[Git][ghc/ghc][wip/expand-do] enable mdo expansion

Apoorv Ingle (@ani) gitlab at gitlab.haskell.org
Mon Aug 21 02:21:23 UTC 2023



Apoorv Ingle pushed to branch wip/expand-do at Glasgow Haskell Compiler / GHC


Commits:
aea5417f by Apoorv Ingle at 2023-08-20T21:18:55-05:00
enable mdo expansion

- - - - -


9 changed files:

- compiler/GHC/Hs/Instances.hs
- compiler/GHC/HsToCore/Match.hs
- compiler/GHC/HsToCore/Pmc.hs
- compiler/GHC/Rename/Utils.hs
- compiler/GHC/Tc/Gen/Expr.hs
- compiler/GHC/Tc/Gen/Match.hs
- compiler/GHC/Types/Basic.hs
- compiler/Language/Haskell/Syntax/Expr.hs
- compiler/Language/Haskell/Syntax/Expr.hs-boot


Changes:

=====================================
compiler/GHC/Hs/Instances.hs
=====================================
@@ -385,7 +385,7 @@ deriving instance Data (HsStmtContext GhcTc)
 
 deriving instance Data HsArrowMatchContext
 
-deriving instance Data HsDoFlavour
+-- deriving instance Data HsDoFlavour
 
 deriving instance Data (HsMatchContext GhcPs)
 deriving instance Data (HsMatchContext GhcRn)


