[Git][ghc/ghc][wip/T25096] Fix nasty bug in occurrence analyser
Simon Peyton Jones (@simonpj)
gitlab at gitlab.haskell.org
Sun Jul 21 22:39:38 UTC 2024
Simon Peyton Jones pushed to branch wip/T25096 at Glasgow Haskell Compiler / GHC
Commits:
58a09b76 by Simon Peyton Jones at 2024-07-21T23:39:28+01:00
Fix nasty bug in occurrence analyser
As #25096 showed, the occurrence analyser was getting one-shot info
flat out wrong.
This commit fixes the bug and actually makes the code a bit tidier too.
- - - - -
5 changed files:
- compiler/GHC/Core/Opt/OccurAnal.hs
- compiler/GHC/Utils/Outputable.hs
- + testsuite/tests/simplCore/should_run/T25096.hs
- + testsuite/tests/simplCore/should_run/T25096.stdout
- testsuite/tests/simplCore/should_run/all.T
Changes:
=====================================
compiler/GHC/Core/Opt/OccurAnal.hs
=====================================
@@ -1035,8 +1035,6 @@ occAnalNonRecRhs !env lvl imp_rule_edges mb_join bndr rhs
| otherwise
= (adj_rhs_uds : adj_unf_uds : adj_rule_uds, final_bndr_with_rules, final_rhs )
where
- is_join_point = isJoinPoint mb_join
-
--------- Right hand side ---------
-- For join points, set occ_encl to OccVanilla, via setTailCtxt. If we have
-- join j = Just (f x) in ...
@@ -1044,12 +1042,9 @@ occAnalNonRecRhs !env lvl imp_rule_edges mb_join bndr rhs
-- let y = f x in join j = Just y in ...
-- That's that OccRhs would do; but there's no point because
-- j will never be scrutinised.
- env1 | is_join_point = setTailCtxt env
- | otherwise = setNonTailCtxt rhs_ctxt env -- Zap occ_join_points
+ rhs_env = mkRhsOccEnv env NonRecursive rhs_ctxt mb_join bndr rhs
rhs_ctxt = mkNonRecRhsCtxt lvl bndr unf
- -- See Note [Sources of one-shot information]
- rhs_env = addOneShotsFromDmd bndr env1
-- See Note [Join arity prediction based on joinRhsArity]
-- Match join arity O from mb_join_arity with manifest join arity M as
-- returned by of occAnalLamTail. It's totally OK for them to mismatch;
@@ -1059,16 +1054,15 @@ occAnalNonRecRhs !env lvl imp_rule_edges mb_join bndr rhs
final_bndr_with_rules
| noBinderSwaps env = bndr -- See Note [Unfoldings and rules]
| otherwise = bndr `setIdSpecialisation` mkRuleInfo rules'
- `setIdUnfolding` unf2
+ `setIdUnfolding` unf1
final_bndr_no_rules
| noBinderSwaps env = bndr -- See Note [Unfoldings and rules]
- | otherwise = bndr `setIdUnfolding` unf2
+ | otherwise = bndr `setIdUnfolding` unf1
--------- Unfolding ---------
-- See Note [Join points and unfoldings/rules]
unf = idUnfolding bndr
WTUD unf_tuds unf1 = occAnalUnfolding rhs_env unf
- unf2 = markNonRecUnfoldingOneShots mb_join unf1
adj_unf_uds = adjustTailArity mb_join unf_tuds
--------- Rules ---------
@@ -1143,10 +1137,8 @@ occAnalRec !_ lvl
| isDeadOcc occ -- Check for dead code: see Note [Dead code]
= WUD body_uds binds
| otherwise
- = let (tagged_bndr, mb_join) = tagNonRecBinder lvl occ bndr
+ = let (bndr', mb_join) = tagNonRecBinder lvl occ bndr
!(WUD rhs_uds' rhs') = adjustNonRecRhs mb_join wtuds
- !unf' = markNonRecUnfoldingOneShots mb_join (idUnfolding tagged_bndr)
- !bndr' = tagged_bndr `setIdUnfolding` unf'
in WUD (body_uds `andUDs` rhs_uds')
(NonRec bndr' rhs' : binds)
where
@@ -1751,10 +1743,9 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
-- Instead, do the occAnalLamTail call here and postpone adjustTailUsage
-- until occAnalRec. In effect, we pretend that the RHS becomes a
-- non-recursive join point and fix up later with adjustTailUsage.
- rhs_env | isJoinId bndr = setTailCtxt env
- | otherwise = setNonTailCtxt OccRhs env
- -- If bndr isn't an /existing/ join point, it's safe to zap the
- -- occ_join_points, because they can't occur in RHS.
+ rhs_env = mkRhsOccEnv env Recursive OccRhs (idJoinPointHood bndr) bndr rhs
+ -- If bndr isn't an /existing/ join point (so idJoinPointHood = NotJoinPoint),
+ -- it's safe to zap the occ_join_points, because they can't occur in RHS.
WTUD (TUD rhs_ja unadj_rhs_uds) rhs' = occAnalLamTail rhs_env rhs
-- The corresponding call to adjustTailUsage is in occAnalRec and tagRecBinders
@@ -2309,20 +2300,8 @@ occAnalRule env rule@(Rule { ru_bndrs = bndrs, ru_args = args, ru_rhs = rhs })
occAnalRule _ other_rule = (other_rule, emptyDetails, TUD 0 emptyDetails)
-{- Note [Join point RHSs]
-~~~~~~~~~~~~~~~~~~~~~~~~~
-Consider
- x = e
- join j = Just x
-
-We want to inline x into j right away, so we don't want to give
-the join point a RhsCtxt (#14137). It's not a huge deal, because
-the FloatIn pass knows to float into join point RHSs; and the simplifier
-does not float things out of join point RHSs. But it's a simple, cheap
-thing to do. See #14137.
-
-Note [Occurrences in stable unfoldings]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+{- Note [Occurrences in stable unfoldings]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider
f p = BIG
{-# INLINE g #-}
@@ -2598,7 +2577,7 @@ occAnalArgs !env fun args !one_shots
| otherwise
= case one_shots of
[] -> (env_args, []) -- Fast path; one_shots is often empty
- (os : one_shots') -> (addOneShots os env_args, one_shots')
+ (os : one_shots') -> (setOneShots os env_args, one_shots')
{-
Applications are dealt with specially because we want
@@ -2910,42 +2889,125 @@ setScrutCtxt !env alts
-- non-default alternative. That in turn influences
-- pre/postInlineUnconditionally. Grep for "occ_int_cxt"!
+{- Note [The OccEnv for a right hand side]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+How do we create the OccEnv for a RHS (in mkRhsOccEnv)?
+
+For a non-join point binding, x = rhs
+
+ * occ_encl: set to OccRhs; but see `mkNonRecRhsCtxt` for wrinkles
+
+ * occ_join_points: zap them!
+
+ * occ_one_shots: initialise from the idDemandInfo;
+ see Note [Sources of one-shot information]
+
+For a join point binding, j x = rhs
+
+ * occ_encl: Consider
+ x = e
+ join j = Just x
+ We want to inline x into j right away, so we don't want to give the join point
+ an OccRhs (#14137); we want OccVanilla. It's not a huge deal, because the
+ FloatIn pass knows to float into join point RHSs; and the simplifier does not
+ float things out of join point RHSs. But it's a simple, cheap thing to do.
+
+ * occ_join_points: no need to zap.
+
+ * occ_one_shots: we start with one-shot-info from the context, which indeed
+ applies to the /body/ of the join point, after walking past the binders.
+ So we add to the front a OneShotInfo for each value-binder of the join
+ point: see `extendOneShotsForJoinPoint`. (Failing to account for the join-point
+ binders caused #25096.)
+
+ For the join point binders themselves, of a /non-recursive/ join point,
+ we make the binder a OneShotLam. Again see `extendOneShotsForJoinPoint`.
+
+ These one-shot infos then get attached to the binder by `occAnalLamTail`.
+-}
+
setNonTailCtxt :: OccEncl -> OccEnv -> OccEnv
setNonTailCtxt ctxt !env
= env { occ_encl = ctxt
, occ_one_shots = []
- , occ_join_points = zapped_jp_env }
- where
- -- zapped_jp_env is basically just emptyVarEnv (hence zapped). See (W3) of
- -- Note [Occurrence analysis for join points] Zapping improves efficiency,
- -- slightly, if you accidentally introduce a bug, in which you zap [jx :-> uds] and
- -- then find an occurrence of jx anyway, you might lose those uds, and
- -- that might mean we don't record all occurrencs, and that means we
- -- duplicate a redex.... a very nasty bug (which I encountered!). Hence
- -- this DEBUG code which doesn't remove jx from the envt; it just gives it
- -- emptyDetails, which in turn causes a panic in mkOneOcc. That will catch
- -- this bug before it does any damage.
-#ifdef DEBUG
- zapped_jp_env = mapVarEnv (\ _ -> emptyVarEnv) (occ_join_points env)
-#else
- zapped_jp_env = emptyVarEnv
-#endif
+ , occ_join_points = zapJoinPointInfo (occ_join_points env) }
setTailCtxt :: OccEnv -> OccEnv
-setTailCtxt !env
- = env { occ_encl = OccVanilla }
+setTailCtxt !env = env { occ_encl = OccVanilla }
-- Preserve occ_one_shots, occ_join points
-- Do not use OccRhs for the RHS of a join point (which is a tail ctxt):
- -- see Note [Join point RHSs]
+ -- see Note [The OccEnv for a right hand side]
-addOneShots :: OneShots -> OccEnv -> OccEnv
-addOneShots os !env
+mkRhsOccEnv :: OccEnv -> RecFlag -> OccEncl -> JoinPointHood -> Id -> CoreExpr -> OccEnv
+-- See Note [The OccEnv for a right hand side]
+-- For a join point:
+-- - Keep occ_one_shots, occ_join_points from the context
+-- - But push enough OneShotInfo onto occ_one_shots to account
+-- for the join-point value binders (see extendOneShotsForJoinPoint)
+-- - Set occ_encl to OccVanilla (the OccEncl argument is ignored)
+-- For non-join points
+-- - Zap occ_join_points; replace occ_one_shots with argOneShots of bndr's idDemandInfo
+-- - Set occ_encl to specified OccEncl
+mkRhsOccEnv env@(OccEnv { occ_one_shots = ctxt_one_shots, occ_join_points = ctxt_join_points })
+ is_rec encl jp_hood bndr rhs
+ | JoinPoint join_arity <- jp_hood
+ = env { occ_encl = OccVanilla
+ , occ_one_shots = extendOneShotsForJoinPoint is_rec join_arity rhs ctxt_one_shots
+ , occ_join_points = ctxt_join_points }
+
+ | otherwise
+ = env { occ_encl = encl
+ , occ_one_shots = argOneShots (idDemandInfo bndr)
+ -- argOneShots: see Note [Sources of one-shot information]
+ , occ_join_points = zapJoinPointInfo ctxt_join_points }
+
+zapJoinPointInfo :: JoinPointInfo -> JoinPointInfo
+-- (zapJoinPointInfo jp_info) basically just returns emptyVarEnv (hence zapped).
+-- See (W3) of Note [Occurrence analysis for join points]
+--
+-- Zapping improves efficiency slightly. But if you accidentally introduce a
+-- bug in which you zap [jx :-> uds] and then find an occurrence of jx anyway,
+-- you might lose those uds; that might mean we don't record all occurrences,
+-- which means we duplicate a redex... a very nasty bug (which I encountered!).
+-- Hence this DEBUG code, which doesn't remove jx from the envt; it just gives
+-- it emptyDetails, which in turn causes a panic in mkOneOcc. That will catch
+-- this bug before it does any damage.
+#ifdef DEBUG
+zapJoinPointInfo jp_info = mapVarEnv (\ _ -> emptyVarEnv) jp_info
+#else
+zapJoinPointInfo _ = emptyVarEnv  -- Argument unused here; `_` avoids -Wunused-matches
+#endif
+
+extendOneShotsForJoinPoint
+ :: RecFlag -> JoinArity -> CoreExpr
+ -> [OneShotInfo] -> [OneShotInfo]
+-- Push enough OneShotInfos on the front of ctxt_one_shots
+-- to account for the value lambdas of the join point
+extendOneShotsForJoinPoint is_rec join_arity rhs ctxt_one_shots
+ = go join_arity rhs
+ where
+ -- For a /non-recursive/ join point we can mark all
+ -- its join lambdas as one-shot; and it's a good idea to do so
+ -- But not so for recursive ones
+ os = case is_rec of
+ NonRecursive -> OneShotLam
+ Recursive -> NoOneShotInfo
+
+ go 0 _ = ctxt_one_shots
+ go n (Lam b rhs)
+ | isId b = os : go (n-1) rhs
+ | otherwise = go (n-1) rhs -- Type lambda: counts towards join_arity, adds no OneShotInfo
+ go _ _ = [] -- Not enough lambdas (result is [], dropping ctxt_one_shots). This can legitimately happen.
+ -- e.g. let j = case ... in j True
+ -- This will become an arity-1 join point after the
+ -- simplifier has eta-expanded it; but it may not have
+ -- enough lambdas /yet/. (Lint checks that JoinIds do
+ -- have enough lambdas.)
+
+setOneShots :: OneShots -> OccEnv -> OccEnv
+setOneShots os !env
+ | null os = env -- Fast path for common case; NB: empty os keeps env's existing occ_one_shots
+ | otherwise = env { occ_one_shots = os }
-addOneShotsFromDmd :: Id -> OccEnv -> OccEnv
-addOneShotsFromDmd bndr = addOneShots (argOneShots (idDemandInfo bndr))
-
isRhsEnv :: OccEnv -> Bool
isRhsEnv (OccEnv { occ_encl = cxt }) = case cxt of
OccRhs -> True
@@ -3732,17 +3794,10 @@ adjustNonRecRhs :: JoinPointHood
-> WithUsageDetails CoreExpr
-- ^ This function concentrates shared logic between occAnalNonRecBind and the
-- AcyclicSCC case of occAnalRec.
--- * It applies 'markNonRecJoinOneShots' to the RHS
--- * and returns the adjusted rhs UsageDetails combined with the body usage
+-- It returns the adjusted rhs UsageDetails combined with the body usage
adjustNonRecRhs mb_join_arity rhs_wuds@(WTUD _ rhs)
- = WUD rhs_uds' rhs'
- where
- --------- Marking (non-rec) join binders one-shot ---------
- !rhs' | JoinPoint ja <- mb_join_arity = markNonRecJoinOneShots ja rhs
- | otherwise = rhs
+ = WUD (adjustTailUsage mb_join_arity rhs_wuds) rhs
- --------- Adjusting right-hand side usage ---------
- rhs_uds' = adjustTailUsage mb_join_arity rhs_wuds
adjustTailUsage :: JoinPointHood
-> WithTailUsageDetails CoreExpr -- Rhs usage, AFTER occAnalLamTail
@@ -3760,33 +3815,6 @@ adjustTailArity :: JoinPointHood -> TailUsageDetails -> UsageDetails
adjustTailArity mb_rhs_ja (TUD ja usage)
= markAllNonTailIf (mb_rhs_ja /= JoinPoint ja) usage
-markNonRecJoinOneShots :: JoinArity -> CoreExpr -> CoreExpr
--- For a /non-recursive/ join point we can mark all
--- its join-lambda as one-shot; and it's a good idea to do so
-markNonRecJoinOneShots join_arity rhs
- = go join_arity rhs
- where
- go 0 rhs = rhs
- go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b)
- (go (n-1) rhs)
- go _ rhs = rhs -- Not enough lambdas. This can legitimately happen.
- -- e.g. let j = case ... in j True
- -- This will become an arity-1 join point after the
- -- simplifier has eta-expanded it; but it may not have
- -- enough lambdas /yet/. (Lint checks that JoinIds do
- -- have enough lambdas.)
-
-markNonRecUnfoldingOneShots :: JoinPointHood -> Unfolding -> Unfolding
--- ^ Apply 'markNonRecJoinOneShots' to a stable unfolding
-markNonRecUnfoldingOneShots mb_join_arity unf
- | JoinPoint ja <- mb_join_arity
- , CoreUnfolding{uf_src=src,uf_tmpl=tmpl} <- unf
- , isStableSource src
- , let !tmpl' = markNonRecJoinOneShots ja tmpl
- = unf{uf_tmpl=tmpl'}
- | otherwise
- = unf
-
type IdWithOccInfo = Id
tagLamBinders :: UsageDetails -- Of scope
=====================================
compiler/GHC/Utils/Outputable.hs
=====================================
@@ -1261,7 +1261,7 @@ data BindingSite
data JoinPointHood
= JoinPoint {-# UNPACK #-} !Int -- The JoinArity (but an Int here because
- | NotJoinPoint -- synonym JoinArity is defined in Types.Basic
+ | NotJoinPoint -- synonym JoinArity is defined in Types.Basic)
deriving( Eq )
isJoinPoint :: JoinPointHood -> Bool
=====================================
testsuite/tests/simplCore/should_run/T25096.hs
=====================================
@@ -0,0 +1,20 @@
+module Main where
+
+import System.IO.Unsafe
+import Control.Monad
+
+main :: IO ()
+main = do
+ foo "test" 10
+
+foo :: String -> Int -> IO ()
+foo x n = go n
+ where
+ oops = unsafePerformIO (putStrLn "Once" >> pure (cycle x)) -- T25096.stdout expects "Once" printed exactly once
+
+ go 0 = return ()
+ go n = do
+ -- `oops` should be shared between loop iterations (regression test for #25096)
+ let p = take n oops
+ let !_ = unsafePerformIO (putStrLn p >> pure ()) -- Bang forces this print before the recursive call
+ go (n-1)
=====================================
testsuite/tests/simplCore/should_run/T25096.stdout
=====================================
@@ -0,0 +1,11 @@
+Once
+testtestte
+testtestt
+testtest
+testtes
+testte
+testt
+test
+tes
+te
+t
=====================================
testsuite/tests/simplCore/should_run/all.T
=====================================
@@ -115,3 +115,4 @@ test('T23134', normal, compile_and_run, ['-O0 -fcatch-nonexhaustive-cases'])
test('T23289', normal, compile_and_run, [''])
test('T23056', [only_ways(['ghci-opt'])], ghci_script, ['T23056.script'])
test('T24725', normal, compile_and_run, ['-O -dcore-lint'])
+test('T25096', normal, compile_and_run, ['-O -dcore-lint'])
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/58a09b76211ba5944222a91b6eb4cd1074951456
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/58a09b76211ba5944222a91b6eb4cd1074951456
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240721/857ec90f/attachment-0001.html>
More information about the ghc-commits
mailing list