[Git][ghc/ghc][wip/expand-do] do stmt expansion for Applicative Do

Apoorv Ingle (@ani) gitlab at gitlab.haskell.org
Thu Mar 23 19:51:43 UTC 2023



Apoorv Ingle pushed to branch wip/expand-do at Glasgow Haskell Compiler / GHC


Commits:
a0f73250 by Apoorv Ingle at 2023-03-23T14:51:32-05:00
do stmt expansion for Applicative Do

- - - - -


7 changed files:

- compiler/GHC/Hs/Expr.hs
- compiler/GHC/Rename/Expr.hs
- compiler/GHC/Tc/Gen/Match.hs
- testsuite/tests/rebindable/T18324.hs
- testsuite/tests/rebindable/all.T
- testsuite/tests/rebindable/pattern-fails.hs
- + testsuite/tests/rebindable/pattern-fails.stdout


Changes:

=====================================
compiler/GHC/Hs/Expr.hs
=====================================
@@ -1077,7 +1077,7 @@ instance (Outputable a, Outputable b) => Outputable (HsExpansion a b) where
   ppr (HsExpanded orig expanded)
     -- = ifPprDebug (vcat [ppr orig, braces (text "Expansion:" <+> ppr expanded)])
     --             (ppr orig)
-    = ppr orig <+> braces (text "Expansion:" <+> ppr expanded)
+    = braces (ppr orig) $$ braces (text "Expansion:" <+> ppr expanded)
 
 
 {-


=====================================
compiler/GHC/Rename/Expr.hs
=====================================
@@ -433,8 +433,7 @@ rnExpr (HsDo _ do_or_lc (L l stmts))
           rnStmtsWithFreeVars (HsDoStmt do_or_lc) rnExpr stmts
             (\ _ -> return ((), emptyFVs))
       ; (pp_stmts, fvs2) <- postProcessStmtsForApplicativeDo do_or_lc stmts1
-      ; return (HsDo noExtField do_or_lc (L l pp_stmts), fvs1 `plusFV` fvs2)
-      }
+      ; return ( HsDo noExtField do_or_lc (L l pp_stmts), fvs1 `plusFV` fvs2 ) }
 -- ExplicitList: see Note [Handling overloaded and rebindable constructs]
 rnExpr (ExplicitList _ exps)
   = do  { (exps', fvs) <- rnExprs exps
@@ -1071,10 +1070,8 @@ postProcessStmtsForApplicativeDo ctxt stmts
        ; in_th_bracket <- isBrackStage <$> getStage
        ; if ado_is_on && is_do_expr && not in_th_bracket
             then do { traceRn "ppsfa" (ppr stmts)
-                    ; ado_stmts_and_fvs <- rearrangeForApplicativeDo ctxt stmts
-                    ; return ado_stmts_and_fvs }
-            else do { do_stmts_and_fvs <- noPostProcessStmts (HsDoStmt ctxt) stmts
-                    ; return do_stmts_and_fvs } }
+                    ; rearrangeForApplicativeDo ctxt stmts }
+            else noPostProcessStmts (HsDoStmt ctxt) stmts }
 
 -- | strip the FreeVars annotations from statements
 noPostProcessStmts
@@ -1813,7 +1810,7 @@ independent and do something like this:
      (y,z) <- (,) <$> B x <*> C
      return (f x y z)
 
-But this isn't enough! A and C were also independent, and this
+But this isn't enough! If A and C were also independent, then this
 transformation loses the ability to do A and C in parallel.
 
 The algorithm works by first splitting the sequence of statements into


=====================================
compiler/GHC/Tc/Gen/Match.hs
=====================================
@@ -71,7 +71,8 @@ import GHC.Builtin.Names (bindMName, returnMName)
 import GHC.Utils.Outputable
 import GHC.Utils.Panic
 import GHC.Utils.Misc
-import GHC.Driver.Session ( getDynFlags )
+import GHC.Driver.Session ( getDynFlags, DynFlags )
+import GHC.Driver.Ppr (showPpr)
 
 import GHC.Types.Fixity (LexicalFixity(..))
 import GHC.Types.Name
