[Git][ghc/ghc][wip/T23083] Simplifier: Eta expand arguments (#23083)

Sebastian Graf (@sgraf812) gitlab at gitlab.haskell.org
Tue Mar 7 13:29:39 UTC 2023



Sebastian Graf pushed to branch wip/T23083 at Glasgow Haskell Compiler / GHC


Commits:
d81cad21 by Sebastian Graf at 2023-03-07T14:28:40+01:00
Simplifier: Eta expand arguments (#23083)

Previously, we'd only eta expand let bindings and lambdas;
now we'll also eta expand arguments, such as in T23083:
```hs
g f h = f (h `seq` (h $))
```
Unless `-fpedantic-bottoms` is set, we'll now transform to
```hs
g f h = f (\eta -> h eta)
```

Tweaking the Simplifier to eta-expand in args was a bit more painful than expected:

  * `tryEtaExpandRhs` and `findRhsArity` previously only worked on bindings.
    But eta expansion of non-recursive bindings is morally the same as eta
    expansion of arguments. And in fact the binder was never really looked at in
    the non-recursive case.
    I was able to make `findRhsArity` cater for both arguments and bindings, as
    well as have a new function `tryEtaExpandArg` that shares most of its code
    with that of `tryEtaExpandRhs`.

  * The Simplifier had a function `simplArg` that wasn't called in `rebuildCall`,
    even though `rebuildCall` seems to be the main way to simplify args. Hence I
    consolidated the code path to call `simplArg`, too, renaming it to `simplLazyArg`.

Fixes #23083.

- - - - -


7 changed files:

- compiler/GHC/Core/Opt/Arity.hs
- compiler/GHC/Core/Opt/Simplify/Iteration.hs
- compiler/GHC/Core/Opt/Simplify/Utils.hs
- compiler/GHC/Core/Opt/Stats.hs
- + testsuite/tests/simplCore/should_compile/T23083.hs
- + testsuite/tests/simplCore/should_compile/T23083.stderr
- testsuite/tests/simplCore/should_compile/all.T


Changes:

=====================================
compiler/GHC/Core/Opt/Arity.hs
=====================================
@@ -872,9 +872,16 @@ exprEtaExpandArity opts e
 *                                                                      *
 ********************************************************************* -}
 
-findRhsArity :: ArityOpts -> RecFlag -> Id -> CoreExpr
-             -> (Bool, SafeArityType)
--- This implements the fixpoint loop for arity analysis
+findRhsArity
+  :: ArityOpts
+  -> Maybe Id       -- ^ `Just bndr` when it's a recursive RHS bound by bndr
+  -> Bool           -- ^ Is it a join binding?
+  -> [OneShotInfo]  -- ^ The one-shot info from the use sites, perhaps from
+                    -- `idDemandOneShots` of the binder
+  -> CoreExpr       -- ^ The RHS (or argument expression)
+  -> Type           -- ^ Type of the CoreExpr
+  -> (Bool, SafeArityType)
+-- ^ This implements the fixpoint loop for arity analysis
 -- See Note [Arity analysis]
 --
 -- The Bool is True if the returned arity is greater than (exprArity rhs)
@@ -884,8 +891,8 @@ findRhsArity :: ArityOpts -> RecFlag -> Id -> CoreExpr
 -- Returns an SafeArityType that is guaranteed trimmed to typeArity of 'bndr'
 --         See Note [Arity trimming]
 
-findRhsArity opts is_rec bndr rhs
-  | isJoinId bndr
+findRhsArity opts mb_rec_bndr is_join use_one_shots rhs rhs_ty
+  | is_join
   = (False, join_arity_type)
     -- False: see Note [Do not eta-expand join points]
     -- But do return the correct arity and bottom-ness, because
@@ -900,28 +907,27 @@ findRhsArity opts is_rec bndr rhs
     old_arity = exprArity rhs
 
     init_env :: ArityEnv
-    init_env = findRhsArityEnv opts (isJoinId bndr)
+    init_env = findRhsArityEnv opts is_join
 
     -- Non-join-points only
-    non_join_arity_type = case is_rec of
-                             Recursive    -> go 0 botArityType
-                             NonRecursive -> step init_env
+    non_join_arity_type = case mb_rec_bndr of
+                             Just bndr    -> go 0 bndr botArityType
+                             Nothing      -> step init_env
     arity_increased = arityTypeArity non_join_arity_type > old_arity
 
     -- Join-points only
     -- See Note [Arity for non-recursive join bindings]
     -- and Note [Arity for recursive join bindings]
-    join_arity_type = case is_rec of
-                         Recursive    -> go 0 botArityType
-                         NonRecursive -> trimArityType ty_arity (cheapArityType rhs)
+    join_arity_type = case mb_rec_bndr of
+                         Just bndr    -> go 0 bndr botArityType
+                         Nothing      -> trimArityType ty_arity (cheapArityType rhs)
 
-    ty_arity     = typeArity (idType bndr)
-    id_one_shots = idDemandOneShots bndr
+    ty_arity     = typeArity rhs_ty
 
     step :: ArityEnv -> SafeArityType
     step env = trimArityType ty_arity $
                safeArityType $ -- See Note [Arity invariants for bindings], item (3)
-               arityType env rhs `combineWithDemandOneShots` id_one_shots
+               arityType env rhs `combineWithDemandOneShots` use_one_shots
        -- trimArityType: see Note [Trim arity inside the loop]
        -- combineWithDemandOneShots: take account of the demand on the
        -- binder.  Perhaps it is always called with 2 args
@@ -934,8 +940,8 @@ findRhsArity opts is_rec bndr rhs
     -- is assumed to be sound. In other words, arities should never
     -- decrease.  Result: the common case is that there is just one
     -- iteration
-    go :: Int -> SafeArityType -> SafeArityType
-    go !n cur_at@(AT lams div)
+    go :: Int -> Id -> SafeArityType -> SafeArityType
+    go !n bndr cur_at@(AT lams div)
       | not (isDeadEndDiv div)           -- the "stop right away" case
       , length lams <= old_arity = cur_at -- from above
       | next_at == cur_at        = cur_at
@@ -944,7 +950,7 @@ findRhsArity opts is_rec bndr rhs
       = warnPprTrace (debugIsOn && n > 2)
             "Exciting arity"
             (nest 2 (ppr bndr <+> ppr cur_at <+> ppr next_at $$ ppr rhs)) $
-        go (n+1) next_at
+        go (n+1) bndr next_at
       where
         next_at = step (extendSigEnv init_env bndr cur_at)
 
@@ -964,30 +970,6 @@ combineWithDemandOneShots at@(AT lams div) oss
     zip_lams ((ch,os1):lams) (os2:oss)
       = (ch, os1 `bestOneShot` os2) : zip_lams lams oss
 
-idDemandOneShots :: Id -> [OneShotInfo]
-idDemandOneShots bndr
-  = call_arity_one_shots `zip_lams` dmd_one_shots
-  where
-    call_arity_one_shots :: [OneShotInfo]
-    call_arity_one_shots
-      | call_arity == 0 = []
-      | otherwise       = NoOneShotInfo : replicate (call_arity-1) OneShotLam
-    -- Call Arity analysis says the function is always called
-    -- applied to this many arguments.  The first NoOneShotInfo is because
-    -- if Call Arity says "always applied to 3 args" then the one-shot info
-    -- we get is [NoOneShotInfo, OneShotLam, OneShotLam]
-    call_arity = idCallArity bndr
-
-    dmd_one_shots :: [OneShotInfo]
-    -- If the demand info is C(x,C(1,C(1,.))) then we know that an
-    -- application to one arg is also an application to three
-    dmd_one_shots = argOneShots (idDemandInfo bndr)
-
-    -- Take the *longer* list
-    zip_lams (lam1:lams1) (lam2:lams2) = (lam1 `bestOneShot` lam2) : zip_lams lams1 lams2
-    zip_lams []           lams2        = lams2
-    zip_lams lams1        []           = lams1
-
 {- Note [Arity analysis]
 ~~~~~~~~~~~~~~~~~~~~~~~~
 The motivating example for arity analysis is this:


=====================================
compiler/GHC/Core/Opt/Simplify/Iteration.hs
=====================================
@@ -1517,7 +1517,7 @@ rebuild env expr cont
       ApplyToVal { sc_arg = arg, sc_env = se, sc_dup = dup_flag
                  , sc_cont = cont, sc_hole_ty = fun_ty }
         -- See Note [Avoid redundant simplification]
-        -> do { (_, _, arg') <- simplArg env dup_flag fun_ty se arg
+        -> do { (_, _, arg') <- simplLazyArg env dup_flag fun_ty Nothing topDmd se arg
               ; rebuild env (App expr arg') cont }
 
 completeBindX :: SimplEnv
@@ -1633,7 +1633,6 @@ simplCast env body co0 cont0
                                    , sc_hole_ty = coercionLKind co }) }
                                         -- NB!  As the cast goes past, the
                                         -- type of the hole changes (#16312)
-
         -- (f |> co) e   ===>   (f (e |> co1)) |> co2
         -- where   co :: (s1->s2) ~ (t1->t2)
         --         co1 :: t1 ~ s1
@@ -1652,7 +1651,7 @@ simplCast env body co0 cont0
                       -- See Note [Avoiding exponential behaviour]
 
                    MCo co1 ->
-            do { (dup', arg_se', arg') <- simplArg env dup fun_ty arg_se arg
+            do { (dup', arg_se', arg') <- simplLazyArg env dup fun_ty Nothing topDmd arg_se arg
                     -- When we build the ApplyTo we can't mix the OutCoercion
                     -- 'co' with the InExpr 'arg', so we simplify
                     -- to make it all consistent.  It's a bit messy.
@@ -1678,17 +1677,21 @@ simplCast env body co0 cont0
           -- See Note [Representation polymorphism invariants] in GHC.Core
           -- test: typecheck/should_run/EtaExpandLevPoly
 
-simplArg :: SimplEnv -> DupFlag
-         -> OutType                 -- Type of the function applied to this arg
-         -> StaticEnv -> CoreExpr   -- Expression with its static envt
-         -> SimplM (DupFlag, StaticEnv, OutExpr)
-simplArg env dup_flag fun_ty arg_env arg
+simplLazyArg :: SimplEnv -> DupFlag
+             -> OutType                 -- Type of the function applied to this arg
+             -> Maybe ArgInfo
+             -> Demand                  -- Demand on the argument expr
+             -> StaticEnv -> CoreExpr   -- Expression with its static envt
+             -> SimplM (DupFlag, StaticEnv, OutExpr)
+simplLazyArg env dup_flag fun_ty mb_arg_info arg_dmd arg_env arg
   | isSimplified dup_flag
   = return (dup_flag, arg_env, arg)
   | otherwise
   = do { let arg_env' = arg_env `setInScopeFromE` env
-       ; arg' <- simplExprC arg_env' arg (mkBoringStop (funArgTy fun_ty))
-       ; return (Simplified, zapSubstEnv arg_env', arg') }
+       ; let arg_ty = funArgTy fun_ty
+       ; arg1 <- simplExprC arg_env' arg (mkLazyArgStop arg_ty mb_arg_info)
+       ; (_arity_type, arg2) <- tryEtaExpandArg env arg_dmd arg1 arg_ty
+       ; return (Simplified, zapSubstEnv arg_env', arg2) }
          -- Return a StaticEnv that includes the in-scope set from 'env',
          -- because arg' may well mention those variables (#20639)
 
@@ -2281,12 +2284,9 @@ rebuildCall env fun_info
         -- There is no benefit (unlike in a let-binding), and we'd
         -- have to be very careful about bogus strictness through
         -- floating a demanded let.
-  = do  { arg' <- simplExprC (arg_se `setInScopeFromE` env) arg
-                             (mkLazyArgStop arg_ty fun_info)
+  = do  { let (dmd:_) = ai_dmds fun_info
+        ; (_, _, arg') <- simplLazyArg env dup_flag fun_ty (Just fun_info) dmd arg_se arg
         ; rebuildCall env (addValArgTo fun_info  arg' fun_ty) cont }
-  where
-    arg_ty = funArgTy fun_ty
-
 
 ---------- No further useful info, revert to generic rebuild ------------
 rebuildCall env (ArgInfo { ai_fun = fun, ai_args = rev_args }) cont
@@ -3723,7 +3723,7 @@ mkDupableContWithDmds env dmds
     do  { let (dmd:cont_dmds) = dmds   -- Never fails
         ; (floats1, cont') <- mkDupableContWithDmds env cont_dmds cont
         ; let env' = env `setInScopeFromF` floats1
-        ; (_, se', arg') <- simplArg env' dup hole_ty se arg
+        ; (_, se', arg') <- simplLazyArg env' dup hole_ty Nothing dmd se arg
         ; (let_floats2, arg'') <- makeTrivial env NotTopLevel dmd (fsLit "karg") arg'
         ; let all_floats = floats1 `addLetFloats` let_floats2
         ; return ( all_floats


=====================================
compiler/GHC/Core/Opt/Simplify/Utils.hs
=====================================
@@ -9,7 +9,7 @@ The simplifier utilities
 module GHC.Core.Opt.Simplify.Utils (
         -- Rebuilding
         rebuildLam, mkCase, prepareAlts,
-        tryEtaExpandRhs, wantEtaExpansion,
+        tryEtaExpandRhs, tryEtaExpandArg, wantEtaExpansion,
 
         -- Inlining,
         preInlineUnconditionally, postInlineUnconditionally,
@@ -461,8 +461,9 @@ mkRhsStop :: OutType -> RecFlag -> Demand -> SimplCont
 -- See Note [RHS of lets] in GHC.Core.Unfold
 mkRhsStop ty is_rec bndr_dmd = Stop ty (RhsCtxt is_rec) (subDemandIfEvaluated bndr_dmd)
 
-mkLazyArgStop :: OutType -> ArgInfo -> SimplCont
-mkLazyArgStop ty fun_info = Stop ty (lazyArgContext fun_info) arg_sd
+mkLazyArgStop :: OutType -> Maybe ArgInfo -> SimplCont
+mkLazyArgStop ty Nothing         = mkBoringStop ty
+mkLazyArgStop ty (Just fun_info) = Stop ty (lazyArgContext fun_info) arg_sd
   where
     arg_sd = subDemandIfEvaluated (Partial.head (ai_dmds fun_info))
 
@@ -1738,7 +1739,7 @@ rebuildLam env bndrs@(bndr:_) body cont
       , seEtaExpand env
       , any isRuntimeVar bndrs  -- Only when there is at least one value lambda already
       , Just body_arity <- exprEtaExpandArity (seArityOpts env) body
-      = do { tick (EtaExpansion bndr)
+      = do { tick (EtaExpansion Nothing)
            ; let body' = etaExpandAT in_scope body_arity body
            ; traceSmpl "eta expand" (vcat [text "before" <+> ppr body
                                           , text "after" <+> ppr body'])
@@ -1859,15 +1860,39 @@ tryEtaExpandRhs :: SimplEnv -> BindContext -> OutId -> OutExpr
                 -> SimplM (ArityType, OutExpr)
 -- See Note [Eta-expanding at let bindings]
 tryEtaExpandRhs env bind_cxt bndr rhs
+  = tryEtaExpandArgOrRhs env mb_rec_bndr (isJoinBC bind_cxt)
+                         (idDemandOneShots bndr) rhs (idType bndr)
+  where
+    mb_rec_bndr = case bindContextRec bind_cxt of
+      Recursive    -> Just bndr
+      NonRecursive -> Nothing
+
+tryEtaExpandArg :: SimplEnv -> Demand -> OutExpr -> OutType
+                -> SimplM (ArityType, OutExpr)
+-- See Note [Eta-expanding at let bindings]
+tryEtaExpandArg env arg_dmd arg arg_ty
+  = tryEtaExpandArgOrRhs env Nothing False (argOneShots arg_dmd) arg arg_ty
+
+tryEtaExpandArgOrRhs
+  :: SimplEnv
+  -> Maybe OutId    -- ^ `Just bndr` when it's a recursive RHS bound by bndr
+  -> Bool           -- ^ Is it a join binding?
+  -> [OneShotInfo]  -- ^ The one-shot info from the use sites, perhaps from
+                    -- `idDemandOneShots` of the binder
+  -> OutExpr        -- ^ The RHS (or argument expression)
+  -> OutType        -- ^ Type of the CoreExpr
+  -> SimplM (ArityType, OutExpr)
+-- See Note [Eta-expanding at let bindings]
+tryEtaExpandArgOrRhs env mb_rec_bndr is_join use_one_shots rhs rhs_ty
   | do_eta_expand           -- If the current manifest arity isn't enough
                             --    (never true for join points)
   , seEtaExpand env         -- and eta-expansion is on
   , wantEtaExpansion rhs
   = -- Do eta-expansion.
-    assertPpr( not (isJoinBC bind_cxt) ) (ppr bndr) $
+    assertPpr( not is_join ) (ppr mb_rec_bndr) $
        -- assert: this never happens for join points; see GHC.Core.Opt.Arity
        --         Note [Do not eta-expand join points]
-    do { tick (EtaExpansion bndr)
+    do { tick (EtaExpansion mb_rec_bndr)
        ; return (arity_type, etaExpandAT in_scope arity_type rhs) }
 
   | otherwise
@@ -1876,8 +1901,7 @@ tryEtaExpandRhs env bind_cxt bndr rhs
   where
     in_scope   = getInScope env
     arity_opts = seArityOpts env
-    is_rec     = bindContextRec bind_cxt
-    (do_eta_expand, arity_type) = findRhsArity arity_opts is_rec bndr rhs
+    (do_eta_expand, arity_type) = findRhsArity arity_opts mb_rec_bndr is_join use_one_shots rhs rhs_ty
 
 wantEtaExpansion :: CoreExpr -> Bool
 -- Mostly True; but False of PAPs which will immediately eta-reduce again
@@ -1890,6 +1914,30 @@ wantEtaExpansion (Var {})               = False
 wantEtaExpansion (Lit {})               = False
 wantEtaExpansion _                      = True
 
+idDemandOneShots :: Id -> [OneShotInfo]
+idDemandOneShots bndr
+  = call_arity_one_shots `zip_lams` dmd_one_shots
+  where
+    call_arity_one_shots :: [OneShotInfo]
+    call_arity_one_shots
+      | call_arity == 0 = []
+      | otherwise       = NoOneShotInfo : replicate (call_arity-1) OneShotLam
+    -- Call Arity analysis says the function is always called
+    -- applied to this many arguments.  The first NoOneShotInfo is because
+    -- if Call Arity says "always applied to 3 args" then the one-shot info
+    -- we get is [NoOneShotInfo, OneShotLam, OneShotLam]
+    call_arity = idCallArity bndr
+
+    dmd_one_shots :: [OneShotInfo]
+    -- If the demand info is C(x,C(1,C(1,.))) then we know that an
+    -- application to one arg is also an application to three
+    dmd_one_shots = argOneShots (idDemandInfo bndr)
+
+    -- Take the *longer* list
+    zip_lams (lam1:lams1) (lam2:lams2) = (lam1 `bestOneShot` lam2) : zip_lams lams1 lams2
+    zip_lams []           lams2        = lams2
+    zip_lams lams1        []           = lams1
+
 {-
 Note [Eta-expanding at let bindings]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


=====================================
compiler/GHC/Core/Opt/Stats.hs
=====================================
@@ -226,7 +226,7 @@ data Tick  -- See Note [Which transformations are innocuous]
   | RuleFired                   FastString      -- Rule name
 
   | LetFloatFromLet
-  | EtaExpansion                Id      -- LHS binder
+  | EtaExpansion                (Maybe Id)      -- LHS binder, if recursive
   | EtaReduction                Id      -- Binder on outer lambda
   | BetaReduction               Id      -- Lambda binder
 


=====================================
testsuite/tests/simplCore/should_compile/T23083.hs
=====================================
@@ -0,0 +1,6 @@
+{-# OPTIONS_GHC -O2 -fforce-recomp #-}
+
+module T23083 where
+
+g :: ((Integer -> Integer) -> Integer) -> (Integer -> Integer) -> Integer
+g f h = f (h `seq` (h $))


=====================================
testsuite/tests/simplCore/should_compile/T23083.stderr
=====================================
@@ -0,0 +1,36 @@
+
+==================== Tidy Core ====================
+Result size of Tidy Core = {terms: 21, types: 17, coercions: 0, joins: 0/0}
+
+-- RHS size: {terms: 6, types: 6, coercions: 0, joins: 0/0}
+g :: ((Integer -> Integer) -> Integer) -> (Integer -> Integer) -> Integer
+[GblId, Arity=2, Str=<1C(1,L)><LC(S,L)>, Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True, WorkFree=True, Expandable=True, Guidance=IF_ARGS [60 60] 50 0}]
+g = \ (f :: (Integer -> Integer) -> Integer) (h :: Integer -> Integer) -> f (\ (eta :: Integer) -> h eta)
+
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
+T23083.$trModule4 :: GHC.Prim.Addr#
+[GblId, Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True, WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 20 0}]
+T23083.$trModule4 = "main"#
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T23083.$trModule3 :: GHC.Types.TrName
+[GblId, Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True, WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T23083.$trModule3 = GHC.Types.TrNameS T23083.$trModule4
+
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
+T23083.$trModule2 :: GHC.Prim.Addr#
+[GblId, Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True, WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
+T23083.$trModule2 = "T23083"#
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T23083.$trModule1 :: GHC.Types.TrName
+[GblId, Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True, WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T23083.$trModule1 = GHC.Types.TrNameS T23083.$trModule2
+
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
+T23083.$trModule :: GHC.Types.Module
+[GblId, Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True, WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T23083.$trModule = GHC.Types.Module T23083.$trModule3 T23083.$trModule1
+
+
+


=====================================
testsuite/tests/simplCore/should_compile/all.T
=====================================
@@ -476,3 +476,4 @@ test('T23012', normal, compile, ['-O'])
 
 test('RewriteHigherOrderPatterns', normal, compile, ['-O -ddump-rule-rewrites -dsuppress-all -dsuppress-uniques'])
 test('T23024', normal, multimod_compile, ['T23024', '-O -v0'])
+test('T23083', [ grep_errmsg(r'f.*eta') ], compile, ['-O -ddump-simpl -dsuppress-uniques -dppr-cols=99999'])



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/d81cad21a5a05b06d3a2413fc7400eba2548cbd6

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/d81cad21a5a05b06d3a2413fc7400eba2548cbd6
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20230307/59b52763/attachment-0001.html>


More information about the ghc-commits mailing list