[Git][ghc/ghc][wip/arity-type-9.4] Fix arityType: -fpedantic-bottoms, join points, etc

Zubin (@wz1000) gitlab at gitlab.haskell.org
Wed Oct 26 13:28:22 UTC 2022



Zubin pushed to branch wip/arity-type-9.4 at Glasgow Haskell Compiler / GHC


Commits:
006662a8 by Simon Peyton Jones at 2022-10-26T18:58:09+05:30
Fix arityType: -fpedantic-bottoms, join points, etc

This MR fixes #21694 and #21755

* For #21694 the underlying problem was that we were calling arityType
  on an expression that had free join points.  This is a Bad Bad Idea.
  See Note [No free join points in arityType].

* I also made andArityType work correctly with -fpedantic-bottoms;
  see Note [Combining case branches: andWithTail].

* I realised that, now we have ae_sigs giving the ArityType for
  let-bound Ids, we don't need the (pre-dating) special code in
  arityType for join points. But instead we need to extend the env for
  Rec bindings, which we weren't doing before.  More uniform now.  See
  Note [arityType for let-bindings].

  This meant we could get rid of ae_joins, and in fact get rid of
  EtaExpandArity altogether.  Simpler.

  And finally, it was the strange treatment of join-point Ids (involving
  a fake ABot type) that led to a serious bug: #21755.  Fixed by this
  refactoring.

* Rewrote Note [Combining case branches: optimistic one-shot-ness]
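To make the first two bullets concrete, here is a small toy model of how case branches are combined. It is a hedged sketch, not GHC's actual code: the constructors `TopDiv`/`BotDiv`, the two-element `OneShotInfo`, and the `Bool` argument standing in for -fpedantic-bottoms are simplifying assumptions (GHC's real `Divergence` and `ArityEnv` are richer). Only the shape of `andArityType`/`andWithTail` follows the diff below.

```haskell
-- Toy model only: simplified stand-ins for GHC.Core.Opt.Arity, not GHC's code.

data OneShotInfo = NoOneShotInfo | OneShotLam
  deriving (Eq, Show)

-- Hypothetical two-point divergence; GHC's real Divergence has more cases.
data Divergence = TopDiv | BotDiv
  deriving (Eq, Show)

-- One OneShotInfo per value lambda, then the divergence of the body.
data ArityType = AT [OneShotInfo] Divergence
  deriving (Eq, Show)

isDeadEndDiv :: Divergence -> Bool
isDeadEndDiv BotDiv = True
isDeadEndDiv TopDiv = False

-- Optimistic one-shot-ness: if either branch is one-shot, so is the result.
bestOneShot :: OneShotInfo -> OneShotInfo -> OneShotInfo
bestOneShot OneShotLam _  = OneShotLam
bestOneShot _          os = os

-- Combine the ArityTypes of two case branches.
-- The Bool stands in for -fpedantic-bottoms (an assumption of this toy).
andArityType :: Bool -> ArityType -> ArityType -> ArityType
andArityType ped (AT (os1:oss1) div1) (AT (os2:oss2) div2) =
  case andArityType ped (AT oss1 div1) (AT oss2 div2) of
    AT oss' div' -> AT (bestOneShot os1 os2 : oss') div'
andArityType ped (AT [] div1) at2 = andWithTail ped div1 at2
andArityType ped at1 (AT [] div2) = andWithTail ped div2 at1

-- One side has run out of lambdas; div1 is that side's divergence.
andWithTail :: Bool -> Divergence -> ArityType -> ArityType
andWithTail ped div1 at2@(AT oss2 _)
  | isDeadEndDiv div1 = at2            -- bottoming branch: max arity wins
  | ped               = AT [] TopDiv   -- pedantic: do not eta-expand
  | otherwise         = AT oss2 TopDiv -- push the other branch under the lambdas
```

With the example from Note [Combining case branches: andWithTail], `case x of { True -> \y. blah; False -> z }`, the branches are `AT [NoOneShotInfo] TopDiv` and `AT [] TopDiv`. `andArityType True` yields `AT [] TopDiv` (no eta-expansion), whereas `andArityType False` yields `AT [NoOneShotInfo] TopDiv`.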

Compile time improves slightly on average:

Metrics: compile_time/bytes allocated
---------------------------------------------------------------------------------------
CoOpt_Read(normal) ghc/alloc    803,788,056    747,832,680  -7.1% GOOD
    T18223(normal) ghc/alloc    928,207,320    959,424,016  +3.1%  BAD
         geo. mean                                          -0.3%
         minimum                                            -7.1%
         maximum                                            +3.1%

On Windows it's a bit better: geo mean is -0.6%, and three more
benchmarks trip their compile-time bytes-allocated threshold (they
were all close on the other build):

   T18698b(normal) ghc/alloc    235,619,776    233,219,008  -1.0% GOOD
     T6048(optasm) ghc/alloc    112,208,192    109,704,936  -2.2% GOOD
    T18140(normal) ghc/alloc     85,064,192     83,168,360  -2.2% GOOD

I had a quick look at T18223 but it is knee deep in coercions and
the size of everything looks similar before and after.  I decided
to accept that 3.4% increase in exchange for goodness elsewhere.

Metric Decrease:
    CoOpt_Read
    T18140
    T18698b
    T6048

Metric Increase:
    T18223

(cherry picked from commit 5e282da37e19a1ab24ae167daf32276a64ed2842)

- - - - -


11 changed files:

- compiler/GHC/Core.hs
- compiler/GHC/Core/Opt/Arity.hs
- compiler/GHC/Core/Opt/Simplify.hs
- compiler/GHC/Core/Opt/Simplify/Utils.hs
- + testsuite/tests/arityanal/should_compile/T21755.hs
- + testsuite/tests/arityanal/should_compile/T21755.stderr
- testsuite/tests/arityanal/should_compile/all.T
- + testsuite/tests/callarity/should_compile/T21694a.hs
- + testsuite/tests/callarity/should_compile/T21694a.stderr
- + testsuite/tests/simplCore/should_compile/T21694b.hs
- + testsuite/tests/simplCore/should_compile/T21694b.stderr


Changes:

=====================================
compiler/GHC/Core.hs
=====================================
@@ -739,6 +739,7 @@ Join points must follow these invariants:
 
          The arity of a join point isn't very important; but short of setting
          it to zero, it is helpful to have an invariant.  E.g. #17294.
+         See also Note [Do not eta-expand join points] in GHC.Core.Opt.Simplify.Utils.
 
   3. If the binding is recursive, then all other bindings in the recursive group
      must also be join points.


=====================================
compiler/GHC/Core/Opt/Arity.hs
=====================================
@@ -17,7 +17,7 @@ module GHC.Core.Opt.Arity
    , exprBotStrictness_maybe
 
    -- ** ArityType
-   , ArityType(..), mkBotArityType, mkTopArityType, expandableArityType
+   , ArityType(..), mkBotArityType, mkManifestArityType, expandableArityType
    , arityTypeArity, maxWithArity, idArityType
 
    -- ** Join points
@@ -53,7 +53,7 @@ import GHC.Types.Demand
 import GHC.Types.Var
 import GHC.Types.Var.Env
 import GHC.Types.Id
-import GHC.Types.Var.Set
+import GHC.Core.DataCon
 import GHC.Types.Basic
 import GHC.Types.Tickish
 
@@ -594,7 +594,8 @@ same fix.
 -- where the @at@ fields of @ALam@ are inductively subject to the same order.
 -- That is, @ALam os at1 < ALam os at2@ iff @at1 < at2@.
 --
--- Why the strange Top element? See Note [Combining case branches].
+-- Why the strange Top element?
+--   See Note [Combining case branches: optimistic one-shot-ness]
 --
 -- We rely on this lattice structure for fixed-point iteration in
 -- 'findRhsArity'. For the semantics of 'ArityType', see Note [ArityType].
@@ -641,11 +642,16 @@ mkBotArityType oss = AT oss botDiv
 botArityType :: ArityType
 botArityType = mkBotArityType []
 
-mkTopArityType :: [OneShotInfo] -> ArityType
-mkTopArityType oss = AT oss topDiv
+mkManifestArityType :: [Var] -> CoreExpr -> ArityType
+mkManifestArityType bndrs body
+  = AT oss div
+  where
+    oss = [idOneShotInfo bndr | bndr <- bndrs, isId bndr]
+    div | exprIsDeadEnd body = botDiv
+        | otherwise          = topDiv
 
 topArityType :: ArityType
-topArityType = mkTopArityType []
+topArityType = AT [] topDiv
 
 -- | The number of value args for the arity type
 arityTypeArity :: ArityType -> Arity
@@ -685,7 +691,7 @@ takeWhileOneShot (AT oss div)
 exprEtaExpandArity :: DynFlags -> CoreExpr -> ArityType
 -- exprEtaExpandArity is used when eta expanding
 --      e  ==>  \xy -> e x y
-exprEtaExpandArity dflags e = arityType (etaExpandArityEnv dflags) e
+exprEtaExpandArity dflags e = arityType (findRhsArityEnv dflags) e
 
 getBotArity :: ArityType -> Maybe Arity
 -- Arity of a divergent function
@@ -825,6 +831,7 @@ floatIn cheap at
   | otherwise                      = takeWhileOneShot at
 
 arityApp :: ArityType -> Bool -> ArityType
+
 -- Processing (fun arg) where at is the ArityType of fun,
 -- Knock off an argument and behave like 'let'
 arityApp (AT (_:oss) div) cheap = floatIn cheap (AT oss div)
@@ -834,16 +841,30 @@ arityApp at               _     = at
 -- See the haddocks on 'ArityType' for the lattice.
 --
 -- Used for branches of a @case@.
-andArityType :: ArityType -> ArityType -> ArityType
-andArityType (AT (os1:oss1) div1) (AT (os2:oss2) div2)
-  | AT oss' div' <- andArityType (AT oss1 div1) (AT oss2 div2)
-  = AT ((os1 `bestOneShot` os2) : oss') div' -- See Note [Combining case branches]
-andArityType at1@(AT []         div1) at2
-  | isDeadEndDiv div1 = at2                  -- Note [ABot branches: max arity wins]
-  | otherwise         = at1                  -- See Note [Combining case branches]
-andArityType at1                  at2@(AT []         div2)
-  | isDeadEndDiv div2 = at1                  -- Note [ABot branches: max arity wins]
-  | otherwise         = at2                  -- See Note [Combining case branches]
+andArityType :: ArityEnv -> ArityType -> ArityType -> ArityType
+andArityType env (AT (lam1:lams1) div1) (AT (lam2:lams2) div2)
+  | AT lams' div' <- andArityType env (AT lams1 div1) (AT lams2 div2)
+  = AT ((lam1 `and_lam` lam2) : lams') div'
+  where
+    (os1) `and_lam` (os2)
+      = ( os1 `bestOneShot` os2)
+        -- bestOneShot: see Note [Combining case branches: optimistic one-shot-ness]
+
+andArityType env (AT [] div1) at2 = andWithTail env div1 at2
+andArityType env at1 (AT [] div2) = andWithTail env div2 at1
+
+andWithTail :: ArityEnv -> Divergence -> ArityType -> ArityType
+andWithTail env div1 at2@(AT lams2 _)
+  | isDeadEndDiv div1     -- case x of { T -> error; F -> \y.e }
+  = at2        -- Note [ABot branches: max arity wins]
+
+  | pedanticBottoms env  -- Note [Combining case branches: andWithTail]
+  = AT [] topDiv
+
+  | otherwise  -- case x of { T -> plusInt <expensive>; F -> \y.e }
+  = AT lams2 topDiv    -- We know div1 = topDiv
+    -- See Note [Combining case branches: andWithTail]
+
 
 {- Note [ABot branches: max arity wins]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -854,8 +875,60 @@ Consider   case x of
 Remember: \o1..on.⊥ means "if you apply to n args, it'll definitely diverge".
 So we need \??.⊥ for the whole thing, the /max/ of both arities.
 
-Note [Combining case branches]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Note [Combining case branches: optimistic one-shot-ness]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When combining the ArityTypes for two case branches (with andArityType)
+and both ArityTypes have ATLamInfo, then we just combine their
+expensive-ness and one-shot info.  The tricky point is when we have
+     case x of True  -> \x{one-shot}. blah1
+               False -> \y.           blah2
+
+Since one-shot-ness is about the /consumer/ not the /producer/, we
+optimistically assume that if either branch is one-shot, we combine
+the best of the two branches, on the (slightly dodgy) basis that if we
+know one branch is one-shot, then they all must be.  Surprisingly,
+this means that the one-shot arity type is effectively the top element
+of the lattice.
+
+Hence the call to `bestOneShot` in `andArityType`.
+
+Note [Combining case branches: andWithTail]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When combining the ArityTypes for two case branches (with andArityType)
+and one side or the other has run out of ATLamInfo, then we get
+into `andWithTail`.
+
+* If one branch is guaranteed bottom (isDeadEndDiv), we just take
+  the other; see Note [ABot branches: max arity wins]
+
+* Otherwise, if pedantic-bottoms is on, we just have to return
+  AT [] topDiv.  E.g. if we have
+    f x z = case x of True  -> \y. blah
+                      False -> z
+  then we can't eta-expand, because that would change the behaviour
+  of (f False bottom).
+
+* But if pedantic-bottoms is not on, we allow ourselves to push
+  `z` under a lambda (much as we allow ourselves to put the `case x`
+  under a lambda).  However we know nothing about the expensiveness
+  or one-shot-ness of `z`, so we'd better assume it looks like
+  (Expensive, NoOneShotInfo) all the way. Remembering
+  Note [Combining case branches: optimistic one-shot-ness],
+  we just add work to every ATLamInfo, keeping the one-shot-ness.
+
+Here's an example:
+  go = \x. let z = go e0
+               go2 = \x. case x of
+                           True  -> z
+                           False -> \s(one-shot). e1
+           in go2 x
+We *really* want to respect the one-shot annotation provided by the
+user and eta-expand go and go2.
+When combining the branches of the case we have
+     T `andAT` \1.T
+and we want to get \1.T.
+But if the inner lambda wasn't one-shot (\?.T) we don't want to do this.
+(We need a usage analysis to justify that.)
 
 Unless we can conclude that **all** branches are safe to eta-expand then we
 must pessimisticaly conclude that we can't eta-expand. See #21694 for where this
@@ -934,22 +1007,21 @@ data ArityEnv
   = AE
   { ae_mode   :: !AnalysisMode
   -- ^ The analysis mode. See 'AnalysisMode'.
-  , ae_joins  :: !IdSet
-  -- ^ In-scope join points. See Note [Eta-expansion and join points]
-  --   INVARIANT: Disjoint with the domain of 'am_sigs' (if present).
   }
 
 -- | The @ArityEnv@ used by 'exprBotStrictness_maybe'. Pedantic about bottoms
 -- and no application is ever considered cheap.
 botStrictnessArityEnv :: ArityEnv
-botStrictnessArityEnv = AE { ae_mode = BotStrictness, ae_joins = emptyVarSet }
+botStrictnessArityEnv = AE { ae_mode = BotStrictness }
 
+{-
 -- | The @ArityEnv@ used by 'exprEtaExpandArity'.
 etaExpandArityEnv :: DynFlags -> ArityEnv
 etaExpandArityEnv dflags
   = AE { ae_mode  = EtaExpandArity { am_ped_bot = gopt Opt_PedanticBottoms dflags
                                    , am_dicts_cheap = gopt Opt_DictsCheap dflags }
        , ae_joins = emptyVarSet }
+-}
 
 -- | The @ArityEnv@ used by 'findRhsArity'.
 findRhsArityEnv :: DynFlags -> ArityEnv
@@ -957,7 +1029,11 @@ findRhsArityEnv dflags
   = AE { ae_mode  = FindRhsArity { am_ped_bot = gopt Opt_PedanticBottoms dflags
                                  , am_dicts_cheap = gopt Opt_DictsCheap dflags
                                  , am_sigs = emptyVarEnv }
-       , ae_joins = emptyVarSet }
+       }
+
+isFindRhsArity :: ArityEnv -> Bool
+isFindRhsArity (AE { ae_mode = FindRhsArity {} }) = True
+isFindRhsArity _                                  = False
 
 -- First some internal functions in snake_case for deleting in certain VarEnvs
 -- of the ArityType. Don't call these; call delInScope* instead!
@@ -976,32 +1052,17 @@ del_sig_env_list :: [Id] -> ArityEnv -> ArityEnv -- internal!
 del_sig_env_list ids = modifySigEnv (\sigs -> delVarEnvList sigs ids)
 {-# INLINE del_sig_env_list #-}
 
-del_join_env :: JoinId -> ArityEnv -> ArityEnv -- internal!
-del_join_env id env@(AE { ae_joins = joins })
-  = env { ae_joins = delVarSet joins id }
-{-# INLINE del_join_env #-}
-
-del_join_env_list :: [JoinId] -> ArityEnv -> ArityEnv -- internal!
-del_join_env_list ids env@(AE { ae_joins = joins })
-  = env { ae_joins = delVarSetList joins ids }
-{-# INLINE del_join_env_list #-}
-
 -- end of internal deletion functions
 
-extendJoinEnv :: ArityEnv -> [JoinId] -> ArityEnv
-extendJoinEnv env@(AE { ae_joins = joins }) join_ids
-  = del_sig_env_list join_ids
-  $ env { ae_joins = joins `extendVarSetList` join_ids }
-
 extendSigEnv :: ArityEnv -> Id -> ArityType -> ArityEnv
 extendSigEnv env id ar_ty
-  = del_join_env id (modifySigEnv (\sigs -> extendVarEnv sigs id ar_ty) env)
+  = modifySigEnv (\sigs -> extendVarEnv sigs id ar_ty) env
 
 delInScope :: ArityEnv -> Id -> ArityEnv
-delInScope env id = del_join_env id $ del_sig_env id env
+delInScope env id = del_sig_env id env
 
 delInScopeList :: ArityEnv -> [Id] -> ArityEnv
-delInScopeList env ids = del_join_env_list ids $ del_sig_env_list ids env
+delInScopeList env ids = del_sig_env_list ids env
 
 lookupSigEnv :: ArityEnv -> Id -> Maybe ArityType
 lookupSigEnv AE{ ae_mode = mode } id = case mode of
@@ -1046,8 +1107,11 @@ myIsCheapApp :: IdEnv ArityType -> CheapAppFun
 myIsCheapApp sigs fn n_val_args = case lookupVarEnv sigs fn of
   -- Nothing means not a local function, fall back to regular
   -- 'GHC.Core.Utils.isCheapApp'
-  Nothing         -> isCheapApp fn n_val_args
-  -- @Just at@ means local function with @at@ as current ArityType.
+  Nothing -> isCheapApp fn n_val_args
+
+  -- `Just at` means local function with `at` as current SafeArityType.
+  -- NB the SafeArityType bit: that means we can ignore the cost flags
+  --    in 'lams', and just consider the length
   -- Roughly approximate what 'isCheapApp' is doing.
   Just (AT oss div)
     | isDeadEndDiv div -> True -- See Note [isCheapApp: bottoming functions] in GHC.Core.Utils
@@ -1055,7 +1119,10 @@ myIsCheapApp sigs fn n_val_args = case lookupVarEnv sigs fn of
     | otherwise -> False
 
 ----------------
-arityType :: ArityEnv -> CoreExpr -> ArityType
+arityType :: HasDebugCallStack => ArityEnv -> CoreExpr -> ArityType
+-- Precondition: all the free join points of the expression
+--               are bound by the ArityEnv
+-- See Note [No free join points in arityType]
 
 arityType env (Cast e co)
   = minWithArity (arityType env e) co_arity -- See Note [Arity trimming]
@@ -1067,12 +1134,13 @@ arityType env (Cast e co)
     -- #5441 is a nice demo
 
 arityType env (Var v)
-  | v `elemVarSet` ae_joins env
-  = botArityType  -- See Note [Eta-expansion and join points]
   | Just at <- lookupSigEnv env v -- Local binding
   = at
   | otherwise
-  = idArityType v
+  = assertPpr  (not (isFindRhsArity env && isJoinId v)) (ppr v) $
+    -- All join points should be in the ae_sigs
+    -- See Note [No free join points in arityType]
+    idArityType v
 
         -- Lambdas; increase arity
 arityType env (Lam x e)
@@ -1109,50 +1177,104 @@ arityType env (Case scrut bndr _ alts)
   where
     env' = delInScope env bndr
     arity_type_alt (Alt _con bndrs rhs) = arityType (delInScopeList env' bndrs) rhs
-    alts_type = foldr1 andArityType (map arity_type_alt alts)
-
-arityType env (Let (NonRec j rhs) body)
-  | Just join_arity <- isJoinId_maybe j
-  , (_, rhs_body)   <- collectNBinders join_arity rhs
-  = -- See Note [Eta-expansion and join points]
-    andArityType (arityType env rhs_body)
-                 (arityType env' body)
+    alts_type = foldr1 (andArityType env) (map arity_type_alt alts)
+
+arityType env (Let (NonRec b r) e)
+  = -- See Note [arityType for let-bindings]
+    floatIn cheap_rhs (arityType env' e)
   where
-     env' = extendJoinEnv env [j]
+    cheap_rhs = myExprIsCheap env r (Just (idType b))
+    env'      = extendSigEnv env b (arityType env r)
 
 arityType env (Let (Rec pairs) body)
   | ((j,_):_) <- pairs
   , isJoinId j
-  = -- See Note [Eta-expansion and join points]
-    foldr (andArityType . do_one) (arityType env' body) pairs
+  = -- See Note [Arity type for recursive join bindings]
+    foldr (andArityType env . do_one) (arityType rec_env body) pairs
   where
-    env' = extendJoinEnv env (map fst pairs)
+    rec_env = foldl add_bot env pairs
+    add_bot env (j,_) = extendSigEnv env j botArityType
+
+    do_one :: (JoinId, CoreExpr) -> ArityType
     do_one (j,rhs)
       | Just arity <- isJoinId_maybe j
-      = arityType env' $ snd $ collectNBinders arity rhs
+      = arityType rec_env $ snd $ collectNBinders arity rhs
       | otherwise
       = pprPanic "arityType:joinrec" (ppr pairs)
 
-arityType env (Let (NonRec b r) e)
-  = floatIn cheap_rhs (arityType env' e)
-  where
-    cheap_rhs = myExprIsCheap env r (Just (idType b))
-    env'      = extendSigEnv env b (arityType env r)
-
 arityType env (Let (Rec prs) e)
   = floatIn (all is_cheap prs) (arityType env' e)
   where
-    env'           = delInScopeList env (map fst prs)
     is_cheap (b,e) = myExprIsCheap env' e (Just (idType b))
+    env'            = foldl extend_rec env prs
+    extend_rec :: ArityEnv -> (Id,CoreExpr) -> ArityEnv
+    extend_rec env (b,e) = extendSigEnv env b  $
+                           mkManifestArityType bndrs body
+                         where
+                           (bndrs, body) = collectBinders e
+      -- We can't call arityType on the RHS, because it might mention
+      -- join points bound in this very letrec, and we don't want to
+      -- do a fixpoint calculation here.  So we make do with the
+      -- manifest arity
 
 arityType env (Tick t e)
   | not (tickishIsCode t)     = arityType env e
 
 arityType _ _ = topArityType
 
-{- Note [Eta-expansion and join points]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Consider this (#18328)
+
+{- Note [No free join points in arityType]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Suppose we call arityType on this expression (EX1)
+   \x . case x of True  -> \y. e
+                  False -> $j 3
+where $j is a join point.  It really makes no sense to talk of the arity
+of this expression, because it has a free join point.  In particular, we
+can't eta-expand the expression because we'd have to do the same thing to the
+binding of $j, and we can't see that binding.
+
+If we had (EX2)
+   \x. join $j y = blah
+       case x of True  -> \y. e
+                 False -> $j 3
+then it would make perfect sense: we can determine $j's ArityType, and
+propagate it to the usage site as usual.
+
+But how can we get (EX1)?  It doesn't make much sense, because $j can't
+be a join point under the \x anyway.  So we make it a precondition of
+arityType that the argument has no free join-point Ids.  (This is checked
+with an assert in the Var case of arityType.)
+
+BUT the invariant risks being invalidated by one very narrow special case: runRW#
+   join $j y = blah
+   runRW# (\s. case x of True  -> \y. e
+                         False -> $j x)
+
+We have special magic in OccurAnal and Simplify to allow continuations to
+move into the body of a runRW# call.
+
+So we are careful never to attempt to eta-expand the (\s.blah) in the
+argument to runRW#, at least not when there is a literal lambda there,
+so that OccurAnal has seen it and allowed join points bound outside.
+See Note [No eta-expansion in runRW#] in GHC.Core.Opt.Simplify.
+
+Note [arityType for let-bindings]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+For non-recursive let-bindings, we just get the arityType of the RHS,
+and extend the environment.  That works nicely for things like this
+(#18793):
+  go = \ ds. case ds of {
+               []     -> id
+               : y ys -> case y of { GHC.Types.I# x ->
+                         let acc = go ys in
+                         case x ># 42# of {
+                            __DEFAULT -> acc
+                            1# -> \x1. acc (negate x1)
+
+Here we want to get a good arity for `acc`, based on the ArityType
+of `go`.
+
+All this is particularly important for join points. Consider this (#18328)
 
   f x = join j y = case y of
                       True -> \a. blah
@@ -1165,42 +1287,64 @@ Consider this (#18328)
 and suppose the join point is too big to inline.  Now, what is the
 arity of f?  If we inlined the join point, we'd definitely say "arity
 2" because we are prepared to push case-scrutinisation inside a
-lambda.  But currently the join point totally messes all that up,
-because (thought of as a vanilla let-binding) the arity pinned on 'j'
-is just 1.
+lambda. It's important that we extend the envt with j's ArityType,
+so that we can use that information in the A/C branch of the case.
+
+For /recursive/ bindings it's more difficult to call arityType,
+because we don't have an ArityType to put in the envt for the
+recursively bound Ids.  So for non-join-point bindings we satisfy
+ourselves with mkManifestArityType.  Typically we'll have eta-expanded
+the binding (based on an earlier fixpoint calculation in
+findRhsArity), so the manifest arity is good.
+
+But for /recursive join points/ things are not so good.
+See Note [Arity type for recursive join bindings]
+
+Note [Arity type for recursive join bindings]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+  f x = joinrec j 0 = \ a b c -> (a,x,b)
+                j n = j (n-1)
+        in j 20
 
-Why don't we eta-expand j?  Because of
-Note [Do not eta-expand join points] in GHC.Core.Opt.Simplify.Utils
+Obviously `f` should get arity 4.  But the manifest arity of `j`
+is 1.  Remember, we don't eta-expand join points; see
+GHC.Core.Opt.Simplify.Utils Note [Do not eta-expand join points].
+And the ArityInfo on `j` will be just 1 too; see GHC.Core
+Note [Invariants on join points], item (2b).  So using
+Note [arityType for let-bindings] won't work well.
 
-Even if we don't eta-expand j, why is its arity only 1?
-See invariant 2b in Note [Invariants on join points] in GHC.Core.
+We could do a fixpoint iteration, but that's a heavy hammer
+to use in arityType.  So we take advantage of it being a join
+point:
 
-So we do this:
+* Extend the ArityEnv to bind each of the recursive binders
+  (all join points) to `botArityType`.  This means that any
+  jump to the join point will return botArityType, which is the
+  unit for `andArityType`:
+      botArityType `andArityType` at = at
+  So it's almost as if those "jump" branches didn't exist.
 
-* Treat the RHS of a join-point binding, /after/ stripping off
-  join-arity lambda-binders, as very like the body of the let.
-  More precisely, do andArityType with the arityType from the
-  body of the let.
+* In this extended env, find the ArityType of each RHS, after
+  stripping off the join-point binders.
 
-* Dually, when we come to a /call/ of a join point, just no-op
-  by returning ABot, the bottom element of ArityType,
-  which so that: bot `andArityType` x = x
+* Use andArityType to combine all these RHS ArityTypes.
 
-* This works if the join point is bound in the expression we are
-  taking the arityType of.  But if it's bound further out, it makes
-  no sense to say that (say) the arityType of (j False) is ABot.
-  Bad things happen.  So we keep track of the in-scope join-point Ids
-  in ae_join.
+* Find the ArityType of the body, also in this strange extended
+  environment
 
-This will make f, above, have arity 2. Then, we'll eta-expand it thus:
+* And combine that into the result with andArityType.
 
-  f x eta = (join j y = ... in case x of ...) eta
+In our example, the jump (j 20) will yield Bot, as will the jump
+(j (n-1)). We'll 'and' those with the ArityType of (\abc. blah).  Good!
 
-and the Simplify will automatically push that application of eta into
-the join points.
+In effect we are treating the RHSs as alternative bodies (like
+in a case), and ignoring all jumps.  In this way we don't need
+to take a fixpoint.  Tricky!
 
-An alternative (roughly equivalent) idea would be to carry an
-environment mapping let-bound Ids to their ArityType.
+NB: we could treat /non-recursive/ join points in the same way, but
+it works fine to treat them uniformly with normal
+let-bindings, and that takes less code.
 -}
 
 idArityType :: Id -> ArityType

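The recursive join-point trick in Note [Arity type for recursive join bindings] can be illustrated with an equally hypothetical toy. The `Expr` type, the list-based environment, treating every `Case` scrutinee as cheap, and counting lambdas with an `Int` (ignoring one-shot info and -fpedantic-bottoms) are all assumptions made for illustration. Only the shape of the algorithm follows the Note: bind every join point to `botArityType`, then combine each RHS and the body with `andArityType`, with no fixpoint iteration.

```haskell
-- Hypothetical toy model of arityType on a joinrec; not GHC's code.

data Divergence = TopDiv | BotDiv
  deriving (Eq, Show)

-- Simplification: count value lambdas, ignore one-shot info.
data ArityType = AT Int Divergence
  deriving (Eq, Show)

botArityType :: ArityType
botArityType = AT 0 BotDiv

-- Non-pedantic combination of two alternatives.
andArityType :: ArityType -> ArityType -> ArityType
andArityType (AT n1 d1) (AT n2 d2)
  | n1 <= n2  = andWithTail d1 (AT n2 d2)
  | otherwise = andWithTail d2 (AT n1 d1)

andWithTail :: Divergence -> ArityType -> ArityType
andWithTail BotDiv at       = at           -- bottoming side: max arity wins
andWithTail TopDiv (AT n _) = AT n TopDiv  -- push the other side under the lambdas

data Expr
  = Lam Expr                      -- a value lambda
  | Case [Expr]                   -- case on a scrutinee assumed to be cheap
  | Jump String                   -- jump to a join point
  | Other                         -- anything else: no lambdas, may not diverge
  | JoinRec [(String, Expr)] Expr -- joinrec; each RHS already has its
                                  -- join-point parameters stripped off

type Env = [(String, ArityType)]

arityType :: Env -> Expr -> ArityType
arityType env (Lam e) = case arityType env e of AT n d -> AT (n + 1) d
-- botArityType is the unit of andArityType, so it is a safe seed for foldr.
arityType env (Case alts) = foldr (andArityType . arityType env) botArityType alts
arityType env (Jump j) = case lookup j env of
  Just at -> at
  Nothing -> error ("free join point: " ++ j)  -- cf. Note [No free join points in arityType]
arityType _ Other = AT 0 TopDiv
arityType env (JoinRec pairs body) =
  foldr (andArityType . arityType recEnv . snd) (arityType recEnv body) pairs
  where
    -- Every jump to a binder of this group yields botArityType.
    recEnv = [ (j, botArityType) | (j, _) <- pairs ] ++ env
```

Encoding the T21694b example as `Lam (JoinRec [("j", Case [Lam (Lam (Lam Other)), Jump "j"])] (Jump "j"))` gives `AT 4 TopDiv`, i.e. arity 4: both jumps contribute `botArityType`, so only the `\a b c` alternative survives the combination.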

=====================================
compiler/GHC/Core/Opt/Simplify.hs
=====================================
@@ -2144,19 +2144,32 @@ rebuildCall env (ArgInfo { ai_fun = fun_id, ai_args = rev_args })
             (ApplyToVal { sc_arg = arg, sc_env = arg_se
                         , sc_cont = cont, sc_hole_ty = fun_ty })
   | fun_id `hasKey` runRWKey
-  , not (contIsStop cont)  -- Don't fiddle around if the continuation is boring
   , [ TyArg {}, TyArg {} ] <- rev_args
-  = do { s <- newId (fsLit "s") Many realWorldStatePrimTy
-       ; let (m,_,_) = splitFunTy fun_ty
-             env'  = (arg_se `setInScopeFromE` env) `addNewInScopeIds` [s]
+  -- Do this even if (contIsStop cont)
+  -- See Note [No eta-expansion in runRW#]
+  = do { let arg_env = arg_se `setInScopeFromE` env
              ty'   = contResultType cont
-             cont' = ApplyToVal { sc_dup = Simplified, sc_arg = Var s
-                                , sc_env = env', sc_cont = cont
-                                , sc_hole_ty = mkVisFunTy m realWorldStatePrimTy ty' }
-                     -- cont' applies to s, then K
-       ; body' <- simplExprC env' arg cont'
-       ; let arg'  = Lam s body'
-             rr'   = getRuntimeRep ty'
+
+       -- If the argument is a literal lambda already, take a short cut
+       -- This isn't just efficiency; if we don't do this we get a beta-redex
+       -- every time, so the simplifier keeps doing more iterations.
+       ; arg' <- case arg of
+           Lam s body -> do { (env', s') <- simplBinder arg_env s
+                            ; body' <- simplExprC env' body cont
+                            ; return (Lam s' body') }
+                            -- Important: do not try to eta-expand this lambda
+                            -- See Note [No eta-expansion in runRW#]
+           _ -> do { s' <- newId (fsLit "s") Many realWorldStatePrimTy
+                   ; let (m,_,_) = splitFunTy fun_ty
+                         env'  = arg_env `addNewInScopeIds` [s']
+                         cont' = ApplyToVal { sc_dup = Simplified, sc_arg = Var s'
+                                            , sc_env = env', sc_cont = cont
+                                            , sc_hole_ty = mkVisFunTy m realWorldStatePrimTy ty' }
+                                -- cont' applies to s', then K
+                   ; body' <- simplExprC env' arg cont'
+                   ; return (Lam s' body') }
+
+       ; let rr'   = getRuntimeRep ty'
              call' = mkApps (Var fun_id) [mkTyArg rr', mkTyArg ty', arg']
        ; return (emptyFloats env, call') }
 
@@ -2263,6 +2276,19 @@ to get the effect that finding (error "foo") in a strict arg position will
 discard the entire application and replace it with (error "foo").  Getting
 all this at once is TOO HARD!
 
+Note [No eta-expansion in runRW#]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When we see `runRW# (\s. blah)` we must not attempt to eta-expand that
+lambda.  Why not?  Because
+* `blah` can mention join points bound outside the runRW#
+* eta-expansion uses arityType, and
+* `arityType` cannot cope with free join Ids.
+
+So the simplifier spots the literal lambda, and simplifies inside it.
+It's a very special lambda, because it is the one that OccAnal spots and
+allows join points bound /outside/ to be called /inside/.
+
+See Note [No free join points in arityType] in GHC.Core.Opt.Arity
 
 ************************************************************************
 *                                                                      *


=====================================
compiler/GHC/Core/Opt/Simplify/Utils.hs
=====================================
@@ -1732,9 +1732,7 @@ tryEtaExpandRhs :: SimplEnv -> OutId -> OutExpr
 tryEtaExpandRhs env bndr rhs
   | Just join_arity <- isJoinId_maybe bndr
   = do { let (join_bndrs, join_body) = collectNBinders join_arity rhs
-             oss   = [idOneShotInfo id | id <- join_bndrs, isId id]
-             arity_type | exprIsDeadEnd join_body = mkBotArityType oss
-                        | otherwise               = mkTopArityType oss
+             arity_type = mkManifestArityType join_bndrs join_body
        ; return (arity_type, rhs) }
          -- Note [Do not eta-expand join points]
          -- But do return the correct arity and bottom-ness, because


=====================================
testsuite/tests/arityanal/should_compile/T21755.hs
=====================================
@@ -0,0 +1,11 @@
+module T21755 where
+
+mySum :: [Int] -> Int
+mySum [] = 0
+mySum (x:xs) = x + mySum xs
+
+f :: Int -> (Int -> Int) -> Int -> Int
+f k z =
+    if even (mySum [0..k])
+      then \n -> n + 1
+      else \n -> z n


=====================================
testsuite/tests/arityanal/should_compile/T21755.stderr
=====================================
@@ -0,0 +1 @@
+ 
\ No newline at end of file


=====================================
testsuite/tests/arityanal/should_compile/all.T
=====================================
@@ -21,3 +21,4 @@ test('Arity16', [ only_ways(['optasm']), grep_errmsg('Arity=') ], compile, ['-dn
 test('T18793', [ only_ways(['optasm']), grep_errmsg('Arity=') ], compile, ['-dno-typeable-binds -ddump-simpl -dppr-cols=99999 -dsuppress-uniques'])
 test('T18870', [ only_ways(['optasm']) ], compile, ['-ddebug-output'])
 test('T18937', [ only_ways(['optasm']) ], compile, ['-ddebug-output'])
+test('T21755',  [ grep_errmsg(r'Arity=') ], compile, ['-O -dno-typeable-binds -fno-worker-wrapper'])


=====================================
testsuite/tests/callarity/should_compile/T21694a.hs
=====================================
@@ -0,0 +1,27 @@
+module Main (main) where
+
+import GHC.Exts
+import Control.DeepSeq
+import System.Exit
+
+-- If we eta-expand, the `False` branch will return
+-- a lambda \eta -> z instead of z.
+-- This behaves differently if the z argument is a bottom.
+-- We used to assume that a oneshot annotation would mean
+-- we could eta-expand on *all* branches. But this is clearly
+-- not sound in this case. So we test for this here.
+{-# NOINLINE f #-}
+f :: Bool -> (Int -> Int) -> Int -> Int
+f b z =
+    case b of
+        True -> oneShot $ \n -> n + 1
+        False -> z
+
+
+
+main :: IO Int
+main = do
+    return $! force $! f False (error "Urkh! But expected!")
+    return 0
+
+


=====================================
testsuite/tests/callarity/should_compile/T21694a.stderr
=====================================
@@ -0,0 +1,3 @@
+T21694a: Urkh! But expected!
+CallStack (from HasCallStack):
+  error, called at T21694a.hs:23:33 in main:Main


=====================================
testsuite/tests/simplCore/should_compile/T21694b.hs
=====================================
@@ -0,0 +1,6 @@
+module T21694 where
+
+-- f should get arity 4
+f x = let j 0 = \ a b c -> (a,x,b)
+          j n = j (n-1 :: Int)
+      in j 20


=====================================
testsuite/tests/simplCore/should_compile/T21694b.stderr
=====================================
@@ -0,0 +1,115 @@
+
+==================== Tidy Core ====================
+Result size of Tidy Core
+  = {terms: 44, types: 40, coercions: 0, joins: 2/2}
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T21694.f1 :: Int
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T21694.f1 = GHC.Types.I# 20#
+
+-- RHS size: {terms: 26, types: 22, coercions: 0, joins: 2/2}
+f :: forall {p1} {a} {c} {p2}. p1 -> a -> c -> p2 -> (a, p1, c)
+[GblId,
+ Arity=4,
+ Str=<L><L><L><A>,
+ Cpr=1,
+ Unf=Unf{Src=InlineStable, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True,
+         Guidance=ALWAYS_IF(arity=4,unsat_ok=True,boring_ok=False)
+         Tmpl= \ (@p_ax8)
+                 (@a_aL5)
+                 (@c_aL6)
+                 (@p1_aL7)
+                 (x_agu [Occ=OnceL1] :: p_ax8)
+                 (eta_B0 [Occ=OnceL1] :: a_aL5)
+                 (eta1_B1 [Occ=OnceL1] :: c_aL6)
+                 _ [Occ=Dead] ->
+                 joinrec {
+                   j_sLX [InlPrag=[2], Occ=T[1]] :: Int -> (a_aL5, p_ax8, c_aL6)
+                   [LclId[JoinId(1)(Just [!])],
+                    Arity=1,
+                    Str=<S!P(SL)>,
+                    Unf=Unf{Src=InlineStable, TopLvl=False, Value=True, ConLike=True,
+                            WorkFree=True, Expandable=True,
+                            Guidance=ALWAYS_IF(arity=1,unsat_ok=True,boring_ok=False)
+                            Tmpl= \ (ds_sM1 [Occ=Once1!] :: Int) ->
+                                    case ds_sM1 of { GHC.Types.I# ww_sM3 [Occ=Once1] ->
+                                    jump $wj_sM6 ww_sM3
+                                    }}]
+                   j_sLX (ds_sM1 [Occ=Once1!] :: Int)
+                     = case ds_sM1 of { GHC.Types.I# ww_sM3 [Occ=Once1] ->
+                       jump $wj_sM6 ww_sM3
+                       };
+                   $wj_sM6 [InlPrag=[2], Occ=LoopBreakerT[1]]
+                     :: GHC.Prim.Int# -> (a_aL5, p_ax8, c_aL6)
+                   [LclId[JoinId(1)(Nothing)], Arity=1, Str=<SL>, Unf=OtherCon []]
+                   $wj_sM6 (ww_sM3 [Occ=Once1!] :: GHC.Prim.Int#)
+                     = case ww_sM3 of ds_X2 [Occ=Once1] {
+                         __DEFAULT -> jump j_sLX (GHC.Types.I# (GHC.Prim.-# ds_X2 1#));
+                         0# -> (eta_B0, x_agu, eta1_B1)
+                       }; } in
+                 jump j_sLX T21694.f1}]
+f = \ (@p_ax8)
+      (@a_aL5)
+      (@c_aL6)
+      (@p1_aL7)
+      (x_agu :: p_ax8)
+      (eta_B0 :: a_aL5)
+      (eta1_B1 :: c_aL6)
+      _ [Occ=Dead] ->
+      join {
+        exit_X3 [Dmd=S!P(L,L,L)] :: (a_aL5, p_ax8, c_aL6)
+        [LclId[JoinId(0)(Nothing)]]
+        exit_X3 = (eta_B0, x_agu, eta1_B1) } in
+      joinrec {
+        $wj_sM6 [InlPrag=[2], Occ=LoopBreaker, Dmd=SCS(!P(L,L,L))]
+          :: GHC.Prim.Int# -> (a_aL5, p_ax8, c_aL6)
+        [LclId[JoinId(1)(Nothing)], Arity=1, Str=<1L>, Unf=OtherCon []]
+        $wj_sM6 (ww_sM3 :: GHC.Prim.Int#)
+          = case ww_sM3 of ds_X2 {
+              __DEFAULT -> jump $wj_sM6 (GHC.Prim.-# ds_X2 1#);
+              0# -> jump exit_X3
+            }; } in
+      jump $wj_sM6 20#
+
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
+T21694.$trModule4 :: GHC.Prim.Addr#
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 20 0}]
+T21694.$trModule4 = "main"#
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T21694.$trModule3 :: GHC.Types.TrName
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T21694.$trModule3 = GHC.Types.TrNameS T21694.$trModule4
+
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
+T21694.$trModule2 :: GHC.Prim.Addr#
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
+T21694.$trModule2 = "T21694"#
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T21694.$trModule1 :: GHC.Types.TrName
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T21694.$trModule1 = GHC.Types.TrNameS T21694.$trModule2
+
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
+T21694.$trModule :: GHC.Types.Module
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T21694.$trModule
+  = GHC.Types.Module T21694.$trModule3 T21694.$trModule1
+
+
+



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/006662a8b2bc867c3a8b0d7b3a7aaab2c7cd4b3b





More information about the ghc-commits mailing list