@@ -325,7 +326,7 @@ tcDoStmts doExpr@(DoExpr _) (L l stmts) res_ty
         ; let expand_do_expr = mkExpandedExpr (HsDo noExtField doExpr (L l stmts))
                                                (unLoc expand_expr)
                                         -- Do expansion on the fly
-        ; traceTc "tcDoStmts" (text "tcExpr:" <+> ppr expand_do_expr)
+        ; traceTc "tcDoStmts do" (text "tcExpr:" <+> ppr expand_do_expr)
         ; tcExpr expand_do_expr res_ty
         }
 
@@ -337,7 +338,7 @@ tcDoStmts mDoExpr@(MDoExpr _) (L l stmts) res_ty
         ; let expand_do_expr = mkExpandedExpr (HsDo noExtField mDoExpr (L l stmts))
                                               (unLoc expand_expr)
                                        -- Do expansion on the fly
-        ; traceTc "tcDoStmts" (text "tcExpr:" <+> ppr expand_do_expr)
+        ; traceTc "tcDoStmts mdo" (text "tcExpr:" <+> ppr expand_do_expr)
         ; tcExpr expand_do_expr res_ty
 
         }
@@ -1220,8 +1221,8 @@ expand_do_stmts do_flavour [L _ (LastStmt _ body _ ret_expr)]
 
 
 expand_do_stmts do_or_lc ((L _ (BindStmt xbsrn pat e)): lstmts)
-  | SyntaxExprRn bind_op        <- xbsrn_bindOp xbsrn
-  , Just (SyntaxExprRn fail_op) <- xbsrn_failOp xbsrn =
+  | SyntaxExprRn bind_op <- xbsrn_bindOp xbsrn
+  , fail_op              <- xbsrn_failOp xbsrn =
 -- the pattern binding x can fail
 --      stmts ~~> stmt'    let f pat = stmts'; f _ = fail ".."
 --    -------------------------------------------------------
@@ -1233,17 +1234,6 @@ expand_do_stmts do_or_lc ((L _ (BindStmt xbsrn pat e)): lstmts)
                                 , expr
                                 ])
 