=====================================
compiler/GHC/HsToCore/Match.hs
=====================================
@@ -845,12 +845,12 @@ matchWrapper ctxt scrs (MG { mg_alts = L _ matches'
       $ replicate (length (grhssGRHSs m)) ldi_nablas
 
     is_pat_syn_match :: Origin -> LMatch GhcTc (LHsExpr GhcTc) -> Bool
-    is_pat_syn_match origin (L _ (Match _ _ [l_pat] _)) | isDoExpansionGenerated origin = isPatSyn l_pat
+    is_pat_syn_match origin (L _ (Match _ _ [l_pat] _)) | isJust (isDoExpansionGenerated origin) = isPatSyn l_pat
     is_pat_syn_match _ _ = False
     -- generated match pattern that is not a wildcard
     non_gen_wc :: Origin -> LMatch GhcTc (LHsExpr GhcTc) -> Bool
     non_gen_wc origin (L _ (Match _ _ ([L _ (WildPat _)]) _))
-                   | isDoExpansionGenerated origin = False
+                   | isJust (isDoExpansionGenerated origin) = False
                    | otherwise = True
     non_gen_wc _ _ = True
 


=====================================
compiler/GHC/HsToCore/Pmc.hs
=====================================
@@ -69,6 +69,7 @@ import GHC.Data.Bag
 import GHC.Data.OrdList
 
 import Control.Monad (when, forM_)
+import Data.Maybe (isNothing)
 import qualified Data.Semigroup as Semi
 import Data.List.NonEmpty ( NonEmpty(..) )
 import qualified Data.List.NonEmpty as NE
@@ -191,7 +192,7 @@ pmcMatches origin ctxt vars matches = {-# SCC "pmcMatches" #-} do
       result  <- {-# SCC "checkMatchGroup" #-}
                  unCA (checkMatchGroup matches) missing
       tracePm "}: " (ppr (cr_uncov result))
-      when (not (isDoExpansionGenerated origin)) -- Generated code shouldn't emit overlapping warnings
+      when (isNothing (isDoExpansionGenerated origin)) -- Generated code shouldn't emit overlapping warnings
         ({-# SCC "formatReportWarnings" #-}
         formatReportWarnings ReportMatchGroup ctxt vars result)
       return (NE.toList (ldiMatchGroup (cr_ret result)))


=====================================
compiler/GHC/Rename/Utils.hs
=====================================
@@ -793,13 +793,14 @@ genHsLet :: HsLocalBindsLR GhcRn GhcRn -> LHsExpr GhcRn -> HsExpr GhcRn
 genHsLet bindings body = HsLet noExtField noHsTok bindings noHsTok body
 
 genHsLamDoExp :: (IsPass p, XMG (GhcPass p) (LHsExpr (GhcPass p)) ~ Origin)
-        => [LPat (GhcPass p)]
+        => HsDoFlavour
+        -> [LPat (GhcPass p)]
         -> LHsExpr (GhcPass p)
         -> LHsExpr (GhcPass p)
-genHsLamDoExp pats body = mkHsPar (wrapGenSpan $ HsLam noExtField matches)
+genHsLamDoExp doFlav pats body = mkHsPar (wrapGenSpan $ HsLam noExtField matches)
   where
-    matches = mkMatchGroup doExpansionOrigin
-                           (wrapGenSpan [genSimpleMatch (StmtCtxt (HsDoStmt (DoExpr Nothing))) pats' body])
+    matches = mkMatchGroup (doExpansionOrigin doFlav)
+                           (wrapGenSpan [genSimpleMatch (StmtCtxt (HsDoStmt doFlav)) pats' body])
     pats' = map (parenthesizePat appPrec) pats
 
 
@@ -807,10 +808,10 @@ genHsCaseAltDoExp :: (Anno (GRHS (GhcPass p) (LocatedA (body (GhcPass p))))
                      ~ SrcAnn NoEpAnns,
                  Anno (Match (GhcPass p) (LocatedA (body (GhcPass p))))
                         ~ SrcSpanAnnA)
-            => LPat (GhcPass p) -> (LocatedA (body (GhcPass p)))
+            => HsDoFlavour -> LPat (GhcPass p) -> (LocatedA (body (GhcPass p)))
             -> LMatch (GhcPass p) (LocatedA (body (GhcPass p)))
-genHsCaseAltDoExp pat expr
-  = genSimpleMatch (StmtCtxt (HsDoStmt (DoExpr Nothing)))  [pat] expr
+genHsCaseAltDoExp doFlav pat expr
+  = genSimpleMatch (StmtCtxt (HsDoStmt doFlav))  [pat] expr
 
 
 genSimpleMatch :: (Anno (Match (GhcPass p) (LocatedA (body (GhcPass p))))


=====================================
compiler/GHC/Tc/Gen/Expr.hs
=====================================
@@ -269,9 +269,9 @@ tcExpr (HsLam _ match) res_ty
         ; return (mkHsWrap wrap (HsLam noExtField match')) }
   where
     match_ctxt
-      | isDoExpansionGenerated (mg_ext match)
+      | Just f <- isDoExpansionGenerated (mg_ext match)
       -- See Part 3. of Note [Expanding HsDo with HsExpansion]
-      = MC { mc_what = StmtCtxt (HsDoStmt (DoExpr Nothing))
+      = MC { mc_what = StmtCtxt (HsDoStmt f)
            , mc_body = tcBodyNC -- NB: Do not add any error contexts
                                 -- It has already been done
            }
@@ -425,6 +425,15 @@ tcExpr hsDo@(HsDo _ do_or_lc@(DoExpr{}) ss@(L _  stmts)) res_ty
                  ; mkExpandedExprTc hsDo <$> tcExpr (unLoc expanded_expr) res_ty
                  }
        }
+tcExpr hsDo@(HsDo _ do_or_lc@(MDoExpr{}) (L _  stmts)) res_ty
+-- In the case of an mdo expression:
+-- we expand the statements into explicit applications of binds, thens and lets.
+-- This helps in inferring the right types for bind expressions when impredicativity is turned on.
+-- See Note [Expanding HsDo with HsExpansion] in GHC.Tc.Gen.Match.hs
+  = do { expanded_expr <- expandDoStmts do_or_lc stmts
+                                               -- Do expansion on the fly
+       ; mkExpandedExprTc hsDo <$> tcExpr (unLoc expanded_expr) res_ty
+       }
 
 tcExpr (HsDo _ do_or_lc stmts) res_ty
   = tcDoStmts do_or_lc stmts res_ty


=====================================
compiler/GHC/Tc/Gen/Match.hs
=====================================
@@ -1277,7 +1277,7 @@ expand_do_stmts do_or_lc (stmt@(L loc (BindStmt xbsrn pat e)): lstmts)
 --    -------------------------------------------------------
 --       pat <- e ; stmts   ~~> (>>=) e f
   = do expand_stmts <- expand_do_stmts do_or_lc lstmts
-       failable_expr <- mk_failable_expr pat expand_stmts fail_op
+       failable_expr <- mk_failable_expr do_or_lc pat expand_stmts fail_op
        let expansion = genHsExpApps bind_op  -- (>>=)
                        [ e
                        , failable_expr ]
@@ -1321,8 +1321,8 @@ expand_do_stmts do_or_lc
      -- as we want to flatten the rec block statements into its parent do block anyway
      return $ mkHsApps (wrapGenSpan bind_fun)                           -- (>>=)
                       [ (wrapGenSpan mfix_fun) `mkHsApp` mfix_expr      -- (mfix (do block))
-                      , genHsLamDoExp [ mkBigLHsVarPatTup all_ids ]     --        (\ x ->
-                                       (expand_stmts)                   --               stmts')
+                      , genHsLamDoExp do_or_lc [ mkBigLHsVarPatTup all_ids ]     --        (\ x ->
+                                       expand_stmts                  --               stmts')
                       ]
   where
     local_only_ids = local_ids \\ later_ids -- get unique local rec ids;
@@ -1339,15 +1339,15 @@ expand_do_stmts do_or_lc
     do_block     :: LHsExpr GhcRn
     do_block     = L loc $ HsDo noExtField do_or_lc do_stmts
     mfix_expr    :: LHsExpr GhcRn
-    mfix_expr    = genHsLamDoExp [ wrapGenSpan (LazyPat noExtField $ mkBigLHsVarPatTup all_ids) ] $ do_block
+    mfix_expr    = genHsLamDoExp do_or_lc [ wrapGenSpan (LazyPat noExtField $ mkBigLHsVarPatTup all_ids) ] $ do_block
                              -- NB: LazyPat because we do not want to eagerly evaluate the pattern
                              -- and potentially loop forever
 
 expand_do_stmts _ stmts = pprPanic "expand_do_stmts: impossible happened" $ (ppr stmts)
 
 -- checks the pattern `pat`for irrefutability which decides if we need to decorate it with a fail block
-mk_failable_expr :: LPat GhcRn -> LHsExpr GhcRn -> FailOperator GhcRn -> TcM (LHsExpr GhcRn)
-mk_failable_expr pat@(L loc _) expr fail_op =
+mk_failable_expr :: HsDoFlavour -> LPat GhcRn -> LHsExpr GhcRn -> FailOperator GhcRn -> TcM (LHsExpr GhcRn)
+mk_failable_expr doFlav pat@(L loc _) expr fail_op =
   do { tc_env <- getGblEnv
      ; is_strict <- xoptM LangExt.Strict
      ; irrf_pat <- isIrrefutableHsPatRn' tc_env is_strict pat
@@ -1357,21 +1357,21 @@ mk_failable_expr pat@(L loc _) expr fail_op =
 
      ; if irrf_pat                        -- don't decorate with fail block if
                                           -- the pattern is irrefutable
-       then return $ genHsLamDoExp [pat] expr
-       else L loc <$> mk_fail_block pat expr fail_op
+       then return $ genHsLamDoExp doFlav [pat] expr
+       else L loc <$> mk_fail_block doFlav pat expr fail_op
      }
 
 -- makes the fail block with a given fail_op
-mk_fail_block :: LPat GhcRn -> LHsExpr GhcRn -> FailOperator GhcRn -> TcM (HsExpr GhcRn)
-mk_fail_block pat@(L ploc _) e (Just (SyntaxExprRn fail_op)) =
+mk_fail_block :: HsDoFlavour -> LPat GhcRn -> LHsExpr GhcRn -> FailOperator GhcRn -> TcM (HsExpr GhcRn)
+mk_fail_block doFlav pat@(L ploc _) e (Just (SyntaxExprRn fail_op)) =
   do  dflags <- getDynFlags
-      return $ HsLam noExtField $ mkMatchGroup doExpansionOrigin     -- \
-                (wrapGenSpan [ genHsCaseAltDoExp pat e               --  pat -> expr
+      return $ HsLam noExtField $ mkMatchGroup (doExpansionOrigin doFlav)     -- \
+                (wrapGenSpan [ genHsCaseAltDoExp doFlav pat e               --  pat -> expr
                              , fail_alt_case dflags pat fail_op      --  _   -> fail "fail pattern"
                              ])
         where
           fail_alt_case :: DynFlags -> LPat GhcRn -> HsExpr GhcRn -> LMatch GhcRn (LHsExpr GhcRn)
-          fail_alt_case dflags pat fail_op = genHsCaseAltDoExp genWildPat $
+          fail_alt_case dflags pat fail_op = genHsCaseAltDoExp doFlav genWildPat $
                                              L ploc (fail_op_expr dflags pat fail_op)
 
           fail_op_expr :: DynFlags -> LPat GhcRn -> HsExpr GhcRn -> HsExpr GhcRn
@@ -1386,7 +1386,7 @@ mk_fail_block pat@(L ploc _) e (Just (SyntaxExprRn fail_op)) =
                    <+> text "at" <+> ppr (getLocA pat)
 
 
-mk_fail_block _ _ _ = pprPanic "mk_fail_block: impossible happened" empty
+mk_fail_block _ _ _ _ = pprPanic "mk_fail_block: impossible happened" empty
 
 
 {- Note [Expanding HsDo with HsExpansion]
@@ -1432,11 +1432,11 @@ For example, the expansion of the do block
                  \ p -> e2
                  _   -> fail "failable pattern p at location")
 
-* Why an anonymous lambda?
-  We need a lambda for the types to match: this expression is a second
-  argument to bind so it needs to be of type `a -> m b`
-  It is anonymous because we do not want to introduce a new name that will
-  never be seen by the user anyway.
+Why an anonymous lambda?
+We need a lambda for the types to match: this expression is the second
+argument to bind, so it needs to be of type `a -> m b`.
+It is anonymous because we do not want to introduce a new name that will
+never be seen by the user anyway.
 
 * Wrinkle 1: For pattern synonyms (see testcase Typeable1.hs)
   We always decorate it with a fail block as the irrefutable pattern checker returns false


=====================================
compiler/GHC/Types/Basic.hs
=====================================
@@ -132,6 +132,7 @@ import GHC.Types.SourceText
 import qualified GHC.LanguageExtensions as LangExt
 import {-# SOURCE #-} Language.Haskell.Syntax.Type (PromotionFlag(..), isPromoted)
 import Language.Haskell.Syntax.Basic (Boxity(..), isBoxed, ConTag)
+import {-# SOURCE #-} Language.Haskell.Syntax.Expr (HsDoFlavour)
 
 import Control.DeepSeq ( NFData(..) )
 import Data.Data
@@ -599,22 +600,22 @@ isGenerated FromSource   = False
 -- | Why was the piece of code generated?
 --   It is useful for generating the right error context
 -- See Part 3 in Note [Expanding HsDo with HsExpansion]
-data GenReason = DoExpansion
+data GenReason = DoExpansion HsDoFlavour
                | OtherExpansion
                deriving (Eq, Data)
 
 instance Outputable GenReason where
-  ppr DoExpansion  = text "DoExpansion"
+  ppr (DoExpansion{})  = text "DoExpansion"
   ppr OtherExpansion  = text "OtherExpansion"
 
 -- See Part 3 in Note [Expanding HsDo with HsExpansion]
-isDoExpansionGenerated :: Origin -> Bool
-isDoExpansionGenerated (Generated DoExpansion _) = True
-isDoExpansionGenerated _ = False
+isDoExpansionGenerated :: Origin -> Maybe HsDoFlavour
+isDoExpansionGenerated (Generated (DoExpansion f) _) = Just f
+isDoExpansionGenerated _ = Nothing
 
 -- See Part 3 in Note [Expanding HsDo with HsExpansion]
-doExpansionOrigin :: Origin
-doExpansionOrigin = Generated DoExpansion DoPmc
+doExpansionOrigin :: HsDoFlavour -> Origin
+doExpansionOrigin f = Generated (DoExpansion f) DoPmc
                     -- It is important that we perfrom PMC
                     -- on the expansions of do statements
                     -- to get the right warnings


=====================================
compiler/Language/Haskell/Syntax/Expr.hs
=====================================
@@ -45,6 +45,7 @@ import Data.Maybe
 import Data.List.NonEmpty ( NonEmpty )
 import GHC.Types.Name.Reader
 
+
 {- Note [RecordDotSyntax field updates]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 The extensions @OverloadedRecordDot@ @OverloadedRecordUpdate@ together
@@ -1625,6 +1626,7 @@ data HsDoFlavour
   | GhciStmtCtxt                     -- ^A command-line Stmt in GHCi pat <- rhs
   | ListComp
   | MonadComp
+  deriving (Eq, Data)
 
 qualifiedDoModuleName_maybe :: HsStmtContext p -> Maybe ModuleName
 qualifiedDoModuleName_maybe ctxt = case ctxt of


=====================================
compiler/Language/Haskell/Syntax/Expr.hs-boot
=====================================
@@ -9,6 +9,9 @@ module Language.Haskell.Syntax.Expr where
 import Language.Haskell.Syntax.Extension ( XRec )
 import Data.Kind  ( Type )
 
+import GHC.Prelude (Eq)
+import Data.Data (Data)
+
 type role HsExpr nominal
 type role MatchGroup nominal nominal
 type role GRHSs nominal nominal
@@ -20,3 +23,7 @@ data GRHSs (a :: Type) (body :: Type)
 type family SyntaxExpr (i :: Type)
 
 type LHsExpr a = XRec a (HsExpr a)
+
+data HsDoFlavour
+instance Eq HsDoFlavour
+instance Data HsDoFlavour
\ No newline at end of file



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/aea5417f1670cbcc8466239696a00bb0dbb9981d

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/aea5417f1670cbcc8466239696a00bb0dbb9981d
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20230820/9658da58/attachment-0001.html>


More information about the ghc-commits mailing list