[Git][ghc/ghc][wip/T23113] WorkWrap: Rethink threshold arity for join points (#23113)

Sebastian Graf (@sgraf812) gitlab at gitlab.haskell.org
Tue Aug 8 18:14:12 UTC 2023



Sebastian Graf pushed to branch wip/T23113 at Glasgow Haskell Compiler / GHC


Commits:
58bfd214 by Sebastian Graf at 2023-08-08T20:12:50+02:00
WorkWrap: Rethink threshold arity for join points (#23113)

... and document our ponderings in `Note [Threshold arity of join points]`.

Fixes #23113

- - - - -


4 changed files:

- compiler/GHC/Core/Opt/DmdAnal.hs
- compiler/GHC/Core/Opt/WorkWrap.hs
- compiler/GHC/Core/Opt/WorkWrap/Utils.hs
- compiler/GHC/Types/Demand.hs


Changes:

=====================================
compiler/GHC/Core/Opt/DmdAnal.hs
=====================================
@@ -1071,30 +1071,35 @@ dmdAnalRhsSig
 -- Process the RHS of the binding, add the strictness signature
 -- to the Id, and augment the environment with the signature as well.
 -- See Note [NOINLINE and strictness]
-dmdAnalRhsSig top_lvl rec_flag env let_dmd id rhs
-  = -- pprTrace "dmdAnalRhsSig" (ppr id $$ ppr let_dmd $$ ppr rhs_dmds $$ ppr sig $$ ppr weak_fvs) $
+dmdAnalRhsSig top_lvl rec_flag env let_sd id rhs
+  = -- pprTrace "dmdAnalRhsSig" (ppr id $$ ppr let_sd $$ ppr rhs_dmds $$ ppr sig $$ ppr weak_fvs) $
     (final_env, weak_fvs, final_id, final_rhs)
   where
-    threshold_arity = thresholdArity id rhs
-
-    rhs_dmd = mkCalledOnceDmds threshold_arity body_dmd
-
-    body_dmd
+    ww_arity = wwArity id rhs
+    threshold_sd = mkCalledOnceDmds ww_arity body_sd
+    body_sd
       | isJoinId id
       -- See Note [Demand analysis for join points]
       -- See Note [Invariants on join points] invariant 2b, in GHC.Core
-      --     threshold_arity matches the join arity of the join point
+      --     ww_arity matches the join arity of the join point
       -- See Note [Unboxed demand on function bodies returning small products]
-      = unboxedWhenSmall env rec_flag (resultType_maybe id) let_dmd
+      = unboxedWhenSmall env rec_flag (resultType_maybe id) let_sd
       | otherwise
       -- See Note [Unboxed demand on function bodies returning small products]
       = unboxedWhenSmall env rec_flag (resultType_maybe id) topSubDmd
 
-    WithDmdType rhs_dmd_ty rhs' = dmdAnal env rhs_dmd rhs
+    WithDmdType rhs_dmd_ty rhs' = dmdAnal env threshold_sd rhs
     DmdType rhs_env rhs_dmds = rhs_dmd_ty
-    (final_rhs_dmds, final_rhs) = finaliseArgBoxities env id threshold_arity
+    (final_rhs_dmds, final_rhs) = finaliseArgBoxities env id ww_arity
                                                       rhs_dmds (de_div rhs_env) rhs'
 
+    -- See Note [Demand signatures are computed for a threshold arity based on idArity]
+    -- The key point is that *we know* it is OK to unleash the signature once
+    -- threshold_arity arguments are supplied; 'mkDmdSigForArity' encodes this
+    -- directly in the signature. For join points, threshold_arity may exceed
+    -- ww_arity; see Note [Threshold arity of join points].
+    threshold_arity = -- pprTrace "threshold" (ppr id $$ ppr threshold_sd) $
+                      calledOnceArity threshold_sd
     sig = mkDmdSigForArity threshold_arity (DmdType sig_env final_rhs_dmds)
 
     opts       = ae_opts env
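The sketch below shows how dmdAnalRhsSig now derives threshold_arity. It is a minimal, self-contained Haskell model built on a toy call-demand type. Card, SubDmd, atMostOnce and thresholdArity are hypothetical names, not GHC's API. Only mkCalledOnceDmds and calledOnceArity mirror functions in the diff, and the sketch ignores unboxedWhenSmall. For the join point j from Note [Threshold arity of join points], ww_arity is 1, but f calls the let body once, so the threshold arity comes out as 2.

```haskell
-- Toy model of call sub-demands; not GHC's real SubDemand type.
data Card = C01 | C11 | C1N
  deriving (Eq, Show)

data SubDmd
  = Call Card SubDmd -- called with the given cardinality, then the body demand
  | TopSd            -- no information (plays the role of topSubDmd)
  deriving (Eq, Show)

-- Hypothetical helper: True for cardinalities meaning "at most once".
atMostOnce :: Card -> Bool
atMostOnce C01 = True
atMostOnce C11 = True
atMostOnce C1N = False

-- Wrap n once-calls around a body sub-demand.
mkCalledOnceDmds :: Int -> SubDmd -> SubDmd
mkCalledOnceDmds n sd = iterate (Call C11) sd !! n

-- Count the leading at-most-once calls, stopping at anything else.
calledOnceArity :: SubDmd -> Int
calledOnceArity = go 0
  where
    go n (Call m sd) | atMostOnce m = go (n + 1) sd
    go n _ = n

-- Hypothetical: threshold arity as computed in dmdAnalRhsSig, ignoring
-- unboxedWhenSmall. A join point analyses its body under the demand on
-- the enclosing let (letSd); other bindings use TopSd.
thresholdArity :: Bool -> Int -> SubDmd -> Int
thresholdArity isJoin wwArity letSd =
  calledOnceArity (mkCalledOnceDmds wwArity bodySd)
  where
    bodySd
      | isJoin = letSd
      | otherwise = TopSd
```

For example, `thresholdArity True 1 (Call C11 TopSd)` is 2, but `thresholdArity False 1 (Call C11 TopSd)` is 1. A many-times call such as `Call C1N TopSd` adds nothing beyond ww_arity.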
@@ -1127,13 +1132,6 @@ splitWeakDmds :: DmdEnv -> (DmdEnv, WeakDmds)
 splitWeakDmds (DE fvs div) = (DE sig_fvs div, weak_fvs)
   where (!weak_fvs, !sig_fvs) = partitionVarEnv isWeakDmd fvs
 
-thresholdArity :: Id -> CoreExpr -> Arity
--- See Note [Demand signatures are computed for a threshold arity based on idArity]
-thresholdArity fn rhs
-  = case idJoinPointHood fn of
-      JoinPoint join_arity -> count isId $ fst $ collectNBinders join_arity rhs
-      NotJoinPoint         -> idArity fn
-
 -- | The result type after applying 'idArity' many arguments. Returns 'Nothing'
 -- when the type doesn't have exactly 'idArity' many arrows.
 resultType_maybe :: Id -> Maybe Type
@@ -1283,13 +1281,10 @@ The threshold we use is
   idArity is /at least/ the number of manifest lambdas, but might be higher for
   PAPs and trivial RHS (see Note [Demand analysis for trivial right-hand sides]).
 
-* Join points: the value-binder subset of the JoinArity.  This can
-  be less than the number of visible lambdas; e.g.
-     join j x = \y. blah
-     in ...(jump j 2)....(jump j 3)....
-  We know that j will never be applied to more than 1 arg (its join
-  arity, and we don't eta-expand join points, so here a threshold
-  of 1 is the best we can do.
+* Join points: usually the number of value binders in the JoinArity, which is
+  often smaller than or equal to idArity. But it can be larger, because we
+  also take into account how the join body is used;
+  see Note [Threshold arity of join points].
 
 Note that the idArity of a function varies independently of its cardinality
 properties (cf. Note [idArity varies independently of dmdTypeDepth]), so we
@@ -1369,6 +1364,48 @@ say that f's arity is no greater than 2, because it'd be false to say
 that f does no work when applied to 3 args.  Lint checks this constraint,
 in `GHC.Core.Lint.lintLetBind`.
 
+See also Note [Threshold arity of join points] for how the threshold arity of
+join points is special.
+
+Note [Threshold arity of join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+The threshold arity in the demand signature of a join point might be
+
+  * Less than `idArity`:
+      join j' x = \pqr. blah in ...(jump j' 1)... (jump j' 2)...
+    Here idArity is 4, but join-arity is 1. Easy.
+  * More than `idArity`:
+      f g = g 42   :: <C(1,L)>
+      h x = f (join j y = (+) y in ... j 13 ...)
+    Note that f's demand on its argument is put on the join expression and hence on j's RHS.
+    How this is achieved is described in Note [Demand analysis for join points].
+    In this Note, we refer to it as the known-context assumption.
+
+The latter example is interesting, because analysis ends up with a demand /type/
+of <1!L><1!L> for the RHS of `j`, based on the arity-2 signature of `(+)`.
+But we can't unbox both arguments: that would require eta-expanding `j`,
+destroying its join-point status:
+  f ( join j y z = case y of I# y# -> case z of I# z# -> $wj y# z# )
+    (      $wj y# z# = y# +# z#                                    )
+    ( in ... j 13 ...                                              )
+This is ill-formed: `j` now takes 2 arguments, but the jump `j 13` supplies only 1.
+
+So `finaliseArgBoxities` will instead drop boxity info of the second arg,
+keeping only the boxity on the first arg. Result: Signature <1!L><1L>.
+Worker/wrapper then ignores any excess argument demands for join points.
+(This is OK, as the RHS of `j` is still always applied to 2 arguments, as can
+be checked by reconstructing the threshold demand on `j`.)
+This produces the following W/W split:
+  join   j y = case y of I# y# -> $wj y#
+       $wj y# = let y = I# y# in (+) y
+  in ... j 13 ...
+It is likely that (+) inlines, so we get
+  join   j y = case y of I# y# -> $wj y#
+       $wj y# = \z -> case z of I# z# -> I# (y# +# z#)
+  in ... $wj 13# ...
+This still saves allocating the boxed 13 at the call site, but in turn needs
+to allocate a closure for the lambda.
+
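The W/W split at the end of Note [Threshold arity of join points] can be mimicked at the source level. Join points cannot be written in Haskell source, so this sketch treats j as an ordinary function. jWrapper and jWorker are hypothetical stand-ins for j's wrapper and $wj. Only the first argument is unboxed, matching the signature <1!L><1L>. The second argument is unboxed inside the lambda once (+) has been inlined.

```haskell
{-# LANGUAGE MagicHash #-}
import GHC.Exts (Int (I#), Int#, (+#))

-- The join point's RHS from the Note, written as an ordinary function.
j :: Int -> Int -> Int
j y = (+) y

-- Hypothetical wrapper: unboxes only the first argument (signature <1!L><1L>).
jWrapper :: Int -> Int -> Int
jWrapper y = case y of I# y# -> jWorker y#

-- Hypothetical worker after (+) has inlined: the second argument is
-- unboxed inside the lambda, which must be allocated as a closure.
jWorker :: Int# -> Int -> Int
jWorker y# = \z -> case z of I# z# -> I# (y# +# z#)
```

At a call site whose argument is a known literal, `jWrapper 13` reduces to `jWorker 13#`, so no boxed 13 is allocated.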
 Note [Demand analysis for trivial right-hand sides]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Consider
@@ -1922,8 +1959,11 @@ positiveTopBudget (MkB n _) = n >= 0
 finaliseArgBoxities :: AnalEnv -> Id -> Arity
                     -> [Demand] -> Divergence
                     -> CoreExpr -> ([Demand], CoreExpr)
-finaliseArgBoxities env fn threshold_arity rhs_dmds div rhs
-
+-- POSTCONDITION:
+-- The demand info in 'rhs_dmds' goes untouched into the first component of the
+-- result pair (including its length).
+-- It might lose some or all of its boxity info, though.
+finaliseArgBoxities env fn ww_arity rhs_dmds div rhs
   -- Check for an OPAQUE function: see Note [OPAQUE pragma]
   -- In that case, trim off all boxity info from argument demands
   -- and demand info on lambda binders
@@ -1933,24 +1973,38 @@ finaliseArgBoxities env fn threshold_arity rhs_dmds div rhs
   = (trimmed_rhs_dmds, set_lam_dmds trimmed_rhs_dmds rhs)
 
   -- Check that we have enough visible binders to match the
-  -- threshold arity; if not, we won't do worker/wrapper
+  -- WW arity; if not, we won't do worker/wrapper
   -- This happens if we have simply  {f = g} or a PAP {f = h 13}
   -- we simply want to give f the same demand signature as g
   -- How can such bindings arise?  Perhaps from {-# NOINLINE[2] f #-},
   -- or if the call to `f` is currently unapplied (map f xs).
   -- It's a bit of a corner case.  Anyway for now we pass on the
   -- unadulterated demands from the RHS, without any boxity trimming.
-  | threshold_arity > count isId bndrs
+  | ww_arity > count isId bndrs
   = (rhs_dmds, rhs)
 
   -- The normal case
-  | otherwise -- NB: threshold_arity might be less than
-              -- manifest arity for join points
+  | otherwise
   = -- pprTrace "finaliseArgBoxities" (
     --   vcat [text "function:" <+> ppr fn
     --        , text "max" <+> ppr max_wkr_args
     --        , text "dmds before:" <+> ppr (map idDemandInfo (filter isId bndrs))
     --        , text "dmds after: " <+>  ppr arg_dmds' ]) $
+
+    -- Let us check for some pre-conditions to keep our sanity:
+    assertPpr (  length arg_triples >= ww_arity
+                   -- As per the PAP case above
+              && length rhs_dmds >= length arg_triples)
+                   -- The rhs_dmds come from the lambda case; hence there should
+                   -- be at least as many rhs_dmds as there are lambda binders
+                   -- (hence arg_triples)
+              (text "finaliseArgBoxities: too few arg_triples or rhs_dmds") $
+    warnPprTrace (not (isJoinId fn) && length rhs_dmds > ww_arity)
+                 "finaliseArgBoxities: excess rhs_dmds"
+                 (ppr fn <+> ppr (length bndrs) <+> ppr ww_arity <+> ppr rhs_dmds) $
+                  -- It is far from clear that it's OK to ignore excess rhs_dmds
+                  -- here rather than zap all boxity. Hence we warn to collect
+                  -- some examples. See Note [Threshold arity of join points]
     (arg_dmds', set_lam_dmds arg_dmds' rhs)
     -- set_lam_dmds: we must attach the final boxities to the lambda-binders
     -- of the function, both because that's kosher, and because CPR analysis
@@ -1963,16 +2017,21 @@ finaliseArgBoxities env fn threshold_arity rhs_dmds div rhs
                       -- This is the budget initialisation step of
                       -- Note [Worker argument budget]
 
-    -- This is the key line, which uses almost-circular programming
-    -- The remaining budget from one layer becomes the initial
-    -- budget for the next layer down.  See Note [Worker argument budget]
-    (remaining_budget, arg_dmds') = go_args (MkB max_wkr_args remaining_budget) arg_triples
-
     arg_triples :: [(Type, StrictnessMark, Demand)]
-    arg_triples = take threshold_arity $
+    arg_triples = take ww_arity $
                   [ (idType bndr, NotMarkedStrict, get_dmd bndr)
                   | bndr <- bndrs, isRuntimeVar bndr ]
 
+    nonboxy_dmds = map trimBoxity $ drop ww_arity rhs_dmds
+    arg_dmds' = boxy_dmds ++ nonboxy_dmds
+      -- NB: length of arg_dmds' is the same as rhs_dmds, as per pre-conditions
+      -- above
+
+    -- This is the key line, which uses almost-circular programming
+    -- The remaining budget from one layer becomes the initial
+    -- budget for the next layer down.  See Note [Worker argument budget]
+    (remaining_budget, boxy_dmds) = go_args (MkB max_wkr_args remaining_budget) arg_triples
+
     get_dmd :: Id -> Demand
     get_dmd bndr
       | is_bot_fn = unboxDeeplyDmd dmd -- See Note [Boxity for bottoming functions],
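The tail of finaliseArgBoxities now keeps argument demands beyond ww_arity but strips their boxity. The minimal sketch of that assembly step below makes several simplifying assumptions. Dmd and Boxity are toy types, and decideBoxities is a hypothetical stand-in for go_args that ignores the worker-argument budget. Also, in GHC the boxy part comes from the lambda binders (arg_triples) rather than directly from rhs_dmds. The sketch shows that the result has the same length as rhs_dmds, as the POSTCONDITION states, and it reproduces the <1!L><1L> signature from the Note.

```haskell
-- Toy demands: a usage string plus a boxity flag; not GHC's Demand type.
data Boxity = Boxed | Unboxed
  deriving (Eq, Show)

data Dmd = Dmd { dmdUsage :: String, dmdBoxity :: Boxity }
  deriving (Eq, Show)

-- Forget any unboxing decision, mirroring the intent of GHC's trimBoxity.
trimBoxity :: Dmd -> Dmd
trimBoxity d = d { dmdBoxity = Boxed }

-- Hypothetical stand-in for go_args: keeps each boxity as is and
-- ignores the worker-argument budget.
decideBoxities :: [Dmd] -> [Dmd]
decideBoxities = id

-- Sketch of how arg_dmds' is assembled: boxity decisions for the first
-- ww_arity demands, boxity trimmed from the excess ones.
finaliseSketch :: Int -> [Dmd] -> [Dmd]
finaliseSketch wwArity rhsDmds = boxyDmds ++ nonboxyDmds
  where
    boxyDmds = decideBoxities (take wwArity rhsDmds)
    nonboxyDmds = map trimBoxity (drop wwArity rhsDmds)
```

With ww_arity 1 and two unboxed demands, the second demand comes back boxed and the list length is unchanged.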


=====================================
compiler/GHC/Core/Opt/WorkWrap.hs
=====================================
@@ -759,11 +759,8 @@ by LitRubbish (see Note [Drop absent bindings]) so there is no great harm.
 ---------------------
 splitFun :: WwOpts -> Id -> CoreExpr -> UniqSM [(Id, CoreExpr)]
 splitFun ww_opts fn_id rhs
-  | Just (arg_vars, body) <- collectNValBinders_maybe (length wrap_dmds) rhs
-  = warnPprTrace (not (wrap_dmds `lengthIs` (arityInfo fn_info)))
-                 "splitFun"
-                 (ppr fn_id <+> (ppr wrap_dmds $$ ppr cpr)) $
-    do { mb_stuff <- mkWwBodies ww_opts fn_id arg_vars (exprType body) wrap_dmds cpr
+  | Just (arg_vars, body) <- collectNValBinders_maybe (wwArity fn_id rhs) rhs
+  = do { mb_stuff <- mkWwBodies ww_opts fn_id arg_vars (exprType body) wrap_dmds cpr
        ; case mb_stuff of
             Nothing -> -- No useful wrapper; leave the binding alone
                        return [(fn_id, rhs)]


=====================================
compiler/GHC/Core/Opt/WorkWrap/Utils.hs
=====================================
@@ -11,7 +11,7 @@ module GHC.Core.Opt.WorkWrap.Utils
    ( WwOpts(..), mkWwBodies, mkWWstr, mkWWstr_one
    , needsVoidWorkerArg
    , DataConPatContext(..)
-   , UnboxingDecision(..), canUnboxArg
+   , UnboxingDecision(..), canUnboxArg, wwArity
    , findTypeShape, IsRecDataConResult(..), isRecDataCon
    , mkAbsentFiller
    , isWorkerSmallEnough, dubiousDataConInstArgTys
@@ -207,7 +207,10 @@ mkWwBodies :: WwOpts
 -- and its unfolding(s) alike.
 --
 mkWwBodies opts fun_id arg_vars res_ty demands res_cpr
-  = do  { massertPpr (filter isId arg_vars `equalLength` demands)
+  = do  { massertPpr (isJoinId fun_id || (filter isId arg_vars `equalLength` demands))
+                        -- Threshold arity should match manifest arity here,
+                        -- UNLESS it's a join point
+                        -- See Note [Threshold arity of join points]
                      (text "wrong wrapper arity" $$ ppr fun_id $$ ppr arg_vars $$ ppr res_ty $$ ppr demands)
 
         -- Clone and prepare arg_vars of the original fun RHS
@@ -289,6 +292,13 @@ isWorkerSmallEnough max_worker_args old_n_args vars
     -- Also if the function took 82 arguments before (old_n_args), it's fine if
     -- it takes <= 82 arguments afterwards.
 
+wwArity :: Id -> CoreExpr -> Arity
+-- The arity for which we want to produce a boxity signature
+wwArity fn rhs
+  = case idJoinPointHood fn of
+      JoinPoint join_arity -> count isId $ fst $ collectNBinders join_arity rhs
+      NotJoinPoint         -> idArity fn
+
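For a join point, wwArity counts only the value binders among the first join-arity lambdas. Otherwise, it falls back to idArity. The toy model below makes this concrete. Binder and JoinHood are hypothetical simplifications of GHC's Var and JoinPointHood. The lambda binders are passed in as a list instead of being collected from a CoreExpr.

```haskell
-- Toy binders and join-point status; hypothetical simplifications of
-- GHC's Var and JoinPointHood.
data Binder = TyBndr String | ValBndr String
  deriving (Eq, Show)

data JoinHood = JoinPoint Int | NotJoinPoint
  deriving (Eq, Show)

isValBndr :: Binder -> Bool
isValBndr (ValBndr _) = True
isValBndr (TyBndr _) = False

-- Sketch of wwArity: for a join point, count the value binders among the
-- first join-arity lambdas; otherwise use the binding's idArity.
-- The leading lambda binders of the RHS are passed in directly.
wwArity :: JoinHood -> Int -> [Binder] -> Int
wwArity (JoinPoint joinArity) _ lamBndrs =
  length (filter isValBndr (take joinArity lamBndrs))
wwArity NotJoinPoint idArity _ = idArity
```

For `join j' x = \pqr. blah` from Note [Threshold arity of join points], `wwArity (JoinPoint 1) 4 [x, p, q, r]` gives 1, not the idArity of 4. Type binders inside the join arity are not counted.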
 {-
 Note [Always do CPR w/w]
 ~~~~~~~~~~~~~~~~~~~~~~~~


=====================================
compiler/GHC/Types/Demand.hs
=====================================
@@ -36,7 +36,7 @@ module GHC.Types.Demand (
     lazyApply1Dmd, lazyApply2Dmd, strictOnceApply1Dmd, strictManyApply1Dmd,
     -- ** Other @Demand@ operations
     oneifyCard, oneifyDmd, strictifyDmd, strictifyDictDmd, lazifyDmd,
-    peelCallDmd, peelManyCalls, mkCalledOnceDmd, mkCalledOnceDmds,
+    peelCallDmd, peelManyCalls, mkCalledOnceDmd, mkCalledOnceDmds, calledOnceArity,
     mkWorkerDemand, subDemandIfEvaluated,
     -- ** Extracting one-shot information
     argOneShots, argsOneShots, saturatedByOneShots,
@@ -1036,6 +1036,16 @@ peelManyCalls k sd = go k C_11 sd
     go _ _  _                          = (topCard, topSubDmd)
 {-# INLINE peelManyCalls #-} -- so that the pair cancels away in a `fst _` context
 
+calledOnceArity :: SubDemand -> Arity
+calledOnceArity sd = go 0 sd
+  where
+    go n (Call m sd) | isUsedOnce m = go (n+1) sd
+      -- NB: /Not/ viewCall, because we'd go infinitely deep on a Poly without
+      -- knowing the type arity (the upper bound for the threshold).
+      -- Besides, we are only really interested in C_11 or C_01 Calls, for
+      -- which we never use Poly anyway (cf. 'CardNonOnce').
+    go n _                          = n
+
 -- | Extract the 'SubDemand' of a 'Demand'.
 -- PRECONDITION: The SubDemand must be used in a context where the expression
 -- denoted by the Demand is under evaluation.
@@ -2135,6 +2145,7 @@ immediately specifying the incoming demand it was produced under. Despite StrSig
 being a newtype wrapper around DmdType, it actually encodes two things:
 
   * The threshold (i.e., minimum arity) to unleash the signature
+    See Note [Demand signatures are computed for a threshold arity based on idArity]
   * A demand type that is sound to unleash when the minimum arity requirement is
     met.
 
@@ -2155,9 +2166,20 @@ newtype DmdSig
 -- | Turns a 'DmdType' computed for the particular 'Arity' into a 'DmdSig'
 -- unleashable at that arity. See Note [Understanding DmdType and DmdSig].
 mkDmdSigForArity :: Arity -> DmdType -> DmdSig
-mkDmdSigForArity arity dmd_ty@(DmdType fvs args)
-  | arity < dmdTypeDepth dmd_ty = DmdSig $ DmdType fvs (take arity args)
-  | otherwise                   = DmdSig (etaExpandDmdType arity dmd_ty)
+mkDmdSigForArity threshold_arity dmd_ty@(DmdType fvs args)
+  | threshold_arity < dmdTypeDepth dmd_ty
+  = assertPpr (de_div fvs == topDiv) (ppr (de_div fvs)) $
+      -- Consider
+      --   f = \x -> case ... of ... -> \y -> bot
+      -- with threshold_arity=1. Note that dmd_ty=<B><B>b can't happen:
+      -- Since threshold arity was 1, we'd lazify the result of the Lam case
+      -- `\y -> bot`, turning <B>b into <L>, resulting in <L><L> as the
+      -- demand type passed into this function.
+      -- Hence we may simply drop the excess arguments and retain the old div,
+      -- which is asserted to be 'topDiv'.
+    DmdSig $ DmdType fvs (take threshold_arity args)
+  | otherwise
+  = DmdSig (etaExpandDmdType threshold_arity dmd_ty)
 
 mkClosedDmdSig :: [Demand] -> Divergence -> DmdSig
 mkClosedDmdSig ds div = mkDmdSigForArity (length ds) (DmdType (mkEmptyDmdEnv div) ds)
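mkDmdSigForArity handles two cases. If a demand type has more argument demands than the threshold arity, it truncates them, which is only sound when the divergence is topDiv (hence the new assertion). If it has fewer, it eta-expands. The toy version below uses simplified stand-ins for Div, DmdType and DmdSig. etaExpandSketch is a hypothetical helper that pads with a lazy top demand "L". That padding is an assumption; GHC's etaExpandDmdType may choose it differently.

```haskell
-- Toy demand types: a divergence plus argument demands written as strings.
data Div = TopDiv | BotDiv
  deriving (Eq, Show)

data DmdType = DmdType Div [String]
  deriving (Eq, Show)

newtype DmdSig = DmdSig DmdType
  deriving (Eq, Show)

dmdTypeDepth :: DmdType -> Int
dmdTypeDepth (DmdType _ args) = length args

-- Hypothetical eta-expansion: pad with the lazy top demand "L".
etaExpandSketch :: Int -> DmdType -> DmdType
etaExpandSketch n (DmdType dv args) =
  DmdType dv (args ++ replicate (n - length args) "L")

-- Sketch of mkDmdSigForArity: drop excess argument demands (only sound
-- for TopDiv, mirroring the new assertion) or eta-expand to the threshold.
mkDmdSigForArity :: Int -> DmdType -> DmdSig
mkDmdSigForArity threshold ty@(DmdType dv args)
  | threshold < dmdTypeDepth ty =
      if dv == TopDiv
        then DmdSig (DmdType dv (take threshold args))
        else error "mkDmdSigForArity: excess arguments with non-top divergence"
  | otherwise = DmdSig (etaExpandSketch threshold ty)
```

Consider the `f = \x -> ... -> \y -> bot` example in the new comment. With threshold arity 1, its demand type is <L><L> with topDiv, and the sketch truncates it to <L>.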



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/58bfd21476d5aaeba30c52c004f21ed33afa613d