-  | SyntaxExprRn bind_op <- xbsrn_bindOp xbsrn
-  , Nothing          <- xbsrn_failOp xbsrn = -- irrefutable pattern so no failure
---                      stmts ~~> stmt'
---    ------------------------------------------------
---       x <- e ; stmts   ~~> (Prelude.>>=) e (\ x -> stmts')
-      do expand_stmts <- expand_do_stmts do_or_lc lstmts
-         return $ noLocA (foldl genHsApp bind_op -- (>>=)
-                          [ e
-                          , mkHsLam [pat] expand_stmts  -- (\ x -> stmts')
-                          ])
-
   | otherwise = -- just use the polymorhpic bindop. TODO: Necessary?
       do expand_stmts <- expand_do_stmts do_or_lc lstmts
          return $ noLocA (genHsApps bindMName -- (Prelude.>>=)
@@ -1251,33 +1241,6 @@ expand_do_stmts do_or_lc ((L _ (BindStmt xbsrn pat e)): lstmts)
                             , mkHsLam [pat] expand_stmts  -- (\ x -> stmts')
                             ])
 
-  where
-    mk_failable_lexpr_tcm :: LPat GhcRn -> LHsExpr GhcRn -> HsExpr GhcRn -> TcM (LHsExpr GhcRn)
-    -- checks the pattern pat and decides if we need to plug in the fail block
-    -- Type checking the pattern is necessary to decide if we need to generate the fail block
-    -- Renamer cannot always determine if a fail block is necessary, and its conservative behaviour would
-    -- generate a fail block even if it is not really needed. cf. GHC.Hs.isIrrefutableHsPat
-    -- Only Tuples are considered irrefutable in the renamer, while newtypes and TyCons with only one datacon
-    -- is not
-    mk_failable_lexpr_tcm pat lexpr fail_op =
-      do { ((tc_pat, _), _) <- tcInferPat (FRRBindStmt DoNotation)
-                               PatBindRhs pat $ return id -- whatever
-         ; dflags <- getDynFlags
-         ; if isIrrefutableHsPat dflags tc_pat
-           then return $ mkHsLam [pat] lexpr
-           else return $ mk_fail_lexpr pat lexpr fail_op
-         }
-    mk_fail_lexpr :: LPat GhcRn -> LHsExpr GhcRn -> HsExpr GhcRn -> LHsExpr GhcRn
-    -- makes the fail block
-    -- TODO: check the discussion around MonadFail.fail type signature.
-    -- Should we really say `mkHsString "fail pattern"`? if yes, maybe a better error message would help
-    mk_fail_lexpr pat lexpr fail_op =
-      noLocA (HsLam noExtField $ mkMatchGroup Generated                 -- let
-               (noLocA [ mkHsCaseAlt pat lexpr                          --   f pat = expr
-                       , mkHsCaseAlt nlWildPatName                      --   f _   = fail "fail pattern"
-                         (noLocA $ genHsApp fail_op
-                           (nlHsLit $ mkHsString "fail pattern")) ]))
-
 expand_do_stmts do_or_lc (L _ (LetStmt _ bnds) : lstmts) =
 --                      stmts ~~> stmts'
 --    ------------------------------------------------
@@ -1296,13 +1259,14 @@ expand_do_stmts do_or_lc ((L _ (BodyStmt _ e (SyntaxExprRn f) _)) : lstmts) =
                 [ e               -- e
                 , expand_stmts ]  -- stmts'
 
-expand_do_stmts do_or_lc ((L _ (RecStmt { recS_stmts = rec_stmts
-                                        , recS_later_ids = later_ids  -- forward referenced local ids
-                                        , recS_rec_ids = local_ids     -- ids referenced outside of the rec block
-                                        , recS_mfix_fn = SyntaxExprRn mfix_fun   -- the `mfix` expr
-                                        , recS_ret_fn  = SyntaxExprRn return_fun -- the `return` expr
-                                                                                 -- use it explicitly
-                                                                                 -- at the end of expanded rec block
+expand_do_stmts do_or_lc
+  ((L _ (RecStmt { recS_stmts = rec_stmts
+                 , recS_later_ids = later_ids  -- forward referenced local ids
+                 , recS_rec_ids = local_ids     -- ids referenced outside of the rec block
+                 , recS_mfix_fn = SyntaxExprRn mfix_fun   -- the `mfix` expr
+                 , recS_ret_fn  = SyntaxExprRn return_fun -- the `return` expr
+                                                          -- use it explicitly
+                                                          -- at the end of expanded rec block
                                       }))
                     : lstmts) =
 -- See Note [Typing a RecStmt]
@@ -1320,7 +1284,8 @@ expand_do_stmts do_or_lc ((L _ (RecStmt { recS_stmts = rec_stmts
                                        expand_stmts                       --         stmts')
                       ])
   where
-    local_only_ids = local_ids \\ later_ids -- get unique local rec ids; local rec ids and later ids overlap
+    local_only_ids = local_ids \\ later_ids -- get unique local rec ids;
+                                            --local rec ids and later ids can overlap
     all_ids = local_only_ids ++ later_ids   -- put local ids before return ids
 
     return_stmt  :: ExprLStmt GhcRn
@@ -1336,13 +1301,51 @@ expand_do_stmts do_or_lc ((L _ (RecStmt { recS_stmts = rec_stmts
     mfix_expr    :: LHsExpr GhcRn
     mfix_expr    = mkHsLam [ mkBigLHsVarPatTup all_ids ] $ do_block
 
-expand_do_stmts _ (stmt@(L _ (ApplicativeStmt _ appargs (Just join))):_) =
--- See Note [Applicative BodyStmt]
-  pprPanic "expand_do_stmts: impossible happened ApplicativeStmt" $ ppr stmt
-  
-expand_do_stmts _ (stmt@(L _ (ApplicativeStmt _ appargs Nothing)):_) =
+expand_do_stmts do_or_lc (stmt@(L _ (ApplicativeStmt _ args mb_join)): lstmts) =
 -- See Note [Applicative BodyStmt]
-  pprPanic "expand_do_stmts: impossible happened ApplicativeStmt" $ ppr stmt
+--
+--                  stmts ~~> stmts'
+--   -------------------------------------------------
+--   {x1 <- e1, x2 <- e2, ..}; stmts ~~> (\ x1 -> \ x2 -> .. -> stmts') <$> e1 <*> e2 ..
+--   (the whole expression is wrapped in join when mb_join carries a join operator)
+-- Very similar to HsToCore.Expr.dsDo
+
+-- args are [(<$>, arg1), (<*>, arg2), ..]: each argument paired with the operator combining it
+-- mb_join is Maybe (join); see the case on mb_join at the end of this equation
+  do { expr' <- expand_do_stmts do_or_lc lstmts
+     ; (pats_can_fail, rhss) <- unzip <$> mapM (do_arg . snd) args
+
+     ; body <- foldrM match_args expr' pats_can_fail -- \ p1 -> .. \ pn -> expr', with fail blocks where needed
+       -- combine left to right: op_n (.. (op_1 body rhs_1) ..) rhs_n
+     ; let expand_ado_expr = foldl mk_app_call body (zip (map fst args) rhss)
+     ; traceTc "expand_do_stmts: debug" $ (vcat [ text "stmt:" <+> ppr stmt
+                                                , text "(pats,rhss):" <+> ppr (pats_can_fail, rhss)
+                                                , text "expr':" <+> ppr expr'
+                                                , text "args" <+> ppr args
+                                                , text "final_ado" <+> ppr expand_ado_expr
+                                                ])
+
+     -- Apply join only if the renamer supplied a join operator (SyntaxExprRn);
+     -- otherwise the combined applicative expression is returned as is.
+     ; case mb_join of
+         Nothing -> return expand_ado_expr
+         Just NoSyntaxExprRn -> return expand_ado_expr -- NOTE(review): presumably unreachable; treated like Nothing
+         Just (SyntaxExprRn join_op) -> return $ mkHsApp (noLocA join_op) expand_ado_expr
+     }
+  where
+    do_arg :: ApplicativeArg GhcRn -> TcM ((LPat GhcRn, FailOperator GhcRn), LHsExpr GhcRn)
+    do_arg (ApplicativeArgOne mb_fail_op pat expr _) = -- single bind: keep the renamer's fail op
+      return ((pat, mb_fail_op), expr)
+    do_arg (ApplicativeArgMany _ stmts ret pat _) = -- stmt group: expand as a nested do ending in ret
+      do { expr <- expand_do_stmts do_or_lc $ stmts ++ [noLocA $ mkLastStmt (noLocA ret)]
+         ; return ((pat, Nothing), expr) } -- NOTE(review): no fail op, so pat must be irrefutable (mk_fail_lexpr panics otherwise)
+
+    match_args :: (LPat GhcRn, FailOperator GhcRn) -> LHsExpr GhcRn -> TcM (LHsExpr GhcRn)
+    match_args (pat, fail_op) body = mk_failable_lexpr_tcm pat body fail_op
+
+    mk_app_call l (op, r) = case op of -- builds the application op l r
+                              SyntaxExprRn op -> mkHsApps (noLocA op) [l, r]
+                              NoSyntaxExprRn -> pprPanic "expand_do_stmts: impossible happened first arg" (ppr op)
 
 expand_do_stmts _ (stmt@(L _ (TransStmt {})):_) =
   pprPanic "expand_do_stmts: impossible happened TransStmt" $ ppr stmt
@@ -1354,3 +1357,40 @@ expand_do_stmts _ (stmt@(L _ (ParStmt {})):_) =
 
 
 expand_do_stmts do_flavor stmts = pprPanic "expand_do_stmts: impossible happened" $ (ppr do_flavor $$ ppr stmts)
+
+
+
+mk_failable_lexpr_tcm :: LPat GhcRn -> LHsExpr GhcRn -> FailOperator GhcRn -> TcM (LHsExpr GhcRn)
+-- Build the lambda `\ pat -> lexpr` that binds a do-statement pattern, adding a
+-- fail branch (via mk_fail_lexpr) only when the pattern can actually fail.
+-- The pattern is typechecked before deciding: the renamer's irrefutability check
+-- (cf. GHC.Hs.isIrrefutableHsPat) is conservative, so relying on it alone would
+-- produce fail blocks that are not needed. In the renamer only tuples count as
+-- irrefutable; newtypes and TyCons with a single data constructor do not.
+mk_failable_lexpr_tcm pat lexpr fail_op = do
+  -- Only the typechecked pattern is used; the inner action's result is ignored.
+  ((tc_pat, _), _) <- tcInferPat (FRRBindStmt DoNotation) PatBindRhs pat (return id)
+  dflags <- getDynFlags
+  let irrefutable = isIrrefutableHsPat dflags tc_pat
+  if irrefutable
+    then return (mkHsLam [pat] lexpr)
+    else mk_fail_lexpr pat lexpr fail_op
+
+-- Makes the fail block `\ pat -> lexpr; _ -> fail msg`, with msg from mk_fail_msg_expr.
+-- Needs a renamed fail operator; any other FailOperator is treated as impossible (panic).
+-- TODO: MonadFail.fail type discussion; msg always uses flavour DoExpr Nothing for now.
+mk_fail_lexpr :: LPat GhcRn -> LHsExpr GhcRn -> FailOperator GhcRn -> TcM (LHsExpr GhcRn)
+mk_fail_lexpr pat lexpr (Just (SyntaxExprRn fail_op)) =
+  do  dflags <- getDynFlags
+      return $ noLocA (HsLam noExtField $ mkMatchGroup Generated   -- \
+                      (noLocA [ mkHsCaseAlt pat lexpr              --   pat -> expr
+                              , mkHsCaseAlt nlWildPatName          --   _   -> fail msg
+                                (noLocA $ genHsApp fail_op
+                                 (mk_fail_msg_expr dflags (DoExpr Nothing) pat))
+                              ]))
+mk_fail_lexpr pat _ fail_op = pprPanic "mk_fail_lexpr: impossible happened" (ppr pat $$ ppr fail_op)
+
+mk_fail_msg_expr :: DynFlags -> HsDoFlavour -> LPat GhcRn -> LHsExpr GhcRn
+-- String literal passed to `fail`: "Pattern match failure in <flavour> at <location of pat>".
+mk_fail_msg_expr dflags ctx pat = nlHsLit (mkHsString (showPpr dflags msg))
+  where msg = text "Pattern match failure in" <+> pprHsDoFlavour ctx <+> text "at" <+> ppr (getLocA pat)


=====================================
testsuite/tests/rebindable/T18324.hs
=====================================
@@ -19,3 +19,9 @@ foo2 = do { x <- t ; return (p x) }
 main = do x <- foo2
           putStrLn $ show x
           
+
+data D a b = D b b | E a a
+
+fffgg daa = case daa of
+              D b1 b2 -> let
+                x = do -- NOTE(review): unfinished; the file ends here, so T18324 will not parse as committed


=====================================
testsuite/tests/rebindable/all.T
=====================================
@@ -45,4 +45,4 @@ test('T20126', normal, compile_fail, [''])
 # Tests for desugaring do before typechecking
 test('T18324', normal, compile, [''])
 test('T23147', normal, compile, [''])
-test('pattern-fails', normal, compile, [''])
+test('pattern-fails', normal, compile_and_run, [''])


=====================================
testsuite/tests/rebindable/pattern-fails.hs
=====================================
@@ -1,8 +1,8 @@
-module PF where
+module Main where
 
 
--- main :: IO ()
--- main = putStrLn . show $ qqq ['c']
+main :: IO ()
+main = print (qqq ['c'])
 
 qqq :: [a] -> Maybe (a, [a])
 qqq ts = do { (a:b:as) <- Just ts
@@ -16,3 +16,5 @@ emptyST = Just $ ST (0, 0)
 ppp :: Maybe (ST Int Int) -> Maybe (ST Int Int)
 ppp st = do { ST (x, y) <- st
             ; return $ ST (x+1, y+1)}
+
+


=====================================
testsuite/tests/rebindable/pattern-fails.stdout
=====================================
@@ -0,0 +1 @@
+Nothing



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/a0f732508aa4fd0fc23a6f9e51052b0413318154

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/a0f732508aa4fd0fc23a6f9e51052b0413318154
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20230323/c41ead73/attachment-0001.html>


More information about the ghc-commits mailing list