[Git][ghc/ghc][wip/T22428] Fix contification with stable unfoldings (#22428)

Sebastian Graf (@sgraf812) gitlab at gitlab.haskell.org
Mon Nov 14 16:50:03 UTC 2022



Sebastian Graf pushed to branch wip/T22428 at Glasgow Haskell Compiler / GHC


Commits:
6dd0b6cb by Sebastian Graf at 2022-11-14T17:49:57+01:00
Fix contification with stable unfoldings (#22428)

- - - - -


1 changed file:

- compiler/GHC/Core/Opt/OccurAnal.hs


Changes:

=====================================
compiler/GHC/Core/Opt/OccurAnal.hs
=====================================
@@ -754,12 +754,11 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage
   = WithUsageDetails body_usage []
 
   | otherwise                   -- It's mentioned in the body
-  = WithUsageDetails (body_usage' `andUDs` rhs_usage) [NonRec final_bndr rhs']
+  = WithUsageDetails (body_usage' `andUDs` rhs_usage) [NonRec final_bndr rhs2]
   where
     (body_usage', tagged_bndr) = tagNonRecBinder lvl body_usage bndr
     final_bndr = tagged_bndr `setIdUnfolding` unf'
                              `setIdSpecialisation` mkRuleInfo rules'
-    rhs_usage = rhs_uds `andUDs` unf_uds `andUDs` rule_uds
 
     -- Get the join info from the *new* decision
     -- See Note [Join points and unfoldings/rules]
@@ -773,13 +772,18 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage
 
     -- See Note [Sources of one-shot information]
     rhs_env = env1 { occ_one_shots = argOneShots dmd }
-    (WithUsageDetails rhs_uds rhs') = occAnalRhs rhs_env NonRecursive mb_join_arity rhs
+    (WithUsageDetails rhs_uds rhs1) = occAnalLam rhs_env rhs
+      -- corresponding call to adjustRhsUsage directly below
+    rhs2 = markNonRecJoinOneShots mb_join_arity rhs1
+    rhs_usage = adjustRhsUsage mb_join_arity rhs1 $
+                rhs_uds `andUDs` unf_uds `andUDs` rule_uds
 
     --------- Unfolding ---------
     -- See Note [Unfoldings and join points]
+    -- and Note [Join points and unfoldings/rules]
     unf | isId bndr = idUnfolding bndr
         | otherwise = NoUnfolding
-    (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env NonRecursive mb_join_arity unf
+    (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env mb_join_arity unf
 
     --------- Rules ---------
     -- See Note [Rules are extra RHSs] and Note [Rule dependency info]
@@ -848,10 +852,14 @@ occAnalRec !_ lvl (AcyclicSCC (ND { nd_bndr = bndr, nd_rhs = rhs
 
   | otherwise                   -- It's mentioned in the body
   = WithUsageDetails (body_uds' `andUDs` rhs_uds')
-                     (NonRec tagged_bndr rhs : binds)
+                     (NonRec tagged_bndr rhs' : binds)
   where
     (body_uds', tagged_bndr) = tagNonRecBinder lvl body_uds bndr
-    rhs_uds'      = adjustRhsUsage mb_join_arity rhs rhs_uds
+    rhs' = markNonRecJoinOneShots mb_join_arity rhs
+    rhs_uds' = adjustRhsUsage mb_join_arity rhs' rhs_uds
+      -- corresponding call to occAnalLam is in makeNode
+      -- rhs_uds is for a non-recursive join point; we should do the same
+      -- as occAnalNonRecBind, so we do 'markNonRecJoinOneShots' before.
     mb_join_arity = willBeJoinId_maybe tagged_bndr
 
         -- The Rec case is the interesting one
@@ -1412,27 +1420,45 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
     -- and the unfolding together.
     -- See Note [inl_fvs]
 
-    mb_join_arity = isJoinId_maybe bndr
-    -- Get join point info from the *current* decision
-    -- We don't know what the new decision will be!
-    -- Using the old decision at least allows us to
-    -- preserve existing join point, even RULEs are added
+    mb_opt_join_arity = Just (manifestJoinArity rhs)
+    -- We don't know yet if bndr will become a join point.
+    -- But *if* it becomes one, we can guess its join arity: manifestJoinArity.
+    -- Here, we pretend that bndr will become a join point so that we preserve
+    -- tail calls (perhaps to other binders) in scope_uds. If it turns out
+    -- that bndr cannot become a join point, we take care of it later, in
+    -- occAnalRec and tagRecBinders, through calling adjustRhsUsage on the
+    -- merged scope_uds.
+    -- By pretending that we know the join arity we can already zap tail call
+    -- info for stable unfoldings and RULES with mismatching manifestJoinArity
+    -- without affecting tail call info from rhs_uds.
+    -- There's a special case worth thinking about: If bndr ends up as an
+    -- AcyclicSCC it might become a non-recursive join point with join arity
+    -- /less/ than mb_opt_join_arity (so 'manifestJoinArity' is really a
+    -- maximum). Well, then we'll still adjustRhsUsage with the lower arity, so
+    -- the only drawback is that we *might* have discarded useful tail call info
+    -- in inl_uds or rule_uds that we will see on the next iteration.
     -- See Note [Join points and unfoldings/rules]
 
     --------- Right hand side ---------
     -- Constructing the edges for the main Rec computation
     -- See Note [Forming Rec groups]
-    -- Do not use occAnalRhs because we don't yet know the final
-    -- answer for mb_join_arity; instead, do the occAnalLam call from
-    -- occAnalRhs, and postpone adjustRhsUsage until occAnalRec
+    -- Compared to occAnalRhs, this will not adjust the RHS because
+    --   (a) we don't yet know the final answer for mb_join_arity
+    --   (b) we don't even know whether it stays a recursive RHS after the SCC
+    --       analysis we are about to seed! So we can't markAllInsideLam in
+    --       advance, because if it ends up as a non-recursive join point we'll
+    --       consider it as one-shot and don't need to markAllInsideLam.
+    -- Instead, do the occAnalLam call from occAnalRhs, and postpone
+    -- adjustRhsUsage until occAnalRec.
     rhs_env                         = rhsCtxt env
     (WithUsageDetails rhs_uds rhs') = occAnalLam rhs_env rhs
+      -- corresponding call to adjustRhsUsage in occAnalRec and tagRecBinders
 
     --------- Unfolding ---------
     -- See Note [Unfoldings and join points]
     unf = realIdUnfolding bndr -- realIdUnfolding: Ignore loop-breaker-ness
                                -- here because that is what we are setting!
-    (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env Recursive mb_join_arity unf
+    (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env mb_opt_join_arity unf
 
     --------- IMP-RULES --------
     is_active     = occ_rule_act env :: Activation -> Bool
@@ -1442,7 +1468,7 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
 
     --------- All rules --------
     rules_w_uds :: [(CoreRule, UsageDetails, UsageDetails)]
-    rules_w_uds = occAnalRules rhs_env mb_join_arity bndr
+    rules_w_uds = occAnalRules rhs_env mb_opt_join_arity bndr
     rules'      = map fstOf3 rules_w_uds
 
     rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
@@ -1815,7 +1841,11 @@ occAnalLam :: OccEnv -> CoreExpr -> (WithUsageDetails CoreExpr)
 -- This function does /not/ do
 --   markAllInsideLam or
 --   markAllNonTail
--- The caller does that, either in occAnal (Lam {}), or in adjustRhsUsage
+-- The caller does that, either in occAnal (Lam {}), or in adjustRhsUsage.
+-- Every call site links to its respective adjustRhsUsage call and vice versa.
+--
+-- In effect, the analysis result is for a non-recursive join point with
+-- manifest arity and adjustRhsUsage does the fixup.
 -- See Note [Adjusting right-hand sides]
 
 occAnalLam env (Lam bndr expr)
@@ -1895,51 +1925,29 @@ of a right hand side is handled by occAnalLam.
 *                                                                      *
 ********************************************************************* -}
 
-occAnalRhs :: OccEnv -> RecFlag -> Maybe JoinArity
+occAnalRhs :: OccEnv -> Maybe JoinArity
            -> CoreExpr   -- RHS
            -> WithUsageDetails CoreExpr
-occAnalRhs !env is_rec mb_join_arity rhs
-  = let (WithUsageDetails usage rhs1) = occAnalLam env rhs
-           -- We call occAnalLam here, not occAnalExpr, so that it doesn't
-           -- do the markAllInsideLam and markNonTailCall stuff before
-           -- we've had a chance to help with join points; that comes next
-        rhs2      = markJoinOneShots is_rec mb_join_arity rhs1
-        rhs_usage = adjustRhsUsage mb_join_arity rhs2 usage
-    in WithUsageDetails rhs_usage rhs2
-
-
-
-markJoinOneShots :: RecFlag -> Maybe JoinArity -> CoreExpr -> CoreExpr
--- For a /non-recursive/ join point we can mark all
--- its join-lambda as one-shot; and it's a good idea to do so
-markJoinOneShots NonRecursive (Just join_arity) rhs
-  = go join_arity rhs
+-- ^ This function immediately does adjustRhsUsage after the call to occAnalLam.
+-- It's useful for anonymous lambdas and unfoldings.
+occAnalRhs !env mb_join_arity rhs
+  = WithUsageDetails (adjustRhsUsage mb_join_arity rhs' usage) rhs'
   where
-    go 0 rhs         = rhs
-    go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b)
-                           (go (n-1) rhs)
-    go _ rhs         = rhs  -- Not enough lambdas.  This can legitimately happen.
-                            -- e.g.    let j = case ... in j True
-                            -- This will become an arity-1 join point after the
-                            -- simplifier has eta-expanded it; but it may not have
-                            -- enough lambdas /yet/. (Lint checks that JoinIds do
-                            -- have enough lambdas.)
-markJoinOneShots _ _ rhs
-  = rhs
+    WithUsageDetails usage rhs' = occAnalLam env rhs
+
 
 occAnalUnfolding :: OccEnv
-                 -> RecFlag
                  -> Maybe JoinArity   -- See Note [Join points and unfoldings/rules]
                  -> Unfolding
                  -> WithUsageDetails Unfolding
 -- Occurrence-analyse a stable unfolding;
--- discard a non-stable one altogether.
-occAnalUnfolding !env is_rec mb_join_arity unf
+-- discard a non-stable one altogether and return empty usage details.
+occAnalUnfolding !env mb_join_arity unf
   = case unf of
       unf@(CoreUnfolding { uf_tmpl = rhs, uf_src = src })
         | isStableSource src ->
             let
-              (WithUsageDetails usage rhs') = occAnalRhs env is_rec mb_join_arity rhs
+              (WithUsageDetails usage rhs') = occAnalRhs env mb_join_arity rhs
 
               unf' | noBinderSwaps env = unf -- Note [Unfoldings and rules]
                    | otherwise         = unf { uf_tmpl = rhs' }
@@ -1958,9 +1966,7 @@ occAnalUnfolding !env is_rec mb_join_arity unf
         where
           env'            = env `addInScope` bndrs
           (WithUsageDetails usage args') = occAnalList env' args
-          final_usage     = markAllManyNonTail (delDetailsList usage bndrs)
-                            `addLamCoVarOccs` bndrs
-                            `delDetailsList` bndrs
+          final_usage     = usage `addLamCoVarOccs` bndrs `delDetailsList` bndrs
               -- delDetailsList; no need to use tagLamBinders because we
               -- never inline DFuns so the occ-info on binders doesn't matter
 
@@ -1989,8 +1995,8 @@ occAnalRules !env mb_join_arity bndr
         (WithUsageDetails rhs_uds rhs') = occAnal env' rhs
                             -- Note [Rules are extra RHSs]
                             -- Note [Rule dependency info]
-        rhs_uds' = markAllNonTailIf (not exact_join) $
-                   markAllMany                             $
+        rhs_uds' = markAllNonTailIf (not exact_join) $ -- Nearly adjustRhsUsage, but we don't want to
+                   markAllMany                       $ -- build `mkLams (map _ args) rhs` just for the call
                    rhs_uds `delDetailsList` bndrs
 
         exact_join = exactJoin mb_join_arity args
@@ -2208,10 +2214,7 @@ occAnal env app@(App _ _)
   = occAnalApp env (collectArgsTicks tickishFloatable app)
 
 occAnal env expr@(Lam {})
-  = let (WithUsageDetails usage expr') = occAnalLam env expr
-        final_usage = markAllInsideLamIf (not (isOneShotFun expr')) $
-                      markAllNonTail usage
-    in WithUsageDetails final_usage expr'
+  = occAnalRhs env Nothing expr -- mb_join_arity == Nothing <=> markAllNonTail
 
 occAnal env (Case scrut bndr ty alts)
   = let
@@ -2286,7 +2289,7 @@ occAnalApp !env (Var fun, args, ticks)
   --     This caused #18296
   | fun `hasKey` runRWKey
   , [t1, t2, arg]  <- args
-  , let (WithUsageDetails usage arg') = occAnalRhs env NonRecursive (Just 1) arg
+  , WithUsageDetails usage arg' <- occAnalRhs env (Just 1) arg
   = WithUsageDetails usage (mkTicks ticks $ mkApps (Var fun) [t1, t2, arg'])
 
 occAnalApp env (Var fun_id, args, ticks)
@@ -3133,7 +3136,7 @@ flattenUsageDetails ud@(UD { ud_env = env })
 -------------------
 -- See Note [Adjusting right-hand sides]
 adjustRhsUsage :: Maybe JoinArity
-               -> CoreExpr       -- Rhs, AFTER occ anal
+               -> CoreExpr       -- Rhs, AFTER occAnalLam
                -> UsageDetails   -- From body of lambda
                -> UsageDetails
 adjustRhsUsage mb_join_arity rhs usage
@@ -3146,6 +3149,28 @@ adjustRhsUsage mb_join_arity rhs usage
     exact_join = exactJoin mb_join_arity bndrs
     (bndrs,_)  = collectBinders rhs
 
+-- | IF rhs becomes a join point, then `manifestJoinArity rhs` returns the
+-- exact join arity of `rhs`.
+manifestJoinArity :: CoreExpr -> JoinArity
+manifestJoinArity rhs = length $ fst $ collectBinders rhs
+
+markNonRecJoinOneShots :: Maybe JoinArity -> CoreExpr -> CoreExpr
+-- For a /non-recursive/ join point we can mark all
+-- its join-lambdas as one-shot; and it's a good idea to do so
+markNonRecJoinOneShots Nothing           rhs = rhs
+markNonRecJoinOneShots (Just join_arity) rhs
+  = go join_arity rhs
+  where
+    go 0 rhs         = rhs
+    go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b)
+                           (go (n-1) rhs)
+    go _ rhs         = rhs  -- Not enough lambdas.  This can legitimately happen.
+                            -- e.g.    let j = case ... in j True
+                            -- This will become an arity-1 join point after the
+                            -- simplifier has eta-expanded it; but it may not have
+                            -- enough lambdas /yet/. (Lint checks that JoinIds do
+                            -- have enough lambdas.)
+
 exactJoin :: Maybe JoinArity -> [a] -> Bool
 exactJoin Nothing           _    = False
 exactJoin (Just join_arity) args = args `lengthIs` join_arity
@@ -3224,7 +3249,7 @@ tagRecBinders lvl body_uds details_s
 
      -- 2. Adjust usage details of each RHS, taking into account the
      --    join-point-hood decision
-     rhs_udss' = [ adjustRhsUsage (mb_join_arity bndr) rhs rhs_uds
+     rhs_udss' = [ adjustRhsUsage (mb_join_arity bndr) rhs rhs_uds -- matching occAnalLam in makeNode
                  | ND { nd_bndr = bndr, nd_uds = rhs_uds
                       , nd_rhs = rhs } <- details_s ]
 



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/6dd0b6cbc0e95999906363066f49e64279badf0a

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/6dd0b6cbc0e95999906363066f49e64279badf0a
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20221114/cedd4998/attachment-0001.html>


More information about the ghc-commits mailing list