[Git][ghc/ghc][wip/T22428] Fix contification with stable unfoldings (#22428)

Sebastian Graf (@sgraf812) gitlab at gitlab.haskell.org
Thu Jan 12 10:13:17 UTC 2023



Sebastian Graf pushed to branch wip/T22428 at Glasgow Haskell Compiler / GHC


Commits:
18f1be99 by Sebastian Graf at 2023-01-12T11:13:00+01:00
Fix contification with stable unfoldings (#22428)

Many functions now return a `TailUsageDetails` that adorns a `UsageDetails` with
a `JoinArity` that reflects the number of join point binders around the body
for which the `UsageDetails` was computed. `TailUsageDetails` is now returned by
`occAnalLamTail` as well as `occAnalUnfolding` and `occAnalRules`.
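
A deliberately simplified model may help make the shape of the new type concrete. Everything in it is hypothetical: `UsageDetails` here is just a map from variable names to tail-call info, whereas GHC's real `UsageDetails`/`TailUsageDetails` in `GHC.Core.Opt.OccurAnal` carry much more. The sketch shows why the RHS, unfolding and RULE usages are each adjusted to the binder's join arity *before* they are combined.

```haskell
import qualified Data.Map.Strict as M

-- Hypothetical, simplified model; not GHC's actual definitions.
type JoinArity = Int

-- Per-variable tail-call info: Just n means every occurrence is a tail call
-- with n arguments; Nothing means some occurrence is not a tail call.
type TailInfo = Maybe JoinArity

newtype UsageDetails = UD (M.Map String TailInfo)
  deriving (Eq, Show)

-- UsageDetails computed under `ja` manifest lambdas, not yet adjusted.
data TailUsageDetails = TUD JoinArity UsageDetails
  deriving (Eq, Show)

-- Combine two usages; a variable present in both keeps its tail info only
-- if both sides are tail calls of the same arity.
andUDs :: UsageDetails -> UsageDetails -> UsageDetails
andUDs (UD a) (UD b) = UD (M.unionWith combine a b)
  where
    combine (Just m) (Just n) | m == n = Just m
    combine _        _                 = Nothing

markAllNonTail :: UsageDetails -> UsageDetails
markAllNonTail (UD m) = UD (M.map (const Nothing) m)

-- Tail info survives only if the binder becomes a join point whose arity
-- equals the number of lambdas the usage was computed under.
adjustTailArity :: Maybe JoinArity -> TailUsageDetails -> UsageDetails
adjustTailArity mb_ja (TUD ja uds)
  | mb_ja == Just ja = uds
  | otherwise        = markAllNonTail uds

-- Adjust RHS, unfolding and RULE usages individually, then combine.
bindingUsage :: Maybe JoinArity -> [TailUsageDetails] -> UsageDetails
bindingUsage mb_ja = foldr (andUDs . adjustTailArity mb_ja) (UD M.empty)
```

For the example in the updated `Note [Join points and unfoldings/rules]` (RHS under 2 lambdas calling `j2`, unfolding `\x. g`, RULE RHS `h 17` under 3 binders), `bindingUsage (Just 2)` keeps the tail call to `j2` while `g` and `h` lose their tail-call info.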

I adjusted `Note [Join points and unfoldings/rules]` and
`Note [Adjusting right-hand sides]` to account for the new machinery.
I also wrote a new `Note [Join arity prediction based on joinRhsArity]`
and refer to it when we combine `TailUsageDetails` for a recursive RHS.
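
As a rough illustration of the recursive case of that Note, here is a toy model. The `Expr` type, `tailArities` and `canContifyRec` are hypothetical helpers invented for this sketch (GHC's `CoreExpr` and occurrence analyser are far richer); only `joinRhsArity` mirrors the definition in `GHC.Core.Opt.Arity`. It assumes all bound names are unique and treats only `if` branches as tail positions.

```haskell
-- Toy expression language; a hypothetical stand-in for CoreExpr.
-- Assumes every bound name is unique, so shadowing is ignored.
data Expr
  = Var String
  | Lit Int
  | App Expr Expr
  | Lam String Expr
  | If Expr Expr Expr
  deriving Show

type JoinArity = Int

-- Manifest join arity M: the number of leading lambdas (cf. joinRhsArity).
joinRhsArity :: Expr -> JoinArity
joinRhsArity (Lam _ e) = 1 + joinRhsArity e
joinRhsArity _         = 0

-- Strip the M leading lambdas to get at the body of a RHS.
stripLams :: Expr -> Expr
stripLams (Lam _ e) = stripLams e
stripLams e         = e

mentionsVar :: String -> Expr -> Bool
mentionsVar f (Var g)    = f == g
mentionsVar _ (Lit _)    = False
mentionsVar f (App a b)  = mentionsVar f a || mentionsVar f b
mentionsVar f (Lam _ b)  = mentionsVar f b
mentionsVar f (If c t e) = any (mentionsVar f) [c, t, e]

-- Arities O of all tail calls to f in an expression, or Nothing if f also
-- occurs in a non-tail position (e.g. as an argument or under a lambda).
tailArities :: String -> Expr -> Maybe [JoinArity]
tailArities f = go
  where
    mentions = mentionsVar f
    go (If c t e)
      | mentions c = Nothing
      | otherwise  = (++) <$> go t <*> go e
    go e = case spine e [] of
      (Var g, args) | g == f, not (any mentions args) -> Just [length args]
      _ | mentions e -> Nothing
        | otherwise  -> Just []
    spine (App fun arg) args = spine fun (arg : args)
    spine hd            args = (hd, args)

-- letrec f = rhs in body can become a joinrec without eta-expansion only if
-- every occurrence of f, in the RHS body and in the let body, is a tail
-- call whose arity O equals the manifest arity M of the RHS.
canContifyRec :: String -> Expr -> Expr -> Bool
canContifyRec f rhs body =
  case (++) <$> tailArities f (stripLams rhs) <*> tailArities f body of
    Just os -> all (== joinRhsArity rhs) os
    Nothing -> False
```

For the Note's `letrec f x y = if ... then f x else True in f 42`, M is 2 but both occurrences have O = 1, so `canContifyRec` rejects it; if both calls instead pass two arguments, M and O agree and it is accepted.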

I also renamed

  * `occAnalLam` to `occAnalLamTail`
  * `adjustRhsUsage` to `adjustTailUsage`
  * a few other less important functions

and properly documented that each call of `occAnalLamTail` must be paired with
a call of `adjustTailUsage`.
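
The pairing can be pictured with another toy model: analysing a lambda records the manifest arity but adjusts nothing, and a separate step applies steps (a) and (b) of `Note [Adjusting right-hand sides]` once join-point-hood is known. `analyseLamTail`, `adjustTail`, `Rhs` and `Occ` are hypothetical stand-ins for `occAnalLamTail`, `adjustTailUsage` and GHC's occurrence info; this is a simplified reading of the Note, not GHC's implementation.

```haskell
import qualified Data.Map.Strict as M

-- Hypothetical, simplified model; GHC's OccInfo and UsageDetails are richer.
type JoinArity = Int

data RecFlag = NonRecursive | Recursive
  deriving (Eq, Show)

data Occ = Occ
  { insideLam :: Bool             -- occurs under a lambda that may run many times
  , tailArity :: Maybe JoinArity  -- Just n: always a tail call with n args
  } deriving (Eq, Show)

type UsageDetails = M.Map String Occ

-- Usage computed under `ja` manifest lambdas, not yet adjusted.
data TailUsageDetails = TUD JoinArity UsageDetails
  deriving (Eq, Show)

-- Toy RHS: its manifest lambdas, whether they are one-shot, and the usage
-- of the body under those lambdas.
data Rhs = Rhs
  { lams      :: Int
  , oneShot   :: Bool
  , bodyUsage :: UsageDetails
  }

-- Stand-in for occAnalLamTail: record M, but do not adjust anything yet.
analyseLamTail :: Rhs -> TailUsageDetails
analyseLamTail rhs = TUD (lams rhs) (bodyUsage rhs)

-- Stand-in for adjustTailUsage, called once join-point-hood is decided.
adjustTail :: RecFlag -> Maybe JoinArity -> Rhs -> TailUsageDetails -> UsageDetails
adjustTail rf mb_ja rhs (TUD ja uds) = M.map adjust uds
  where
    -- (a) tail calls survive only for a join point of matching arity
    keepTail = mb_ja == Just ja
    -- a non-recursive join point of arity n is entered at most once, so its
    -- first n lambdas count as one-shot (cf. markNonRecJoinOneShots)
    joinOneShot = case mb_ja of
      Just n | rf == NonRecursive -> n >= lams rhs
      _                           -> False
    -- (b) mark inside-lam unless thunk, one-shot lambdas or non-rec join point
    markInside = lams rhs > 0 && not (oneShot rhs || joinOneShot)
    adjust o = o { insideLam = insideLam o || markInside
                 , tailArity = if keepTail then tailArity o else Nothing }
```

Keeping the two steps apart is what lets the recursive case postpone the decision until after SCC analysis: the same unadjusted `TailUsageDetails` can later be adjusted as a joinrec, a non-recursive join point, or an ordinary binding.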

I removed `Note [Unfoldings and join points]` because it was redundant with
`Note [Occurrences in stable unfoldings]`.

While in town, I refactored `mkLoopBreakerNodes` so that it returns a condensed
`NodeDetails` called `SimpleNodeDetails`.
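
One way to read this refactoring: the old `Details` record carried an `nd_score` field initialised to `pprPanic` and filled in later, whereas now the score lives only in a smaller record built once the score is known. The schematic sketch below uses hypothetical field types (strings for binders and RHSs, a plain triple standing in for `NodeScore`); only the split itself reflects the commit.

```haskell
-- Schematic only: String stands in for Id/CoreExpr, and the triple is a
-- stand-in for GHC's NodeScore.
type NodeScore = (Int, Int, Bool)

-- What SCC analysis needs; there is deliberately no score field, so no
-- placeholder value can be read by mistake.
data NodeDetails = ND
  { ndBndr :: String
  , ndRhs  :: String
  , ndInl  :: [String]  -- free vars relevant to inlining
  }

-- What loop breaking needs: the final binder, RHS and score.
data SimpleNodeDetails = SND
  { sndBndr  :: String
  , sndRhs   :: String
  , sndScore :: NodeScore
  } deriving (Eq, Show)

-- Build the condensed node once the score can actually be computed.
condense :: (NodeDetails -> NodeScore) -> NodeDetails -> SimpleNodeDetails
condense score nd =
  SND { sndBndr = ndBndr nd, sndRhs = ndRhs nd, sndScore = score nd }
```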

Fixes #22428.

The refactoring seems to have quite a beneficial effect on ghc/alloc performance:

```
     CoOpt_Read(normal) ghc/alloc    784,778,420    768,091,176  -2.1% GOOD
         T12150(optasm) ghc/alloc     77,762,270     75,986,720  -2.3% GOOD
         T12425(optasm) ghc/alloc     85,740,186     84,641,712  -1.3% GOOD
         T13056(optasm) ghc/alloc    306,104,656    299,811,632  -2.1% GOOD
         T13253(normal) ghc/alloc    350,233,952    346,004,008  -1.2%
         T14683(normal) ghc/alloc  2,800,514,792  2,754,651,360  -1.6%
         T15304(normal) ghc/alloc  1,230,883,318  1,215,978,336  -1.2%
         T15630(normal) ghc/alloc    153,379,590    151,796,488  -1.0%
         T16577(normal) ghc/alloc  7,356,797,056  7,244,194,416  -1.5%
         T17516(normal) ghc/alloc  1,718,941,448  1,692,157,288  -1.6%
         T19695(normal) ghc/alloc  1,485,794,632  1,458,022,112  -1.9%
        T21839c(normal) ghc/alloc    437,562,314    431,295,896  -1.4% GOOD
        T21839r(normal) ghc/alloc    446,927,580    440,615,776  -1.4% GOOD

              geo. mean                                          -0.6%
              minimum                                            -2.4%
              maximum                                            -0.0%
```
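
Every listed row matches a plain relative change, (new - old) / old, rounded to one decimal place. The small sketch below recomputes a few rows; `percentChange` is a hypothetical helper, not part of GHC's testsuite driver. The geometric mean, minimum and maximum lines evidently cover more tests than the rows shown (the -2.4% minimum appears in no listed row), so they cannot be recomputed from this table alone.

```haskell
-- Relative change in percent, rounded to one decimal place.
percentChange :: Integer -> Integer -> Double
percentChange old new =
  fromIntegral (round (1000 * ratio) :: Integer) / 10
  where
    ratio = fromIntegral (new - old) / fromIntegral old :: Double

-- A few rows from the table above: (test, old ghc/alloc, new ghc/alloc).
rows :: [(String, Integer, Integer)]
rows =
  [ ("CoOpt_Read(normal)", 784778420, 768091176)
  , ("T12150(optasm)",      77762270,  75986720)
  , ("T15630(normal)",     153379590, 151796488)
  ]
```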

Metric Decrease:
    CoOpt_Read
    T10421
    T12150
    T12425
    T13056
    T18698a
    T18698b
    T21839c
    T21839r
    T9961

- - - - -


7 changed files:

- compiler/GHC/Core/Opt/Arity.hs
- compiler/GHC/Core/Opt/OccurAnal.hs
- compiler/GHC/Data/Graph/Directed.hs
- compiler/GHC/Utils/Misc.hs
- + testsuite/tests/simplCore/should_compile/T22428.hs
- + testsuite/tests/simplCore/should_compile/T22428.stderr
- testsuite/tests/simplCore/should_compile/all.T


Changes:

=====================================
compiler/GHC/Core/Opt/Arity.hs
=====================================
@@ -132,6 +132,9 @@ joinRhsArity :: CoreExpr -> JoinArity
 -- Join points are supposed to have manifestly-visible
 -- lambdas at the top: no ticks, no casts, nothing
 -- Moreover, type lambdas count in JoinArity
+-- NB: For non-recursive bindings, the join arity of the binding may actually be
+-- less than the number of manifestly-visible lambdas.
+-- See Note [Join arity prediction based on joinRhsArity] in GHC.Core.Opt.OccurAnal
 joinRhsArity (Lam _ e) = 1 + joinRhsArity e
 joinRhsArity _         = 0
 


=====================================
compiler/GHC/Core/Opt/OccurAnal.hs
=====================================
@@ -59,7 +59,7 @@ import GHC.Builtin.Names( runRWKey )
 import GHC.Unit.Module( Module )
 
 import Data.List (mapAccumL, mapAccumR)
-import Data.List.NonEmpty (NonEmpty (..), nonEmpty)
+import Data.List.NonEmpty (NonEmpty (..))
 import qualified Data.List.NonEmpty as NE
 
 {-
@@ -510,6 +510,16 @@ of the file.
     at any of the definitions.  This is done by Simplify.simplRecBind,
     when it calls addLetIdInfo.
 
+Note [TailUsageDetails when forming Rec groups]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+The `TailUsageDetails` stored in the `nd_uds` field of a `NodeDetails` is
+computed by `occAnalLamTail` applied to the RHS, not `occAnalExpr`.
+That is because the binding might still become a *non-recursive join point* in
+the AcyclicSCC case of dependency analysis!
+Hence we do the delayed `adjustTailUsage` in `occAnalRec`/`tagRecBinders` to get
+a regular, adjusted UsageDetails.
+See Note [Join points and unfoldings/rules] for more details on the contract.
+
 Note [Stable unfoldings]
 ~~~~~~~~~~~~~~~~~~~~~~~~
 None of the above stuff about RULES applies to a stable unfolding
@@ -608,6 +618,65 @@ tail call with `n` arguments (counting both value and type arguments). Otherwise
 'occ_tail' will be 'NoTailCallInfo'. The tail call info flows bottom-up with the
 rest of 'OccInfo' until it goes on the binder.
 
+Note [Join arity prediction based on joinRhsArity]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+In general, the join arity from tail occurrences of a join point (O) may be
+higher or lower than the manifest join arity of the join body (M). E.g.,
+
+  -- M > O:
+  let f x y = x + y              -- M = 2
+  in if b then f 1 else f 2      -- O = 1
+  ==> { Contify for join arity 1 }
+  join f x = \y -> x + y
+  in if b then jump f 1 else jump f 2
+
+  -- M < O:
+  let f = id                     -- M = 0
+  in if b then f 12 else f 13    -- O = 1
+  ==> { Contify for join arity 1, eta-expand f }
+  join f x = id x
+  in if b then jump f 12 else jump f 13
+
+But for a *recursive* let, it is crucial that both arities match up. Consider
+
+  letrec f x y = if ... then f x else True
+  in f 42
+
+Here, M=2 but O=1. If we settled for a joinrec arity of 1, the recursive jump
+`f x` would sit under the remaining `\y`, which is not a tail context!
+Contification is invalid here. So indeed it is crucial to demand that M=O.
+
+(Side note: Actually, we could be more specific: Let O1 be the join arity of
+occurrences from the letrec RHS and O2 the join arity from the let body. Then
+we need M=O1 and M<=O2 and could simply eta-expand the RHS to match O2 later.
+M=O is the specific case where we don't want to eta-expand. Neither the join
+points paper nor GHC does this at the moment.)
+
+We can capitalise on this observation and conclude that *if* f could become a
+joinrec (without eta-expansion), it will have join arity M.
+Now, M is just the result of 'joinRhsArity', a rather simple, local analysis.
+It is also the join arity inside the 'TailUsageDetails' returned by
+'occAnalLamTail', so we can predict join arity without doing any fixed-point
+iteration or really doing any deep traversal of let body or RHS at all.
+We check for M in the 'adjustTailUsage' call inside 'tagRecBinders'.
+
+All this is quite apparent if you look at the contification transformation in
+Fig. 5 of "Compiling without Continuations" (which does not account for
+eta-expansion at all, mind you). The letrec case looks like this:
+
+  letrec f = /\as.\xs. L[us] in L'[es]
+    ... and a bunch of conditions establishing that f only occurs
+        in app heads of join arity (len as + len xs) inside us and es ...
+
+The syntactic form `/\as.\xs. L[us]` forces M=O iff `f` occurs in `us`. However,
+for non-recursive functions, this is the definition of contification from the
+paper:
+
+  let f = /\as.\xs.u in L[es]     ... conditions ...
+
+Note that u could be a lambda itself, as we have seen, so there is no
+relationship between M and O to exploit here.
+
 Note [Join points and unfoldings/rules]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Consider
@@ -618,8 +687,10 @@ Consider
 
 Before j is inlined, we'll have occurrences of j2 in
 both j's RHS and in its stable unfolding.  We want to discover
-j2 as a join point.  So we must do the adjustRhsUsage thing
-on j's RHS.  That's why we pass mb_join_arity to calcUnfolding.
+j2 as a join point. So 'occAnalUnfolding' returns an unadjusted
+'TailUsageDetails', like 'occAnalLamTail'. We adjust the usage details of the
+unfolding to the actual join arity using the same 'adjustTailArity' as for
+the RHS, see Note [Adjusting right-hand sides].
 
 Same with rules. Suppose we have:
 
@@ -636,14 +707,31 @@ up.  So provided the join-point arity of k matches the args of the
 rule we can allow the tail-call info from the RHS of the rule to
 propagate.
 
-* Wrinkle for Rec case. In the recursive case we don't know the
-  join-point arity in advance, when calling occAnalUnfolding and
-  occAnalRules.  (See makeNode.)  We don't want to pass Nothing,
-  because then a recursive joinrec might lose its join-poin-hood
-  when SpecConstr adds a RULE.  So we just make do with the
-  *current* join-poin-hood, stored in the Id.
+* Note that the join arity of the RHS and that of the unfolding or RULE might
+  mismatch:
+
+    let j x y = j2 (x+x)
+        {-# INLINE[2] j = \x. g #-}
+        {-# RULE forall x y z. j x y z = h 17 #-}
+    in j 1 2
 
-  In the non-recursive case things are simple: see occAnalNonRecBind
+  So it is crucial that we adjust each TailUsageDetails individually
+  with the actual join arity 2 here before we combine with `andUDs`.
+  Here, that means losing tail call info on `g` and `h`.
+
+* Wrinkle for Rec case: We store one TailUsageDetails in the NodeDetails for
+  the RHS, unfolding and RULEs combined. Clearly, if they don't agree on their join
+  arity, we have to do some adjusting. We choose to adjust to the join arity
+  of the RHS, because that is likely the join arity that the join point will
+  have; see Note [Join arity prediction based on joinRhsArity].
+
+  If the guess is correct, then tail calls in the RHS are preserved; a necessary
+  condition for the whole binding becoming a joinrec.
+  The guess can only be incorrect in the 'AcyclicSCC' case when the binding
+  becomes a non-recursive join point with a different join arity. But then the
+  eventual call to 'adjustTailUsage' in 'tagRecBinders'/'occAnalRec' will
+  be with a different join arity and destroy unsound tail call info with
+  'markNonTail'.
 
 * Wrinkle for RULES.  Suppose the example was a bit different:
       let j :: Int -> Int
@@ -669,28 +757,21 @@ propagate.
   This appears to be very rare in practice. TODO Perhaps we should gather
   statistics to be sure.
 
-Note [Unfoldings and join points]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-We assume that anything in an unfolding occurs multiple times, since
-unfoldings are often copied (that's the whole point!). But we still
-need to track tail calls for the purpose of finding join points.
-
-
 ------------------------------------------------------------
 Note [Adjusting right-hand sides]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 There's a bit of a dance we need to do after analysing a lambda expression or
 a right-hand side. In particular, we need to
 
-  a) call 'markAllInsideLam' *unless* the binding is for a thunk, a one-shot
-     lambda, or a non-recursive join point; and
-  b) call 'markAllNonTail' *unless* the binding is for a join point, and
-     the RHS has the right arity; e.g.
+  a) call 'markAllNonTail' *unless* the binding is for a join point, and
+     the TailUsageDetails from the RHS has the right join arity; e.g.
         join j x y = case ... of
                        A -> j2 p
                        B -> j2 q
         in j a b
      Here we want the tail calls to j2 to be tail calls of the whole expression
+  b) call 'markAllInsideLam' *unless* the binding is for a thunk, a one-shot
+     lambda, or a non-recursive join point
 
 Some examples, with how the free occurrences in e (assumed not to be a value
 lambda) get marked:
@@ -707,26 +788,39 @@ lambda) get marked:
 There are a few other caveats; most importantly, if we're marking a binding as
 'AlwaysTailCalled', it's *going* to be a join point, so we treat it as one so
 that the effect cascades properly. Consequently, at the time the RHS is
-analysed, we won't know what adjustments to make; thus 'occAnalLamOrRhs' must
-return the unadjusted 'UsageDetails', to be adjusted by 'adjustRhsUsage' once
-join-point-hood has been decided.
-
-Thus the overall sequence taking place in 'occAnalNonRecBind' and
-'occAnalRecBind' is as follows:
-
-  1. Call 'occAnalLamOrRhs' to find usage information for the RHS.
-  2. Call 'tagNonRecBinder' or 'tagRecBinders', which decides whether to make
+analysed, we won't know what adjustments to make; thus 'occAnalLamTail' must
+return the unadjusted 'TailUsageDetails', to be adjusted by 'adjustTailUsage'
+once join-point-hood has been decided and any one-shot annotations have
+been added through 'markNonRecJoinOneShots'.
+
+It is not so simple to see that 'occAnalNonRecBind' and 'occAnalRecBind' indeed
+perform a similar sequence of steps. Thus, here is an interleaving of events
+of both functions, serving as a specification:
+
+  1. Call 'occAnalLamTail' to find usage information for the RHS.
+     Recursive case:     'makeNode'
+     Non-recursive case: 'occAnalNonRecBind'
+  2. (Analyse the binding's scope. Done in 'occAnalBind'/`occAnal Let{}`.
+      Same whether recursive or not.)
+  3. Call 'tagNonRecBinder' or 'tagRecBinders', which decides whether to make
      the binding a join point.
-  3. Call 'adjustRhsUsage' accordingly. (Done as part of 'tagRecBinders' when
-     recursive.)
-
-(In the recursive case, this logic is spread between 'makeNode' and
-'occAnalRec'.)
+     Cyclic  Recursive case:  'mkLoopBreakerNodes'
+     Acyclic Recursive case:  `occAnalRec AcyclicSCC{}`
+     Non-recursive case:      'occAnalNonRecBind'
+  4. Non-recursive join point: Call 'markNonRecJoinOneShots' so that, e.g.,
+     FloatOut sees one-shot annotations on lambdas.
+     Acyclic Recursive case:  `occAnalRec AcyclicSCC{}`  calls 'adjustNonRecRhs'
+     Non-recursive case:      'occAnalNonRecBind'        calls 'adjustNonRecRhs'
+  5. Call 'adjustTailUsage' accordingly.
+     Cyclic Recursive case:   'tagRecBinders'
+     Acyclic Recursive case:  'adjustNonRecRhs'
+     Non-recursive case:      'adjustNonRecRhs'
 -}
 
-
 data WithUsageDetails a = WithUsageDetails !UsageDetails !a
 
+data WithTailUsageDetails a = WithTailUsageDetails !TailUsageDetails !a
+
 ------------------------------------------------------------------
 --                 occAnalBind
 ------------------------------------------------------------------
@@ -750,19 +844,17 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage
   | isTyVar bndr      -- A type let; we don't gather usage info
   = WithUsageDetails body_usage [NonRec bndr rhs]
 
-  | not (bndr `usedIn` body_usage)    -- It's not mentioned
-  = WithUsageDetails body_usage []
+  | not (bndr `usedIn` body_usage)
+  = WithUsageDetails body_usage [] -- See Note [Dead code]
 
   | otherwise                   -- It's mentioned in the body
-  = WithUsageDetails (body_usage' `andUDs` rhs_usage) [NonRec final_bndr rhs']
+  = WithUsageDetails (body_usage' `andUDs` rhs_usage) [NonRec final_bndr final_rhs]
   where
-    (body_usage', tagged_bndr) = tagNonRecBinder lvl body_usage bndr
-    final_bndr = tagged_bndr `setIdUnfolding` unf'
-                             `setIdSpecialisation` mkRuleInfo rules'
-    rhs_usage = rhs_uds `andUDs` unf_uds `andUDs` rule_uds
+    WithUsageDetails body_usage' tagged_bndr = tagNonRecBinder lvl body_usage bndr
 
     -- Get the join info from the *new* decision
     -- See Note [Join points and unfoldings/rules]
+    -- => join arity O of Note [Join arity prediction based on joinRhsArity]
     mb_join_arity = willBeJoinId_maybe tagged_bndr
     is_join_point = isJust mb_join_arity
 
@@ -773,17 +865,28 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage
 
     -- See Note [Sources of one-shot information]
     rhs_env = env1 { occ_one_shots = argOneShots dmd }
-    (WithUsageDetails rhs_uds rhs') = occAnalRhs rhs_env NonRecursive mb_join_arity rhs
+    -- See Note [Join arity prediction based on joinRhsArity]
+    -- Match join arity O from mb_join_arity with manifest join arity M as
+    -- returned by occAnalLamTail. It's totally OK for them to mismatch;
+    -- hence adjust the UDs from the RHS.
+    WithUsageDetails adj_rhs_uds final_rhs
+      = adjustNonRecRhs mb_join_arity $ occAnalLamTail rhs_env rhs
+    rhs_usage = adj_rhs_uds `andUDs` adj_unf_uds `andUDs` adj_rule_uds
+    final_bndr = tagged_bndr `setIdSpecialisation` mkRuleInfo rules'
+                             `setIdUnfolding` unf2
 
     --------- Unfolding ---------
-    -- See Note [Unfoldings and join points]
+    -- See Note [Join points and unfoldings/rules]
     unf | isId bndr = idUnfolding bndr
         | otherwise = NoUnfolding
-    (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env NonRecursive mb_join_arity unf
+    WithTailUsageDetails unf_uds unf1 = occAnalUnfolding rhs_env unf
+    unf2 = markNonRecUnfoldingOneShots mb_join_arity unf1
+    adj_unf_uds = adjustTailArity mb_join_arity unf_uds
 
     --------- Rules ---------
     -- See Note [Rules are extra RHSs] and Note [Rule dependency info]
-    rules_w_uds  = occAnalRules rhs_env mb_join_arity bndr
+    -- and Note [Join points and unfoldings/rules]
+    rules_w_uds  = occAnalRules rhs_env bndr
     rules'       = map fstOf3 rules_w_uds
     imp_rule_uds = impRulesScopeUsage (lookupImpRules imp_rule_edges bndr)
          -- imp_rule_uds: consider
@@ -794,8 +897,9 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage
          -- that g is (since the RULE might turn g into h), so
          -- we make g mention h.
 
-    rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
-    add_rule_uds (_, l, r) uds = l `andUDs` r `andUDs` uds
+    adj_rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
+    add_rule_uds (_, l, r) uds
+      = l `andUDs` adjustTailArity mb_join_arity r `andUDs` uds
 
     ----------
     occ = idOccInfo tagged_bndr
@@ -820,7 +924,7 @@ occAnalRecBind :: OccEnv -> TopLevelFlag -> ImpRuleEdges -> [(Var,CoreExpr)]
 occAnalRecBind !env lvl imp_rule_edges pairs body_usage
   = foldr (occAnalRec rhs_env lvl) (WithUsageDetails body_usage []) sccs
   where
-    sccs :: [SCC Details]
+    sccs :: [SCC NodeDetails]
     sccs = {-# SCC "occAnalBind.scc" #-}
            stronglyConnCompFromEdgedVerticesUniq nodes
 
@@ -832,48 +936,62 @@ occAnalRecBind !env lvl imp_rule_edges pairs body_usage
     bndr_set = mkVarSet bndrs
     rhs_env  = env `addInScope` bndrs
 
+adjustNonRecRhs :: Maybe JoinArity -> WithTailUsageDetails CoreExpr -> WithUsageDetails CoreExpr
+-- ^ This function concentrates shared logic between occAnalNonRecBind and the
+-- AcyclicSCC case of occAnalRec.
+--   * It applies 'markNonRecJoinOneShots' to the RHS
+--   * and returns the RHS UsageDetails, adjusted with 'adjustTailUsage'
+adjustNonRecRhs mb_join_arity (WithTailUsageDetails rhs_tuds rhs)
+  = WithUsageDetails rhs_uds' rhs'
+  where
+    --------- Marking (non-rec) join binders one-shot ---------
+    !rhs' | Just ja <- mb_join_arity = markNonRecJoinOneShots ja rhs
+          | otherwise                = rhs
+    --------- Adjusting right-hand side usage ---------
+    rhs_uds' = adjustTailUsage mb_join_arity rhs' rhs_tuds
+
+bindersOfSCC :: SCC NodeDetails -> [Var]
+bindersOfSCC (AcyclicSCC nd) = [nd_bndr nd]
+bindersOfSCC (CyclicSCC ds)  = map nd_bndr ds
 
 -----------------------------
 occAnalRec :: OccEnv -> TopLevelFlag
-           -> SCC Details
+           -> SCC NodeDetails
            -> WithUsageDetails [CoreBind]
            -> WithUsageDetails [CoreBind]
 
-        -- The NonRec case is just like a Let (NonRec ...) above
-occAnalRec !_ lvl (AcyclicSCC (ND { nd_bndr = bndr, nd_rhs = rhs
-                                  , nd_uds = rhs_uds }))
-           (WithUsageDetails body_uds binds)
-  | not (bndr `usedIn` body_uds)
-  = WithUsageDetails body_uds binds -- See Note [Dead code]
+-- Check for Note [Dead code]
+-- NB: Only look at body_uds, ignoring uses in the SCC
+occAnalRec !_ _ scc (WithUsageDetails body_uds binds)
+  | not (any (`usedIn` body_uds) (bindersOfSCC scc))
+  = WithUsageDetails body_uds binds
 
-  | otherwise                   -- It's mentioned in the body
-  = WithUsageDetails (body_uds' `andUDs` rhs_uds')
-                     (NonRec tagged_bndr rhs : binds)
+-- The NonRec case is just like a Let (NonRec ...) above
+occAnalRec !_ lvl
+           (AcyclicSCC (ND { nd_bndr = bndr, nd_rhs = wtuds }))
+           (WithUsageDetails body_uds binds)
+  = WithUsageDetails (body_uds' `andUDs` rhs_uds') (NonRec bndr' rhs' : binds)
   where
-    (body_uds', tagged_bndr) = tagNonRecBinder lvl body_uds bndr
-    rhs_uds'      = adjustRhsUsage mb_join_arity rhs rhs_uds
+    WithUsageDetails body_uds' tagged_bndr = tagNonRecBinder lvl body_uds bndr
     mb_join_arity = willBeJoinId_maybe tagged_bndr
+    WithUsageDetails rhs_uds' rhs' = adjustNonRecRhs mb_join_arity wtuds
+    !unf'  = markNonRecUnfoldingOneShots mb_join_arity (idUnfolding tagged_bndr)
+    !bndr' = tagged_bndr `setIdUnfolding` unf'
 
-        -- The Rec case is the interesting one
-        -- See Note [Recursive bindings: the grand plan]
-        -- See Note [Loop breaking]
+-- The Rec case is the interesting one
+-- See Note [Recursive bindings: the grand plan]
+-- See Note [Loop breaking]
 occAnalRec env lvl (CyclicSCC details_s) (WithUsageDetails body_uds binds)
-  | not (any (`usedIn` body_uds) bndrs) -- NB: look at body_uds, not total_uds
-  = WithUsageDetails body_uds binds     -- See Note [Dead code]
-
-  | otherwise   -- At this point we always build a single Rec
   = -- pprTrace "occAnalRec" (ppr loop_breaker_nodes)
     WithUsageDetails final_uds (Rec pairs : binds)
-
   where
-    bndrs      = map nd_bndr details_s
     all_simple = all nd_simple details_s
 
     ------------------------------
     -- Make the nodes for the loop-breaker analysis
     -- See Note [Choosing loop breakers] for loop_breaker_nodes
     final_uds :: UsageDetails
-    loop_breaker_nodes :: [LetrecNode]
+    loop_breaker_nodes :: [LoopBreakerNode]
     (WithUsageDetails final_uds loop_breaker_nodes) = mkLoopBreakerNodes env lvl body_uds details_s
 
     ------------------------------
@@ -1102,7 +1220,7 @@ type Binding = (Id,CoreExpr)
 loopBreakNodes :: Int
                -> VarSet        -- Binders whose dependencies may be "missing"
                                 -- See Note [Weak loop breakers]
-               -> [LetrecNode]
+               -> [LoopBreakerNode]
                -> [Binding]             -- Append these to the end
                -> [Binding]
 
@@ -1121,7 +1239,7 @@ loopBreakNodes depth weak_fvs nodes binds
           CyclicSCC nodes  -> reOrderNodes depth weak_fvs nodes binds
 
 ----------------------------------
-reOrderNodes :: Int -> VarSet -> [LetrecNode] -> [Binding] -> [Binding]
+reOrderNodes :: Int -> VarSet -> [LoopBreakerNode] -> [Binding] -> [Binding]
     -- Choose a loop breaker, mark it no-inline,
     -- and call loopBreakNodes on the rest
 reOrderNodes _ _ []     _     = panic "reOrderNodes"
@@ -1133,7 +1251,7 @@ reOrderNodes depth weak_fvs (node : nodes) binds
     (map (nodeBinding mk_loop_breaker) chosen_nodes ++ binds)
   where
     (chosen_nodes, unchosen) = chooseLoopBreaker approximate_lb
-                                                 (nd_score (node_payload node))
+                                                 (snd_score (node_payload node))
                                                  [node] [] nodes
 
     approximate_lb = depth >= 2
@@ -1142,8 +1260,8 @@ reOrderNodes depth weak_fvs (node : nodes) binds
         -- After two iterations (d=0, d=1) give up
         -- and approximate, returning to d=0
 
-nodeBinding :: (Id -> Id) -> LetrecNode -> Binding
-nodeBinding set_id_occ (node_payload -> ND { nd_bndr = bndr, nd_rhs = rhs})
+nodeBinding :: (Id -> Id) -> LoopBreakerNode -> Binding
+nodeBinding set_id_occ (node_payload -> SND { snd_bndr = bndr, snd_rhs = rhs})
   = (set_id_occ bndr, rhs)
 
 mk_loop_breaker :: Id -> Id
@@ -1163,13 +1281,13 @@ mk_non_loop_breaker weak_fvs bndr
     tail_info = tailCallInfo (idOccInfo bndr)
 
 ----------------------------------
-chooseLoopBreaker :: Bool             -- True <=> Too many iterations,
-                                      --          so approximate
-                  -> NodeScore            -- Best score so far
-                  -> [LetrecNode]       -- Nodes with this score
-                  -> [LetrecNode]       -- Nodes with higher scores
-                  -> [LetrecNode]       -- Unprocessed nodes
-                  -> ([LetrecNode], [LetrecNode])
+chooseLoopBreaker :: Bool                -- True <=> Too many iterations,
+                                         --          so approximate
+                  -> NodeScore           -- Best score so far
+                  -> [LoopBreakerNode]   -- Nodes with this score
+                  -> [LoopBreakerNode]   -- Nodes with higher scores
+                  -> [LoopBreakerNode]   -- Unprocessed nodes
+                  -> ([LoopBreakerNode], [LoopBreakerNode])
     -- This loop looks for the bind with the lowest score
     -- to pick as the loop  breaker.  The rest accumulate in
 chooseLoopBreaker _ _ loop_nodes acc []
@@ -1189,7 +1307,7 @@ chooseLoopBreaker approx_lb loop_sc loop_nodes acc (node : nodes)
   | otherwise              -- Worse score so don't pick it
   = chooseLoopBreaker approx_lb loop_sc loop_nodes (node : acc) nodes
   where
-    sc = nd_score (node_payload node)
+    sc = snd_score (node_payload node)
 
 {-
 Note [Complexity of loop breaking]
@@ -1322,16 +1440,21 @@ ToDo: try using the occurrence info for the inline'd binder.
 ************************************************************************
 -}
 
-type LetrecNode = Node Unique Details  -- Node comes from Digraph
-                                       -- The Unique key is gotten from the Id
-data Details
-  = ND { nd_bndr :: Id          -- Binder
+-- | Digraph node as constructed by 'makeNode' and consumed by 'occAnalRec'.
+-- The Unique key is gotten from the Id.
+type LetrecNode = Node Unique NodeDetails
 
-       , nd_rhs  :: CoreExpr    -- RHS, already occ-analysed
+-- | Node details as consumed by 'occAnalRec'.
+data NodeDetails
+  = ND { nd_bndr :: Id          -- Binder
 
-       , nd_uds  :: UsageDetails  -- Usage from RHS, and RULES, and stable unfoldings
-                                  -- ignoring phase (ie assuming all are active)
-                                  -- See Note [Forming Rec groups]
+       , nd_rhs  :: !(WithTailUsageDetails CoreExpr)
+         -- ^ RHS, already occ-analysed
+         -- With TailUsageDetails from RHS, and RULES, and stable unfoldings,
+         -- ignoring phase (ie assuming all are active).
+         -- NB: Unadjusted TailUsageDetails, as if this Node becomes a
+         -- non-recursive join point!
+         -- See Note [TailUsageDetails when forming Rec groups]
 
        , nd_inl  :: IdSet       -- Free variables of the stable unfolding and the RHS
                                 -- but excluding any RULES
@@ -1348,18 +1471,33 @@ data Details
        , nd_active_rule_fvs :: IdSet    -- Variables bound in this Rec group that are free
                                         -- in the RHS of an active rule for this bndr
                                         -- See Note [Rules and loop breakers]
-
-       , nd_score :: NodeScore
   }
 
-instance Outputable Details where
+instance Outputable NodeDetails where
    ppr nd = text "ND" <> braces
              (sep [ text "bndr =" <+> ppr (nd_bndr nd)
-                  , text "uds =" <+> ppr (nd_uds nd)
+                  , text "uds =" <+> ppr uds
                   , text "inl =" <+> ppr (nd_inl nd)
                   , text "simple =" <+> ppr (nd_simple nd)
                   , text "active_rule_fvs =" <+> ppr (nd_active_rule_fvs nd)
-                  , text "score =" <+> ppr (nd_score nd)
+             ])
+            where WithTailUsageDetails uds _ = nd_rhs nd
+
+-- | Digraph with simplified and completely occurrence analysed
+-- 'SimpleNodeDetails', retaining just the info we need for breaking loops.
+type LoopBreakerNode = Node Unique SimpleNodeDetails
+
+-- | Condensed variant of 'NodeDetails' needed during loop breaking.
+data SimpleNodeDetails
+  = SND { snd_bndr  :: IdWithOccInfo  -- OccInfo accurate
+        , snd_rhs   :: CoreExpr       -- properly occ-analysed
+        , snd_score :: NodeScore
+        }
+
+instance Outputable SimpleNodeDetails where
+   ppr nd = text "SND" <> braces
+             (sep [ text "bndr =" <+> ppr (snd_bndr nd)
+                  , text "score =" <+> ppr (snd_score nd)
              ])
 
 -- The NodeScore is compared lexicographically;
@@ -1387,52 +1525,59 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
     -- explained in Note [Deterministic SCC] in GHC.Data.Graph.Directed.
   where
     details = ND { nd_bndr            = bndr'
-                 , nd_rhs             = rhs'
-                 , nd_uds             = scope_uds
+                 , nd_rhs             = WithTailUsageDetails scope_uds rhs'
                  , nd_inl             = inl_fvs
                  , nd_simple          = null rules_w_uds && null imp_rule_info
                  , nd_weak_fvs        = weak_fvs
-                 , nd_active_rule_fvs = active_rule_fvs
-                 , nd_score           = pprPanic "makeNodeDetails" (ppr bndr) }
+                 , nd_active_rule_fvs = active_rule_fvs }
 
     bndr' = bndr `setIdUnfolding`      unf'
                  `setIdSpecialisation` mkRuleInfo rules'
 
-    inl_uds = rhs_uds `andUDs` unf_uds
-    scope_uds = inl_uds `andUDs` rule_uds
+    -- NB: Both adj_unf_uds and adj_rule_uds have been adjusted to match the
+    --     JoinArity rhs_ja of unadj_rhs_uds.
+    unadj_inl_uds   = unadj_rhs_uds `andUDs` adj_unf_uds
+    unadj_scope_uds = unadj_inl_uds `andUDs` adj_rule_uds
+    scope_uds       = TUD rhs_ja unadj_scope_uds
                    -- Note [Rules are extra RHSs]
                    -- Note [Rule dependency info]
-    scope_fvs = udFreeVars bndr_set scope_uds
+    scope_fvs = udFreeVars bndr_set unadj_scope_uds
     -- scope_fvs: all occurrences from this binder: RHS, unfolding,
     --            and RULES, both LHS and RHS thereof, active or inactive
 
-    inl_fvs  = udFreeVars bndr_set inl_uds
+    inl_fvs  = udFreeVars bndr_set unadj_inl_uds
     -- inl_fvs: vars that would become free if the function was inlined.
     -- We conservatively approximate that by the free vars from the RHS
     -- and the unfolding together.
     -- See Note [inl_fvs]
 
-    mb_join_arity = isJoinId_maybe bndr
-    -- Get join point info from the *current* decision
-    -- We don't know what the new decision will be!
-    -- Using the old decision at least allows us to
-    -- preserve existing join point, even RULEs are added
-    -- See Note [Join points and unfoldings/rules]
 
     --------- Right hand side ---------
     -- Constructing the edges for the main Rec computation
     -- See Note [Forming Rec groups]
-    -- Do not use occAnalRhs because we don't yet know the final
-    -- answer for mb_join_arity; instead, do the occAnalLam call from
-    -- occAnalRhs, and postpone adjustRhsUsage until occAnalRec
-    rhs_env                         = rhsCtxt env
-    (WithUsageDetails rhs_uds rhs') = occAnalLam rhs_env rhs
+    -- and Note [TailUsageDetails when forming Rec groups]
+    -- Compared to occAnalNonRecBind, we can't yet adjust the RHS because
+    --   (a) we don't yet know the final joinpointhood. It might not become a
+    --       join point after all!
+    --   (b) we don't even know whether it stays a recursive RHS after the SCC
+    --       analysis we are about to seed! So we can't markAllInsideLam in
+    --       advance, because if it ends up as a non-recursive join point we'll
+    --       consider it one-shot and won't need to markAllInsideLam.
+    -- Instead, do the occAnalLamTail call here and postpone adjustTailUsage
+    -- until occAnalRec. In effect, we pretend that the RHS becomes a
+    -- non-recursive join point and fix up later with adjustTailUsage.
+    rhs_env = rhsCtxt env
+    WithTailUsageDetails (TUD rhs_ja unadj_rhs_uds) rhs' = occAnalLamTail rhs_env rhs
+      -- corresponding call to adjustTailUsage in occAnalRec and tagRecBinders
 
     --------- Unfolding ---------
-    -- See Note [Unfoldings and join points]
+    -- See Note [Join points and unfoldings/rules]
     unf = realIdUnfolding bndr -- realIdUnfolding: Ignore loop-breaker-ness
                                -- here because that is what we are setting!
-    (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env Recursive mb_join_arity unf
+    WithTailUsageDetails unf_tuds unf' = occAnalUnfolding rhs_env unf
+    adj_unf_uds = adjustTailArity (Just rhs_ja) unf_tuds
+      -- `rhs_ja` is `joinRhsArity rhs` and is the prediction for source M
+      -- of Note [Join arity prediction based on joinRhsArity]
 
     --------- IMP-RULES --------
     is_active     = occ_rule_act env :: Activation -> Bool
@@ -1441,11 +1586,15 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
     imp_rule_fvs  = impRulesActiveFvs is_active bndr_set imp_rule_info
 
     --------- All rules --------
+    -- See Note [Join points and unfoldings/rules]
+    -- `rhs_ja` is `joinRhsArity rhs'` and is the prediction for source M
+    -- of Note [Join arity prediction based on joinRhsArity]
     rules_w_uds :: [(CoreRule, UsageDetails, UsageDetails)]
-    rules_w_uds = occAnalRules rhs_env mb_join_arity bndr
+    rules_w_uds = [ (r,l,adjustTailArity (Just rhs_ja) rhs_tuds)
+                  | (r,l,rhs_tuds) <- occAnalRules rhs_env bndr ]
     rules'      = map fstOf3 rules_w_uds
 
-    rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
+    adj_rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
     add_rule_uds (_, l, r) uds = l `andUDs` r `andUDs` uds
 
     -------- active_rule_fvs ------------
@@ -1463,8 +1612,8 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
 
 mkLoopBreakerNodes :: OccEnv -> TopLevelFlag
                    -> UsageDetails   -- for BODY of let
-                   -> [Details]
-                   -> WithUsageDetails [LetrecNode] -- adjusted
+                   -> [NodeDetails]
+                   -> WithUsageDetails [LoopBreakerNode] -- with OccInfo up-to-date
 -- See Note [Choosing loop breakers]
 -- This function primarily creates the Nodes for the
 -- loop-breaker SCC analysis.  More specifically:
@@ -1477,10 +1626,10 @@ mkLoopBreakerNodes :: OccEnv -> TopLevelFlag
 mkLoopBreakerNodes !env lvl body_uds details_s
   = WithUsageDetails final_uds (zipWithEqual "mkLoopBreakerNodes" mk_lb_node details_s bndrs')
   where
-    (final_uds, bndrs') = tagRecBinders lvl body_uds details_s
+    WithUsageDetails final_uds bndrs' = tagRecBinders lvl body_uds details_s
 
     mk_lb_node nd@(ND { nd_bndr = old_bndr, nd_inl = inl_fvs }) new_bndr
-      = DigraphNode { node_payload      = new_nd
+      = DigraphNode { node_payload      = simple_nd
                     , node_key          = varUnique old_bndr
                     , node_dependencies = nonDetKeysUniqSet lb_deps }
               -- It's OK to use nonDetKeysUniqSet here as
@@ -1488,7 +1637,8 @@ mkLoopBreakerNodes !env lvl body_uds details_s
               -- in nondeterministic order as explained in
               -- Note [Deterministic SCC] in GHC.Data.Graph.Directed.
       where
-        new_nd = nd { nd_bndr = new_bndr, nd_score = score }
+        WithTailUsageDetails _ rhs = nd_rhs nd
+        simple_nd = SND { snd_bndr = new_bndr, snd_rhs = rhs, snd_score = score }
         score  = nodeScore env new_bndr lb_deps nd
         lb_deps = extendFvs_ rule_fv_env inl_fvs
         -- See Note [Loop breaker dependencies]
@@ -1524,10 +1674,10 @@ group { f1 = e1; ...; fn = en } are:
 nodeScore :: OccEnv
           -> Id        -- Binder with new occ-info
           -> VarSet    -- Loop-breaker dependencies
-          -> Details
+          -> NodeDetails
           -> NodeScore
 nodeScore !env new_bndr lb_deps
-          (ND { nd_bndr = old_bndr, nd_rhs = bind_rhs })
+          (ND { nd_bndr = old_bndr, nd_rhs = WithTailUsageDetails _ bind_rhs })
 
   | not (isId old_bndr)     -- A type or coercion variable is never a loop breaker
   = (100, 0, False)
@@ -1748,7 +1898,7 @@ lambda and casts, e.g.
 
 * Occurrence analyser: we just mark each binder in the lambda-group
   (here: x,y,z) with its occurrence info in the *body* of the
-  lambda-group.  See occAnalLam.
+  lambda-group.  See occAnalLamTail.
 
 * Simplifier.  The simplifier is careful when partially applying
   lambda-groups. See the call to zapLambdaBndrs in
@@ -1804,25 +1954,31 @@ zapLambdaBndrs fun arg_count
     zap_bndr b | isTyVar b = b
                | otherwise = zapLamIdInfo b
 
-occAnalLam :: OccEnv -> CoreExpr -> (WithUsageDetails CoreExpr)
--- See Note [Occurrence analysis for lambda binders]
+occAnalLamTail :: OccEnv -> CoreExpr -> WithTailUsageDetails CoreExpr
+-- ^ See Note [Occurrence analysis for lambda binders].
 -- It does the following:
 --   * Sets one-shot info on the lambda binder from the OccEnv, and
 --     removes that one-shot info from the OccEnv
 --   * Sets the OccEnv to OccVanilla when going under a value lambda
 --   * Tags each lambda with its occurrence information
 --   * Walks through casts
+--   * Packages up the analysed lambda with its manifest join arity
+--
 -- This function does /not/ do
 --   markAllInsideLam or
 --   markAllNonTail
--- The caller does that, either in occAnal (Lam {}), or in adjustRhsUsage
+-- The caller does that via adjustTailUsage (most calls go through
+-- adjustNonRecRhs). Every call to occAnalLamTail must ultimately call
+-- adjustTailUsage to discharge the assumed join arity.
+--
+-- In effect, the analysis result is for a non-recursive join point with
+-- manifest arity and adjustTailUsage does the fixup.
 -- See Note [Adjusting right-hand sides]
-
-occAnalLam env (Lam bndr expr)
+occAnalLamTail env (Lam bndr expr)
   | isTyVar bndr
-  = let env1 = addOneInScope env bndr
-        WithUsageDetails usage expr' = occAnalLam env1 expr
-    in WithUsageDetails usage (Lam bndr expr')
+  , let env1 = addOneInScope env bndr
+  , WithTailUsageDetails (TUD ja usage) expr' <- occAnalLamTail env1 expr
+  = WithTailUsageDetails (TUD (ja+1) usage) (Lam bndr expr')
        -- Important: Keep the 'env' unchanged so that with a RHS like
        --   \(@ x) -> K @x (f @x)
        -- we'll see that (K @x (f @x)) is in a OccRhs, and hence refrain
@@ -1840,14 +1996,14 @@ occAnalLam env (Lam bndr expr)
 
         env1 = env { occ_encl = OccVanilla, occ_one_shots = env_one_shots' }
         env2 = addOneInScope env1 bndr
-        (WithUsageDetails usage expr') = occAnalLam env2 expr
+        WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env2 expr
         (usage', bndr2) = tagLamBinder usage bndr1
-    in WithUsageDetails usage' (Lam bndr2 expr')
+    in WithTailUsageDetails (TUD (ja+1) usage') (Lam bndr2 expr')
 
 -- For casts, keep going in the same lambda-group
 -- See Note [Occurrence analysis for lambda binders]
-occAnalLam env (Cast expr co)
-  = let  (WithUsageDetails usage expr') = occAnalLam env expr
+occAnalLamTail env (Cast expr co)
+  = let  WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env expr
          -- usage1: see Note [Gather occurrences of coercion variables]
          usage1 = addManyOccs usage (coVarsOfCo co)
 
@@ -1857,15 +2013,16 @@ occAnalLam env (Cast expr co)
                     _ -> usage1
 
          -- usage3: you might think this was not necessary, because of
-         -- the markAllNonTail in adjustRhsUsage; but not so!  For a
-         -- join point, adjustRhsUsage doesn't do this; yet if there is
+         -- the markAllNonTail in adjustTailUsage; but not so!  For a
+         -- join point, adjustTailUsage doesn't do this; yet if there is
          -- a cast, we must!  Also: why markAllNonTail?  See
          -- GHC.Core.Lint: Note [Join points and casts]
          usage3 = markAllNonTail usage2
 
-    in WithUsageDetails usage3 (Cast expr' co)
+    in WithTailUsageDetails (TUD ja usage3) (Cast expr' co)
 
-occAnalLam env expr = occAnal env expr
+occAnalLamTail env expr = case occAnal env expr of
+  WithUsageDetails usage expr' -> WithTailUsageDetails (TUD 0 usage) expr'
 
 {- Note [Occ-anal and cast worker/wrapper]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1885,8 +2042,8 @@ RHS. So it'll get a Many occ-info.  (Maybe Cast w/w should create a stable
 unfolding, which would obviate this Note; but that seems a bit of a
 heavyweight solution.)
 
-We only need to this in occAnalLam, not occAnal, because the top leve
-of a right hand side is handled by occAnalLam.
+We only need to do this in occAnalLamTail, not occAnal, because the top level
+of a right-hand side is handled by occAnalLamTail.
 -}
 
 
@@ -1896,57 +2053,23 @@ of a right hand side is handled by occAnalLam.
 *                                                                      *
 ********************************************************************* -}
 
-occAnalRhs :: OccEnv -> RecFlag -> Maybe JoinArity
-           -> CoreExpr   -- RHS
-           -> WithUsageDetails CoreExpr
-occAnalRhs !env is_rec mb_join_arity rhs
-  = let (WithUsageDetails usage rhs1) = occAnalLam env rhs
-           -- We call occAnalLam here, not occAnalExpr, so that it doesn't
-           -- do the markAllInsideLam and markNonTailCall stuff before
-           -- we've had a chance to help with join points; that comes next
-        rhs2      = markJoinOneShots is_rec mb_join_arity rhs1
-        rhs_usage = adjustRhsUsage mb_join_arity rhs2 usage
-    in WithUsageDetails rhs_usage rhs2
-
-
-
-markJoinOneShots :: RecFlag -> Maybe JoinArity -> CoreExpr -> CoreExpr
--- For a /non-recursive/ join point we can mark all
--- its join-lambda as one-shot; and it's a good idea to do so
-markJoinOneShots NonRecursive (Just join_arity) rhs
-  = go join_arity rhs
-  where
-    go 0 rhs         = rhs
-    go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b)
-                           (go (n-1) rhs)
-    go _ rhs         = rhs  -- Not enough lambdas.  This can legitimately happen.
-                            -- e.g.    let j = case ... in j True
-                            -- This will become an arity-1 join point after the
-                            -- simplifier has eta-expanded it; but it may not have
-                            -- enough lambdas /yet/. (Lint checks that JoinIds do
-                            -- have enough lambdas.)
-markJoinOneShots _ _ rhs
-  = rhs
-
 occAnalUnfolding :: OccEnv
-                 -> RecFlag
-                 -> Maybe JoinArity   -- See Note [Join points and unfoldings/rules]
                  -> Unfolding
-                 -> WithUsageDetails Unfolding
+                 -> WithTailUsageDetails Unfolding
 -- Occurrence-analyse a stable unfolding;
--- discard a non-stable one altogether.
-occAnalUnfolding !env is_rec mb_join_arity unf
+-- discard a non-stable one altogether and return empty usage details.
+occAnalUnfolding !env unf
   = case unf of
       unf@(CoreUnfolding { uf_tmpl = rhs, uf_src = src })
         | isStableSource src ->
             let
-              (WithUsageDetails usage rhs') = occAnalRhs env is_rec mb_join_arity rhs
+              WithTailUsageDetails (TUD rhs_ja usage) rhs' = occAnalLamTail env rhs
 
               unf' | noBinderSwaps env = unf -- Note [Unfoldings and rules]
                    | otherwise         = unf { uf_tmpl = rhs' }
-            in WithUsageDetails (markAllMany usage) unf'
+            in WithTailUsageDetails (TUD rhs_ja (markAllMany usage)) unf'
               -- markAllMany: see Note [Occurrences in stable unfoldings]
-        | otherwise          -> WithUsageDetails emptyDetails unf
+        | otherwise          -> WithTailUsageDetails (TUD 0 emptyDetails) unf
               -- For non-Stable unfoldings we leave them undisturbed, but
               -- don't count their usage because the simplifier will discard them.
               -- We leave them undisturbed because nodeScore uses their size info
@@ -1955,29 +2078,26 @@ occAnalUnfolding !env is_rec mb_join_arity unf
               -- scope remain in scope; there is no cloning etc.
 
       unf@(DFunUnfolding { df_bndrs = bndrs, df_args = args })
-        -> WithUsageDetails final_usage (unf { df_args = args' })
+        -> WithTailUsageDetails (TUD 0 final_usage) (unf { df_args = args' })
         where
           env'            = env `addInScope` bndrs
           (WithUsageDetails usage args') = occAnalList env' args
-          final_usage     = markAllManyNonTail (delDetailsList usage bndrs)
-                            `addLamCoVarOccs` bndrs
-                            `delDetailsList` bndrs
+          final_usage     = usage `addLamCoVarOccs` bndrs `delDetailsList` bndrs
               -- delDetailsList; no need to use tagLamBinders because we
               -- never inline DFuns so the occ-info on binders doesn't matter
 
-      unf -> WithUsageDetails emptyDetails unf
+      unf -> WithTailUsageDetails (TUD 0 emptyDetails) unf
 
 occAnalRules :: OccEnv
-             -> Maybe JoinArity  -- See Note [Join points and unfoldings/rules]
              -> Id               -- Get rules from here
              -> [(CoreRule,      -- Each (non-built-in) rule
                   UsageDetails,  -- Usage details for LHS
-                  UsageDetails)] -- Usage details for RHS
-occAnalRules !env mb_join_arity bndr
+                  TailUsageDetails)] -- Usage details for RHS
+occAnalRules !env bndr
   = map occ_anal_rule (idCoreRules bndr)
   where
     occ_anal_rule rule@(Rule { ru_bndrs = bndrs, ru_args = args, ru_rhs = rhs })
-      = (rule', lhs_uds', rhs_uds')
+      = (rule', lhs_uds', TUD rhs_ja rhs_uds')
       where
         env' = env `addInScope` bndrs
         rule' | noBinderSwaps env = rule  -- Note [Unfoldings and rules]
@@ -1990,14 +2110,11 @@ occAnalRules !env mb_join_arity bndr
         (WithUsageDetails rhs_uds rhs') = occAnal env' rhs
                             -- Note [Rules are extra RHSs]
                             -- Note [Rule dependency info]
-        rhs_uds' = markAllNonTailIf (not exact_join) $
-                   markAllMany                             $
+        rhs_uds' = markAllMany $
                    rhs_uds `delDetailsList` bndrs
+        rhs_ja = length args -- See Note [Join points and unfoldings/rules]
 
-        exact_join = exactJoin mb_join_arity args
-                     -- See Note [Join points and unfoldings/rules]
-
-    occ_anal_rule other_rule = (other_rule, emptyDetails, emptyDetails)
+    occ_anal_rule other_rule = (other_rule, emptyDetails, TUD 0 emptyDetails)
 
 {- Note [Join point RHSs]
 ~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -2032,6 +2149,8 @@ Another way to think about it: if we inlined g as-is into multiple
 call sites, now there's be multiple calls to f.
 
 Bottom line: treat all occurrences in a stable unfolding as "Many".
+We still leave tail call information intact, though, so as not to spoil
+potential join points.
 
 Note [Unfoldings and rules]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -2209,10 +2328,7 @@ occAnal env app@(App _ _)
   = occAnalApp env (collectArgsTicks tickishFloatable app)
 
 occAnal env expr@(Lam {})
-  = let (WithUsageDetails usage expr') = occAnalLam env expr
-        final_usage = markAllInsideLamIf (not (isOneShotFun expr')) $
-                      markAllNonTail usage
-    in WithUsageDetails final_usage expr'
+  = adjustNonRecRhs Nothing $ occAnalLamTail env expr -- mb_join_arity == Nothing <=> markAllManyNonTail
 
 occAnal env (Case scrut bndr ty alts)
   = let
@@ -2287,7 +2403,7 @@ occAnalApp !env (Var fun, args, ticks)
   --     This caused #18296
   | fun `hasKey` runRWKey
   , [t1, t2, arg]  <- args
-  , let (WithUsageDetails usage arg') = occAnalRhs env NonRecursive (Just 1) arg
+  , WithUsageDetails usage arg' <- adjustNonRecRhs (Just 1) $ occAnalLamTail env arg
   = WithUsageDetails usage (mkTicks ticks $ mkApps (Var fun) [t1, t2, arg'])
 
 occAnalApp env (Var fun_id, args, ticks)
@@ -2872,7 +2988,6 @@ lookupBndrSwap env@(OccEnv { occ_bs_env = bs_env })  bndr
     case lookupBndrSwap env bndr1 of
       (fun, fun_id) -> (mkCastMCo fun mco, fun_id) }
 
-
 {- Historical note [Proxy let-bindings]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 We used to do the binder-swap transformation by introducing
@@ -2998,6 +3113,19 @@ data UsageDetails
 instance Outputable UsageDetails where
   ppr ud = ppr (ud_env (flattenUsageDetails ud))
 
+-- | Captures the result of applying 'occAnalLamTail' to a function `\xyz.body`.
+-- The TailUsageDetails records
+--   * the number of lambdas (including type lambdas: a JoinArity)
+--   * UsageDetails for the `body`, unadjusted by `adjustTailUsage`.
+--     If the binding turns out to be a join point with the indicated join
+--     arity, these unadjusted usage details are just what we need; otherwise we
+--     need to discard tail calls. That's what `adjustTailUsage` does.
+data TailUsageDetails = TUD !JoinArity !UsageDetails
+
+instance Outputable TailUsageDetails where
+  ppr (TUD ja uds) = lambda <> ppr ja <> ppr uds
+
+
 -------------------
 -- UsageDetails API
 
@@ -3139,24 +3267,49 @@ flattenUsageDetails ud@(UD { ud_env = env })
 
 -------------------
 -- See Note [Adjusting right-hand sides]
-adjustRhsUsage :: Maybe JoinArity
-               -> CoreExpr       -- Rhs, AFTER occ anal
-               -> UsageDetails   -- From body of lambda
+adjustTailUsage :: Maybe JoinArity
+               -> CoreExpr           -- Rhs, AFTER occAnalLamTail
+               -> TailUsageDetails   -- From body of lambda
                -> UsageDetails
-adjustRhsUsage mb_join_arity rhs usage
+adjustTailUsage mb_join_arity rhs (TUD rhs_ja usage)
   = -- c.f. occAnal (Lam {})
     markAllInsideLamIf (not one_shot) $
     markAllNonTailIf (not exact_join) $
     usage
   where
     one_shot   = isOneShotFun rhs
-    exact_join = exactJoin mb_join_arity bndrs
-    (bndrs,_)  = collectBinders rhs
+    exact_join = mb_join_arity == Just rhs_ja
+
+adjustTailArity :: Maybe JoinArity -> TailUsageDetails -> UsageDetails
+adjustTailArity mb_rhs_ja (TUD ud_ja usage) =
+  markAllNonTailIf (mb_rhs_ja /= Just ud_ja) usage
+
+markNonRecJoinOneShots :: JoinArity -> CoreExpr -> CoreExpr
+-- For a /non-recursive/ join point we can mark all
+-- its join lambdas as one-shot; and it's a good idea to do so
+markNonRecJoinOneShots join_arity rhs
+  = go join_arity rhs
+  where
+    go 0 rhs         = rhs
+    go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b)
+                           (go (n-1) rhs)
+    go _ rhs         = rhs  -- Not enough lambdas.  This can legitimately happen.
+                            -- e.g.    let j = case ... in j True
+                            -- This will become an arity-1 join point after the
+                            -- simplifier has eta-expanded it; but it may not have
+                            -- enough lambdas /yet/. (Lint checks that JoinIds do
+                            -- have enough lambdas.)
 
-exactJoin :: Maybe JoinArity -> [a] -> Bool
-exactJoin Nothing           _    = False
-exactJoin (Just join_arity) args = args `lengthIs` join_arity
-  -- Remember join_arity includes type binders
+markNonRecUnfoldingOneShots :: Maybe JoinArity -> Unfolding -> Unfolding
+-- ^ Apply 'markNonRecJoinOneShots' to a stable unfolding
+markNonRecUnfoldingOneShots mb_join_arity unf
+  | Just ja <- mb_join_arity
+  , CoreUnfolding{uf_src=src,uf_tmpl=tmpl} <- unf
+  , isStableSource src
+  , let !tmpl' = markNonRecJoinOneShots ja tmpl
+  = unf{uf_tmpl=tmpl'}
+  | otherwise
+  = unf
 
 type IdWithOccInfo = Id
 
@@ -3192,8 +3345,8 @@ tagLamBinder usage bndr
 tagNonRecBinder :: TopLevelFlag           -- At top level?
                 -> UsageDetails           -- Of scope
                 -> CoreBndr               -- Binder
-                -> (UsageDetails,         -- Details with binder removed
-                    IdWithOccInfo)        -- Tagged binder
+                -> WithUsageDetails       -- Details with binder removed
+                    IdWithOccInfo         -- Tagged binder
 
 tagNonRecBinder lvl usage binder
  = let
@@ -3205,37 +3358,34 @@ tagNonRecBinder lvl usage binder
      binder' = setBinderOcc occ' binder
      usage'  = usage `delDetails` binder
    in
-   usage' `seq` (usage', binder')
+   WithUsageDetails usage' binder'
 
 tagRecBinders :: TopLevelFlag           -- At top level?
               -> UsageDetails           -- Of body of let ONLY
-              -> [Details]
-              -> (UsageDetails,         -- Adjusted details for whole scope,
+              -> [NodeDetails]
+              -> WithUsageDetails       -- Adjusted details for whole scope,
                                         -- with binders removed
-                  [IdWithOccInfo])      -- Tagged binders
+                  [IdWithOccInfo]       -- Tagged binders
 -- Substantially more complicated than non-recursive case. Need to adjust RHS
 -- details *before* tagging binders (because the tags depend on the RHSes).
 tagRecBinders lvl body_uds details_s
  = let
      bndrs    = map nd_bndr details_s
-     rhs_udss = map nd_uds  details_s
-
-     -- 1. Determine join-point-hood of whole group, as determined by
-     --    the *unadjusted* usage details
-     unadj_uds     = foldr andUDs body_uds rhs_udss
 
-     -- This is only used in `mb_join_arity`, to adjust each `Details` in `details_s`, thus,
-     -- when `bndrs` is non-empty. So, we only write `maybe False` as `decideJoinPointHood`
-     -- takes a `NonEmpty CoreBndr`; the default value `False` won't affect program behavior.
-     will_be_joins = maybe False (decideJoinPointHood lvl unadj_uds) (nonEmpty bndrs)
+     -- 1. See Note [Join arity prediction based on joinRhsArity]
+     --    Determine possible join-point-hood of whole group, by testing for
+     --    manifest join arity M.
+     --    This (re-)asserts that makeNode had made tuds for that same arity M!
+     unadj_uds     = foldr (andUDs . test_manifest_arity) body_uds details_s
+     test_manifest_arity ND{nd_rhs=WithTailUsageDetails tuds rhs}
+       = adjustTailArity (Just (joinRhsArity rhs)) tuds
 
-     -- 2. Adjust usage details of each RHS, taking into account the
-     --    join-point-hood decision
-     rhs_udss' = [ adjustRhsUsage (mb_join_arity bndr) rhs rhs_uds
-                 | ND { nd_bndr = bndr, nd_uds = rhs_uds
-                      , nd_rhs = rhs } <- details_s ]
+     bndr_ne = expectNonEmpty "List of binders is never empty" bndrs
+     will_be_joins = decideJoinPointHood lvl unadj_uds bndr_ne
 
      mb_join_arity :: Id -> Maybe JoinArity
+     -- mb_join_arity: See Note [Join arity prediction based on joinRhsArity]
+     -- This is the source O
      mb_join_arity bndr
          -- Can't use willBeJoinId_maybe here because we haven't tagged
          -- the binder yet (the tag depends on these adjustments!)
@@ -3247,6 +3397,12 @@ tagRecBinders lvl body_uds details_s
        = assert (not will_be_joins) -- Should be AlwaysTailCalled if
          Nothing                   -- we are making join points!
 
+     -- 2. Adjust usage details of each RHS, taking into account the
+     --    join-point-hood decision
+     rhs_udss' = [ adjustTailUsage (mb_join_arity bndr) rhs rhs_tuds -- matching occAnalLamTail in makeNode
+                 | ND { nd_bndr = bndr, nd_rhs = WithTailUsageDetails rhs_tuds rhs }
+                     <- details_s ]
+
      -- 3. Compute final usage details from adjusted RHS details
      adj_uds   = foldr andUDs body_uds rhs_udss'
 
@@ -3257,7 +3413,7 @@ tagRecBinders lvl body_uds details_s
      -- 5. Drop the binders from the adjusted details and return
      usage'    = adj_uds `delDetailsList` bndrs
    in
-   (usage', bndrs')
+   WithUsageDetails usage' bndrs'
 
 setBinderOcc :: OccInfo -> CoreBndr -> CoreBndr
 setBinderOcc occ_info bndr
@@ -3271,12 +3427,13 @@ setBinderOcc occ_info bndr
 
   | otherwise = setIdOccInfo bndr occ_info
 
--- | Decide whether some bindings should be made into join points or not.
+-- | Decide whether some bindings should be made into join points or not, based
+-- on their occurrences.
 -- Returns `False` if they can't be join points. Note that it's an
 -- all-or-nothing decision, as if multiple binders are given, they're
 -- assumed to be mutually recursive.
 --
--- It must, however, be a final decision. If we say "True" for 'f',
+-- It must, however, be a final decision. If we say `True` for 'f',
 -- and then subsequently decide /not/ make 'f' into a join point, then
 -- the decision about another binding 'g' might be invalidated if (say)
 -- 'f' tail-calls 'g'.

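A minimal, self-contained model of the `TailUsageDetails` idea, for readers who want to experiment with it: `occAnalLamTail` counts the manifest lambdas (type lambdas included) and returns the body's usage *unadjusted*, and `adjustTailArity` keeps tail calls only when the predicted join arity equals the counted one. Everything here (`Expr`, `Occ`, `TailUsage`, `occBody`) is a hypothetical toy, not GHC's API; real `OccInfo` also records tail-call arity, inside-lambda-ness and more, none of which is modelled.

```haskell
import qualified Data.Map.Strict as M

-- Hypothetical miniature Core: just enough to show join-arity counting.
data Expr
  = Var String
  | App Expr Expr
  | Lam String Expr     -- value lambda
  | TyLam String Expr   -- type lambda; also counts towards join arity
  deriving (Eq, Show)

-- Simplified occurrence info: how often a variable occurs and whether
-- every one of those occurrences was a tail call.
data Occ = Occ { occCount :: !Int, occAllTail :: !Bool }
  deriving (Eq, Show)

type Usage = M.Map String Occ

-- Toy analogue of TUD: join arity paired with the *unadjusted* body usage.
data TailUsage = TailUsage !Int Usage
  deriving (Eq, Show)

plusOcc :: Occ -> Occ -> Occ
plusOcc (Occ c1 t1) (Occ c2 t2) = Occ (c1 + c2) (t1 && t2)

markAllNonTail :: Usage -> Usage
markAllNonTail = M.map (\o -> o { occAllTail = False })

-- Usage of an expression in tail position: the head of an application
-- spine stays a tail call; arguments and lambda bodies do not.
occBody :: Expr -> Usage
occBody (Var v)     = M.singleton v (Occ 1 True)
occBody (App f a)   = M.unionWith plusOcc (occBody f) (markAllNonTail (occBody a))
occBody (Lam x b)   = markAllNonTail (M.delete x (occBody b))
occBody (TyLam _ b) = markAllNonTail (occBody b)

-- Walk the manifest lambdas, counting them, but do NOT mark the body's
-- occurrences as non-tail: that decision is deferred to the caller.
occAnalLamTail :: Expr -> TailUsage
occAnalLamTail (Lam x body) =
  let TailUsage ja u = occAnalLamTail body
  in TailUsage (ja + 1) (M.delete x u)     -- drop the lambda-bound variable
occAnalLamTail (TyLam _ body) =
  let TailUsage ja u = occAnalLamTail body
  in TailUsage (ja + 1) u
occAnalLamTail e = TailUsage 0 (occBody e)  -- no more lambdas: arity 0

-- Discharge the assumed arity: tail calls survive only if the binding is
-- predicted to be a join point of exactly the counted arity.
adjustTailArity :: Maybe Int -> TailUsage -> Usage
adjustTailArity mbJa (TailUsage ja u)
  | mbJa == Just ja = u
  | otherwise       = markAllNonTail u
```

For `\x -> j x` the toy reports arity 1 with `j` tail-called; predicting arity 1 keeps that, while predicting 2 or `Nothing` turns `j` into a non-tail occurrence, mirroring the `markAllNonTailIf (mb_rhs_ja /= Just ud_ja)` test in the patch.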

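The one-shot marking performed by `markNonRecJoinOneShots` can be reproduced on a toy expression type. This sketch uses a hypothetical `Bndr` record with an explicit one-shot flag in place of GHC's `setOneShotLambda`; like the original, it leaves type binders alone and stops quietly when there are fewer lambdas than the join arity.

```haskell
-- Hypothetical binder: name, whether it is a type variable, one-shot flag.
data Bndr = Bndr { bndrName :: String, bndrIsTyVar :: Bool, bndrOneShot :: Bool }
  deriving (Eq, Show)

data Expr = Lam Bndr Expr | Body String
  deriving (Eq, Show)

-- Mark the first join_arity lambda binders as one-shot.
markNonRecJoinOneShots :: Int -> Expr -> Expr
markNonRecJoinOneShots = go
  where
    go :: Int -> Expr -> Expr
    go 0 e            = e
    go n (Lam b body) = Lam (markOneShot b) (go (n - 1) body)
    go _ e            = e  -- fewer lambdas than the join arity: leave as is

    markOneShot b
      | bndrIsTyVar b = b                        -- type binders: unchanged
      | otherwise     = b { bndrOneShot = True }
```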
=====================================
compiler/GHC/Data/Graph/Directed.hs
=====================================
@@ -4,6 +4,7 @@
 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE DeriveFunctor #-}
 
 module GHC.Data.Graph.Directed (
         Graph, graphFromEdgedVerticesOrd, graphFromEdgedVerticesUniq,
@@ -108,7 +109,7 @@ data Node key payload = DigraphNode {
       node_payload :: payload, -- ^ User data
       node_key :: key, -- ^ User defined node id
       node_dependencies :: [key] -- ^ Dependencies/successors of the node
-  }
+  } deriving Functor
 
 
 instance (Outputable a, Outputable b) => Outputable (Node a b) where


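With `deriving Functor`, `fmap` on a `Node` acts on its last type parameter, `payload`, leaving `node_key` and `node_dependencies` untouched. The standalone copy below mirrors the record from the patch; the derived `Eq` and `Show` instances are additions for testing only, and the excerpt does not show which call sites rely on the new instance.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Standalone copy of the Node record; Functor maps over 'payload' only.
data Node key payload = DigraphNode
  { node_payload      :: payload  -- user data
  , node_key          :: key      -- user-defined node id
  , node_dependencies :: [key]    -- dependencies/successors of the node
  } deriving (Eq, Show, Functor)
```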
=====================================
compiler/GHC/Utils/Misc.hs
=====================================
@@ -35,7 +35,7 @@ module GHC.Utils.Misc (
         equalLength, compareLength, leLength, ltLength,
 
         isSingleton, only, expectOnly, GHC.Utils.Misc.singleton,
-        notNull, snocView,
+        notNull, expectNonEmpty, snocView,
 
         chunkList,
 
@@ -481,7 +481,6 @@ expectOnly _   (a:_) = a
 #endif
 expectOnly msg _     = panic ("expectOnly: " ++ msg)
 
-
 -- | Split a list into chunks of /n/ elements
 chunkList :: Int -> [a] -> [[a]]
 chunkList _ [] = []
@@ -500,6 +499,16 @@ changeLast []     _  = panic "changeLast"
 changeLast [_]    x  = [x]
 changeLast (x:xs) x' = x : changeLast xs x'
 
+-- | Like @expectJust msg . nonEmpty@; a better alternative to 'NE.fromList'.
+expectNonEmpty :: HasCallStack => String -> [a] -> NonEmpty a
+{-# INLINE expectNonEmpty #-}
+expectNonEmpty _   (x:xs) = x:|xs
+expectNonEmpty msg []     = expectNonEmptyPanic msg
+
+expectNonEmptyPanic :: String -> a
+expectNonEmptyPanic msg = panic ("expectNonEmpty: " ++ msg)
+{-# NOINLINE expectNonEmptyPanic #-}
+
 -- | Apply an effectful function to the last list element.
 mapLastM :: Functor f => (a -> f a) -> NonEmpty a -> f (NonEmpty a)
 mapLastM f (x:|[]) = NE.singleton <$> f x


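A standalone version of `expectNonEmpty` that compiles outside GHC's code base. It swaps GHC's internal `panic` for Prelude's `error` (an assumption made for portability) and keeps the failure path in a separate `NOINLINE` helper, presumably so that only the cheap pattern match is duplicated at each inlined call site.

```haskell
import Data.List.NonEmpty (NonEmpty (..))
import GHC.Stack (HasCallStack)

-- Like 'nonEmpty', but fails loudly with a message instead of returning Nothing.
expectNonEmpty :: HasCallStack => String -> [a] -> NonEmpty a
{-# INLINE expectNonEmpty #-}
expectNonEmpty _   (x:xs) = x :| xs
expectNonEmpty msg []     = expectNonEmptyPanic msg

-- Kept out of line so the error-message construction is not inlined.
expectNonEmptyPanic :: String -> a
expectNonEmptyPanic msg = error ("expectNonEmpty: " ++ msg)
{-# NOINLINE expectNonEmptyPanic #-}
```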
=====================================
testsuite/tests/simplCore/should_compile/T22428.hs
=====================================
@@ -0,0 +1,9 @@
+module T22428 where
+
+f :: Integer -> Integer -> Integer
+f x y = go y
+  where
+    go :: Integer -> Integer
+    go 0 = x
+    go n = go (n-1)
+    {-# INLINE go #-}


=====================================
testsuite/tests/simplCore/should_compile/T22428.stderr
=====================================
@@ -0,0 +1,45 @@
+
+==================== Tidy Core ====================
+Result size of Tidy Core
+  = {terms: 32, types: 14, coercions: 0, joins: 1/1}
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T22428.f1 :: Integer
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T22428.f1 = GHC.Num.Integer.IS 1#
+
+-- RHS size: {terms: 28, types: 10, coercions: 0, joins: 1/1}
+f :: Integer -> Integer -> Integer
+[GblId,
+ Arity=2,
+ Str=<SL><1L>,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [0 0] 156 0}]
+f = \ (x :: Integer) (y :: Integer) ->
+      joinrec {
+        go [InlPrag=INLINE (sat-args=1), Occ=LoopBreaker, Dmd=SC(S,L)]
+          :: Integer -> Integer
+        [LclId[JoinId(1)(Just [!])],
+         Arity=1,
+         Str=<1L>,
+         Unf=Unf{Src=StableUser, TopLvl=False, Value=True, ConLike=True,
+                 WorkFree=True, Expandable=True,
+                 Guidance=ALWAYS_IF(arity=1,unsat_ok=False,boring_ok=False)}]
+        go (ds :: Integer)
+          = case ds of wild {
+              GHC.Num.Integer.IS x1 ->
+                case x1 of {
+                  __DEFAULT -> jump go (GHC.Num.Integer.integerSub wild T22428.f1);
+                  0# -> x
+                };
+              GHC.Num.Integer.IP x1 ->
+                jump go (GHC.Num.Integer.integerSub wild T22428.f1);
+              GHC.Num.Integer.IN x1 ->
+                jump go (GHC.Num.Integer.integerSub wild T22428.f1)
+            }; } in
+      jump go y
+
+
+


=====================================
testsuite/tests/simplCore/should_compile/all.T
=====================================
@@ -459,6 +459,10 @@ test('T22494', [grep_errmsg(r'case') ], compile, ['-O -ddump-simpl -dsuppress-un
 test('T22491', normal, compile, ['-O2'])
 test('T21476', normal, compile, [''])
 test('T22272', normal, multimod_compile, ['T22272', '-O -fexpose-all-unfoldings -fno-omit-interface-pragmas -fno-ignore-interface-pragmas'])
+
+# go should become a join point
+test('T22428', [grep_errmsg(r'jump go') ], compile, ['-O -ddump-simpl -dsuppress-uniques -dno-typeable-binds -dsuppress-unfoldings'])
+
 test('T22459', normal, compile, [''])
 test('T22623', normal, multimod_compile, ['T22623', '-O -v0'])
 test('T22662', normal, compile, [''])



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/18f1be99f7ada0e489c37c2c695703b8d54e95d5



More information about the ghc-commits mailing list