[Git][ghc/ghc][wip/T22428] Fix contification with stable unfoldings (#22428)
Sebastian Graf (@sgraf812)
gitlab at gitlab.haskell.org
Tue Nov 29 12:39:47 UTC 2022
Sebastian Graf pushed to branch wip/T22428 at Glasgow Haskell Compiler / GHC
Commits:
71da9da7 by Sebastian Graf at 2022-11-29T13:39:34+01:00
Fix contification with stable unfoldings (#22428)
Many functions now return a `TailUsageDetails` that adorns a `UsageDetails` with
a `JoinArity` that reflects the number of join point binders around the body
for which the `UsageDetails` was computed. `TailUsageDetails` is now returned by
`occAnalLamTail` as well as `occAnalUnfolding` and `occAnalRules`.
I adjusted `Note [Join points and unfoldings/rules]` and
`Note [Adjusting right-hand sides]` to account for the new machinery.
I also wrote a new `Note [Join arity prediction based on joinRhsArity]`
and refer to it when we combine `TailUsageDetails` for a recursive RHS.
I also renamed
* `occAnalLam` to `occAnalLamTail`
* `adjustRhsUsage` to `adjustTailUsage`
* a few other less important functions
and properly documented that each call of `occAnalLamTail` must pair up with
`adjustTailUsage`.
I removed `Note [Unfoldings and join points]` because it was redundant with
`Note [Occurrences in stable unfoldings]`.
While in town, I refactored `mkLoopBreakerNodes` so that it returns a condensed
`NodeDetails` called `SimpleNodeDetails`.
Fixes #22428.
- - - - -
7 changed files:
- compiler/GHC/Core/Opt/Arity.hs
- compiler/GHC/Core/Opt/OccurAnal.hs
- compiler/GHC/Data/Graph/Directed.hs
- compiler/GHC/Utils/Misc.hs
- + testsuite/tests/simplCore/should_compile/T22428.hs
- + testsuite/tests/simplCore/should_compile/T22428.stderr
- testsuite/tests/simplCore/should_compile/all.T
Changes:
=====================================
compiler/GHC/Core/Opt/Arity.hs
=====================================
@@ -133,6 +133,9 @@ joinRhsArity :: CoreExpr -> JoinArity
-- Join points are supposed to have manifestly-visible
-- lambdas at the top: no ticks, no casts, nothing
-- Moreover, type lambdas count in JoinArity
+-- NB: For non-recursive bindings, the join arity of the binding may actually be
+-- less than the number of manifestly-visible lambdas.
+-- See Note [Join arity prediction based on joinRhsArity] in GHC.Core.Opt.OccurAnal
joinRhsArity (Lam _ e) = 1 + joinRhsArity e
joinRhsArity _ = 0
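Below is a minimal sketch of what `joinRhsArity` counts. It uses a toy model over a hypothetical, heavily simplified expression type (`Expr` and `manifestArity` are illustrative names, not GHC's `CoreExpr` API). As in the real function, type and value lambdas both count, and anything else ends the count, including a cast.

```haskell
-- Hypothetical toy stand-in for CoreExpr; not GHC's real data type.
data Expr
  = Var String
  | App Expr Expr
  | TyLam String Expr   -- type lambda: counts towards join arity
  | Lam String Expr     -- value lambda: counts towards join arity
  | Cast Expr           -- stands in for casts and ticks: stops the count
  deriving Show

-- Number of manifestly-visible lambdas at the top of an RHS,
-- mirroring the two equations of joinRhsArity.
manifestArity :: Expr -> Int
manifestArity (TyLam _ e) = 1 + manifestArity e
manifestArity (Lam _ e)   = 1 + manifestArity e
manifestArity _           = 0
```

For instance, `manifestArity (Lam "x" (Lam "y" (Var "rhs")))` is 2. Per the NB above, though, a non-recursive `f x y = rhs` that is only called as `f 1` and `f 2` becomes a join point of arity 1.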
=====================================
compiler/GHC/Core/Opt/OccurAnal.hs
=====================================
@@ -59,7 +59,7 @@ import GHC.Builtin.Names( runRWKey )
import GHC.Unit.Module( Module )
import Data.List (mapAccumL, mapAccumR)
-import Data.List.NonEmpty (NonEmpty (..), nonEmpty)
+import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NE
{-
@@ -510,6 +510,14 @@ of the file.
at any of the definitions. This is done by Simplify.simplRecBind,
when it calls addLetIdInfo.
+Note that the TailUsageDetails stored in the nd_uds field of a NodeDetails is
+computed by `occAnalLamTail` applied to the RHS, not `occAnalExpr`. Specifically,
+`occAnalLamTail` does not do `markAllInsideLam` or `markAllNonTail`, as if the
+binding were a *non-recursive join point*, which it might indeed become in the
+AcyclicSCC case of dependency analysis!
+Hence we do the delayed 'adjustTailUsage' in 'occAnalRec'/'tagRecBinders'.
+See Note [Join points and unfoldings/rules] for more details on the contract.
+
Note [Stable unfoldings]
~~~~~~~~~~~~~~~~~~~~~~~~
None of the above stuff about RULES applies to a stable unfolding
@@ -608,6 +616,53 @@ tail call with `n` arguments (counting both value and type arguments). Otherwise
'occ_tail' will be 'NoTailCallInfo'. The tail call info flows bottom-up with the
rest of 'OccInfo' until it goes on the binder.
+Note [Join arity prediction based on joinRhsArity]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Recall the contification transformation in Fig. 5 of
+"Compiling without Continuations". The letrec case looks like this:
+
+ letrec f = /\as.\xs. L[us] in L'[es]
+ ... and a bunch of conditions establishing that f only occurs
+ in app heads of join arity (len as + len xs) inside us and es ...
+
+So for a recursive group, the possible join arity of each binding is determined
+by two sources:
+
+ O. The arity of AlwaysTailCalled occurrences in the let body and its RHSs
+ after stripping off all manifest lambdas (because a lambda body is not a
+ tail context). This is governed by 'decideJoinPointHood'.
+ M. The manifest join arity of the RHSs as reported by 'joinRhsArity'.
+
+If (O) and (M) do not agree, we may not turn the letrec into a join point.
+What goes wrong if we ignore (M)? Consider
+
+ letrec f x y = if ... then f x else True
+ in f 42
+
+According to (O), we'd turn `f` into a join point of arity 1. But that is wrong,
+because then the recursive jump in its RHS is not in tail position!
+
+We can use (M) and conclude that *if* f becomes a join point, then it will have
+join arity 2. Generally, we can predict the join arity of a recursive binding
+by the manifest join arity of its RHS, which is exactly the join arity returned
+when analysing the RHS with 'occAnalLamTail'.
+
+However, for non-recursive functions, this prediction can be wrong.
+Note again the definition of contification from the paper:
+
+ let f = /\as.\xs.u in L[es] ... conditions ...
+
+Note the occurrence of u, not L[us]. u might indeed be a lambda itself, e.g.,
+
+ let f x y = rhs
+ in if b then f 1 else f 2
+
+Based on (O), we'll turn it into a join point of arity 1, because, unlike in the
+letrec case, f cannot occur in rhs. So we adjust the result returned by
+'occAnalLamTail', spoiling all tail calls; see Note [Adjusting right-hand sides].
+It's also possible for (M) to yield a lower prediction than (O); then the
+Simplifier will eta-expand accordingly.
+
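Below is a minimal sketch of how the two sources (O) and (M) combine, under simplifying assumptions. The occurrences of `f` are summarised as a hypothetical list, where each entry is `Just n` for a tail call with `n` arguments and `Nothing` for a non-tail occurrence. The names `tailCallArity`, `recJoinArity` and `nonRecJoinArity` are illustrative only, not GHC functions. The real decision lives in 'decideJoinPointHood', and the paper's other side conditions are ignored here.

```haskell
-- Hypothetical summary of the occurrences of a binder f (illustration only):
--   Just n  = a tail call of f with n arguments (type and value arguments)
--   Nothing = an occurrence that is not a tail call
type Occs = [Maybe Int]

-- Source (O): every occurrence is a tail call, all with the same arity.
tailCallArity :: Occs -> Maybe Int
tailCallArity occs = case sequence occs of
  Just (n : ns) | all (== n) ns -> Just n
  _                             -> Nothing

-- letrec: (O) must agree with the manifest join arity (M) of the RHS,
-- otherwise the recursive calls in the RHS would not be in tail position.
recJoinArity :: Int -> Occs -> Maybe Int
recJoinArity manifest occs = case tailCallArity occs of
  Just n | n == manifest -> Just n
  _                      -> Nothing

-- let: f cannot occur in its own RHS, so (O) alone decides; surplus
-- manifest lambdas stay in the join body, missing ones get eta-expanded.
nonRecJoinArity :: Occs -> Maybe Int
nonRecJoinArity = tailCallArity
```

In the letrec example above, `f x` in the RHS and `f 42` in the body give `recJoinArity 2 [Just 1, Just 1] == Nothing`. The non-recursive `if b then f 1 else f 2` instead gives `nonRecJoinArity [Just 1, Just 1] == Just 1`.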
Note [Join points and unfoldings/rules]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider
@@ -618,8 +673,10 @@ Consider
Before j is inlined, we'll have occurrences of j2 in
both j's RHS and in its stable unfolding. We want to discover
-j2 as a join point. So we must do the adjustRhsUsage thing
-on j's RHS. That's why we pass mb_join_arity to calcUnfolding.
+j2 as a join point. So 'occAnalUnfolding' returns an unadjusted
+'TailUsageDetails', like 'occAnalLamTail'. We adjust the usage details of the
+unfolding to the actual join arity using the same 'adjustTailArity' as for
+the RHS, see Note [Adjusting right-hand sides].
Same with rules. Suppose we have:
@@ -636,14 +693,31 @@ up. So provided the join-point arity of k matches the args of the
rule we can allow the tail-call info from the RHS of the rule to
propagate.
-* Wrinkle for Rec case. In the recursive case we don't know the
- join-point arity in advance, when calling occAnalUnfolding and
- occAnalRules. (See makeNode.) We don't want to pass Nothing,
- because then a recursive joinrec might lose its join-poin-hood
- when SpecConstr adds a RULE. So we just make do with the
- *current* join-poin-hood, stored in the Id.
+* Note that the join arity of the RHS and that of the unfolding or RULE might
+ differ:
+
+ let j x y = j2 (x+x)
+ {-# INLINE[2] j = \x. g #-}
+ {-# RULE forall x y z. j x y z = h 17 #-}
+ in j 1 2
+
+ So it is crucial that we adjust each TailUsageDetails individually
+ with the actual join arity 2 here before we combine with `andUDs`.
+ Here, that means losing tail call info on `g` and `h`.
- In the non-recursive case things are simple: see occAnalNonRecBind
+* Wrinkle for Rec case: We store one TailUsageDetails in the NodeDetails for the
+ RHS, unfolding and RULES combined. Clearly, if they don't agree on their join
+ arity, we have to do some adjusting. We choose to adjust to the join arity
+ of the RHS, because that is likely the join arity that the join point will
+ have; see Note [Join arity prediction based on joinRhsArity].
+
+ If the guess is correct, then tail calls in the RHS are preserved, which is a
+ necessary condition for the whole binding to become a joinrec.
+ The guess can only be incorrect in the 'AcyclicSCC' case, when the binding
+ becomes a non-recursive join point with a different join arity. But then the
+ later call to 'adjustTailUsage' in 'tagRecBinders'/'occAnalRec' will
+ be with a different join arity and will destroy the unsound tail call info
+ with 'markNonTail'.
* Wrinkle for RULES. Suppose the example was a bit different:
let j :: Int -> Int
@@ -669,28 +743,21 @@ propagate.
This appears to be very rare in practice. TODO Perhaps we should gather
statistics to be sure.
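The first two bullets above can be sketched with a toy model: adjust each `TailUsageDetails` individually, and in the Rec case adjust everything to the RHS's join arity first. The types below are hypothetical simplifications, in which a usage map only records, per variable, whether all occurrences are tail calls. The sketch ignores the `markAllInsideLam` half of the real adjustment. `adjustTailArity` here follows the behaviour described in this Note, keeping tail calls only on an exact arity match; it is not necessarily GHC's exact code.

```haskell
import qualified Data.Map.Strict as Map

-- Toy usage details: True <=> every occurrence of the variable is a tail call.
type UsageDetails = Map.Map String Bool

-- Usage computed as if the binding were a join point of the given arity.
data TailUsageDetails = TUD Int UsageDetails

markAllNonTail :: UsageDetails -> UsageDetails
markAllNonTail = Map.map (const False)

-- A variable stays a tail call only if it is one on both sides.
andUDs :: UsageDetails -> UsageDetails -> UsageDetails
andUDs = Map.unionWith (&&)

-- Keep tail-call info only if the actual join arity matches the assumed one.
adjustTailArity :: Maybe Int -> TailUsageDetails -> UsageDetails
adjustTailArity mbJa (TUD ja uds)
  | mbJa == Just ja = uds
  | otherwise       = markAllNonTail uds

-- The example above: j x y = j2 (x+x), INLINE j = \x. g, RULE j x y z = h 17
rhsT, unfT, ruleT :: TailUsageDetails
rhsT  = TUD 2 (Map.fromList [("j2", True)])
unfT  = TUD 1 (Map.fromList [("g", True)])
ruleT = TUD 3 (Map.fromList [("h", True)])

-- Non-recursive case: adjust each part to the decided arity, then combine.
nonRecUsage :: Maybe Int -> UsageDetails
nonRecUsage mbJa = foldr1 andUDs (map (adjustTailArity mbJa) [rhsT, unfT, ruleT])

-- Rec case: combine at the RHS's predicted arity and store one
-- TailUsageDetails; it is adjusted again once the final arity is known.
recNodeUsage :: TailUsageDetails
recNodeUsage =
  let TUD rhsJa rhsUds = rhsT
      adj = adjustTailArity (Just rhsJa)
  in TUD rhsJa (rhsUds `andUDs` adj unfT `andUDs` adj ruleT)
```

With join arity 2, both cases keep the tail call to `j2` and lose the tail-call info on `g` and `h`. If the Rec node is later adjusted with a different arity, every occurrence conservatively becomes a non-tail call.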
-Note [Unfoldings and join points]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-We assume that anything in an unfolding occurs multiple times, since
-unfoldings are often copied (that's the whole point!). But we still
-need to track tail calls for the purpose of finding join points.
-
-
------------------------------------------------------------
Note [Adjusting right-hand sides]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
There's a bit of a dance we need to do after analysing a lambda expression or
a right-hand side. In particular, we need to
- a) call 'markAllInsideLam' *unless* the binding is for a thunk, a one-shot
- lambda, or a non-recursive join point; and
- b) call 'markAllNonTail' *unless* the binding is for a join point, and
- the RHS has the right arity; e.g.
+ a) call 'markAllNonTail' *unless* the binding is for a join point, and
+ the TailUsageDetails from the RHS has the right join arity; e.g.
join j x y = case ... of
A -> j2 p
B -> j2 q
in j a b
Here we want the tail calls to j2 to be tail calls of the whole expression
+ b) call 'markAllInsideLam' *unless* the binding is for a thunk, a one-shot
+ lambda, or a non-recursive join point
Some examples, with how the free occurrences in e (assumed not to be a value
lambda) get marked:
@@ -707,26 +774,30 @@ lambda) get marked:
There are a few other caveats; most importantly, if we're marking a binding as
'AlwaysTailCalled', it's *going* to be a join point, so we treat it as one so
that the effect cascades properly. Consequently, at the time the RHS is
-analysed, we won't know what adjustments to make; thus 'occAnalLamOrRhs' must
-return the unadjusted 'UsageDetails', to be adjusted by 'adjustRhsUsage' once
-join-point-hood has been decided.
+analysed, we won't know what adjustments to make; thus 'occAnalLamTail' must
+return the unadjusted 'TailUsageDetails', to be adjusted by 'adjustTailUsage' once
+join-point-hood has been decided and any one-shot annotations have been
+added through 'markNonRecJoinOneShots'.
Thus the overall sequence taking place in 'occAnalNonRecBind' and
'occAnalRecBind' is as follows:
- 1. Call 'occAnalLamOrRhs' to find usage information for the RHS.
+ 1. Call 'occAnalLamTail' to find usage information for the RHS.
2. Call 'tagNonRecBinder' or 'tagRecBinders', which decides whether to make
the binding a join point.
- 3. Call 'adjustRhsUsage' accordingly. (Done as part of 'tagRecBinders' when
+ 3. Call 'markNonRecJoinOneShots' so that we recognise every non-recursive join
+ point as one-shot.
+ 4. Call 'adjustTailUsage' accordingly. (Done as part of 'tagRecBinders' when
recursive.)
-(In the recursive case, this logic is spread between 'makeNode' and
-'occAnalRec'.)
+In the recursive case, this logic is spread out: step 1 is done in 'makeNode',
+and steps (2)-(4) are done in the 'AcyclicSCC' case of 'occAnalRec'.
-}
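The two adjustments (a) and (b) can be sketched as follows, along with the reason step 3 must precede step 4. This is a hypothetical toy model, not GHC's code. Each variable's usage records only a tail-call flag and an inside-lambda flag. An RHS is reduced to the one-shot flags of its leading value lambdas, as they stand after step 3. `Occ`, `Rhs` and `isOneShotRhs` are illustrative names.

```haskell
import qualified Data.Map.Strict as Map

-- Toy per-variable occurrence info.
data Occ = Occ { occTail :: Bool, occInsideLam :: Bool } deriving (Eq, Show)
type UsageDetails = Map.Map String Occ

-- Usage computed as if the binding were a non-recursive join point of this arity.
data TailUsageDetails = TUD Int UsageDetails

-- Toy RHS: one-shot flags of its leading value lambdas ([] for a thunk).
newtype Rhs = Rhs [Bool]

-- True for a thunk, or for a lambda group whose lambdas are all one-shot.
isOneShotRhs :: Rhs -> Bool
isOneShotRhs (Rhs oneShots) = and oneShots

adjustTailUsage :: Maybe Int -> Rhs -> TailUsageDetails -> UsageDetails
adjustTailUsage mbJa rhs (TUD ja uds) = Map.map adjust uds
  where
    exactJoin = mbJa == Just ja   -- (a) keep tail calls only for an exact join arity
    oneShot   = isOneShotRhs rhs  -- (b) skip markAllInsideLam for thunks/one-shot RHSs
    adjust (Occ t l) = Occ (t && exactJoin) (l || not oneShot)
```

For the `join j x y = case ... of { A -> j2 p; B -> j2 q }` example, the body usage is `TUD 2 {j2: tail, not inside lambda}`. As a non-recursive join point whose lambdas step 3 has marked one-shot, `j2` stays a tail call outside any lambda. As an ordinary function with non-one-shot lambdas, it becomes a non-tail occurrence inside a lambda.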
-
data WithUsageDetails a = WithUsageDetails !UsageDetails !a
+data WithTailUsageDetails a = WithTailUsageDetails !TailUsageDetails !a
+
------------------------------------------------------------------
-- occAnalBind
------------------------------------------------------------------
@@ -754,16 +825,14 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage
= WithUsageDetails body_usage []
| otherwise -- It's mentioned in the body
- = WithUsageDetails (body_usage' `andUDs` rhs_usage) [NonRec final_bndr rhs']
+ = WithUsageDetails final_uds [NonRec final_bndr rhs']
where
- (body_usage', tagged_bndr) = tagNonRecBinder lvl body_usage bndr
- final_bndr = tagged_bndr `setIdUnfolding` unf'
- `setIdSpecialisation` mkRuleInfo rules'
- rhs_usage = rhs_uds `andUDs` unf_uds `andUDs` rule_uds
+ final_uds = body_rhs_uds' `andUDs` adj_unf_uds `andUDs` adj_rule_uds
+ final_bndr = bndr' `setIdSpecialisation` mkRuleInfo rules'
-- Get the join info from the *new* decision
-- See Note [Join points and unfoldings/rules]
- mb_join_arity = willBeJoinId_maybe tagged_bndr
+ mb_join_arity = willBeJoinId_maybe bndr'
is_join_point = isJust mb_join_arity
--------- Right hand side ---------
@@ -773,17 +842,21 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage
-- See Note [Sources of one-shot information]
rhs_env = env1 { occ_one_shots = argOneShots dmd }
- (WithUsageDetails rhs_uds rhs') = occAnalRhs rhs_env NonRecursive mb_join_arity rhs
+ WithTailUsageDetails rhs_uds rhs1 = occAnalLamTail rhs_env rhs
+ WithUsageDetails body_rhs_uds' (bndr', rhs')
+ = tagNonRecBind lvl bndr rhs1 unf1 body_usage rhs_uds
--------- Unfolding ---------
- -- See Note [Unfoldings and join points]
+ -- See Note [Join points and unfoldings/rules]
unf | isId bndr = idUnfolding bndr
| otherwise = NoUnfolding
- (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env NonRecursive mb_join_arity unf
+ WithTailUsageDetails unf_uds unf1 = occAnalUnfolding rhs_env unf
+ adj_unf_uds = adjustTailArity mb_join_arity unf_uds
--------- Rules ---------
-- See Note [Rules are extra RHSs] and Note [Rule dependency info]
- rules_w_uds = occAnalRules rhs_env mb_join_arity bndr
+ -- and Note [Join points and unfoldings/rules]
+ rules_w_uds = occAnalRules rhs_env bndr
rules' = map fstOf3 rules_w_uds
imp_rule_uds = impRulesScopeUsage (lookupImpRules imp_rule_edges bndr)
-- imp_rule_uds: consider
@@ -794,11 +867,12 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage
-- that g is (since the RULE might turn g into h), so
-- we make g mention h.
- rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
- add_rule_uds (_, l, r) uds = l `andUDs` r `andUDs` uds
+ adj_rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
+ add_rule_uds (_, l, r) uds
+ = l `andUDs` adjustTailArity mb_join_arity r `andUDs` uds
----------
- occ = idOccInfo tagged_bndr
+ occ = idOccInfo bndr'
certainly_inline -- See Note [Cascading inlines]
= case occ of
OneOcc { occ_in_lam = NotInsideLam, occ_n_br = 1 }
@@ -820,7 +894,7 @@ occAnalRecBind :: OccEnv -> TopLevelFlag -> ImpRuleEdges -> [(Var,CoreExpr)]
occAnalRecBind !env lvl imp_rule_edges pairs body_usage
= foldr (occAnalRec rhs_env lvl) (WithUsageDetails body_usage []) sccs
where
- sccs :: [SCC Details]
+ sccs :: [SCC NodeDetails]
sccs = {-# SCC "occAnalBind.scc" #-}
stronglyConnCompFromEdgedVerticesUniq nodes
@@ -832,10 +906,38 @@ occAnalRecBind !env lvl imp_rule_edges pairs body_usage
bndr_set = mkVarSet bndrs
rhs_env = env `addInScope` bndrs
+tagNonRecBind :: TopLevelFlag -> Id -> CoreExpr -> Unfolding -> UsageDetails -> TailUsageDetails -> WithUsageDetails (IdWithOccInfo, CoreExpr)
+-- ^ This function concentrates shared logic between occAnalNonRecBind and the
+-- AcyclicSCC case of occAnalRec.
+-- * It decides join-point-hood and tags the binder with the result,
+-- * then applies 'markNonRecJoinOneShots' to the RHS and the stable unfolding,
+-- * and finally returns the adjusted RHS UsageDetails combined with the body
+-- usage.
+tagNonRecBind lvl bndr rhs unf body_uds rhs_uds
+ = WithUsageDetails (body_uds' `andUDs` rhs_uds') (final_bndr, rhs')
+ where
+ --------- Decide join-point-hood ---------
+ WithUsageDetails body_uds' tagged_bndr = tagNonRecBinder lvl body_uds bndr
+ mb_join_arity = willBeJoinId_maybe tagged_bndr
+ --------- Marking (non-rec) join binders one-shot ---------
+ !rhs'
+ | Just ja <- mb_join_arity = markNonRecJoinOneShots ja rhs
+ | otherwise = rhs
+ !final_bndr
+ | Just ja <- mb_join_arity
+ , CoreUnfolding{uf_src=src,uf_tmpl=tmpl} <- unf
+ , isStableSource src
+ , let !tmpl' = markNonRecJoinOneShots ja tmpl
+ = tagged_bndr `setIdUnfolding` unf{uf_tmpl=tmpl'}
+ | otherwise
+ = tagged_bndr
+ --------- Adjusting right-hand side usage ---------
+ rhs_uds' = adjustTailUsage mb_join_arity rhs' rhs_uds
+ -- corresponding call to occAnalLamTail is in makeNode/occAnalNonRecBind
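`tagNonRecBind` applies one-shot marking to the RHS, and to a stable unfolding's template. This can be sketched on a hypothetical toy expression type, modelled on the `markJoinOneShots` helper that this diff removes further down. The sketch walks the first `ja` lambdas, marks value binders one-shot and leaves type binders alone. It stops early if there are fewer lambdas than the join arity, which can legitimately happen before eta-expansion.

```haskell
-- Hypothetical toy binders and expressions; not GHC's Var/CoreExpr.
data Binder = Binder
  { bndrName  :: String
  , isTyVar   :: Bool   -- type binders carry no one-shot info
  , isOneShot :: Bool
  } deriving (Eq, Show)

data Expr
  = Var String
  | App Expr Expr
  | Lam Binder Expr
  deriving (Eq, Show)

-- Mark the first ja lambda binders of a non-recursive join point's RHS one-shot.
markNonRecJoinOneShots :: Int -> Expr -> Expr
markNonRecJoinOneShots = go
  where
    go :: Int -> Expr -> Expr
    go 0 e         = e
    go n (Lam b e) = Lam (if isTyVar b then b else b { isOneShot = True }) (go (n - 1) e)
    go _ e         = e  -- not enough lambdas (yet): leave the rest untouched
```

For `/\a. \x. \y. body` with join arity 2, the type lambda uses up one unit of arity, so only `x` is marked one-shot and `y` is left alone.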
-----------------------------
occAnalRec :: OccEnv -> TopLevelFlag
- -> SCC Details
+ -> SCC NodeDetails
-> WithUsageDetails [CoreBind]
-> WithUsageDetails [CoreBind]
@@ -847,12 +949,10 @@ occAnalRec !_ lvl (AcyclicSCC (ND { nd_bndr = bndr, nd_rhs = rhs
= WithUsageDetails body_uds binds -- See Note [Dead code]
| otherwise -- It's mentioned in the body
- = WithUsageDetails (body_uds' `andUDs` rhs_uds')
- (NonRec tagged_bndr rhs : binds)
+ = WithUsageDetails final_uds (NonRec bndr' rhs' : binds)
where
- (body_uds', tagged_bndr) = tagNonRecBinder lvl body_uds bndr
- rhs_uds' = adjustRhsUsage mb_join_arity rhs rhs_uds
- mb_join_arity = willBeJoinId_maybe tagged_bndr
+ WithUsageDetails final_uds (bndr', rhs')
+ = tagNonRecBind lvl bndr rhs (idUnfolding bndr) body_uds rhs_uds
-- The Rec case is the interesting one
-- See Note [Recursive bindings: the grand plan]
@@ -873,7 +973,7 @@ occAnalRec env lvl (CyclicSCC details_s) (WithUsageDetails body_uds binds)
-- Make the nodes for the loop-breaker analysis
-- See Note [Choosing loop breakers] for loop_breaker_nodes
final_uds :: UsageDetails
- loop_breaker_nodes :: [LetrecNode]
+ loop_breaker_nodes :: [LoopBreakerNode]
(WithUsageDetails final_uds loop_breaker_nodes) = mkLoopBreakerNodes env lvl body_uds details_s
------------------------------
@@ -1102,7 +1202,7 @@ type Binding = (Id,CoreExpr)
loopBreakNodes :: Int
-> VarSet -- Binders whose dependencies may be "missing"
-- See Note [Weak loop breakers]
- -> [LetrecNode]
+ -> [LoopBreakerNode]
-> [Binding] -- Append these to the end
-> [Binding]
@@ -1121,7 +1221,7 @@ loopBreakNodes depth weak_fvs nodes binds
CyclicSCC nodes -> reOrderNodes depth weak_fvs nodes binds
----------------------------------
-reOrderNodes :: Int -> VarSet -> [LetrecNode] -> [Binding] -> [Binding]
+reOrderNodes :: Int -> VarSet -> [LoopBreakerNode] -> [Binding] -> [Binding]
-- Choose a loop breaker, mark it no-inline,
-- and call loopBreakNodes on the rest
reOrderNodes _ _ [] _ = panic "reOrderNodes"
@@ -1133,7 +1233,7 @@ reOrderNodes depth weak_fvs (node : nodes) binds
(map (nodeBinding mk_loop_breaker) chosen_nodes ++ binds)
where
(chosen_nodes, unchosen) = chooseLoopBreaker approximate_lb
- (nd_score (node_payload node))
+ (snd_score (node_payload node))
[node] [] nodes
approximate_lb = depth >= 2
@@ -1142,8 +1242,8 @@ reOrderNodes depth weak_fvs (node : nodes) binds
-- After two iterations (d=0, d=1) give up
-- and approximate, returning to d=0
-nodeBinding :: (Id -> Id) -> LetrecNode -> Binding
-nodeBinding set_id_occ (node_payload -> ND { nd_bndr = bndr, nd_rhs = rhs})
+nodeBinding :: (Id -> Id) -> LoopBreakerNode -> Binding
+nodeBinding set_id_occ (node_payload -> SND { snd_bndr = bndr, snd_rhs = rhs})
= (set_id_occ bndr, rhs)
mk_loop_breaker :: Id -> Id
@@ -1163,13 +1263,13 @@ mk_non_loop_breaker weak_fvs bndr
tail_info = tailCallInfo (idOccInfo bndr)
----------------------------------
-chooseLoopBreaker :: Bool -- True <=> Too many iterations,
- -- so approximate
- -> NodeScore -- Best score so far
- -> [LetrecNode] -- Nodes with this score
- -> [LetrecNode] -- Nodes with higher scores
- -> [LetrecNode] -- Unprocessed nodes
- -> ([LetrecNode], [LetrecNode])
+chooseLoopBreaker :: Bool -- True <=> Too many iterations,
+ -- so approximate
+ -> NodeScore -- Best score so far
+ -> [LoopBreakerNode] -- Nodes with this score
+ -> [LoopBreakerNode] -- Nodes with higher scores
+ -> [LoopBreakerNode] -- Unprocessed nodes
+ -> ([LoopBreakerNode], [LoopBreakerNode])
-- This loop looks for the bind with the lowest score
-- to pick as the loop breaker. The rest accumulate in
chooseLoopBreaker _ _ loop_nodes acc []
@@ -1189,7 +1289,7 @@ chooseLoopBreaker approx_lb loop_sc loop_nodes acc (node : nodes)
| otherwise -- Worse score so don't pick it
= chooseLoopBreaker approx_lb loop_sc loop_nodes (node : acc) nodes
where
- sc = nd_score (node_payload node)
+ sc = snd_score (node_payload node)
{-
Note [Complexity of loop breaking]
@@ -1322,16 +1422,21 @@ ToDo: try using the occurrence info for the inline'd binder.
************************************************************************
-}
-type LetrecNode = Node Unique Details -- Node comes from Digraph
- -- The Unique key is gotten from the Id
-data Details
+-- | Digraph node as constructed by 'makeNode' and consumed by 'occAnalRec'.
+-- The Unique key is gotten from the Id.
+type LetrecNode = Node Unique NodeDetails
+
+-- | Node details as consumed by 'occAnalRec'.
+data NodeDetails
= ND { nd_bndr :: Id -- Binder
, nd_rhs :: CoreExpr -- RHS, already occ-analysed
- , nd_uds :: UsageDetails -- Usage from RHS, and RULES, and stable unfoldings
- -- ignoring phase (ie assuming all are active)
- -- See Note [Forming Rec groups]
+ , nd_uds :: TailUsageDetails -- Usage from RHS, and RULES, and stable unfoldings
+ -- ignoring phase (ie assuming all are active)
+ -- NB: Unadjusted TailUsageDetails, as if
+ -- this Node becomes a non-recursive join point!
+ -- See Note [Forming Rec groups]
, nd_inl :: IdSet -- Free variables of the stable unfolding and the RHS
-- but excluding any RULES
@@ -1348,18 +1453,32 @@ data Details
, nd_active_rule_fvs :: IdSet -- Variables bound in this Rec group that are free
-- in the RHS of an active rule for this bndr
-- See Note [Rules and loop breakers]
-
- , nd_score :: NodeScore
}
-instance Outputable Details where
+instance Outputable NodeDetails where
ppr nd = text "ND" <> braces
(sep [ text "bndr =" <+> ppr (nd_bndr nd)
, text "uds =" <+> ppr (nd_uds nd)
, text "inl =" <+> ppr (nd_inl nd)
, text "simple =" <+> ppr (nd_simple nd)
, text "active_rule_fvs =" <+> ppr (nd_active_rule_fvs nd)
- , text "score =" <+> ppr (nd_score nd)
+ ])
+
+-- | Digraph node with simplified and completely occurrence-analysed
+-- 'SimpleNodeDetails', retaining just the info we need for breaking loops.
+type LoopBreakerNode = Node Unique SimpleNodeDetails
+
+-- | Condensed variant of 'NodeDetails' needed during loop breaking.
+data SimpleNodeDetails
+ = SND { snd_bndr :: IdWithOccInfo -- OccInfo accurate
+ , snd_rhs :: CoreExpr -- properly occurrence-analysed
+ , snd_score :: NodeScore
+ }
+
+instance Outputable SimpleNodeDetails where
+ ppr nd = text "SND" <> braces
+ (sep [ text "bndr =" <+> ppr (snd_bndr nd)
+ , text "score =" <+> ppr (snd_score nd)
])
-- The NodeScore is compared lexicographically;
@@ -1392,47 +1511,54 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
, nd_inl = inl_fvs
, nd_simple = null rules_w_uds && null imp_rule_info
, nd_weak_fvs = weak_fvs
- , nd_active_rule_fvs = active_rule_fvs
- , nd_score = pprPanic "makeNodeDetails" (ppr bndr) }
+ , nd_active_rule_fvs = active_rule_fvs }
bndr' = bndr `setIdUnfolding` unf'
`setIdSpecialisation` mkRuleInfo rules'
- inl_uds = rhs_uds `andUDs` unf_uds
- scope_uds = inl_uds `andUDs` rule_uds
+ -- NB: Both adj_unf_uds and adj_rule_uds have been adjusted to match the
+ -- JoinArity rhs_ja of unadj_rhs_uds.
+ unadj_inl_uds = unadj_rhs_uds `andUDs` adj_unf_uds
+ unadj_scope_uds = unadj_inl_uds `andUDs` adj_rule_uds
+ scope_uds = TUD rhs_ja unadj_scope_uds
-- Note [Rules are extra RHSs]
-- Note [Rule dependency info]
- scope_fvs = udFreeVars bndr_set scope_uds
+ scope_fvs = udFreeVars bndr_set unadj_scope_uds
-- scope_fvs: all occurrences from this binder: RHS, unfolding,
-- and RULES, both LHS and RHS thereof, active or inactive
- inl_fvs = udFreeVars bndr_set inl_uds
+ inl_fvs = udFreeVars bndr_set unadj_inl_uds
-- inl_fvs: vars that would become free if the function was inlined.
-- We conservatively approximate that by thefree vars from the RHS
-- and the unfolding together.
-- See Note [inl_fvs]
- mb_join_arity = isJoinId_maybe bndr
- -- Get join point info from the *current* decision
- -- We don't know what the new decision will be!
- -- Using the old decision at least allows us to
- -- preserve existing join point, even RULEs are added
- -- See Note [Join points and unfoldings/rules]
--------- Right hand side ---------
-- Constructing the edges for the main Rec computation
-- See Note [Forming Rec groups]
- -- Do not use occAnalRhs because we don't yet know the final
- -- answer for mb_join_arity; instead, do the occAnalLam call from
- -- occAnalRhs, and postpone adjustRhsUsage until occAnalRec
- rhs_env = rhsCtxt env
- (WithUsageDetails rhs_uds rhs') = occAnalLam rhs_env rhs
+ -- Compared to occAnalLam, we can't yet adjust the RHS because
+ -- (a) we don't yet know the final answer for pred_join_arity. It might be
+ -- Nothing!
+ -- (b) we don't even know whether it stays a recursive RHS after the SCC
+ -- analysis we are about to seed! So we can't markAllInsideLam in
+ -- advance, because if it ends up as a non-recursive join point we'll
+ -- consider it one-shot and won't need to markAllInsideLam.
+ -- Instead, do the occAnalLamTail call from occAnalLam, and postpone
+ -- adjustTailUsage until occAnalRec. In effect, we pretend that the RHS
+ -- becomes a non-recursive join point and fix up later with adjustTailUsage.
+ rhs_env = rhsCtxt env
+ WithTailUsageDetails (TUD rhs_ja unadj_rhs_uds) rhs' = occAnalLamTail rhs_env rhs
+ -- corresponding call to adjustTailUsage in occAnalRec and tagRecBinders
--------- Unfolding ---------
- -- See Note [Unfoldings and join points]
+ -- See Note [Join points and unfoldings/rules]
unf = realIdUnfolding bndr -- realIdUnfolding: Ignore loop-breaker-ness
-- here because that is what we are setting!
- (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env Recursive mb_join_arity unf
+ WithTailUsageDetails unf_tuds unf' = occAnalUnfolding rhs_env unf
+ adj_unf_uds = adjustTailArity (Just rhs_ja) unf_tuds
+ -- `rhs_ja` is `joinRhsArity rhs` and is the prediction for source (M)
+ -- of Note [Join arity prediction based on joinRhsArity]
--------- IMP-RULES --------
is_active = occ_rule_act env :: Activation -> Bool
@@ -1441,11 +1567,15 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
imp_rule_fvs = impRulesActiveFvs is_active bndr_set imp_rule_info
--------- All rules --------
+ -- See Note [Join points and unfoldings/rules]
+ -- `rhs_ja` is `joinRhsArity rhs'` and is the prediction for source (M)
+ -- of Note [Join arity prediction based on joinRhsArity]
rules_w_uds :: [(CoreRule, UsageDetails, UsageDetails)]
- rules_w_uds = occAnalRules rhs_env mb_join_arity bndr
+ rules_w_uds = [ (r,l,adjustTailArity (Just rhs_ja) rhs_tuds)
+ | (r,l,rhs_tuds) <- occAnalRules rhs_env bndr ]
rules' = map fstOf3 rules_w_uds
- rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
+ adj_rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
add_rule_uds (_, l, r) uds = l `andUDs` r `andUDs` uds
-------- active_rule_fvs ------------
@@ -1463,8 +1593,8 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
mkLoopBreakerNodes :: OccEnv -> TopLevelFlag
-> UsageDetails -- for BODY of let
- -> [Details]
- -> WithUsageDetails [LetrecNode] -- adjusted
+ -> [NodeDetails]
+ -> WithUsageDetails [LoopBreakerNode] -- with OccInfo up-to-date
-- See Note [Choosing loop breakers]
-- This function primarily creates the Nodes for the
-- loop-breaker SCC analysis. More specifically:
@@ -1477,10 +1607,10 @@ mkLoopBreakerNodes :: OccEnv -> TopLevelFlag
mkLoopBreakerNodes !env lvl body_uds details_s
= WithUsageDetails final_uds (zipWithEqual "mkLoopBreakerNodes" mk_lb_node details_s bndrs')
where
- (final_uds, bndrs') = tagRecBinders lvl body_uds details_s
+ WithUsageDetails final_uds bndrs' = tagRecBinders lvl body_uds details_s
mk_lb_node nd@(ND { nd_bndr = old_bndr, nd_inl = inl_fvs }) new_bndr
- = DigraphNode { node_payload = new_nd
+ = DigraphNode { node_payload = simple_nd
, node_key = varUnique old_bndr
, node_dependencies = nonDetKeysUniqSet lb_deps }
-- It's OK to use nonDetKeysUniqSet here as
@@ -1488,7 +1618,7 @@ mkLoopBreakerNodes !env lvl body_uds details_s
-- in nondeterministic order as explained in
-- Note [Deterministic SCC] in GHC.Data.Graph.Directed.
where
- new_nd = nd { nd_bndr = new_bndr, nd_score = score }
+ simple_nd = SND { snd_bndr = new_bndr, snd_rhs = nd_rhs nd, snd_score = score }
score = nodeScore env new_bndr lb_deps nd
lb_deps = extendFvs_ rule_fv_env inl_fvs
-- See Note [Loop breaker dependencies]
@@ -1524,7 +1654,7 @@ group { f1 = e1; ...; fn = en } are:
nodeScore :: OccEnv
-> Id -- Binder with new occ-info
-> VarSet -- Loop-breaker dependencies
- -> Details
+ -> NodeDetails
-> NodeScore
nodeScore !env new_bndr lb_deps
(ND { nd_bndr = old_bndr, nd_rhs = bind_rhs })
@@ -1748,7 +1878,7 @@ lambda and casts, e.g.
* Occurrence analyser: we just mark each binder in the lambda-group
(here: x,y,z) with its occurrence info in the *body* of the
- lambda-group. See occAnalLam.
+ lambda-group. See occAnalLamTail.
* Simplifier. The simplifier is careful when partially applying
lambda-groups. See the call to zapLambdaBndrs in
@@ -1804,7 +1934,7 @@ zapLambdaBndrs fun arg_count
zap_bndr b | isTyVar b = b
| otherwise = zapLamIdInfo b
-occAnalLam :: OccEnv -> CoreExpr -> (WithUsageDetails CoreExpr)
+occAnalLamTail :: OccEnv -> CoreExpr -> WithTailUsageDetails CoreExpr
-- See Note [Occurrence analysis for lambda binders]
-- It does the following:
-- * Sets one-shot info on the lambda binder from the OccEnv, and
@@ -1815,13 +1945,17 @@ occAnalLam :: OccEnv -> CoreExpr -> (WithUsageDetails CoreExpr)
-- This function does /not/ do
-- markAllInsideLam or
-- markAllNonTail
--- The caller does that, either in occAnal (Lam {}), or in adjustRhsUsage
+-- The caller does that, either in occAnalLam, or calling adjustTailUsage directly.
+-- Every call site links to its respective adjustTailUsage call and vice versa.
+--
+-- In effect, the analysis result is for a non-recursive join point with
+-- manifest arity and adjustTailUsage does the fixup.
-- See Note [Adjusting right-hand sides]
-occAnalLam env (Lam bndr expr)
+occAnalLamTail env (Lam bndr expr)
| isTyVar bndr
- = let (WithUsageDetails usage expr') = occAnalLam env expr
- in WithUsageDetails usage (Lam bndr expr')
+ , WithTailUsageDetails (TUD ja usage) expr' <- occAnalLamTail env expr
+ = WithTailUsageDetails (TUD (ja+1) usage) (Lam bndr expr')
-- Important: Keep the 'env' unchanged so that with a RHS like
-- \(@ x) -> K @x (f @x)
-- we'll see that (K @x (f @x)) is in a OccRhs, and hence refrain
@@ -1839,14 +1973,14 @@ occAnalLam env (Lam bndr expr)
env1 = env { occ_encl = OccVanilla, occ_one_shots = env_one_shots' }
env2 = addOneInScope env1 bndr
- (WithUsageDetails usage expr') = occAnalLam env2 expr
+ WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env2 expr
(usage', bndr2) = tagLamBinder usage bndr1
- in WithUsageDetails usage' (Lam bndr2 expr')
+ in WithTailUsageDetails (TUD (ja+1) usage') (Lam bndr2 expr')
-- For casts, keep going in the same lambda-group
-- See Note [Occurrence analysis for lambda binders]
-occAnalLam env (Cast expr co)
- = let (WithUsageDetails usage expr') = occAnalLam env expr
+occAnalLamTail env (Cast expr co)
+ = let WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env expr
-- usage1: see Note [Gather occurrences of coercion variables]
usage1 = addManyOccs usage (coVarsOfCo co)
@@ -1856,15 +1990,16 @@ occAnalLam env (Cast expr co)
_ -> usage1
-- usage3: you might think this was not necessary, because of
- -- the markAllNonTail in adjustRhsUsage; but not so! For a
- -- join point, adjustRhsUsage doesn't do this; yet if there is
+ -- the markAllNonTail in adjustTailUsage; but not so! For a
+ -- join point, adjustTailUsage doesn't do this; yet if there is
-- a cast, we must! Also: why markAllNonTail? See
-- GHC.Core.Lint: Note Note [Join points and casts]
usage3 = markAllNonTail usage2
- in WithUsageDetails usage3 (Cast expr' co)
+ in WithTailUsageDetails (TUD ja usage3) (Cast expr' co)
-occAnalLam env expr = occAnal env expr
+occAnalLamTail env expr = case occAnal env expr of
+ WithUsageDetails usage expr' -> WithTailUsageDetails (TUD 0 usage) expr'
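Below is a toy model of the contract of `occAnalLamTail`. It analyses the body as though it were the body of a non-recursive join point, counts one join-arity unit per lambda, and leaves `markAllInsideLam`/`markAllNonTail` to the caller. The model assumes a hypothetical expression type and a usage map that only tracks whether each free variable is always tail-called. It does not distinguish type from value lambdas; both count towards the arity, as in `joinRhsArity`. Casts stay in the lambda group but spoil tail calls, as in the `Cast` equation above.

```haskell
import qualified Data.Map.Strict as Map

-- Hypothetical toy expressions; Lam stands for both type and value lambdas.
data Expr = Var String | App Expr Expr | Lam String Expr | Cast Expr

-- Toy usage: True <=> every occurrence of the variable is a tail call.
type UsageDetails = Map.Map String Bool

-- Usage of the body, plus the number of lambdas it was computed under.
data TailUsageDetails = TUD Int UsageDetails deriving (Eq, Show)

markAllNonTail :: UsageDetails -> UsageDetails
markAllNonTail = Map.map (const False)

-- Ordinary analysis: the head of an application spine inherits the
-- tail-call status of the application; arguments are never tail calls.
occAnal :: Expr -> UsageDetails
occAnal (Var v)   = Map.singleton v True
occAnal (App f a) = Map.unionWith (&&) (occAnal f) (markAllNonTail (occAnal a))
occAnal (Lam b e) = markAllNonTail (Map.delete b (occAnal e))  -- lambda body: not a tail context
occAnal (Cast e)  = markAllNonTail (occAnal e)

-- Analyse an RHS as if it were a non-recursive join point: the lambda group
-- itself does not spoil tail calls; the caller adjusts later.
occAnalLamTail :: Expr -> TailUsageDetails
occAnalLamTail (Lam b e) = let TUD ja u = occAnalLamTail e in TUD (ja + 1) (Map.delete b u)
occAnalLamTail (Cast e)  = let TUD ja u = occAnalLamTail e in TUD ja (markAllNonTail u)
occAnalLamTail e         = TUD 0 (occAnal e)
```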
{- Note [Occ-anal and cast worker/wrapper]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1884,8 +2019,8 @@ RHS. So it'll get a Many occ-info. (Maybe Cast w/w should create a stable
unfolding, which would obviate this Note; but that seems a bit of a
heavyweight solution.)
-We only need to this in occAnalLam, not occAnal, because the top leve
-of a right hand side is handled by occAnalLam.
+We only need to do this in occAnalLamTail, not occAnal, because the top level
+of a right-hand side is handled by occAnalLamTail.
-}
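To make the arity bookkeeping in `occAnalLamTail` concrete, here is a minimal toy model. It is not GHC code. It assumes a hypothetical `Expr` type with only variables, applications, lambdas, casts and literals, and it tracks only which variables are called in tail position rather than full `UsageDetails`. As in the hunks above, each lambda adds one to the join arity (in GHC, type lambdas count too). A cast keeps the arity but discards tail calls, mirroring `markAllNonTail` on `usage3`. Any other expression starts at arity 0.

```haskell
-- Toy model of occAnalLamTail's join-arity bookkeeping (hypothetical types, not GHC's).
data Expr
  = Var String
  | App Expr Expr
  | Lam String Expr
  | Cast Expr          -- the coercion itself is irrelevant for this sketch
  | Lit Int

-- | Join arity plus the variables still called in tail position.
data TUD = TUD Int [String]
  deriving (Eq, Show)

-- | The head variable of an application spine, if any.
headVar :: Expr -> Maybe String
headVar (Var v)   = Just v
headVar (App f _) = headVar f
headVar _         = Nothing

occLamTail :: Expr -> TUD
occLamTail (Lam x body) =
  -- One more lambda: bump the arity; the bound variable is not a free tail call.
  let TUD ja tails = occLamTail body
  in TUD (ja + 1) (filter (/= x) tails)
occLamTail (Cast e) =
  -- Same lambda group, so keep the arity, but a cast spoils every tail call.
  let TUD ja _ = occLamTail e
  in TUD ja []
occLamTail e =
  -- End of the lambda group: arity 0; only the head of the spine is a tail call.
  TUD 0 (maybe [] (: []) (headVar e))
```

For example, `occLamTail (Lam "n" (App (Var "go") (Var "n")))` yields `TUD 1 ["go"]`: one lambda, and `go` is still tail-called, which is exactly what a join point of arity 1 needs.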
@@ -1895,57 +2030,23 @@ of a right hand side is handled by occAnalLam.
* *
********************************************************************* -}
-occAnalRhs :: OccEnv -> RecFlag -> Maybe JoinArity
- -> CoreExpr -- RHS
- -> WithUsageDetails CoreExpr
-occAnalRhs !env is_rec mb_join_arity rhs
- = let (WithUsageDetails usage rhs1) = occAnalLam env rhs
- -- We call occAnalLam here, not occAnalExpr, so that it doesn't
- -- do the markAllInsideLam and markNonTailCall stuff before
- -- we've had a chance to help with join points; that comes next
- rhs2 = markJoinOneShots is_rec mb_join_arity rhs1
- rhs_usage = adjustRhsUsage mb_join_arity rhs2 usage
- in WithUsageDetails rhs_usage rhs2
-
-
-
-markJoinOneShots :: RecFlag -> Maybe JoinArity -> CoreExpr -> CoreExpr
--- For a /non-recursive/ join point we can mark all
--- its join-lambda as one-shot; and it's a good idea to do so
-markJoinOneShots NonRecursive (Just join_arity) rhs
- = go join_arity rhs
- where
- go 0 rhs = rhs
- go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b)
- (go (n-1) rhs)
- go _ rhs = rhs -- Not enough lambdas. This can legitimately happen.
- -- e.g. let j = case ... in j True
- -- This will become an arity-1 join point after the
- -- simplifier has eta-expanded it; but it may not have
- -- enough lambdas /yet/. (Lint checks that JoinIds do
- -- have enough lambdas.)
-markJoinOneShots _ _ rhs
- = rhs
-
occAnalUnfolding :: OccEnv
- -> RecFlag
- -> Maybe JoinArity -- See Note [Join points and unfoldings/rules]
-> Unfolding
- -> WithUsageDetails Unfolding
+ -> WithTailUsageDetails Unfolding
-- Occurrence-analyse a stable unfolding;
--- discard a non-stable one altogether.
-occAnalUnfolding !env is_rec mb_join_arity unf
+-- discard a non-stable one altogether and return empty usage details.
+occAnalUnfolding !env unf
= case unf of
unf@(CoreUnfolding { uf_tmpl = rhs, uf_src = src })
| isStableSource src ->
let
- (WithUsageDetails usage rhs') = occAnalRhs env is_rec mb_join_arity rhs
+ WithTailUsageDetails (TUD rhs_ja usage) rhs' = occAnalLamTail env rhs
unf' | noBinderSwaps env = unf -- Note [Unfoldings and rules]
| otherwise = unf { uf_tmpl = rhs' }
- in WithUsageDetails (markAllMany usage) unf'
+ in WithTailUsageDetails (TUD rhs_ja (markAllMany usage)) unf'
-- markAllMany: see Note [Occurrences in stable unfoldings]
- | otherwise -> WithUsageDetails emptyDetails unf
+ | otherwise -> WithTailUsageDetails (TUD 0 emptyDetails) unf
-- For non-Stable unfoldings we leave them undisturbed, but
-- don't count their usage because the simplifier will discard them.
-- We leave them undisturbed because nodeScore uses their size info
@@ -1954,29 +2055,26 @@ occAnalUnfolding !env is_rec mb_join_arity unf
-- scope remain in scope; there is no cloning etc.
unf@(DFunUnfolding { df_bndrs = bndrs, df_args = args })
- -> WithUsageDetails final_usage (unf { df_args = args' })
+ -> WithTailUsageDetails (TUD 0 final_usage) (unf { df_args = args' })
where
env' = env `addInScope` bndrs
(WithUsageDetails usage args') = occAnalList env' args
- final_usage = markAllManyNonTail (delDetailsList usage bndrs)
- `addLamCoVarOccs` bndrs
- `delDetailsList` bndrs
+ final_usage = usage `addLamCoVarOccs` bndrs `delDetailsList` bndrs
-- delDetailsList; no need to use tagLamBinders because we
-- never inline DFuns so the occ-info on binders doesn't matter
- unf -> WithUsageDetails emptyDetails unf
+ unf -> WithTailUsageDetails (TUD 0 emptyDetails) unf
occAnalRules :: OccEnv
- -> Maybe JoinArity -- See Note [Join points and unfoldings/rules]
-> Id -- Get rules from here
-> [(CoreRule, -- Each (non-built-in) rule
UsageDetails, -- Usage details for LHS
- UsageDetails)] -- Usage details for RHS
-occAnalRules !env mb_join_arity bndr
+ TailUsageDetails)] -- Usage details for RHS
+occAnalRules !env bndr
= map occ_anal_rule (idCoreRules bndr)
where
occ_anal_rule rule@(Rule { ru_bndrs = bndrs, ru_args = args, ru_rhs = rhs })
- = (rule', lhs_uds', rhs_uds')
+ = (rule', lhs_uds', TUD rhs_ja rhs_uds')
where
env' = env `addInScope` bndrs
rule' | noBinderSwaps env = rule -- Note [Unfoldings and rules]
@@ -1989,14 +2087,11 @@ occAnalRules !env mb_join_arity bndr
(WithUsageDetails rhs_uds rhs') = occAnal env' rhs
-- Note [Rules are extra RHSs]
-- Note [Rule dependency info]
- rhs_uds' = markAllNonTailIf (not exact_join) $
- markAllMany $
+ rhs_uds' = markAllMany $
rhs_uds `delDetailsList` bndrs
+ rhs_ja = length args -- See Note [Join points and unfoldings/rules]
- exact_join = exactJoin mb_join_arity args
- -- See Note [Join points and unfoldings/rules]
-
- occ_anal_rule other_rule = (other_rule, emptyDetails, emptyDetails)
+ occ_anal_rule other_rule = (other_rule, emptyDetails, TUD 0 emptyDetails)
{- Note [Join point RHSs]
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -2031,6 +2126,8 @@ Another way to think about it: if we inlined g as-is into multiple
call sites, now there's be multiple calls to f.
Bottom line: treat all occurrences in a stable unfolding as "Many".
+We still leave tail-call information intact, though, so as not to spoil
+potential join points.
Note [Unfoldings and rules]
~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -2208,10 +2305,9 @@ occAnal env app@(App _ _)
= occAnalApp env (collectArgsTicks tickishFloatable app)
occAnal env expr@(Lam {})
- = let (WithUsageDetails usage expr') = occAnalLam env expr
- final_usage = markAllInsideLamIf (not (isOneShotFun expr')) $
- markAllNonTail usage
- in WithUsageDetails final_usage expr'
+ = WithUsageDetails (adjustTailUsage Nothing expr' tusg) expr' -- mb_join_arity == Nothing <=> markAllManyNonTail
+ where
+ WithTailUsageDetails tusg expr' = occAnalLamTail env expr
occAnal env (Case scrut bndr ty alts)
= let
@@ -2286,8 +2382,10 @@ occAnalApp !env (Var fun, args, ticks)
-- This caused #18296
| fun `hasKey` runRWKey
, [t1, t2, arg] <- args
- , let (WithUsageDetails usage arg') = occAnalRhs env NonRecursive (Just 1) arg
- = WithUsageDetails usage (mkTicks ticks $ mkApps (Var fun) [t1, t2, arg'])
+ , WithTailUsageDetails tusg arg1 <- occAnalLamTail env arg
+ , let !arg2 = markNonRecJoinOneShots 1 arg1
+ , let !usage = adjustTailUsage (Just 1) arg2 tusg
+ = WithUsageDetails usage (mkTicks ticks $ mkApps (Var fun) [t1, t2, arg2])
occAnalApp env (Var fun_id, args, ticks)
= WithUsageDetails all_uds (mkTicks ticks app')
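The `runRW#` case above now performs three explicit steps. It analyses the argument with `occAnalLamTail`, marks its first lambda one-shot with `markNonRecJoinOneShots 1`, and adjusts the usage with `adjustTailUsage (Just 1)`. In effect, it treats the argument like a non-recursive join point of arity 1. The code below is a rough sketch of the middle step only. It re-implements the `go` loop of `markNonRecJoinOneShots` over a hypothetical binder type, where `isTyVar` plays the role of `not (isId b)`.

```haskell
-- Hypothetical binder and expression types; only the lambda spine matters here.
data Bndr = Bndr
  { bndrName :: String
  , isTyVar  :: Bool   -- stands in for `not (isId b)` in GHC
  , oneShot  :: Bool
  } deriving (Eq, Show)

data Expr = Lam Bndr Expr | Body String
  deriving (Eq, Show)

-- | Walk the first n lambdas of a non-recursive join point. Type lambdas count
-- towards n, but only value binders are marked one-shot.
markNonRecJoinOneShots :: Int -> Expr -> Expr
markNonRecJoinOneShots = go
  where
    go :: Int -> Expr -> Expr
    go 0 e = e
    go n (Lam b e) = Lam (if isTyVar b then b else b { oneShot = True }) (go (n - 1) e)
    -- Not enough lambdas (yet): leave the rest alone, as the GHC code does.
    go _ e = e
```

The fall-through equation matters. As the original comment explains, an RHS may not have enough lambdas until the simplifier eta-expands it, so the function must not fail in that case.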
@@ -2865,7 +2963,6 @@ lookupBndrSwap env@(OccEnv { occ_bs_env = bs_env }) bndr
case lookupBndrSwap env bndr1 of
(fun, fun_id) -> (mkCastMCo fun mco, fun_id) }
-
{- Historical note [Proxy let-bindings]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We used to do the binder-swap transformation by introducing
@@ -2991,6 +3088,19 @@ data UsageDetails
instance Outputable UsageDetails where
ppr ud = ppr (ud_env (flattenUsageDetails ud))
+-- | Captures the result of applying 'occAnalLamTail' to a function `\xyz.body`.
+-- The TailUsageDetails records
+-- * the number of lambdas (including type lambdas: a JoinArity)
+-- * UsageDetails for the `body`, unadjusted by `adjustTailUsage`.
+-- If the binding turns out to be a join point with the indicated join
+-- arity, this unadjusted usage details is just what we need; otherwise we
+-- need to discard tail calls. That's what `adjustTailUsage` does.
+data TailUsageDetails = TUD !JoinArity !UsageDetails
+
+instance Outputable TailUsageDetails where
+ ppr (TUD ja uds) = lambda <> ppr ja <> ppr uds
+
+
-------------------
-- UsageDetails API
@@ -3132,24 +3242,38 @@ flattenUsageDetails ud@(UD { ud_env = env })
-------------------
-- See Note [Adjusting right-hand sides]
-adjustRhsUsage :: Maybe JoinArity
- -> CoreExpr -- Rhs, AFTER occ anal
- -> UsageDetails -- From body of lambda
+adjustTailUsage :: Maybe JoinArity
+ -> CoreExpr -- Rhs, AFTER occAnalLamTail
+ -> TailUsageDetails -- From body of lambda
-> UsageDetails
-adjustRhsUsage mb_join_arity rhs usage
+adjustTailUsage mb_join_arity rhs (TUD rhs_ja usage)
= -- c.f. occAnal (Lam {})
markAllInsideLamIf (not one_shot) $
markAllNonTailIf (not exact_join) $
usage
where
one_shot = isOneShotFun rhs
- exact_join = exactJoin mb_join_arity bndrs
- (bndrs,_) = collectBinders rhs
+ exact_join = mb_join_arity == Just rhs_ja
-exactJoin :: Maybe JoinArity -> [a] -> Bool
-exactJoin Nothing _ = False
-exactJoin (Just join_arity) args = args `lengthIs` join_arity
- -- Remember join_arity includes type binders
+adjustTailArity :: Maybe JoinArity -> TailUsageDetails -> UsageDetails
+adjustTailArity mb_rhs_ja (TUD ud_ja usage) =
+ markAllNonTailIf (mb_rhs_ja /= Just ud_ja) usage
+
+markNonRecJoinOneShots :: JoinArity -> CoreExpr -> CoreExpr
+-- For a /non-recursive/ join point we can mark all
+-- its join lambdas as one-shot; and it's a good idea to do so
+markNonRecJoinOneShots join_arity rhs
+ = go join_arity rhs
+ where
+ go 0 rhs = rhs
+ go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b)
+ (go (n-1) rhs)
+ go _ rhs = rhs -- Not enough lambdas. This can legitimately happen.
+ -- e.g. let j = case ... in j True
+ -- This will become an arity-1 join point after the
+ -- simplifier has eta-expanded it; but it may not have
+ -- enough lambdas /yet/. (Lint checks that JoinIds do
+ -- have enough lambdas.)
type IdWithOccInfo = Id
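The following sketch models `adjustTailUsage` and `adjustTailArity` over a deliberately simplified usage map. Several pieces are assumptions made for illustration and are not GHC's real `OccInfo`:

- the `Occ` record, which holds an inside-lambda flag plus an optional tail-call arity;
- the plain `Bool` standing in for `isOneShotFun rhs`;
- all other names in the block.

The sketch shows the key invariant from the `TailUsageDetails` comment: tail-call information survives only when the binder is a join point whose arity equals the number of lambdas that `occAnalLamTail` counted. For brevity, `adjustTailUsage` is expressed here via `adjustTailArity`, which computes the same `markAllNonTailIf` condition.

```haskell
import qualified Data.Map.Strict as M

-- Hypothetical, much-simplified occurrence info for one variable.
data Occ = Occ
  { insideLam :: Bool        -- occurs under a lambda that may run many times
  , tailArity :: Maybe Int   -- Just n: every occurrence is a tail call with n args
  } deriving (Eq, Show)

type UsageDetails = M.Map String Occ

-- | Join arity counted by (a model of) occAnalLamTail, plus the unadjusted usage.
data TailUsageDetails = TUD Int UsageDetails

markAllNonTailIf :: Bool -> UsageDetails -> UsageDetails
markAllNonTailIf True  = M.map (\o -> o { tailArity = Nothing })
markAllNonTailIf False = id

markAllInsideLamIf :: Bool -> UsageDetails -> UsageDetails
markAllInsideLamIf True  = M.map (\o -> o { insideLam = True })
markAllInsideLamIf False = id

-- | Keep tail calls only if the binder is a join point of exactly the counted arity.
adjustTailArity :: Maybe Int -> TailUsageDetails -> UsageDetails
adjustTailArity mbJoinArity (TUD rhsJa usage) =
  markAllNonTailIf (mbJoinArity /= Just rhsJa) usage

-- | Additionally mark everything as inside a lambda unless the RHS is one-shot.
adjustTailUsage :: Maybe Int -> Bool -> TailUsageDetails -> UsageDetails
adjustTailUsage mbJoinArity rhsIsOneShot tud =
  markAllInsideLamIf (not rhsIsOneShot) (adjustTailArity mbJoinArity tud)
```

Passing `Nothing` as the join arity always discards tail calls. This corresponds to the `occAnal env expr@(Lam {})` case, which calls `adjustTailUsage Nothing`.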
@@ -3185,8 +3309,8 @@ tagLamBinder usage bndr
tagNonRecBinder :: TopLevelFlag -- At top level?
-> UsageDetails -- Of scope
-> CoreBndr -- Binder
- -> (UsageDetails, -- Details with binder removed
- IdWithOccInfo) -- Tagged binder
+ -> WithUsageDetails -- Details with binder removed
+ IdWithOccInfo -- Tagged binder
tagNonRecBinder lvl usage binder
= let
@@ -3198,37 +3322,36 @@ tagNonRecBinder lvl usage binder
binder' = setBinderOcc occ' binder
usage' = usage `delDetails` binder
in
- usage' `seq` (usage', binder')
+ WithUsageDetails usage' binder'
tagRecBinders :: TopLevelFlag -- At top level?
-> UsageDetails -- Of body of let ONLY
- -> [Details]
- -> (UsageDetails, -- Adjusted details for whole scope,
+ -> [NodeDetails]
+ -> WithUsageDetails -- Adjusted details for whole scope,
-- with binders removed
- [IdWithOccInfo]) -- Tagged binders
+ [IdWithOccInfo] -- Tagged binders
-- Substantially more complicated than non-recursive case. Need to adjust RHS
-- details *before* tagging binders (because the tags depend on the RHSes).
tagRecBinders lvl body_uds details_s
= let
bndrs = map nd_bndr details_s
- rhs_udss = map nd_uds details_s
-- 1. Determine join-point-hood of whole group, as determined by
- -- the *unadjusted* usage details
- unadj_uds = foldr andUDs body_uds rhs_udss
-
- -- This is only used in `mb_join_arity`, to adjust each `Details` in `details_s`, thus,
- -- when `bndrs` is non-empty. So, we only write `maybe False` as `decideJoinPointHood`
- -- takes a `NonEmpty CoreBndr`; the default value `False` won't affect program behavior.
- will_be_joins = maybe False (decideJoinPointHood lvl unadj_uds) (nonEmpty bndrs)
-
- -- 2. Adjust usage details of each RHS, taking into account the
- -- join-point-hood decision
- rhs_udss' = [ adjustRhsUsage (mb_join_arity bndr) rhs rhs_uds
- | ND { nd_bndr = bndr, nd_uds = rhs_uds
- , nd_rhs = rhs } <- details_s ]
+ -- the *unadjusted* usage details, as if they all ended up as join
-- points. Hence 'assumeCorrectTailArity'.
+ unadj_uds = foldr (andUDs . test_manifest_arity) body_uds details_s
+ test_manifest_arity ND{nd_rhs=rhs,nd_uds=uds}
+ = adjustTailArity (Just (joinRhsArity rhs)) uds
+ -- joinRhsArity: See Note [Join arity prediction based on joinRhsArity]
+ -- This is the place we test (again) for source (M); makeNode had
-- better have made uds for that same join arity!
+
+ bndr_ne = expectNonEmpty "List of binders is never empty" bndrs
+ will_be_joins = decideJoinPointHood lvl unadj_uds bndr_ne
mb_join_arity :: Id -> Maybe JoinArity
+ -- mb_join_arity: See Note [Join arity prediction based on joinRhsArity]
+ -- This is the source (O)
mb_join_arity bndr
-- Can't use willBeJoinId_maybe here because we haven't tagged
-- the binder yet (the tag depends on these adjustments!)
@@ -3240,6 +3363,12 @@ tagRecBinders lvl body_uds details_s
= assert (not will_be_joins) -- Should be AlwaysTailCalled if
Nothing -- we are making join points!
+ -- 2. Adjust usage details of each RHS, taking into account the
+ -- join-point-hood decision
+ rhs_udss' = [ adjustTailUsage (mb_join_arity bndr) rhs rhs_uds -- matching occAnalLamTail in makeNode
+ | ND { nd_bndr = bndr, nd_uds = rhs_uds
+ , nd_rhs = rhs } <- details_s ]
+
-- 3. Compute final usage details from adjusted RHS details
adj_uds = foldr andUDs body_uds rhs_udss'
@@ -3250,7 +3379,7 @@ tagRecBinders lvl body_uds details_s
-- 5. Drop the binders from the adjusted details and return
usage' = adj_uds `delDetailsList` bndrs
in
- (usage', bndrs')
+ WithUsageDetails usage' bndrs'
setBinderOcc :: OccInfo -> CoreBndr -> CoreBndr
setBinderOcc occ_info bndr
@@ -3264,12 +3393,13 @@ setBinderOcc occ_info bndr
| otherwise = setIdOccInfo bndr occ_info
--- | Decide whether some bindings should be made into join points or not.
+-- | Decide whether some bindings should be made into join points or not, based
+-- on their occurrences.
-- Returns `False` if they can't be join points. Note that it's an
-- all-or-nothing decision, as if multiple binders are given, they're
-- assumed to be mutually recursive.
--
--- It must, however, be a final decision. If we say "True" for 'f',
+-- It must, however, be a final decision. If we say `True` for 'f',
-- and then subsequently decide /not/ to make 'f' into a join point, then
-- the decision about another binding 'g' might be invalidated if (say)
-- 'f' tail-calls 'g'.
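The all-or-nothing rule described above can be illustrated with a toy decision function. This is a simplified sketch, not GHC's `decideJoinPointHood`, and it makes two assumptions. First, each binder comes with a predicted join arity, standing in for `joinRhsArity`. Second, the combined, unadjusted usage has been summarised as a map from each binder to `Just n` (always tail-called with `n` arguments) or `Nothing` (some non-tail occurrence). The real function performs further checks that this sketch omits, and `willBeJoins` is a hypothetical name.

```haskell
import qualified Data.Map.Strict as M

-- Hypothetical summary of the *unadjusted* usage of a recursive group:
--   Just n  => every occurrence is a tail call with exactly n arguments
--   Nothing => at least one non-tail occurrence
type TailInfo = M.Map String (Maybe Int)

-- | All-or-nothing: either every binder in the group becomes a join point, or none does.
willBeJoins :: Bool             -- ^ at top level?
            -> TailInfo         -- ^ combined usage of body and RHSs
            -> [(String, Int)]  -- ^ binders with their predicted join arity
            -> Bool
willBeJoins isTopLevel uds bndrs = not isTopLevel && all okForJoin bndrs
  where
    okForJoin (b, predicted) = case M.lookup b uds of
      Just (Just n) -> n == predicted
      _             -> False  -- non-tail occurrence, or (in this toy) no occurrence at all
```

A single binder with a non-tail occurrence, or with a tail-call arity that differs from its prediction, makes the whole group fail.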
=====================================
compiler/GHC/Data/Graph/Directed.hs
=====================================
@@ -4,6 +4,7 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE DeriveFunctor #-}
module GHC.Data.Graph.Directed (
Graph, graphFromEdgedVerticesOrd, graphFromEdgedVerticesUniq,
@@ -107,7 +108,7 @@ data Node key payload = DigraphNode {
node_payload :: payload, -- ^ User data
node_key :: key, -- ^ User defined node id
node_dependencies :: [key] -- ^ Dependencies/successors of the node
- }
+ } deriving Functor
instance (Outputable a, Outputable b) => Outputable (Node a b) where
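With `DeriveFunctor`, the derived `fmap` maps over the last type parameter of `Node`, namely `payload`. It leaves `node_key` and `node_dependencies` untouched. This is presumably what lets the refactored `mkLoopBreakerNodes` swap a node's payload for the condensed `SimpleNodeDetails` mentioned in the commit message. The standalone copy of the record below is only an illustration, and `summarise` is a hypothetical helper.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Standalone copy of the record from the hunk above (Outputable instance omitted).
data Node key payload = DigraphNode
  { node_payload      :: payload  -- User data
  , node_key          :: key      -- User defined node id
  , node_dependencies :: [key]    -- Dependencies/successors of the node
  } deriving (Functor, Eq, Show)

-- | Replace each payload by a summary of it; keys and edges are preserved.
summarise :: Node Int String -> Node Int Int
summarise = fmap length
```

Because the key type also appears in `node_dependencies`, it could not be the mapped parameter without touching the edges. Putting `payload` last is what makes the derived instance useful here.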
=====================================
compiler/GHC/Utils/Misc.hs
=====================================
@@ -39,7 +39,7 @@ module GHC.Utils.Misc (
equalLength, compareLength, leLength, ltLength,
isSingleton, only, expectOnly, GHC.Utils.Misc.singleton,
- notNull, snocView,
+ notNull, expectNonEmpty, snocView,
chunkList,
@@ -504,7 +504,6 @@ expectOnly _ (a:_) = a
#endif
expectOnly msg _ = panic ("expectOnly: " ++ msg)
-
-- | Split a list into chunks of /n/ elements
chunkList :: Int -> [a] -> [[a]]
chunkList _ [] = []
@@ -523,6 +522,16 @@ changeLast [] _ = panic "changeLast"
changeLast [_] x = [x]
changeLast (x:xs) x' = x : changeLast xs x'
+-- | Like @expectJust msg . nonEmpty@; a better alternative to 'NE.fromList'.
+expectNonEmpty :: HasCallStack => String -> [a] -> NonEmpty a
+{-# INLINE expectNonEmpty #-}
+expectNonEmpty _ (x:xs) = x:|xs
+expectNonEmpty msg [] = expectNonEmptyPanic msg
+
+expectNonEmptyPanic :: String -> a
+expectNonEmptyPanic msg = panic ("expectNonEmpty: " ++ msg)
+{-# NOINLINE expectNonEmptyPanic #-}
+
-- | Apply an effectful function to the last list element.
mapLastM :: Functor f => (a -> f a) -> NonEmpty a -> f (NonEmpty a)
mapLastM f (x:|[]) = NE.singleton <$> f x
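Outside GHC, where `panic` is not available, the same pattern can be sketched with Prelude's `error`. The design splits the code into two parts. A small `INLINE` wrapper has a cons case that inlines into callers. A `NOINLINE` helper keeps the cold failure path, and its message-building code, out of every call site. This sketch also propagates `HasCallStack` to the helper, which is a choice not made in the original: it makes `error` report the caller's location.

```haskell
import Data.List.NonEmpty (NonEmpty (..))
import GHC.Stack (HasCallStack)

-- | Like @fromJust . nonEmpty@, but with a caller-supplied message on failure.
expectNonEmpty :: HasCallStack => String -> [a] -> NonEmpty a
{-# INLINE expectNonEmpty #-}
expectNonEmpty _   (x : xs) = x :| xs
expectNonEmpty msg []       = expectNonEmptyPanic msg

-- | Cold path, kept out of line so inlined call sites stay small.
expectNonEmptyPanic :: HasCallStack => String -> a
{-# NOINLINE expectNonEmptyPanic #-}
expectNonEmptyPanic msg = error ("expectNonEmpty: " ++ msg)
```

In `tagRecBinders` above, `expectNonEmpty "List of binders is never empty" bndrs` replaces the earlier `maybe False ... (nonEmpty bndrs)` workaround. An empty binder list now fails loudly instead of silently defaulting to `False`.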
=====================================
testsuite/tests/simplCore/should_compile/T22428.hs
=====================================
@@ -0,0 +1,9 @@
+module T22428 where
+
+f :: Integer -> Integer -> Integer
+f x y = go y
+ where
+ go :: Integer -> Integer
+ go 0 = x
+ go n = go (n-1)
+ {-# INLINE go #-}
=====================================
testsuite/tests/simplCore/should_compile/T22428.stderr
=====================================
@@ -0,0 +1,45 @@
+
+==================== Tidy Core ====================
+Result size of Tidy Core
+ = {terms: 32, types: 14, coercions: 0, joins: 1/1}
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T22428.f1 :: Integer
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+ WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T22428.f1 = GHC.Num.Integer.IS 1#
+
+-- RHS size: {terms: 28, types: 10, coercions: 0, joins: 1/1}
+f :: Integer -> Integer -> Integer
+[GblId,
+ Arity=2,
+ Str=<SL><1L>,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+ WorkFree=True, Expandable=True, Guidance=IF_ARGS [0 0] 156 0}]
+f = \ (x :: Integer) (y :: Integer) ->
+ joinrec {
+ go [InlPrag=INLINE (sat-args=1), Occ=LoopBreaker, Dmd=SC(S,L)]
+ :: Integer -> Integer
+ [LclId[JoinId(1)(Just [!])],
+ Arity=1,
+ Str=<1L>,
+ Unf=Unf{Src=StableUser, TopLvl=False, Value=True, ConLike=True,
+ WorkFree=True, Expandable=True,
+ Guidance=ALWAYS_IF(arity=1,unsat_ok=False,boring_ok=False)}]
+ go (ds :: Integer)
+ = case ds of wild {
+ GHC.Num.Integer.IS x1 ->
+ case x1 of {
+ __DEFAULT -> jump go (GHC.Num.Integer.integerSub wild T22428.f1);
+ 0# -> x
+ };
+ GHC.Num.Integer.IP x1 ->
+ jump go (GHC.Num.Integer.integerSub wild T22428.f1);
+ GHC.Num.Integer.IN x1 ->
+ jump go (GHC.Num.Integer.integerSub wild T22428.f1)
+ }; } in
+ jump go y
+
+
+
=====================================
testsuite/tests/simplCore/should_compile/all.T
=====================================
@@ -450,5 +450,9 @@ test('T22375', normal, compile, ['-O -ddump-simpl -dsuppress-uniques -dno-typeab
# One module, T21851_2.hs, has OPTIONS_GHC -ddump-simpl
# Expecting to see $s$wwombat
test('T21851_2', [grep_errmsg(r'wwombat') ], multimod_compile, ['T21851_2', '-O -dno-typeable-binds -dsuppress-uniques'])
+
# Should not inline m, so there shouldn't be a single YES
test('T22317', [grep_errmsg(r'ANSWER = YES') ], compile, ['-O -dinline-check m -ddebug-output'])
+
+# go should become a join point
+test('T22428', [grep_errmsg(r'jump go') ], compile, ['-O -ddump-simpl -dsuppress-uniques -dno-typeable-binds -dsuppress-unfoldings'])
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/71da9da7da0c2e1d95111b46ec8fdf93ed763527
More information about the ghc-commits mailing list