[Git][ghc/ghc][wip/T3781] Make UnfoldingGuidance account for free variables

Jaro Reinders (@Noughtmare) gitlab at gitlab.haskell.org
Tue Aug 8 21:32:40 UTC 2023



Jaro Reinders pushed to branch wip/T3781 at Glasgow Haskell Compiler / GHC


Commits:
35ab806d by Simon Peyton Jones at 2023-08-08T23:32:13+02:00
Make UnfoldingGuidance account for free variables

For over a decade (#3781) I've wondered whether GHC's
UnfoldingGuidance should take account of free variables.  For example:

   f x = let g y = case x of ....
         in
         ...(case x of ....g e....)...

If we inline the call to g in the body of the case x, we can save
re-evaluating (and case analysis of) x. This is very similar to the
"discounts" we give for arguments.  All we need is to do it for
free variables.

It turned out to be rather easy.

* Beef up GHC.Core.Unfold.sizeExpr to accumulate discounts for
  free variables.

* Beef up GHC.Core.Unfold.callSiteInline to compute those
  discounts, based on info in the InScopeSet

Fixes #3781
Also fixes #19643 (which reminded me of this issue)

- - - - -


6 changed files:

- compiler/GHC/Core.hs
- compiler/GHC/Core/Opt/Simplify/Inline.hs
- compiler/GHC/Core/Opt/Simplify/Iteration.hs
- compiler/GHC/Core/Ppr.hs
- compiler/GHC/Core/Seq.hs
- compiler/GHC/Core/Unfold.hs


Changes:

=====================================
compiler/GHC/Core.hs
=====================================
@@ -94,7 +94,7 @@ module GHC.Core (
 import GHC.Prelude
 import GHC.Platform
 
-import GHC.Types.Var.Env( InScopeSet )
+import GHC.Types.Var.Env( InScopeSet, IdEnv )
 import GHC.Types.Var
 import GHC.Core.Type
 import GHC.Core.Coercion
@@ -1385,6 +1385,7 @@ data UnfoldingGuidance
       ug_args ::  [Int],  -- Discount if the argument is evaluated.
                           -- (i.e., a simplification will definitely
                           -- be possible).  One elt of the list per *value* arg.
+      ug_fvs :: IdEnv Int,   -- Discount for free variables
 
       ug_size :: Int,     -- The "size" of the unfolding.
 


=====================================
compiler/GHC/Core/Opt/Simplify/Inline.hs
=====================================
@@ -31,6 +31,8 @@ import GHC.Utils.Outputable
 import GHC.Types.Name
 
 import Data.List (isPrefixOf)
+import GHC.Types.Var.Env
+import GHC.Types.Unique.FM (nonDetStrictFoldUFM_Directly)
 
 {-
 ************************************************************************
@@ -89,6 +91,7 @@ StrictAnal.addStrictnessInfoToTopId
 
 callSiteInline :: Logger
                -> UnfoldingOpts
+               -> InScopeSet
                -> Int                   -- Case depth
                -> Id                    -- The Id
                -> Bool                  -- True <=> unfolding is active
@@ -96,7 +99,8 @@ callSiteInline :: Logger
                -> [ArgSummary]          -- One for each value arg; True if it is interesting
                -> CallCtxt              -- True <=> continuation is interesting
                -> Maybe CoreExpr        -- Unfolding, if any
-callSiteInline logger opts !case_depth id active_unfolding lone_variable arg_infos cont_info
+callSiteInline logger opts in_scope !case_depth id
+               active_unfolding lone_variable arg_infos cont_info
   = case idUnfolding id of
       -- idUnfolding checks for loop-breakers, returning NoUnfolding
       -- Things with an INLINE pragma may have an unfolding *and*
@@ -104,7 +108,7 @@ callSiteInline logger opts !case_depth id active_unfolding lone_variable arg_inf
         CoreUnfolding { uf_tmpl = unf_template
                       , uf_cache = unf_cache
                       , uf_guidance = guidance }
-          | active_unfolding -> tryUnfolding logger opts case_depth id lone_variable
+          | active_unfolding -> tryUnfolding logger opts in_scope case_depth id lone_variable
                                     arg_infos cont_info unf_template
                                     unf_cache guidance
           | otherwise -> traceInline logger opts id "Inactive unfolding:" (ppr id) Nothing
@@ -227,10 +231,12 @@ needed on a per-module basis.
 
 -}
 
-tryUnfolding :: Logger -> UnfoldingOpts -> Int -> Id -> Bool -> [ArgSummary] -> CallCtxt
+tryUnfolding :: Logger -> UnfoldingOpts -> InScopeSet -> Int -> Id
+             -> Bool -> [ArgSummary] -> CallCtxt
              -> CoreExpr -> UnfoldingCache -> UnfoldingGuidance
              -> Maybe CoreExpr
-tryUnfolding logger opts !case_depth id lone_variable arg_infos
+tryUnfolding logger opts in_scope !case_depth id
+             lone_variable arg_infos
              cont_info unf_template unf_cache guidance
  = case guidance of
      UnfNever -> traceInline logger opts id str (text "UnfNever") Nothing
@@ -245,7 +251,8 @@ tryUnfolding logger opts !case_depth id lone_variable arg_infos
           some_benefit = calc_some_benefit uf_arity
           enough_args  = (n_val_args >= uf_arity) || (unsat_ok && n_val_args > 0)
 
-     UnfIfGoodArgs { ug_args = arg_discounts, ug_res = res_discount, ug_size = size }
+     UnfIfGoodArgs { ug_args = arg_discounts, ug_fvs = fv_discounts
+                   , ug_res = res_discount, ug_size = size }
         | unfoldingVeryAggressive opts
         -> traceInline logger opts id str (mk_doc some_benefit extra_doc True) (Just unf_template)
         | is_wf && some_benefit && small_enough
@@ -261,7 +268,8 @@ tryUnfolding logger opts !case_depth id lone_variable arg_infos
                         | otherwise       = (size * (case_depth - depth_treshold)) `div` depth_scaling
           adjusted_size = size + depth_penalty - discount
           small_enough = adjusted_size <= unfoldingUseThreshold opts
-          discount = computeDiscount arg_discounts res_discount arg_infos cont_info
+          discount = computeDiscount in_scope arg_discounts fv_discounts res_discount
+                                     arg_infos cont_info
 
           extra_doc = vcat [ text "case depth =" <+> int case_depth
                            , text "depth based penalty =" <+> int depth_penalty
@@ -510,9 +518,12 @@ which Roman did.
 
 -}
 
-computeDiscount :: [Int] -> Int -> [ArgSummary] -> CallCtxt
-                -> Int
-computeDiscount arg_discounts res_discount arg_infos cont_info
+computeDiscount :: InScopeSet
+                -> [Int]      -- Argument discounts
+                -> VarEnv Int -- Free-variable discounts
+                -> Int -> [ArgSummary] -> CallCtxt -> Int
+computeDiscount in_scope arg_discounts fv_discounts res_discount
+                arg_infos cont_info
 
   = 10          -- Discount of 10 because the result replaces the call
                 -- so we count 10 for the function itself
@@ -521,10 +532,17 @@ computeDiscount arg_discounts res_discount arg_infos cont_info
                -- Discount of 10 for each arg supplied,
                -- because the result replaces the call
 
-    + total_arg_discount + res_discount'
+    + total_arg_discount + fv_discount + res_discount'
   where
     actual_arg_discounts = zipWith mk_arg_discount arg_discounts arg_infos
     total_arg_discount   = sum actual_arg_discounts
+    fv_discount = nonDetStrictFoldUFM_Directly add_fv 0 fv_discounts
+    add_fv uniq disc tot_disc  -- fold step: disc = this free var's discount, tot_disc = running total
+      | Just v <- lookupInScope_Directly in_scope uniq
+      , hasCoreUnfolding (idUnfolding v)   -- the free variable has an unfolding at this call site
+      = disc + tot_disc
+      | otherwise                          -- no discount earned for this variable, so
+      = tot_disc                           -- keep the running total (was 'disc', which dropped it)
 
     mk_arg_discount _        TrivArg    = 0
     mk_arg_discount _        NonTrivArg = 10


=====================================
compiler/GHC/Core/Opt/Simplify/Iteration.hs
=====================================
@@ -2300,7 +2300,7 @@ rebuildCall env (ArgInfo { ai_fun = fun, ai_args = rev_args }) cont
 -----------------------------------
 tryInlining :: SimplEnv -> Logger -> OutId -> SimplCont -> SimplM (Maybe OutExpr)
 tryInlining env logger var cont
-  | Just expr <- callSiteInline logger uf_opts case_depth var active_unf
+  | Just expr <- callSiteInline logger uf_opts in_scope case_depth var active_unf
                                 lone_variable arg_infos interesting_cont
   = do { dump_inline expr cont
        ; return (Just expr) }
@@ -2311,6 +2311,7 @@ tryInlining env logger var cont
   where
     uf_opts    = seUnfoldingOpts env
     case_depth = seCaseDepth env
+    in_scope   = seInScope env
     (lone_variable, arg_infos, call_cont) = contArgs cont
     interesting_cont = interestingCallContext env call_cont
     active_unf       = activeUnfolding (seMode env) var


=====================================
compiler/GHC/Core/Ppr.hs
=====================================
@@ -35,6 +35,7 @@ import GHC.Types.Fixity (LexicalFixity(..))
 import GHC.Types.Literal( pprLiteral )
 import GHC.Types.Name( pprInfixName, pprPrefixName )
 import GHC.Types.Var
+import GHC.Types.Var.Env( isEmptyVarEnv )
 import GHC.Types.Id
 import GHC.Types.Id.Info
 import GHC.Types.Demand
@@ -611,11 +612,13 @@ instance Outputable UnfoldingGuidance where
         parens (text "arity="     <> int arity    <> comma <>
                 text "unsat_ok="  <> ppr unsat_ok <> comma <>
                 text "boring_ok=" <> ppr boring_ok)
-    ppr (UnfIfGoodArgs { ug_args = cs, ug_size = size, ug_res = discount })
-      = hsep [ text "IF_ARGS",
-               brackets (hsep (map int cs)),
-               int size,
-               int discount ]
+    ppr (UnfIfGoodArgs { ug_args = cs, ug_fvs = fvs
+                       , ug_size = size, ug_res = discount })
+      = hsep [ text "IF_ARGS"
+             , brackets (hsep (map int cs))
+             , if isEmptyVarEnv fvs then empty else ppr fvs
+             , int size
+             , int discount ]
 
 instance Outputable Unfolding where
   ppr NoUnfolding                = text "No unfolding"


=====================================
compiler/GHC/Core/Seq.hs
=====================================
@@ -113,5 +113,6 @@ seqUnfolding (CoreUnfolding { uf_tmpl = e, uf_is_top = top,
 seqUnfolding _ = ()
 
 seqGuidance :: UnfoldingGuidance -> ()
-seqGuidance (UnfIfGoodArgs ns n b) = n `seq` sum ns `seq` b `seq` ()
-seqGuidance _                      = ()
+seqGuidance (UnfIfGoodArgs ns ds n b) = n `seq` sum ns `seq` ds `seq` b `seq` ()
+                                        -- We use strict maps so I think `seq` ds will do
+seqGuidance _                         = ()


=====================================
compiler/GHC/Core/Unfold.hs
=====================================
@@ -48,11 +48,12 @@ import GHC.Types.RepType ( isZeroBitTy )
 import GHC.Types.Basic  ( Arity, RecFlag )
 import GHC.Core.Type
 import GHC.Builtin.Names
-import GHC.Data.Bag
 import GHC.Utils.Misc
 import GHC.Utils.Outputable
 import GHC.Types.ForeignCall
 import GHC.Types.Tickish
+import GHC.Types.Var.Env
+import GHC.Types.Var.Set
 
 import qualified Data.ByteString as BS
 
@@ -247,7 +248,7 @@ calcUnfoldingGuidance opts is_top_bottoming (Tick t expr)
 calcUnfoldingGuidance opts is_top_bottoming expr
   = case sizeExpr opts bOMB_OUT_SIZE val_bndrs body of
       TooBig -> UnfNever
-      SizeIs size cased_bndrs scrut_discount
+      SizeIs size id_discounts scrut_discount
         | uncondInline expr n_val_bndrs size
         -> UnfWhen { ug_unsat_ok = unSaturatedOk
                    , ug_boring_ok =  boringCxtOk
@@ -257,10 +258,11 @@ calcUnfoldingGuidance opts is_top_bottoming expr
         -> UnfNever   -- See Note [Do not inline top-level bottoming functions]
 
         | otherwise
-        -> UnfIfGoodArgs { ug_args  = map (mk_discount cased_bndrs) val_bndrs
+        -> UnfIfGoodArgs { ug_args  = map (lookupDiscount id_discounts) val_bndrs
+                         , ug_fvs   = mapVarEnv getDiscount $
+                                      id_discounts `delVarEnvList` val_bndrs
                          , ug_size  = size
                          , ug_res   = scrut_discount }
-
   where
     (bndrs, body) = collectBinders expr
     bOMB_OUT_SIZE = unfoldingCreationThreshold opts
@@ -268,17 +270,7 @@ calcUnfoldingGuidance opts is_top_bottoming expr
     val_bndrs   = filter isId bndrs
     n_val_bndrs = length val_bndrs
 
-    mk_discount :: Bag (Id,Int) -> Id -> Int
-    mk_discount cbs bndr = foldl' combine 0 cbs
-           where
-             combine acc (bndr', disc)
-               | bndr == bndr' = acc `plus_disc` disc
-               | otherwise     = acc
 
-             plus_disc :: Int -> Int -> Int
-             plus_disc | isFunTy (idType bndr) = max
-                       | otherwise             = (+)
-             -- See Note [Function and non-function discounts]
 
 {- Note [Inline unsafeCoerce]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -434,46 +426,58 @@ sizeExpr :: UnfoldingOpts
 -- Forcing bOMB_OUT_SIZE early prevents repeated
 -- unboxing of the Int argument.
 sizeExpr opts !bOMB_OUT_SIZE top_args expr
-  = size_up expr
+  = size_up (mkVarSet top_args, emptyVarSet) expr
   where
-    size_up (Cast e _) = size_up e
-    size_up (Tick _ e) = size_up e
-    size_up (Type _)   = sizeZero           -- Types cost nothing
-    size_up (Coercion _) = sizeZero
-    size_up (Lit lit)  = sizeN (litSize lit)
-    size_up (Var f) | isZeroBitId f = sizeZero
+    size_up :: IgnoreSet -> CoreExpr -> ExprSize
+      -- The second component of the IgnoreSet holds the Ids that we
+      -- *don't* want to collect discount information for; namely, the
+      -- Ids bound locally in the expression
+
+    size_up ig (Cast e _)   = size_up ig e
+    size_up ig (Tick _ e)   = size_up ig e
+    size_up _  (Type _)     = sizeZero           -- Types cost nothing
+    size_up _  (Coercion _) = sizeZero
+    size_up _  (Lit lit)    = sizeN (litSize lit)
+    size_up ig (Var f) | isZeroBitId f = sizeZero
                       -- Make sure we get constructor discounts even
                       -- on nullary constructors
-                    | otherwise       = size_up_call f [] 0
+                    | otherwise       = size_up_call ig f [] 0
 
-    size_up (App fun arg)
-      | isTyCoArg arg = size_up fun
-      | otherwise     = size_up arg  `addSizeNSD`
-                        size_up_app fun [arg] (if isZeroBitExpr arg then 1 else 0)
+    size_up ig (App fun arg)
+      | isTyCoArg arg = size_up ig fun
+      | otherwise     = size_up ig arg  `addSizeNSD`
+                        size_up_app ig fun [arg] (if isZeroBitExpr arg then 1 else 0)
 
-    size_up (Lam b e)
-      | isId b && not (isZeroBitId b) = lamScrutDiscount opts (size_up e `addSizeN` 10)
-      | otherwise = size_up e
+    size_up ig (Lam b e)
+      | isId b && not (isZeroBitId b) = lamScrutDiscount opts (size_up ig' e `addSizeN` 10)
+      | otherwise = size_up ig' e
+      where
+        ig' = ig `extendIgnore` b
 
-    size_up (Let (NonRec binder rhs) body)
-      = size_up_rhs (binder, rhs) `addSizeNSD`
-        size_up body              `addSizeN`
+    size_up ig (Let (NonRec binder rhs) body)
+      = size_up_rhs ig (binder, rhs) `addSizeNSD`
+        size_up ig' body              `addSizeN`
         size_up_alloc binder
+      where
+       ig' = ig `extendIgnore` binder
 
-    size_up (Let (Rec pairs) body)
-      = foldr (addSizeNSD . size_up_rhs)
-              (size_up body `addSizeN` sum (map (size_up_alloc . fst) pairs))
+    size_up ig (Let (Rec pairs) body)
+      = foldr (addSizeNSD . size_up_rhs ig')
+              (size_up ig' body `addSizeN` sum (map (size_up_alloc . fst) pairs))
               pairs
+      where
+        ig' = ig `extendIgnoreList` map fst pairs
 
-    size_up (Case e _ _ alts)
+    size_up ig (Case e bndr _ alts)
         | null alts
-        = size_up e    -- case e of {} never returns, so take size of scrutinee
+        = size_up (ig `extendIgnore` bndr) e
+          -- case e of {} never returns, so take size of scrutinee
 
-    size_up (Case e _ _ alts)
+    size_up ig (Case e bndr _ alts)
         -- Now alts is non-empty
-        | Just v <- is_top_arg e -- We are scrutinising an argument variable
+        | Just v <- is_var e -- We are scrutinising a variable
         = let
-            alt_sizes = map size_up_alt alts
+            alt_sizes = map (size_up_alt ig') alts
 
                   -- alts_size tries to compute a good discount for
                   -- the case when we are scrutinising an argument variable
@@ -481,17 +485,17 @@ sizeExpr opts !bOMB_OUT_SIZE top_args expr
                           -- Size of all alternatives
                       (SizeIs max _        _)
                           -- Size of biggest alternative
-                  = SizeIs tot (unitBag (v, 20 + tot - max)
-                      `unionBags` tot_disc) tot_scrut
+                  = SizeIs tot (unitDisc ig v (CaseDisc (20 + tot - max))
+                                `addDiscs` tot_disc) tot_scrut
                           -- If the variable is known, we produce a
                           -- discount that will take us back to 'max',
                           -- the size of the largest alternative The
-                          -- 1+ is a little discount for reduced
+                          -- 20+ is a little discount for reduced
                           -- allocation in the caller
                           --
                           -- Notice though, that we return tot_disc,
-                          -- the total discount from all branches.  I
-                          -- think that's right.
+                          -- the total discount from all branches.
+                          -- I think that's right.
 
             alts_size tot_size _ = tot_size
           in
@@ -501,14 +505,18 @@ sizeExpr opts !bOMB_OUT_SIZE top_args expr
                 -- that may eliminate allocation in the caller
                 -- And it eliminates the case itself
         where
-          is_top_arg (Var v) | v `elem` top_args = Just v
-          is_top_arg (Cast e _) = is_top_arg e
-          is_top_arg _ = Nothing
+          ig' = ig `extendIgnore` bndr
 
+          is_var (Var v)    = Just v
+          is_var (Cast e _) = is_var e
+          is_var _          = Nothing
 
-    size_up (Case e _ _ alts) = size_up e  `addSizeNSD`
-                                foldr (addAltSize . size_up_alt) case_size alts
+
+    size_up ig (Case e bndr _ alts) = size_up ig' e  `addSizeNSD`
+                                      foldr (addAltSize . size_up_alt ig') case_size alts
       where
+          ig' = ig `extendIgnore` bndr
+
           case_size
            | is_inline_scrut e, lengthAtMost alts 1 = sizeN (-10)
            | otherwise = sizeZero
@@ -544,42 +552,44 @@ sizeExpr opts !bOMB_OUT_SIZE top_args expr
               | otherwise
                 = False
 
-    size_up_rhs (bndr, rhs)
+    size_up_rhs ig (bndr, rhs)
       | JoinPoint join_arity <- idJoinPointHood bndr
         -- Skip arguments to join point
-      , (_bndrs, body) <- collectNBinders join_arity rhs
-      = size_up body
+      , (bndrs, body) <- collectNBinders join_arity rhs
+      = size_up (ig `extendIgnoreList` bndrs) body
       | otherwise
-      = size_up rhs
+      = size_up ig rhs
 
     ------------
     -- size_up_app is used when there's ONE OR MORE value args
-    size_up_app (App fun arg) args voids
-        | isTyCoArg arg                  = size_up_app fun args voids
-        | isZeroBitExpr arg              = size_up_app fun (arg:args) (voids + 1)
-        | otherwise                      = size_up arg  `addSizeNSD`
-                                           size_up_app fun (arg:args) voids
-    size_up_app (Var fun)     args voids = size_up_call fun args voids
-    size_up_app (Tick _ expr) args voids = size_up_app expr args voids
-    size_up_app (Cast expr _) args voids = size_up_app expr args voids
-    size_up_app other         args voids = size_up other `addSizeN`
-                                           callSize (length args) voids
+    size_up_app ig (App fun arg) args voids
+        | isTyCoArg arg                  = size_up_app ig fun args voids
+        | isZeroBitExpr arg              = size_up_app ig fun (arg:args) (voids + 1)
+        | otherwise                      = size_up ig arg  `addSizeNSD`
+                                           size_up_app ig fun (arg:args) voids
+    size_up_app ig (Var fun)     args voids = size_up_call ig fun args voids
+    size_up_app ig (Tick _ expr) args voids = size_up_app ig expr args voids
+    size_up_app ig (Cast expr _) args voids = size_up_app ig expr args voids
+    size_up_app ig other         args voids = size_up ig other `addSizeN`
+                                              callSize (length args) voids
        -- if the lhs is not an App or a Var, or an invisible thing like a
        -- Tick or Cast, then we should charge for a complete call plus the
        -- size of the lhs itself.
 
     ------------
-    size_up_call :: Id -> [CoreExpr] -> Int -> ExprSize
-    size_up_call fun val_args voids
+    size_up_call :: IgnoreSet -> Id -> [CoreExpr] -> Int -> ExprSize
+    size_up_call ig fun val_args voids
        = case idDetails fun of
            FCallId _        -> sizeN (callSize (length val_args) voids)
            DataConWorkId dc -> conSize    dc (length val_args)
            PrimOpId op _    -> primOpSize op (length val_args)
-           ClassOpId {}     -> classOpSize opts top_args val_args
-           _                -> funSize opts top_args fun (length val_args) voids
+           ClassOpId {}     -> classOpSize opts ig val_args
+           _                -> funSize opts ig fun (length val_args) voids
 
     ------------
-    size_up_alt (Alt _con _bndrs rhs) = size_up rhs `addSizeN` 10
+    size_up_alt ig (Alt _con bndrs rhs) = size_up ig' rhs `addSizeN` 10
+      where
+        ig' = ig `extendIgnoreList` bndrs
         -- Don't charge for args, so that wrappers look cheap
         -- (See comments about wrappers with Case)
         --
@@ -603,21 +613,23 @@ sizeExpr opts !bOMB_OUT_SIZE top_args expr
     addSizeN TooBig          _  = TooBig
     addSizeN (SizeIs n xs d) m  = mkSizeIs bOMB_OUT_SIZE (n + m) xs d
 
-        -- addAltSize is used to add the sizes of case alternatives
-    addAltSize TooBig            _      = TooBig
+    -- addAltSize is used to add the sizes of case alternatives
+    -- The /second/ argument is expected to be the bigger one; force it first
     addAltSize _                 TooBig = TooBig
+    addAltSize TooBig            _      = TooBig
     addAltSize (SizeIs n1 xs d1) (SizeIs n2 ys d2)
         = mkSizeIs bOMB_OUT_SIZE (n1 + n2)
-                                 (xs `unionBags` ys)
+                                 (xs `addDiscs` ys)
                                  (d1 + d2) -- Note [addAltSize result discounts]
 
-        -- This variant ignores the result discount from its LEFT argument
-        -- It's used when the second argument isn't part of the result
-    addSizeNSD TooBig            _      = TooBig
+    -- This variant ignores the result discount from its FIRST argument
+    -- It's used when the first argument isn't part of the result
+    -- The second argument is also expected to be bigger: force it first
     addSizeNSD _                 TooBig = TooBig
+    addSizeNSD TooBig            _      = TooBig
     addSizeNSD (SizeIs n1 xs _) (SizeIs n2 ys d2)
         = mkSizeIs bOMB_OUT_SIZE (n1 + n2)
-                                 (xs `unionBags` ys)
+                                 (xs `addDiscs` ys)
                                  d2  -- Ignore d1
 
     -- don't count expressions such as State# RealWorld
@@ -641,11 +653,11 @@ litSize _other = 0    -- Must match size of nullary constructors
                       -- Key point: if  x |-> 4, then x must inline unconditionally
                       --            (eg via case binding)
 
-classOpSize :: UnfoldingOpts -> [Id] -> [CoreExpr] -> ExprSize
+classOpSize :: UnfoldingOpts -> IgnoreSet -> [CoreExpr] -> ExprSize
 -- See Note [Conlike is interesting]
 classOpSize _ _ []
   = sizeZero
-classOpSize opts top_args (arg1 : other_args)
+classOpSize opts ig (arg1 : other_args)
   = SizeIs size arg_discount 0
   where
     size = 20 + (10 * length other_args)
@@ -653,9 +665,8 @@ classOpSize opts top_args (arg1 : other_args)
     -- give it a discount, to encourage the inlining of this function
     -- The actual discount is rather arbitrarily chosen
     arg_discount = case arg1 of
-                     Var dict | dict `elem` top_args
-                              -> unitBag (dict, unfoldingDictDiscount opts)
-                     _other   -> emptyBag
+                     Var dict -> unitDisc ig dict (CaseDisc (unfoldingDictDiscount opts))
+                     _other   -> emptyIdDiscounts
 
 -- | The size of a function call
 callSize
@@ -678,10 +689,10 @@ jumpSize n_val_args voids = 2 * (1 + n_val_args - voids)
   -- spectral/puzzle. TODO Perhaps adjusting the default threshold would be a
   -- better solution?
 
-funSize :: UnfoldingOpts -> [Id] -> Id -> Int -> Int -> ExprSize
+funSize :: UnfoldingOpts -> IgnoreSet -> Id -> Int -> Int -> ExprSize
 -- Size for functions that are not constructors or primops
 -- Note [Function applications]
-funSize opts top_args fun n_val_args voids
+funSize opts ig fun n_val_args voids
   | fun `hasKey` buildIdKey   = buildSize
   | fun `hasKey` augmentIdKey = augmentSize
   | otherwise = SizeIs size arg_discount res_discount
@@ -695,9 +706,9 @@ funSize opts top_args fun n_val_args voids
 
         --                  DISCOUNTS
         --  See Note [Function and non-function discounts]
-    arg_discount | some_val_args && fun `elem` top_args
-                 = unitBag (fun, unfoldingFunAppDiscount opts)
-                 | otherwise = emptyBag
+    arg_discount | some_val_args
+                 = unitDisc ig fun (AppDisc (unfoldingFunAppDiscount opts))
+                 | otherwise = emptyIdDiscounts
         -- If the function is an argument and is applied
         -- to some values, give it an arg-discount
 
@@ -708,13 +719,13 @@ funSize opts top_args fun n_val_args voids
 
 conSize :: DataCon -> Int -> ExprSize
 conSize dc n_val_args
-  | n_val_args == 0 = SizeIs 0 emptyBag 10    -- Like variables
+  | n_val_args == 0 = SizeIs 0 emptyIdDiscounts 10    -- Like variables
 
 -- See Note [Unboxed tuple size and result discount]
-  | isUnboxedTupleDataCon dc = SizeIs 0 emptyBag 10
+  | isUnboxedTupleDataCon dc = SizeIs 0 emptyIdDiscounts 10
 
 -- See Note [Constructor size and result discount]
-  | otherwise = SizeIs 10 emptyBag 10
+  | otherwise = SizeIs 10 emptyIdDiscounts 10
 
 {- Note [Constructor size and result discount]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -792,8 +803,8 @@ the function to a constructor application, so we *want* a big discount
 if the argument is scrutinised by many case expressions.
 
 Conclusion:
-  - For functions, take the max of the discounts
-  - For data values, take the sum of the discounts
+  - For functions,   take the max of the discounts (AppDisc)
+  - For data values, take the sum of the discounts (CaseDisc)
 
 
 Note [Literal integer size]
@@ -818,7 +829,7 @@ primOpSize op n_val_args
 
 
 buildSize :: ExprSize
-buildSize = SizeIs 0 emptyBag 40
+buildSize = SizeIs 0 emptyIdDiscounts 40
         -- We really want to inline applications of build
         -- build t (\cn -> e) should cost only the cost of e (because build will be inlined later)
         -- Indeed, we should add a result_discount because build is
@@ -827,14 +838,14 @@ buildSize = SizeIs 0 emptyBag 40
         -- The "4" is rather arbitrary.
 
 augmentSize :: ExprSize
-augmentSize = SizeIs 0 emptyBag 40
+augmentSize = SizeIs 0 emptyIdDiscounts 40
         -- Ditto (augment t (\cn -> e) ys) should cost only the cost of
         -- e plus ys. The -2 accounts for the \cn
 
 -- When we return a lambda, give a discount if it's used (applied)
 lamScrutDiscount :: UnfoldingOpts -> ExprSize -> ExprSize
 lamScrutDiscount opts (SizeIs n vs _) = SizeIs n vs (unfoldingFunAppDiscount opts)
-lamScrutDiscount _      TooBig          = TooBig
+lamScrutDiscount _      TooBig        = TooBig
 
 {-
 Note [addAltSize result discounts]
@@ -909,13 +920,71 @@ Code for manipulating sizes
 data ExprSize
     = TooBig
     | SizeIs { _es_size_is  :: {-# UNPACK #-} !Int -- ^ Size found
-             , _es_args     :: !(Bag (Id,Int))
+
+             , _es_args     :: !IdDiscounts
                -- ^ Arguments cased herein, and discount for each such
+
              , _es_discount :: {-# UNPACK #-} !Int
                -- ^ Size to subtract if result is scrutinised by a case
                -- expression
              }
 
+type IgnoreSet
+  = ( IdSet     -- Lambda-bound binders for this unfolding
+    , IdSet )   -- Locally-bound binders, within this unfolding
+
+type IdDiscounts = IdEnv IdDiscount
+
+
+lookupDiscount :: IdDiscounts -> Id -> Int
+lookupDiscount discounts bndr
+  = case lookupVarEnv discounts bndr of
+       Just d  -> getDiscount d
+       Nothing -> 0
+
+emptyIdDiscounts :: IdDiscounts
+emptyIdDiscounts = emptyVarEnv
+
+extendIgnore :: IgnoreSet -> Id -> IgnoreSet
+extendIgnore (tops,locals) v = (tops, locals `extendVarSet` v)
+
+extendIgnoreList :: IgnoreSet -> [Id] -> IgnoreSet
+extendIgnoreList (tops,locals) vs = (tops, locals `extendVarSetList` vs)
+
+unitDisc :: IgnoreSet -> Id -> IdDiscount -> IdDiscounts
+-- Record a discount for the use of an Id
+-- But not if it is
+--   (a) a GlobalId
+--   (b) bound locally within the function body we are analysing
+--   (c) an AppDisc for a free variable
+--   (d) an Id that already has an unfolding (see the hasCoreUnfolding guard)
+unitDisc (top_args, ignore_these) v disc
+  | isLocalId v
+  , not (v `elemVarSet` ignore_these)
+  , case disc of { AppDisc _ -> v `elemVarSet` top_args
+                 ; CaseDisc {} -> True }
+  , not (hasCoreUnfolding (idUnfolding v))
+  = unitVarEnv v disc
+  | otherwise
+  = emptyIdDiscounts
+
+addDiscs :: IdDiscounts -> IdDiscounts -> IdDiscounts
+addDiscs = plusVarEnv_C addIdDiscount
+
+data IdDiscount
+  = CaseDisc {-# UNPACK #-} !Int
+  | AppDisc  {-# UNPACK #-} !Int  -- See Note [Function and non-function discounts]
+
+getDiscount :: IdDiscount -> Int
+getDiscount (CaseDisc n) = n
+getDiscount (AppDisc n)  = n
+
+addIdDiscount :: IdDiscount -> IdDiscount -> IdDiscount
+addIdDiscount (CaseDisc n1) (CaseDisc n2) = CaseDisc (n1+n2)
+addIdDiscount (CaseDisc n1) (AppDisc n2)  = AppDisc  (n1 `max` n2)
+addIdDiscount (AppDisc n1)  (CaseDisc n2) = AppDisc  (n1 `max` n2)
+addIdDiscount (AppDisc n1)  (AppDisc n2)  = AppDisc  (n1 `max` n2)
+
 instance Outputable ExprSize where
   ppr TooBig         = text "TooBig"
   ppr (SizeIs a _ c) = brackets (int a <+> int c)
@@ -925,7 +994,7 @@ instance Outputable ExprSize where
 --      tup = (a_1, ..., a_99)
 --      x = case tup of ...
 --
-mkSizeIs :: Int -> Int -> Bag (Id, Int) -> Int -> ExprSize
+mkSizeIs :: Int -> Int -> IdDiscounts -> Int -> ExprSize
 mkSizeIs max n xs d | (n - d) > max = TooBig
                     | otherwise     = SizeIs n xs d
 
@@ -938,5 +1007,5 @@ maxSize s1@(SizeIs n1 _ _) s2@(SizeIs n2 _ _) | n1 > n2   = s1
 sizeZero :: ExprSize
 sizeN :: Int -> ExprSize
 
-sizeZero = SizeIs 0 emptyBag 0
-sizeN n  = SizeIs n emptyBag 0
+sizeZero = SizeIs 0 emptyIdDiscounts 0
+sizeN n  = SizeIs n emptyIdDiscounts 0



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/35ab806dd4085651e12284027f3c3269ecb944be

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/35ab806dd4085651e12284027f3c3269ecb944be
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20230808/867d75e4/attachment-0001.html>


More information about the ghc-commits mailing list