[Git][ghc/ghc][wip/T22428] Fix contification with stable unfoldings (#22428)

Sebastian Graf (@sgraf812) gitlab at gitlab.haskell.org
Thu Nov 17 17:34:29 UTC 2022



Sebastian Graf pushed to branch wip/T22428 at Glasgow Haskell Compiler / GHC


Commits:
12a81354 by Sebastian Graf at 2022-11-17T18:34:09+01:00
Fix contification with stable unfoldings (#22428)

... by predicting the join arity of a recursive RHS.
See `Note [Join arity prediction for recursive functions]`.
I also adjusted `Note [Join points and unfoldings/rules]` to account for the
usage of predicted join arity in `occAnalRules` and `occAnalUnfolding`; the
Note now also refers to #22428.
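To make the prediction concrete, here is a minimal sketch over a toy expression type. It assumes a simplified stand-in for GHC's `CoreExpr`: `Expr`, `Var`, `App`, `Lam` and `goRhs` are hypothetical. `collectBinders` and `manifestJoinArity` borrow their names from the diff but only operate on the toy type. The predicted join arity of a recursive RHS is just the number of its leading lambdas.

```haskell
-- Toy stand-in for GHC's CoreExpr (hypothetical; the real type has many
-- more constructors, and binders are Vars rather than Strings).
data Expr
  = Var String
  | App Expr Expr
  | Lam String Expr
  deriving (Eq, Show)

-- Split off the leading lambda binders, in the spirit of collectBinders.
collectBinders :: Expr -> ([String], Expr)
collectBinders = go []
  where
    go acc (Lam b body) = go (b : acc) body
    go acc e            = (reverse acc, e)

-- Predicted join arity of a recursive RHS: the number of leading lambdas.
manifestJoinArity :: Expr -> Int
manifestJoinArity = length . fst . collectBinders

-- A heavily simplified version of the recursive RHS of `go` in T22428:
--   go = \n -> go (pred n)
goRhs :: Expr
goRhs = Lam "n" (App (Var "go") (App (Var "pred") (Var "n")))
```

Here `manifestJoinArity goRhs` is 1, matching the `JoinId(1)` that `go` receives in the expected Tidy Core output. A lambda that is not leading, as in `App (Lam "x" (Var "x")) (Var "y")`, does not count.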

I also renamed

  * `occAnalLam` to `occAnalLamTail`
  * `occAnalRhs` to `occAnalLam`
  * `adjustRhsUsage` to `adjustTailUsage`
  * a few other less important functions

and properly documented that each call of `occAnalLamTail` must pair up with a
call of `adjustTailUsage`.
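The pairing can be illustrated with a toy model of the fix-up step. This is a sketch under stated assumptions: `Occ`, `Usage`, `Bndr`, `Expr` and the record fields are hypothetical simplifications of GHC's `OccInfo`/`UsageDetails`, and `isOneShotFun` here just means "all leading lambdas are one-shot". The shape of `adjustTailUsage` follows the diff below: usage from the lambda body keeps its tail-call info only if the RHS is a join point of exactly its manifest arity, and stays out of "inside lambda" only if the RHS is one-shot.

```haskell
import qualified Data.Map.Strict as Map

-- Toy occurrence info (hypothetical; GHC's OccInfo/UsageDetails are richer).
data Occ = Occ
  { occInsideLam :: Bool       -- occurs under a lambda that may be entered many times
  , occTailArity :: Maybe Int  -- Just n: all occurrences are tail calls with n args
  } deriving (Eq, Show)

type Usage = Map.Map String Occ

-- Toy lambda binder with a one-shot flag, and a toy expression type.
data Bndr = Bndr { bndrName :: String, bndrOneShot :: Bool }
  deriving (Eq, Show)

data Expr = Var String | App Expr Expr | Lam Bndr Expr
  deriving (Eq, Show)

collectBinders :: Expr -> ([Bndr], Expr)
collectBinders = go []
  where
    go acc (Lam b body) = go (b : acc) body
    go acc e            = (reverse acc, e)

exactJoin :: Maybe Int -> [a] -> Bool
exactJoin Nothing  _  = False
exactJoin (Just n) bs = length bs == n

-- In this toy model a function is one-shot if all its leading lambdas are.
isOneShotFun :: Expr -> Bool
isOneShotFun = all bndrOneShot . fst . collectBinders

markAllInsideLamIf :: Bool -> Usage -> Usage
markAllInsideLamIf True  = Map.map (\o -> o { occInsideLam = True })
markAllInsideLamIf False = id

markAllNonTailIf :: Bool -> Usage -> Usage
markAllNonTailIf True  = Map.map (\o -> o { occTailArity = Nothing })
markAllNonTailIf False = id

-- The fix-up that must follow every call of the (unmodelled) occAnalLamTail.
adjustTailUsage :: Maybe Int -> Expr -> Usage -> Usage
adjustTailUsage mbJoinArity rhs usage =
  markAllInsideLamIf (not oneShot) $
  markAllNonTailIf (not exact) usage
  where
    oneShot = isOneShotFun rhs
    exact   = exactJoin mbJoinArity (fst (collectBinders rhs))
```

For example, with body usage `[("j", Occ False (Just 1))]` and a one-shot RHS `\x -> j x`, passing `Just 1` leaves the usage unchanged, while passing `Nothing` (an ordinary lambda) drops the tail-call info for `j`.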

I removed `Note [Unfoldings and join points]` because it was redundant with
`Note [Occurrences in stable unfoldings]`.

Fixes #22428.
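The diff below also moves the one-shot marking for non-recursive join points into `markNonRecJoinOneShots`. A minimal sketch of its behaviour on a toy expression type follows; `Bndr`, `bndrIsId`, `bndrOneShot`, `Expr` and its constructors are hypothetical. As in the Note on tail calls, type binders count towards the join arity but carry no one-shot info, and an RHS with too few lambdas (not yet eta-expanded) is left alone.

```haskell
-- Toy binder: bndrIsId distinguishes term binders from type binders
-- (hypothetical model, not GHC's Var).
data Bndr = Bndr { bndrName :: String, bndrIsId :: Bool, bndrOneShot :: Bool }
  deriving (Eq, Show)

data Expr = Var String | App Expr Expr | Lam Bndr Expr
  deriving (Eq, Show)

-- Mark the first joinArity lambdas of a non-recursive join point as one-shot.
markNonRecJoinOneShots :: Maybe Int -> Expr -> Expr
markNonRecJoinOneShots Nothing rhs = rhs
markNonRecJoinOneShots (Just joinArity) rhs = go joinArity rhs
  where
    go :: Int -> Expr -> Expr
    go 0 e = e
    go n (Lam b body) =
      -- only term binders get the one-shot flag; type binders still count
      Lam (if bndrIsId b then b { bndrOneShot = True } else b) (go (n - 1) body)
    go _ e = e  -- not enough lambdas (yet): leave the rest alone
```

For `\ @a x y -> x` with join arity 2, the type binder `a` is skipped but counted, `x` becomes one-shot and `y` is untouched.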

- - - - -


4 changed files:

- compiler/GHC/Core/Opt/OccurAnal.hs
- + testsuite/tests/simplCore/should_compile/T22428.hs
- + testsuite/tests/simplCore/should_compile/T22428.stderr
- testsuite/tests/simplCore/should_compile/all.T


Changes:

=====================================
compiler/GHC/Core/Opt/OccurAnal.hs
=====================================
@@ -510,6 +510,13 @@ of the file.
     at any of the definitions.  This is done by Simplify.simplRecBind,
     when it calls addLetIdInfo.
 
+Note that the UsageDetails stored in the Details are as if the binding were a
+*non-recursive join point*, which is quite optimistic. The reason is that
+the binding might actually become a non-recursive join point after dependency
+analysis, in which case we can pretend that the whole RHS is only entered
+once. We do the delayed 'adjustTailUsage' in 'occAnalRec'/'tagRecBinders'.
+See Note [Join points and unfoldings/rules] for more details on the contract.
+
 Note [Stable unfoldings]
 ~~~~~~~~~~~~~~~~~~~~~~~~
 None of the above stuff about RULES applies to a stable unfolding
@@ -608,6 +615,24 @@ tail call with `n` arguments (counting both value and type arguments). Otherwise
 'occ_tail' will be 'NoTailCallInfo'. The tail call info flows bottom-up with the
 rest of 'OccInfo' until it goes on the binder.
 
+Note [Join arity prediction for recursive functions]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+For recursive functions, we can predict the only possible join arity:
+It's the manifest join arity, i.e., the number of leading lambda binders in the
+RHS. No other join arity would do:
+
+  * If the join arity were lower, then the leading lambdas would spoil any
+    recursive tail calls in the RHS and hence join-point-hood of the whole letrec.
+  * If the join arity were higher, then we'd have to eta-expand the RHS first.
+    It is the job of Arity Analysis, Call Arity and Demand Analysis to decide
+    whether that would lose sharing; checking this is outside the scope of
+    occurrence analysis.
+
+For non-recursive functions, we can make no such prediction; the join arity
+may be either higher than the manifest join arity (in which case the Simplifier
+will eta-expand; it's simple to see that no sharing can be lost) or lower than
+the manifest join arity.
+
 Note [Join points and unfoldings/rules]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Consider
@@ -618,7 +643,7 @@ Consider
 
 Before j is inlined, we'll have occurrences of j2 in
 both j's RHS and in its stable unfolding.  We want to discover
-j2 as a join point.  So we must do the adjustRhsUsage thing
+j2 as a join point.  So we must do the adjustTailUsage thing
 on j's RHS.  That's why we pass mb_join_arity to calcUnfolding.
 
 Same with rules. Suppose we have:
@@ -636,14 +661,18 @@ up.  So provided the join-point arity of k matches the args of the
 rule we can allow the tail-call info from the RHS of the rule to
 propagate.
 
-* Wrinkle for Rec case. In the recursive case we don't know the
-  join-point arity in advance, when calling occAnalUnfolding and
-  occAnalRules.  (See makeNode.)  We don't want to pass Nothing,
-  because then a recursive joinrec might lose its join-poin-hood
-  when SpecConstr adds a RULE.  So we just make do with the
-  *current* join-poin-hood, stored in the Id.
+* Wrinkle for Rec case. We need to know the potential join arity in 'makeNode',
+  so that 'occAnalRules' and 'occAnalUnfolding' don't spoil potential
+  join-point-hood of the letrec. This is especially important because otherwise
+    * RULEs added by SpecConstr might lose join-point-hood of
+      previously detected join points
+    * Stable unfoldings might prevent contification (#22428)
+  So as per Note [Join arity prediction for recursive functions] we pick
+  'manifestJoinArity' to do the job. Note that in case the letrec won't become
+  a join point, we'll 'adjustTailUsage' eventually in 'occAnalRec'.
 
-  In the non-recursive case things are simple: see occAnalNonRecBind
+  In the non-recursive case things are simple, because we know the join arity
+  from the body_usage: see occAnalNonRecBind
 
 * Wrinkle for RULES.  Suppose the example was a bit different:
       let j :: Int -> Int
@@ -669,13 +698,6 @@ propagate.
   This appears to be very rare in practice. TODO Perhaps we should gather
   statistics to be sure.
 
-Note [Unfoldings and join points]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-We assume that anything in an unfolding occurs multiple times, since
-unfoldings are often copied (that's the whole point!). But we still
-need to track tail calls for the purpose of finding join points.
-
-
 ------------------------------------------------------------
 Note [Adjusting right-hand sides]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -707,21 +729,24 @@ lambda) get marked:
 There are a few other caveats; most importantly, if we're marking a binding as
 'AlwaysTailCalled', it's *going* to be a join point, so we treat it as one so
 that the effect cascades properly. Consequently, at the time the RHS is
-analysed, we won't know what adjustments to make; thus 'occAnalLamOrRhs' must
-return the unadjusted 'UsageDetails', to be adjusted by 'adjustRhsUsage' once
-join-point-hood has been decided.
+analysed, we won't know what adjustments to make; thus 'occAnalLamTail' must
+return the unadjusted 'UsageDetails', to be adjusted by 'adjustTailUsage' once
+join-point-hood has been decided and any one-shot annotations have been
+added through 'markNonRecJoinOneShots'.
 
 Thus the overall sequence taking place in 'occAnalNonRecBind' and
 'occAnalRecBind' is as follows:
 
-  1. Call 'occAnalLamOrRhs' to find usage information for the RHS.
+  1. Call 'occAnalLamTail' to find usage information for the RHS.
   2. Call 'tagNonRecBinder' or 'tagRecBinders', which decides whether to make
      the binding a join point.
-  3. Call 'adjustRhsUsage' accordingly. (Done as part of 'tagRecBinders' when
+  3. Call 'markNonRecJoinOneShots' so that we recognise every non-recursive join
+     point as one-shot.
+  4. Call 'adjustTailUsage' accordingly. (Done as part of 'tagRecBinders' when
      recursive.)
 
 (In the recursive case, this logic is spread between 'makeNode' and
-'occAnalRec'.)
+the 'AcyclicSCC' case of 'occAnalRec'.)
 -}
 
 
@@ -754,12 +779,11 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage
   = WithUsageDetails body_usage []
 
   | otherwise                   -- It's mentioned in the body
-  = WithUsageDetails (body_usage' `andUDs` rhs_usage) [NonRec final_bndr rhs']
+  = WithUsageDetails (body_usage' `andUDs` rhs_usage) [NonRec final_bndr rhs2]
   where
     (body_usage', tagged_bndr) = tagNonRecBinder lvl body_usage bndr
     final_bndr = tagged_bndr `setIdUnfolding` unf'
                              `setIdSpecialisation` mkRuleInfo rules'
-    rhs_usage = rhs_uds `andUDs` unf_uds `andUDs` rule_uds
 
     -- Get the join info from the *new* decision
     -- See Note [Join points and unfoldings/rules]
@@ -773,16 +797,21 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage
 
     -- See Note [Sources of one-shot information]
     rhs_env = env1 { occ_one_shots = argOneShots dmd }
-    (WithUsageDetails rhs_uds rhs') = occAnalRhs rhs_env NonRecursive mb_join_arity rhs
+    (WithUsageDetails rhs_uds rhs1) = occAnalLamTail rhs_env rhs
+      -- corresponding call to adjustTailUsage directly below
+    rhs2 = markNonRecJoinOneShots mb_join_arity rhs1
+    rhs_usage = adjustTailUsage mb_join_arity rhs1 $
+                rhs_uds `andUDs` unf_uds `andUDs` rule_uds
 
     --------- Unfolding ---------
-    -- See Note [Unfoldings and join points]
+    -- See Note [Join points and unfoldings/rules]
     unf | isId bndr = idUnfolding bndr
         | otherwise = NoUnfolding
-    (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env NonRecursive mb_join_arity unf
+    (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env mb_join_arity unf
 
     --------- Rules ---------
     -- See Note [Rules are extra RHSs] and Note [Rule dependency info]
+    -- and Note [Join points and unfoldings/rules]
     rules_w_uds  = occAnalRules rhs_env mb_join_arity bndr
     rules'       = map fstOf3 rules_w_uds
     imp_rule_uds = impRulesScopeUsage (lookupImpRules imp_rule_edges bndr)
@@ -848,10 +877,14 @@ occAnalRec !_ lvl (AcyclicSCC (ND { nd_bndr = bndr, nd_rhs = rhs
 
   | otherwise                   -- It's mentioned in the body
   = WithUsageDetails (body_uds' `andUDs` rhs_uds')
-                     (NonRec tagged_bndr rhs : binds)
+                     (NonRec tagged_bndr rhs' : binds)
   where
     (body_uds', tagged_bndr) = tagNonRecBinder lvl body_uds bndr
-    rhs_uds'      = adjustRhsUsage mb_join_arity rhs rhs_uds
+    rhs' = markNonRecJoinOneShots mb_join_arity rhs
+    rhs_uds' = adjustTailUsage mb_join_arity rhs' rhs_uds
+      -- corresponding call to occAnalLamTail is in makeNode
+      -- rhs_uds is for a non-recursive join point; we should do the same
+      -- as occAnalNonRecBind, so we do 'markNonRecJoinOneShots' first.
     mb_join_arity = willBeJoinId_maybe tagged_bndr
 
         -- The Rec case is the interesting one
@@ -1331,6 +1364,8 @@ data Details
 
        , nd_uds  :: UsageDetails  -- Usage from RHS, and RULES, and stable unfoldings
                                   -- ignoring phase (ie assuming all are active)
+                                  -- NB: These UsageDetails optimistically assume
+                                  -- that this Node becomes a non-recursive join point!
                                   -- See Note [Forming Rec groups]
 
        , nd_inl  :: IdSet       -- Free variables of the stable unfolding and the RHS
@@ -1412,27 +1447,33 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
     -- and the unfolding together.
     -- See Note [inl_fvs]
 
-    mb_join_arity = isJoinId_maybe bndr
-    -- Get join point info from the *current* decision
-    -- We don't know what the new decision will be!
-    -- Using the old decision at least allows us to
-    -- preserve existing join point, even RULEs are added
-    -- See Note [Join points and unfoldings/rules]
+
+    -- See Note [Join arity prediction for recursive functions]
+    -- and Note [Join points and unfoldings/rules]
+    pred_join_arity = Just (manifestJoinArity rhs)
 
     --------- Right hand side ---------
     -- Constructing the edges for the main Rec computation
     -- See Note [Forming Rec groups]
-    -- Do not use occAnalRhs because we don't yet know the final
-    -- answer for mb_join_arity; instead, do the occAnalLam call from
-    -- occAnalRhs, and postpone adjustRhsUsage until occAnalRec
+    -- Compared to occAnalLam, we can't yet adjust the RHS because
+    --   (a) we don't yet know whether pred_join_arity is the final answer;
+    --       the binding might not become a join point at all (Nothing)!
+    --   (b) we don't even know whether it stays a recursive RHS after the SCC
+    --       analysis we are about to seed! So we can't markAllInsideLam in
+    --       advance, because if it ends up as a non-recursive join point we'll
+    --       consider it as one-shot and don't need to markAllInsideLam.
+    -- Instead, do the occAnalLamTail call from occAnalLam, and postpone
+    -- adjustTailUsage until occAnalRec. In effect, we pretend that the RHS
+    -- becomes a non-recursive join point and fix up later with adjustTailUsage.
     rhs_env                         = rhsCtxt env
-    (WithUsageDetails rhs_uds rhs') = occAnalLam rhs_env rhs
+    (WithUsageDetails rhs_uds rhs') = occAnalLamTail rhs_env rhs
+      -- corresponding call to adjustTailUsage in occAnalRec and tagRecBinders
 
     --------- Unfolding ---------
-    -- See Note [Unfoldings and join points]
+    -- See Note [Join points and unfoldings/rules]
     unf = realIdUnfolding bndr -- realIdUnfolding: Ignore loop-breaker-ness
                                -- here because that is what we are setting!
-    (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env Recursive mb_join_arity unf
+    (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env pred_join_arity unf
 
     --------- IMP-RULES --------
     is_active     = occ_rule_act env :: Activation -> Bool
@@ -1441,8 +1482,9 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
     imp_rule_fvs  = impRulesActiveFvs is_active bndr_set imp_rule_info
 
     --------- All rules --------
+    -- See Note [Join points and unfoldings/rules]
     rules_w_uds :: [(CoreRule, UsageDetails, UsageDetails)]
-    rules_w_uds = occAnalRules rhs_env mb_join_arity bndr
+    rules_w_uds = occAnalRules rhs_env pred_join_arity bndr
     rules'      = map fstOf3 rules_w_uds
 
     rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
@@ -1748,7 +1790,7 @@ lambda and casts, e.g.
 
 * Occurrence analyser: we just mark each binder in the lambda-group
   (here: x,y,z) with its occurrence info in the *body* of the
-  lambda-group.  See occAnalLam.
+  lambda-group.  See occAnalLamTail.
 
 * Simplifier.  The simplifier is careful when partially applying
   lambda-groups. See the call to zapLambdaBndrs in
@@ -1804,7 +1846,7 @@ zapLambdaBndrs fun arg_count
     zap_bndr b | isTyVar b = b
                | otherwise = zapLamIdInfo b
 
-occAnalLam :: OccEnv -> CoreExpr -> (WithUsageDetails CoreExpr)
+occAnalLamTail :: OccEnv -> CoreExpr -> (WithUsageDetails CoreExpr)
 -- See Note [Occurrence analysis for lambda binders]
 -- It does the following:
 --   * Sets one-shot info on the lambda binder from the OccEnv, and
@@ -1815,12 +1857,16 @@ occAnalLam :: OccEnv -> CoreExpr -> (WithUsageDetails CoreExpr)
 -- This function does /not/ do
 --   markAllInsideLam or
 --   markAllNonTail
--- The caller does that, either in occAnal (Lam {}), or in adjustRhsUsage
+-- The caller does that, either in occAnalLam, or calling adjustTailUsage directly.
+-- Every call site links to its respective adjustTailUsage call and vice versa.
+--
+-- In effect, the analysis result is for a non-recursive join point with
+-- manifest arity and adjustTailUsage does the fixup.
 -- See Note [Adjusting right-hand sides]
 
-occAnalLam env (Lam bndr expr)
+occAnalLamTail env (Lam bndr expr)
   | isTyVar bndr
-  = let (WithUsageDetails usage expr') = occAnalLam env expr
+  = let (WithUsageDetails usage expr') = occAnalLamTail env expr
     in WithUsageDetails usage (Lam bndr expr')
        -- Important: Keep the 'env' unchanged so that with a RHS like
        --   \(@ x) -> K @x (f @x)
@@ -1839,14 +1885,14 @@ occAnalLam env (Lam bndr expr)
 
         env1 = env { occ_encl = OccVanilla, occ_one_shots = env_one_shots' }
         env2 = addOneInScope env1 bndr
-        (WithUsageDetails usage expr') = occAnalLam env2 expr
+        (WithUsageDetails usage expr') = occAnalLamTail env2 expr
         (usage', bndr2) = tagLamBinder usage bndr1
     in WithUsageDetails usage' (Lam bndr2 expr')
 
 -- For casts, keep going in the same lambda-group
 -- See Note [Occurrence analysis for lambda binders]
-occAnalLam env (Cast expr co)
-  = let  (WithUsageDetails usage expr') = occAnalLam env expr
+occAnalLamTail env (Cast expr co)
+  = let  (WithUsageDetails usage expr') = occAnalLamTail env expr
          -- usage1: see Note [Gather occurrences of coercion variables]
          usage1 = addManyOccs usage (coVarsOfCo co)
 
@@ -1856,15 +1902,15 @@ occAnalLam env (Cast expr co)
                     _ -> usage1
 
          -- usage3: you might think this was not necessary, because of
-         -- the markAllNonTail in adjustRhsUsage; but not so!  For a
-         -- join point, adjustRhsUsage doesn't do this; yet if there is
+         -- the markAllNonTail in adjustTailUsage; but not so!  For a
+         -- join point, adjustTailUsage doesn't do this; yet if there is
          -- a cast, we must!  Also: why markAllNonTail?  See
          -- GHC.Core.Lint: Note Note [Join points and casts]
          usage3 = markAllNonTail usage2
 
     in WithUsageDetails usage3 (Cast expr' co)
 
-occAnalLam env expr = occAnal env expr
+occAnalLamTail env expr = occAnal env expr
 
 {- Note [Occ-anal and cast worker/wrapper]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1884,8 +1930,8 @@ RHS. So it'll get a Many occ-info.  (Maybe Cast w/w should create a stable
 unfolding, which would obviate this Note; but that seems a bit of a
 heavyweight solution.)
 
-We only need to this in occAnalLam, not occAnal, because the top leve
-of a right hand side is handled by occAnalLam.
+We only need to do this in occAnalLamTail, not occAnal, because the top level
+of a right hand side is handled by occAnalLamTail.
 -}
 
 
@@ -1895,51 +1941,30 @@ of a right hand side is handled by occAnalLam.
 *                                                                      *
 ********************************************************************* -}
 
-occAnalRhs :: OccEnv -> RecFlag -> Maybe JoinArity
+occAnalLam :: OccEnv -> Maybe JoinArity
            -> CoreExpr   -- RHS
            -> WithUsageDetails CoreExpr
-occAnalRhs !env is_rec mb_join_arity rhs
-  = let (WithUsageDetails usage rhs1) = occAnalLam env rhs
-           -- We call occAnalLam here, not occAnalExpr, so that it doesn't
-           -- do the markAllInsideLam and markNonTailCall stuff before
-           -- we've had a chance to help with join points; that comes next
-        rhs2      = markJoinOneShots is_rec mb_join_arity rhs1
-        rhs_usage = adjustRhsUsage mb_join_arity rhs2 usage
-    in WithUsageDetails rhs_usage rhs2
-
-
-
-markJoinOneShots :: RecFlag -> Maybe JoinArity -> CoreExpr -> CoreExpr
--- For a /non-recursive/ join point we can mark all
--- its join-lambda as one-shot; and it's a good idea to do so
-markJoinOneShots NonRecursive (Just join_arity) rhs
-  = go join_arity rhs
+-- ^ This function immediately does adjustTailUsage with the fixed join arity
+-- after the call to occAnalLamTail.
+-- It's useful for anonymous lambdas and unfoldings.
+occAnalLam !env mb_join_arity rhs
+  = WithUsageDetails (adjustTailUsage mb_join_arity rhs' usage) rhs'
   where
-    go 0 rhs         = rhs
-    go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b)
-                           (go (n-1) rhs)
-    go _ rhs         = rhs  -- Not enough lambdas.  This can legitimately happen.
-                            -- e.g.    let j = case ... in j True
-                            -- This will become an arity-1 join point after the
-                            -- simplifier has eta-expanded it; but it may not have
-                            -- enough lambdas /yet/. (Lint checks that JoinIds do
-                            -- have enough lambdas.)
-markJoinOneShots _ _ rhs
-  = rhs
+    WithUsageDetails usage rhs' = occAnalLamTail env rhs
+
 
 occAnalUnfolding :: OccEnv
-                 -> RecFlag
                  -> Maybe JoinArity   -- See Note [Join points and unfoldings/rules]
                  -> Unfolding
                  -> WithUsageDetails Unfolding
 -- Occurrence-analyse a stable unfolding;
--- discard a non-stable one altogether.
-occAnalUnfolding !env is_rec mb_join_arity unf
+-- discard a non-stable one altogether and return empty usage details.
+occAnalUnfolding !env mb_join_arity unf
   = case unf of
       unf@(CoreUnfolding { uf_tmpl = rhs, uf_src = src })
         | isStableSource src ->
             let
-              (WithUsageDetails usage rhs') = occAnalRhs env is_rec mb_join_arity rhs
+              (WithUsageDetails usage rhs') = occAnalLam env mb_join_arity rhs
 
               unf' | noBinderSwaps env = unf -- Note [Unfoldings and rules]
                    | otherwise         = unf { uf_tmpl = rhs' }
@@ -1958,9 +1983,7 @@ occAnalUnfolding !env is_rec mb_join_arity unf
         where
           env'            = env `addInScope` bndrs
           (WithUsageDetails usage args') = occAnalList env' args
-          final_usage     = markAllManyNonTail (delDetailsList usage bndrs)
-                            `addLamCoVarOccs` bndrs
-                            `delDetailsList` bndrs
+          final_usage     = usage `addLamCoVarOccs` bndrs `delDetailsList` bndrs
               -- delDetailsList; no need to use tagLamBinders because we
               -- never inline DFuns so the occ-info on binders doesn't matter
 
@@ -1989,8 +2012,8 @@ occAnalRules !env mb_join_arity bndr
         (WithUsageDetails rhs_uds rhs') = occAnal env' rhs
                             -- Note [Rules are extra RHSs]
                             -- Note [Rule dependency info]
-        rhs_uds' = markAllNonTailIf (not exact_join) $
-                   markAllMany                             $
+        rhs_uds' = markAllNonTailIf (not exact_join) $ -- Nearly adjustTailUsage, but we don't want to
+                   markAllMany                       $ -- build `mkLams (map _ args) rhs` just for the call
                    rhs_uds `delDetailsList` bndrs
 
         exact_join = exactJoin mb_join_arity args
@@ -2031,6 +2054,8 @@ Another way to think about it: if we inlined g as-is into multiple
 call sites, now there's be multiple calls to f.
 
 Bottom line: treat all occurrences in a stable unfolding as "Many".
+We still leave tail call information intact, though, so as not to spoil
+potential join points.
 
 Note [Unfoldings and rules]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -2208,10 +2233,7 @@ occAnal env app@(App _ _)
   = occAnalApp env (collectArgsTicks tickishFloatable app)
 
 occAnal env expr@(Lam {})
-  = let (WithUsageDetails usage expr') = occAnalLam env expr
-        final_usage = markAllInsideLamIf (not (isOneShotFun expr')) $
-                      markAllNonTail usage
-    in WithUsageDetails final_usage expr'
+  = occAnalLam env Nothing expr -- mb_join_arity == Nothing <=> markAllManyNonTail
 
 occAnal env (Case scrut bndr ty alts)
   = let
@@ -2286,7 +2308,7 @@ occAnalApp !env (Var fun, args, ticks)
   --     This caused #18296
   | fun `hasKey` runRWKey
   , [t1, t2, arg]  <- args
-  , let (WithUsageDetails usage arg') = occAnalRhs env NonRecursive (Just 1) arg
+  , WithUsageDetails usage arg' <- occAnalLam env (Just 1) arg
   = WithUsageDetails usage (mkTicks ticks $ mkApps (Var fun) [t1, t2, arg'])
 
 occAnalApp env (Var fun_id, args, ticks)
@@ -3132,11 +3154,11 @@ flattenUsageDetails ud@(UD { ud_env = env })
 
 -------------------
 -- See Note [Adjusting right-hand sides]
-adjustRhsUsage :: Maybe JoinArity
-               -> CoreExpr       -- Rhs, AFTER occ anal
+adjustTailUsage :: Maybe JoinArity
+               -> CoreExpr       -- Rhs, AFTER occAnalLamTail
                -> UsageDetails   -- From body of lambda
                -> UsageDetails
-adjustRhsUsage mb_join_arity rhs usage
+adjustTailUsage mb_join_arity rhs usage
   = -- c.f. occAnal (Lam {})
     markAllInsideLamIf (not one_shot) $
     markAllNonTailIf (not exact_join) $
@@ -3146,6 +3168,28 @@ adjustRhsUsage mb_join_arity rhs usage
     exact_join = exactJoin mb_join_arity bndrs
     (bndrs,_)  = collectBinders rhs
 
+-- | IF rhs becomes a join point, then `manifestJoinArity rhs` returns the
+-- exact join arity of `rhs`.
+manifestJoinArity :: CoreExpr -> JoinArity
+manifestJoinArity rhs = length $ fst $ collectBinders rhs
+
+markNonRecJoinOneShots :: Maybe JoinArity -> CoreExpr -> CoreExpr
+-- For a /non-recursive/ join point we can mark all
+-- its join-lambda as one-shot; and it's a good idea to do so
+markNonRecJoinOneShots Nothing           rhs = rhs
+markNonRecJoinOneShots (Just join_arity) rhs
+  = go join_arity rhs
+  where
+    go 0 rhs         = rhs
+    go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b)
+                           (go (n-1) rhs)
+    go _ rhs         = rhs  -- Not enough lambdas.  This can legitimately happen.
+                            -- e.g.    let j = case ... in j True
+                            -- This will become an arity-1 join point after the
+                            -- simplifier has eta-expanded it; but it may not have
+                            -- enough lambdas /yet/. (Lint checks that JoinIds do
+                            -- have enough lambdas.)
+
 exactJoin :: Maybe JoinArity -> [a] -> Bool
 exactJoin Nothing           _    = False
 exactJoin (Just join_arity) args = args `lengthIs` join_arity
@@ -3224,7 +3268,7 @@ tagRecBinders lvl body_uds details_s
 
      -- 2. Adjust usage details of each RHS, taking into account the
      --    join-point-hood decision
-     rhs_udss' = [ adjustRhsUsage (mb_join_arity bndr) rhs rhs_uds
+     rhs_udss' = [ adjustTailUsage (mb_join_arity bndr) rhs rhs_uds -- matching occAnalLamTail in makeNode
                  | ND { nd_bndr = bndr, nd_uds = rhs_uds
                       , nd_rhs = rhs } <- details_s ]
 


=====================================
testsuite/tests/simplCore/should_compile/T22428.hs
=====================================
@@ -0,0 +1,9 @@
+module T22428 where
+
+f :: Integer -> Integer -> Integer
+f x y = go y
+  where
+    go :: Integer -> Integer
+    go 0 = x
+    go n = go (n-1)
+    {-# INLINE go #-}


=====================================
testsuite/tests/simplCore/should_compile/T22428.stderr
=====================================
@@ -0,0 +1,45 @@
+
+==================== Tidy Core ====================
+Result size of Tidy Core
+  = {terms: 32, types: 14, coercions: 0, joins: 1/1}
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T22428.f1 :: Integer
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T22428.f1 = GHC.Num.Integer.IS 1#
+
+-- RHS size: {terms: 28, types: 10, coercions: 0, joins: 1/1}
+f :: Integer -> Integer -> Integer
+[GblId,
+ Arity=2,
+ Str=<SL><1L>,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [0 0] 156 0}]
+f = \ (x :: Integer) (y :: Integer) ->
+      joinrec {
+        go [InlPrag=INLINE (sat-args=1), Occ=LoopBreaker, Dmd=SC(S,L)]
+          :: Integer -> Integer
+        [LclId[JoinId(1)(Just [!])],
+         Arity=1,
+         Str=<1L>,
+         Unf=Unf{Src=StableUser, TopLvl=False, Value=True, ConLike=True,
+                 WorkFree=True, Expandable=True,
+                 Guidance=ALWAYS_IF(arity=1,unsat_ok=False,boring_ok=False)}]
+        go (ds :: Integer)
+          = case ds of wild {
+              GHC.Num.Integer.IS x1 ->
+                case x1 of {
+                  __DEFAULT -> jump go (GHC.Num.Integer.integerSub wild T22428.f1);
+                  0# -> x
+                };
+              GHC.Num.Integer.IP x1 ->
+                jump go (GHC.Num.Integer.integerSub wild T22428.f1);
+              GHC.Num.Integer.IN x1 ->
+                jump go (GHC.Num.Integer.integerSub wild T22428.f1)
+            }; } in
+      jump go y
+
+
+


=====================================
testsuite/tests/simplCore/should_compile/all.T
=====================================
@@ -447,3 +447,6 @@ test('T22375', normal, compile, ['-O -ddump-simpl -dsuppress-uniques -dno-typeab
 # One module, T21851_2.hs, has OPTIONS_GHC -ddump-simpl
 # Expecting to see $s$wwombat
 test('T21851_2', [grep_errmsg(r'wwombat') ], multimod_compile, ['T21851_2', '-O -dno-typeable-binds -dsuppress-uniques'])
+
+# go should become a join point
+test('T22428', [grep_errmsg(r'jump go') ], compile, ['-O -ddump-simpl -dsuppress-uniques -dno-typeable-binds -dsuppress-unfoldings'])



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/12a8135476bbc4cf48a13caf3d0583e87e50e762



More information about the ghc-commits mailing list