[Git][ghc/ghc][wip/T22277] Denest NonRecs in SpecConstr for more specialisation (#22277)

Sebastian Graf (@sgraf812) gitlab at gitlab.haskell.org
Tue Oct 11 12:52:50 UTC 2022



Sebastian Graf pushed to branch wip/T22277 at Glasgow Haskell Compiler / GHC


Commits:
0beddf90 by Sebastian Graf at 2022-10-11T14:52:43+02:00
Denest NonRecs in SpecConstr for more specialisation (#22277)

See Note [Denesting non-recursive let bindings].

Fixes #22277. It is also related to #14951 and #14844 in that it
fixes a very specific case of looking through a non-recursive let binding in
SpecConstr.

- - - - -


6 changed files:

- compiler/GHC/Core.hs
- compiler/GHC/Core/Opt/SpecConstr.hs
- + testsuite/tests/simplCore/should_compile/T14951.hs
- + testsuite/tests/simplCore/should_compile/T22277.hs
- + testsuite/tests/simplCore/should_compile/T22277.stderr
- testsuite/tests/simplCore/should_compile/all.T


Changes:

=====================================
compiler/GHC/Core.hs
=====================================
@@ -41,7 +41,7 @@ module GHC.Core (
         isId, cmpAltCon, cmpAlt, ltAlt,
 
         -- ** Simple 'Expr' access functions and predicates
-        bindersOf, bindersOfBinds, rhssOfBind, rhssOfAlts,
+        bindersOf, bindersOfBinds, rhssOfBind, rhssOfAlts, collectLets,
         collectBinders, collectTyBinders, collectTyAndValBinders,
         collectNBinders, collectNValBinders_maybe,
         collectArgs, stripNArgs, collectArgsTicks, flattenBinds,
@@ -1940,6 +1940,15 @@ flattenBinds (NonRec b r : binds) = (b,r) : flattenBinds binds
 flattenBinds (Rec prs1   : binds) = prs1 ++ flattenBinds binds
 flattenBinds []                   = []
 
+-- | Split an expression into its leading 'Let' bindings (outermost first)
+-- and the body underneath them. The inverse of 'mkLets'.
+collectLets :: Expr b -> ([Bind b], Expr b)
+collectLets
+  = peel []
+  where
+    peel acc (Let bind body) = peel (bind : acc) body
+    peel acc body            = (reverse acc, body)
+
 -- | We often want to strip off leading lambdas before getting down to
 -- business. Variants are 'collectTyBinders', 'collectValBinders',
 -- and 'collectTyAndValBinders'
@@ -1957,7 +1966,7 @@ collectBinders expr
   = go [] expr
   where
     go bs (Lam b e) = go (b:bs) e
-    go bs e          = (reverse bs, e)
+    go bs e         = (reverse bs, e)
 
 collectTyBinders expr
   = go [] expr


=====================================
compiler/GHC/Core/Opt/SpecConstr.hs
=====================================
@@ -32,7 +32,7 @@ import GHC.Core
 import GHC.Core.Subst
 import GHC.Core.Utils
 import GHC.Core.Unfold
-import GHC.Core.FVs     ( exprsFreeVarsList, exprFreeVars )
+import GHC.Core.FVs     ( exprsFreeVarsList, exprFreeVars, exprsFreeVars, exprSomeFreeVarsList )
 import GHC.Core.Opt.Monad
 import GHC.Core.Opt.WorkWrap.Utils
 import GHC.Core.DataCon
@@ -52,6 +52,7 @@ import GHC.Unit.Module.ModGuts
 import GHC.Types.Literal ( litIsLifted )
 import GHC.Types.Id
 import GHC.Types.Id.Info ( IdDetails(..) )
+import GHC.Types.Var ( setIdDetails )
 import GHC.Types.Var.Env
 import GHC.Types.Var.Set
 import GHC.Types.Name
@@ -80,10 +81,11 @@ import GHC.Exts( SpecConstrAnnotation(..) )
 import GHC.Serialized   ( deserializeWithData )
 
 import Control.Monad    ( zipWithM )
-import Data.List (nubBy, sortBy, partition, dropWhileEnd, mapAccumL )
+import Data.List ( nubBy, sortBy, partition, dropWhileEnd, mapAccumL )
 import Data.Maybe( mapMaybe )
 import Data.Ord( comparing )
 import Data.Tuple
+import Data.Bifunctor ( first )
 
 {-
 -----------------------------------------------------
@@ -773,10 +775,21 @@ specConstrProgram guts
        ; return (guts { mg_binds = binds' }) }
 
 scTopBinds :: ScEnv -> [InBind] -> UniqSM (ScUsage, [OutBind])
-scTopBinds _env []     = return (nullUsage, [])
-scTopBinds env  (b:bs) = do { (usg, b', bs') <- scBind TopLevel env b $
-                                                (\env -> scTopBinds env bs)
-                            ; return (usg, b' ++ bs') }
+scTopBinds env bs = do
+  (usg, bs, ()) <- scBinds TopLevel env bs (\_env -> return (nullUsage, ()))
+  return (usg, bs)
+
+-- | Run 'scBind' over a list of bindings from left to right, then run the
+-- continuation in the environment handed on by the last binding.
+scBinds :: TopLevelFlag -> ScEnv -> [InBind]
+       -> (ScEnv -> UniqSM (ScUsage, a))   -- Specialise the scope of the bindings
+       -> UniqSM (ScUsage, [OutBind], a)
+scBinds _lvl env [] k
+  = do { (usg, res) <- k env; return (usg, [], res) }
+scBinds lvl env (bind:rest) k
+  = do { (usg, bind', (rest', res)) <- scBind lvl env bind $ \env' ->
+           do { (usg', rest', res) <- scBinds lvl env' rest k; return (usg', (rest', res)) }
+       ; return (usg, bind' ++ rest', res) }
 
 {-
 ************************************************************************
@@ -1018,6 +1031,9 @@ extendScInScope env qvars
 extendScSubst :: ScEnv -> Var -> OutExpr -> ScEnv
 extendScSubst env var expr = env { sc_subst = extendSubst (sc_subst env) var expr }
 
+extendScSubstPre :: ScEnv -> Var -> InExpr -> ScEnv
+extendScSubstPre env var expr = extendScSubst env var (substExpr (sc_subst env) expr)
+
 extendScSubstList :: ScEnv -> [(Var,OutExpr)] -> ScEnv
 extendScSubstList env prs = env { sc_subst = extendSubstList (sc_subst env) prs }
 
@@ -1330,6 +1346,13 @@ creates specialised versions of functions.
 scBind :: TopLevelFlag -> ScEnv -> InBind
        -> (ScEnv -> UniqSM (ScUsage, a))   -- Specialise the scope of the binding
        -> UniqSM (ScUsage, [OutBind], a)
+scBind top_lvl env (NonRec bndr rhs) do_body
+  | Just (app, binds) <- denest_nonrec_let (getSubstInScope (sc_subst env)) bndr rhs
+    -- See Note [Denesting non-recursive let bindings]
+  , let env' | isTopLevel top_lvl = extendScInScope env (bindersOfBinds binds)
+             | otherwise          = env
+    -- At top level, scBinds assumes that we've already put all binders into scope; see initScEnv
+  = scBinds top_lvl env' binds (\env -> do_body $ extendScSubstPre env bndr app)
 scBind top_lvl env (NonRec bndr rhs) do_body
   | isTyVar bndr         -- Type-lets may be created by doBeta
   = do { (final_usage, body') <- do_body (extendScSubst env bndr rhs)
@@ -1424,8 +1447,88 @@ scBind top_lvl env (Rec prs) do_body
 
     rhs_env2 = extendHowBound rhs_env1 bndrs' RecFun
 
-{- Note [Specialising local let bindings]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+-- | Implements Note [Denesting non-recursive let bindings].
+--
+-- The call `denest_nonrec_let in_scope f (\xs -> let binds in g ys)` returns
+-- `Just (\xs -> g' ys, binds')`, where `g'` and `binds'` were stripped of their
+-- join-point-ness (if `f` was not a join point itself).
+-- The function returns `Nothing` if the code does not match.
+--
+-- The `InScopeSet` is used to check that the binders of `binds` do not shadow
+-- anything already in scope; if they would, this function returns `Nothing`, too.
+denest_nonrec_let :: InScopeSet -> InId -> InExpr -> Maybe (InExpr, [InBind])
+denest_nonrec_let in_scope f rhs
+  | (xs,          body) <- collectBinders rhs
+  , (binds@(_:_), call) <- collectLets body
+  , (Var g, args)       <- collectArgs call
+  , let bndrs = bindersOfBinds binds
+  , (g', binds') <- need_zap_join_point_hood f g binds `orElse` (g, binds)
+  -- expensive tests last:
+  , bndrs `dont_shadow` in_scope     -- floating binds out may not shadow bindings already in scope
+  , args  `exprs_dont_mention` bndrs -- args may not mention binds
+  , binds `binds_dont_mention` xs    -- binds may not mention xs
+  = Just (mkLams xs $ mkApps (Var g') args, binds')
+  | otherwise
+  = Nothing
+  where
+    dont_shadow :: [Var] -> InScopeSet -> Bool
+    dont_shadow bndrs in_scope =
+      disjointVarSet (getInScopeVars in_scope) (mkVarSet bndrs)
+
+    exprs_dont_mention :: [CoreExpr] -> [Var] -> Bool
+    exprs_dont_mention exprs vs =
+      disjointVarSet (exprsFreeVars exprs) (mkVarSet vs)
+
+    binds_dont_mention :: [CoreBind] -> [Var] -> Bool
+    binds_dont_mention binds vs =
+      let some_var = head (bindersOfBinds binds)
+          vs_set   = mkVarSet vs
+      in null $ exprSomeFreeVarsList (`elemVarSet` vs_set) (mkLets binds (Var some_var))
+
+    need_zap_join_point_hood :: Id -> Id -> [CoreBind] -> Maybe (Id, [CoreBind])
+    need_zap_join_point_hood f g binds
+      | isJoinId f       = Nothing -- `f` and `g` share tail context
+      | not (isJoinId g) = Nothing -- `g` and thus `binds` never were joinpoints to begin with
+      | otherwise        = Just (mark_non_join g, map (map_binders mark_non_join) binds)
+
+    map_binders :: (b -> b) -> Bind b -> Bind b
+    map_binders f (NonRec b rhs) = NonRec (f b) rhs
+    map_binders f (Rec prs)      = Rec (map (first f) prs)
+
+    mark_non_join :: Id -> Id
+    mark_non_join id = case idDetails id of
+      JoinId _ Nothing          -> id `setIdDetails` VanillaId
+      JoinId _ (Just cbv_marks) -> id `setIdDetails` WorkerLikeId cbv_marks
+      _                         -> id
+
+{- Note [Denesting non-recursive let bindings]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Suppose we see (local or at top-level)
+
+  f xs = let binds in g as;
+  rest
+
+where `xs` don't occur in `binds` and `as` do not mention `binds`.
+It might be interesting to specialise `f` and `g` for call patterns in `rest`,
+but it is difficult to do it in this nested form, because
+
+  1. We only get to see `ScrutOcc`s on `g`, in its RHS
+  2. The interesting call patterns in `rest` apply only to `f`
+  3. Specialising `f` and `g` for those call patterns duplicates `binds` twice:
+     We keep one copy of `binds` in the original `f`, one copy of `binds` in `$sf`
+     and another specialised copy `$sbinds` (containing `$sg`) in `$sf`.
+
+So for SpecConstr, we float out `binds` (removing potential join-point-ness)
+
+  binds;
+  rest[f:=\xs -> g as]
+
+Because now all call patterns of `f` directly apply to `g` and might match up
+with one of the `ScrutOcc`s in its RHS, while only needing a single duplicate of
+`binds`.
+
+Note [Specialising local let bindings]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 It is not uncommon to find this
 
    let $j = \x. <blah> in ...$j True...$j True...


=====================================
testsuite/tests/simplCore/should_compile/T14951.hs
=====================================
@@ -0,0 +1,24 @@
+-- {-# OPTIONS_GHC -Wincomplete-patterns -fforce-recomp #-}
+-- {-# OPTIONS_GHC -O2 -fforce-recomp #-}
+-- {-# LANGUAGE PatternSynonyms #-}
+-- {-# LANGUAGE BangPatterns #-}
+-- {-# LANGUAGE MagicHash, UnboxedTuples #-}
+
+module T14844Example (topLvl) where
+
+topLvl large = (bar1, bar2, foo)
+  where
+    foo :: Integer -> (a -> b -> Bool) -> (a,b) -> Bool
+    foo 0 _ _ = False
+    foo s f t = l s' t
+       where
+         l 0 t = False
+         l 1 t = case t of (x,y) -> f x y
+         l n (x,y) = l (n-1) (x,y)
+         s' = large s
+
+    bar1 :: Integer -> (a -> b -> Bool) -> a -> b -> Bool
+    bar1 s f x y = foo s f (x,y)
+
+    bar2 :: Integer ->  (a -> b -> Bool) -> a -> b -> Bool
+    bar2 s f x y = foo (s + 1) f (x,y)


=====================================
testsuite/tests/simplCore/should_compile/T22277.hs
=====================================
@@ -0,0 +1,16 @@
+{-# OPTIONS_GHC -O2 -fforce-recomp #-}
+
+module T22277 where
+
+entry :: Int -> Int
+entry n = case n of
+  0 -> f n (13,24)
+  _ -> f n (n,n)
+  where
+    f :: Int -> (Int,Int) -> Int
+    f m x = g m x
+      where
+        exit m = (length $ reverse $ reverse $ reverse $ reverse $ [0..m]) + n
+        g n p | even n    = exit n
+              | n > 43    = g (n-1) p
+              | otherwise = fst p


=====================================
testsuite/tests/simplCore/should_compile/T22277.stderr
=====================================
@@ -0,0 +1,132 @@
+[1 of 1] Compiling T22277           ( T22277.hs, T22277.o )
+
+==================== Tidy Core ====================
+Result size of Tidy Core
+  = {terms: 110, types: 49, coercions: 0, joins: 3/4}
+
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
+T22277.$trModule4 :: GHC.Prim.Addr#
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 20 0}]
+T22277.$trModule4 = "main"#
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T22277.$trModule3 :: GHC.Types.TrName
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T22277.$trModule3 = GHC.Types.TrNameS T22277.$trModule4
+
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
+T22277.$trModule2 :: GHC.Prim.Addr#
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
+T22277.$trModule2 = "T22277"#
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T22277.$trModule1 :: GHC.Types.TrName
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T22277.$trModule1 = GHC.Types.TrNameS T22277.$trModule2
+
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
+T22277.$trModule :: GHC.Types.Module
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T22277.$trModule
+  = GHC.Types.Module T22277.$trModule3 T22277.$trModule1
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T22277.entry2 :: Int
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T22277.entry2 = GHC.Types.I# 13#
+
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
+T22277.entry1 :: Int
+[GblId,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}]
+T22277.entry1 = GHC.Types.I# 24#
+
+-- RHS size: {terms: 89, types: 40, coercions: 0, joins: 3/4}
+entry :: Int -> Int
+[GblId,
+ Arity=1,
+ Str=<1P(SL)>,
+ Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
+         WorkFree=True, Expandable=True, Guidance=IF_ARGS [20] 403 0}]
+entry
+  = \ (n :: Int) ->
+      case n of wild { GHC.Types.I# ds ->
+      join {
+        $w$sexit [InlPrag=[2], Dmd=LC(S,!P(L))] :: GHC.Prim.Int# -> Int
+        [LclId[JoinId(1)(Nothing)], Arity=1, Str=<L>]
+        $w$sexit (ww [OS=OneShot] :: GHC.Prim.Int#)
+          = join {
+              $j [Dmd=1C(1,!P(L))] :: [Int] -> Int
+              [LclId[JoinId(1)(Just [!])], Arity=1, Str=<1L>, Unf=OtherCon []]
+              $j (arg [OS=OneShot] :: [Int])
+                = case GHC.List.$wlenAcc
+                         @Int
+                         (GHC.List.reverse1
+                            @Int
+                            (GHC.List.reverse1
+                               @Int
+                               (GHC.List.reverse1
+                                  @Int
+                                  (GHC.List.reverse1 @Int arg (GHC.Types.[] @Int))
+                                  (GHC.Types.[] @Int))
+                               (GHC.Types.[] @Int))
+                            (GHC.Types.[] @Int))
+                         0#
+                  of ww1
+                  { __DEFAULT ->
+                  GHC.Types.I# (GHC.Prim.+# ww1 ds)
+                  } } in
+            case GHC.Prim.># 0# ww of {
+              __DEFAULT ->
+                letrec {
+                  go3 [Occ=LoopBreaker, Dmd=SC(S,L)] :: GHC.Prim.Int# -> [Int]
+                  [LclId, Arity=1, Str=<L>, Unf=OtherCon []]
+                  go3
+                    = \ (x :: GHC.Prim.Int#) ->
+                        GHC.Types.:
+                          @Int
+                          (GHC.Types.I# x)
+                          (case GHC.Prim.==# x ww of {
+                             __DEFAULT -> go3 (GHC.Prim.+# x 1#);
+                             1# -> GHC.Types.[] @Int
+                           }); } in
+                jump $j (go3 0#);
+              1# -> jump $j (GHC.Types.[] @Int)
+            } } in
+      joinrec {
+        $s$wg [Occ=LoopBreaker, Dmd=SC(S,C(1,C(1,!P(L))))]
+          :: Int -> Int -> GHC.Prim.Int# -> Int
+        [LclId[JoinId(3)(Nothing)],
+         Arity=3,
+         Str=<ML><A><L>,
+         Unf=OtherCon []]
+        $s$wg (sc :: Int) (sc1 :: Int) (sc2 :: GHC.Prim.Int#)
+          = case GHC.Prim.remInt# sc2 2# of {
+              __DEFAULT ->
+                case GHC.Prim.># sc2 43# of {
+                  __DEFAULT -> sc;
+                  1# -> jump $s$wg sc sc1 (GHC.Prim.-# sc2 1#)
+                };
+              0# -> jump $w$sexit sc2
+            }; } in
+      case ds of ds1 {
+        __DEFAULT -> jump $s$wg wild wild ds1;
+        0# -> jump $s$wg T22277.entry2 T22277.entry1 0#
+      }
+      }
+
+
+


=====================================
testsuite/tests/simplCore/should_compile/all.T
=====================================
@@ -273,6 +273,9 @@ test('T14152a', [extra_files(['T14152.hs']), pre_cmd('cp T14152.hs T14152a.hs'),
                 compile, ['-fno-exitification -ddump-simpl'])
 test('T13990', normal, compile, ['-dcore-lint -O'])
 test('T14650', normal, compile, ['-O2'])
+
+# SpecConstr should specialise `l` here (still broken, tracked as #14951):
+test('T14951', [expect_broken(14951), grep_errmsg(r'\$sl') ], compile, ['-O2 -dsuppress-uniques -ddump-simpl'])
 test('T14959', normal, compile, ['-O'])
 test('T14978',
      normal,
@@ -434,3 +437,5 @@ test('T21286',  normal, multimod_compile, ['T21286', '-O -ddump-rule-firings'])
 test('T21851', [grep_errmsg(r'case.*w\$sf') ], multimod_compile, ['T21851', '-O -dno-typeable-binds -dsuppress-uniques'])
 # One module, T22097.hs, has OPTIONS_GHC -ddump-simpl
 test('T22097', [grep_errmsg(r'case.*wgoEven') ], multimod_compile, ['T22097', '-O -dno-typeable-binds -dsuppress-uniques'])
+# SpecConstr should be able to specialise `g` for the pair (expect `$s$wg` in the output)
+test('T22277', [grep_errmsg(r'\$s\$wg') ], compile, ['-O2 -ddump-simpl -dsuppress-uniques'])



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/0beddf90d6ba63a84e6a00093d164f260c8d50c9

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/0beddf90d6ba63a84e6a00093d164f260c8d50c9
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20221011/95d16a9e/attachment-0001.html>


More information about the ghc-commits mailing list