[Git][ghc/ghc][wip/T25096] Fix nasty bug in occurrence analyser

Simon Peyton Jones (@simonpj) gitlab at gitlab.haskell.org
Sun Jul 21 22:39:38 UTC 2024



Simon Peyton Jones pushed to branch wip/T25096 at Glasgow Haskell Compiler / GHC


Commits:
58a09b76 by Simon Peyton Jones at 2024-07-21T23:39:28+01:00
Fix nasty bug in occurrence analyser

As #25096 showed, the occurrence analyser was getting one-shot info
flat out wrong.

This commit fixes the bug and actually makes the code a bit tidier too.

- - - - -


5 changed files:

- compiler/GHC/Core/Opt/OccurAnal.hs
- compiler/GHC/Utils/Outputable.hs
- + testsuite/tests/simplCore/should_run/T25096.hs
- + testsuite/tests/simplCore/should_run/T25096.stdout
- testsuite/tests/simplCore/should_run/all.T


Changes:

=====================================
compiler/GHC/Core/Opt/OccurAnal.hs
=====================================
@@ -1035,8 +1035,6 @@ occAnalNonRecRhs !env lvl imp_rule_edges mb_join bndr rhs
   | otherwise
   = (adj_rhs_uds : adj_unf_uds : adj_rule_uds, final_bndr_with_rules, final_rhs )
   where
-    is_join_point = isJoinPoint mb_join
-
     --------- Right hand side ---------
     -- For join points, set occ_encl to OccVanilla, via setTailCtxt.  If we have
     --    join j = Just (f x) in ...
@@ -1044,12 +1042,9 @@ occAnalNonRecRhs !env lvl imp_rule_edges mb_join bndr rhs
     --    let y = f x in join j = Just y in ...
     -- That's that OccRhs would do; but there's no point because
     -- j will never be scrutinised.
-    env1 | is_join_point = setTailCtxt env
-         | otherwise     = setNonTailCtxt rhs_ctxt env  -- Zap occ_join_points
+    rhs_env  = mkRhsOccEnv env NonRecursive rhs_ctxt mb_join bndr rhs
     rhs_ctxt = mkNonRecRhsCtxt lvl bndr unf
 
-    -- See Note [Sources of one-shot information]
-    rhs_env = addOneShotsFromDmd bndr env1
     -- See Note [Join arity prediction based on joinRhsArity]
     -- Match join arity O from mb_join_arity with manifest join arity M as
     -- returned by of occAnalLamTail. It's totally OK for them to mismatch;
@@ -1059,16 +1054,15 @@ occAnalNonRecRhs !env lvl imp_rule_edges mb_join bndr rhs
     final_bndr_with_rules
       | noBinderSwaps env = bndr -- See Note [Unfoldings and rules]
       | otherwise         = bndr `setIdSpecialisation` mkRuleInfo rules'
-                                 `setIdUnfolding` unf2
+                                 `setIdUnfolding` unf1
     final_bndr_no_rules
       | noBinderSwaps env = bndr -- See Note [Unfoldings and rules]
-      | otherwise         = bndr `setIdUnfolding` unf2
+      | otherwise         = bndr `setIdUnfolding` unf1
 
     --------- Unfolding ---------
     -- See Note [Join points and unfoldings/rules]
     unf = idUnfolding bndr
     WTUD unf_tuds unf1 = occAnalUnfolding rhs_env unf
-    unf2 = markNonRecUnfoldingOneShots mb_join unf1
     adj_unf_uds = adjustTailArity mb_join unf_tuds
 
     --------- Rules ---------
@@ -1143,10 +1137,8 @@ occAnalRec !_ lvl
   | isDeadOcc occ  -- Check for dead code: see Note [Dead code]
   = WUD body_uds binds
   | otherwise
-  = let (tagged_bndr, mb_join) = tagNonRecBinder lvl occ bndr
+  = let (bndr', mb_join) = tagNonRecBinder lvl occ bndr
         !(WUD rhs_uds' rhs') = adjustNonRecRhs mb_join wtuds
-        !unf'  = markNonRecUnfoldingOneShots mb_join (idUnfolding tagged_bndr)
-        !bndr' = tagged_bndr `setIdUnfolding` unf'
     in WUD (body_uds `andUDs` rhs_uds')
            (NonRec bndr' rhs' : binds)
   where
@@ -1751,10 +1743,9 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
     -- Instead, do the occAnalLamTail call here and postpone adjustTailUsage
     -- until occAnalRec. In effect, we pretend that the RHS becomes a
     -- non-recursive join point and fix up later with adjustTailUsage.
-    rhs_env | isJoinId bndr = setTailCtxt env
-            | otherwise     = setNonTailCtxt OccRhs env
-            -- If bndr isn't an /existing/ join point, it's safe to zap the
-            -- occ_join_points, because they can't occur in RHS.
+    rhs_env = mkRhsOccEnv env Recursive OccRhs (idJoinPointHood bndr) bndr rhs
+            -- If bndr isn't an /existing/ join point (so idJoinPointHood = NotJoinPoint),
+            -- it's safe to zap the occ_join_points, because they can't occur in RHS.
     WTUD (TUD rhs_ja unadj_rhs_uds) rhs' = occAnalLamTail rhs_env rhs
       -- The corresponding call to adjustTailUsage is in occAnalRec and tagRecBinders
 
@@ -2309,20 +2300,8 @@ occAnalRule env rule@(Rule { ru_bndrs = bndrs, ru_args = args, ru_rhs = rhs })
 
 occAnalRule _ other_rule = (other_rule, emptyDetails, TUD 0 emptyDetails)
 
-{- Note [Join point RHSs]
-~~~~~~~~~~~~~~~~~~~~~~~~~
-Consider
-   x = e
-   join j = Just x
-
-We want to inline x into j right away, so we don't want to give
-the join point a RhsCtxt (#14137).  It's not a huge deal, because
-the FloatIn pass knows to float into join point RHSs; and the simplifier
-does not float things out of join point RHSs.  But it's a simple, cheap
-thing to do.  See #14137.
-
-Note [Occurrences in stable unfoldings]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+{- Note [Occurrences in stable unfoldings]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Consider
     f p = BIG
     {-# INLINE g #-}
@@ -2598,7 +2577,7 @@ occAnalArgs !env fun args !one_shots
             | otherwise
             = case one_shots of
                 []                -> (env_args, []) -- Fast path; one_shots is often empty
-                (os : one_shots') -> (addOneShots os env_args, one_shots')
+                (os : one_shots') -> (setOneShots os env_args, one_shots')
 
 {-
 Applications are dealt with specially because we want
@@ -2910,42 +2889,125 @@ setScrutCtxt !env alts
      -- non-default alternative.  That in turn influences
      -- pre/postInlineUnconditionally.  Grep for "occ_int_cxt"!
 
+{- Note [The OccEnv for a right hand side]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+How do we create the OccEnv for a RHS (in mkRhsOccEnv)?
+
+For a non-join point binding, x = rhs
+
+  * occ_encl: set to OccRhs; but see `mkNonRecRhsCtxt` for wrinkles
+
+  * occ_join_points: zap them!
+
+  * occ_one_shots: initialise from the idDemandInfo;
+    see Note [Sources of one-shot information]
+
+For a join point binding,  j x = rhs
+
+  * occ_encl: Consider
+       x = e
+       join j = Just x
+    We want to inline x into j right away, so we don't want to give the join point
+    an OccRhs (#14137); we want OccVanilla.  It's not a huge deal, because the
+    FloatIn pass knows to float into join point RHSs; and the simplifier does not
+    float things out of join point RHSs.  But it's a simple, cheap thing to do.
+
+  * occ_join_points: no need to zap.
+
+  * occ_one_shots: we start with one-shot-info from the context, which indeed
+    applies to the /body/ of the join point, after walking past the binders.
+    So we add to the front a OneShotInfo for each value-binder of the join
+    point: see `extendOneShotsForJoinPoint`. (Failing to account for the join-point
+    binders caused #25096.)
+
+    For the join point binders themselves, of a /non-recursive/ join point,
+    we make the binder a OneShotLam.  Again see `extendOneShotsForJoinPoint`.
+
+    These one-shot infos then get attached to the binder by `occAnalLamTail`.
+-}
+
 setNonTailCtxt :: OccEncl -> OccEnv -> OccEnv
 setNonTailCtxt ctxt !env
   = env { occ_encl        = ctxt
         , occ_one_shots   = []
-        , occ_join_points = zapped_jp_env }
-  where
-    -- zapped_jp_env is basically just emptyVarEnv (hence zapped).  See (W3) of
-    -- Note [Occurrence analysis for join points] Zapping improves efficiency,
-    -- slightly, if you accidentally introduce a bug, in which you zap [jx :-> uds] and
-    -- then find an occurrence of jx anyway, you might lose those uds, and
-    -- that might mean we don't record all occurrencs, and that means we
-    -- duplicate a redex....  a very nasty bug (which I encountered!).  Hence
-    -- this DEBUG code which doesn't remove jx from the envt; it just gives it
-    -- emptyDetails, which in turn causes a panic in mkOneOcc. That will catch
-    -- this bug before it does any damage.
-#ifdef DEBUG
-    zapped_jp_env = mapVarEnv (\ _ -> emptyVarEnv) (occ_join_points env)
-#else
-    zapped_jp_env = emptyVarEnv
-#endif
+        , occ_join_points = zapJoinPointInfo (occ_join_points env) }
 
 setTailCtxt :: OccEnv -> OccEnv
-setTailCtxt !env
-  = env { occ_encl = OccVanilla }
+setTailCtxt !env = env { occ_encl = OccVanilla }
     -- Preserve occ_one_shots, occ_join points
     -- Do not use OccRhs for the RHS of a join point (which is a tail ctxt):
-    --    see Note [Join point RHSs]
 
-addOneShots :: OneShots -> OccEnv -> OccEnv
-addOneShots os !env
+mkRhsOccEnv :: OccEnv -> RecFlag -> OccEncl -> JoinPointHood -> Id -> CoreExpr -> OccEnv
+-- See Note [The OccEnv for a right hand side]
+-- For a join point:
+--   - Keep occ_one_shots, occ_join_points from the context
+--   - But push enough OneShotInfo onto occ_one_shots to account
+--     for the join-point value binders
+--   - Set occ_encl to OccVanilla
+-- For non-join points
+--   - Zap occ_join_points; initialise occ_one_shots from the idDemandInfo
+--   - Set occ_encl to specified OccEncl
+mkRhsOccEnv env@(OccEnv { occ_one_shots = ctxt_one_shots, occ_join_points = ctxt_join_points })
+            is_rec encl jp_hood bndr rhs
+  | JoinPoint join_arity <- jp_hood
+  = env { occ_encl        = OccVanilla
+        , occ_one_shots   = extendOneShotsForJoinPoint is_rec join_arity rhs ctxt_one_shots
+        , occ_join_points = ctxt_join_points }
+
+  | otherwise
+  = env { occ_encl        = encl
+        , occ_one_shots   = argOneShots (idDemandInfo bndr)
+                            -- argOneShots: see Note [Sources of one-shot information]
+        , occ_join_points = zapJoinPointInfo ctxt_join_points }
+
+zapJoinPointInfo :: JoinPointInfo -> JoinPointInfo
+-- (zapJoinPointInfo jp_info) basically just returns emptyVarEnv (hence zapped).
+-- See (W3) of Note [Occurrence analysis for join points]
+--
+-- Zapping improves efficiency, slightly.  But if you accidentally introduce a
+-- bug, in which you zap [jx :-> uds] and then find an occurrence of jx anyway,
+-- you might lose those uds, and that might mean we don't record all occurrences,
+-- and that means we duplicate a redex....  a very nasty bug (which I encountered!).
+-- Hence this DEBUG code which doesn't remove jx from the envt; it just gives it
+-- emptyDetails, which in turn causes a panic in mkOneOcc. That will catch this
+-- bug before it does any damage.
+#ifdef DEBUG
+zapJoinPointInfo jp_info = mapVarEnv (\ _ -> emptyVarEnv) jp_info
+#else
+zapJoinPointInfo jp_info = emptyVarEnv
+#endif
+
+extendOneShotsForJoinPoint
+  :: RecFlag -> JoinArity -> CoreExpr
+  -> [OneShotInfo] -> [OneShotInfo]
+-- Push enough OneShotInfos on the front of ctxt_one_shots
+-- to account for the value lambdas of the join point
+extendOneShotsForJoinPoint is_rec join_arity rhs ctxt_one_shots
+  = go join_arity rhs
+  where
+    -- For a /non-recursive/ join point we can mark all
+    -- its join-lambda as one-shot; and it's a good idea to do so
+    -- But not so for recursive ones
+    os = case is_rec of
+           NonRecursive -> OneShotLam
+           Recursive    -> NoOneShotInfo
+
+    go 0 _        = ctxt_one_shots
+    go n (Lam b rhs)
+      | isId b    = os : go (n-1) rhs
+      | otherwise =      go (n-1) rhs
+    go _ _        = []  -- Not enough lambdas.  This can legitimately happen.
+                        -- e.g.    let j = case ... in j True
+                        -- This will become an arity-1 join point after the
+                        -- simplifier has eta-expanded it; but it may not have
+                        -- enough lambdas /yet/. (Lint checks that JoinIds do
+                        -- have enough lambdas.)
+
+setOneShots :: OneShots -> OccEnv -> OccEnv
+setOneShots os !env
   | null os   = env  -- Fast path for common case
   | otherwise = env { occ_one_shots = os }
 
-addOneShotsFromDmd :: Id -> OccEnv -> OccEnv
-addOneShotsFromDmd bndr = addOneShots (argOneShots (idDemandInfo bndr))
-
 isRhsEnv :: OccEnv -> Bool
 isRhsEnv (OccEnv { occ_encl = cxt }) = case cxt of
                                           OccRhs -> True
@@ -3732,17 +3794,10 @@ adjustNonRecRhs :: JoinPointHood
                 -> WithUsageDetails CoreExpr
 -- ^ This function concentrates shared logic between occAnalNonRecBind and the
 -- AcyclicSCC case of occAnalRec.
---   * It applies 'markNonRecJoinOneShots' to the RHS
---   * and returns the adjusted rhs UsageDetails combined with the body usage
+-- It returns the adjusted rhs UsageDetails combined with the body usage
 adjustNonRecRhs mb_join_arity rhs_wuds@(WTUD _ rhs)
-  = WUD rhs_uds' rhs'
-  where
-    --------- Marking (non-rec) join binders one-shot ---------
-    !rhs' | JoinPoint ja <- mb_join_arity = markNonRecJoinOneShots ja rhs
-          | otherwise                     = rhs
+  = WUD (adjustTailUsage mb_join_arity rhs_wuds) rhs
 
-    --------- Adjusting right-hand side usage ---------
-    rhs_uds' = adjustTailUsage mb_join_arity rhs_wuds
 
 adjustTailUsage :: JoinPointHood
                 -> WithTailUsageDetails CoreExpr    -- Rhs usage, AFTER occAnalLamTail
@@ -3760,33 +3815,6 @@ adjustTailArity :: JoinPointHood -> TailUsageDetails -> UsageDetails
 adjustTailArity mb_rhs_ja (TUD ja usage)
   = markAllNonTailIf (mb_rhs_ja /= JoinPoint ja) usage
 
-markNonRecJoinOneShots :: JoinArity -> CoreExpr -> CoreExpr
--- For a /non-recursive/ join point we can mark all
--- its join-lambda as one-shot; and it's a good idea to do so
-markNonRecJoinOneShots join_arity rhs
-  = go join_arity rhs
-  where
-    go 0 rhs         = rhs
-    go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b)
-                           (go (n-1) rhs)
-    go _ rhs         = rhs  -- Not enough lambdas.  This can legitimately happen.
-                            -- e.g.    let j = case ... in j True
-                            -- This will become an arity-1 join point after the
-                            -- simplifier has eta-expanded it; but it may not have
-                            -- enough lambdas /yet/. (Lint checks that JoinIds do
-                            -- have enough lambdas.)
-
-markNonRecUnfoldingOneShots :: JoinPointHood -> Unfolding -> Unfolding
--- ^ Apply 'markNonRecJoinOneShots' to a stable unfolding
-markNonRecUnfoldingOneShots mb_join_arity unf
-  | JoinPoint ja <- mb_join_arity
-  , CoreUnfolding{uf_src=src,uf_tmpl=tmpl} <- unf
-  , isStableSource src
-  , let !tmpl' = markNonRecJoinOneShots ja tmpl
-  = unf{uf_tmpl=tmpl'}
-  | otherwise
-  = unf
-
 type IdWithOccInfo = Id
 
 tagLamBinders :: UsageDetails        -- Of scope


=====================================
compiler/GHC/Utils/Outputable.hs
=====================================
@@ -1261,7 +1261,7 @@ data BindingSite
 
 data JoinPointHood
   = JoinPoint {-# UNPACK #-} !Int   -- The JoinArity (but an Int here because
-  | NotJoinPoint                    -- synonym JoinArity is defined in Types.Basic
+  | NotJoinPoint                    -- synonym JoinArity is defined in Types.Basic)
   deriving( Eq )
 
 isJoinPoint :: JoinPointHood -> Bool


=====================================
testsuite/tests/simplCore/should_run/T25096.hs
=====================================
@@ -0,0 +1,20 @@
+module Main where
+
+import System.IO.Unsafe
+import Control.Monad
+
+main :: IO ()
+main = do
+  foo "test" 10
+
+foo :: String -> Int -> IO ()
+foo x n = go n
+  where
+    oops = unsafePerformIO (putStrLn "Once" >> pure (cycle x))
+
+    go 0 = return ()
+    go n = do
+      -- `oops` should be shared between loop iterations
+      let p  = take n oops
+      let !_ = unsafePerformIO (putStrLn p >> pure ())
+      go (n-1)


=====================================
testsuite/tests/simplCore/should_run/T25096.stdout
=====================================
@@ -0,0 +1,11 @@
+Once
+testtestte
+testtestt
+testtest
+testtes
+testte
+testt
+test
+tes
+te
+t


=====================================
testsuite/tests/simplCore/should_run/all.T
=====================================
@@ -115,3 +115,4 @@ test('T23134', normal, compile_and_run, ['-O0 -fcatch-nonexhaustive-cases'])
 test('T23289', normal, compile_and_run, [''])
 test('T23056', [only_ways(['ghci-opt'])], ghci_script, ['T23056.script'])
 test('T24725', normal, compile_and_run, ['-O -dcore-lint'])
+test('T25096', normal, compile_and_run, ['-O -dcore-lint'])



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/58a09b76211ba5944222a91b6eb4cd1074951456

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/58a09b76211ba5944222a91b6eb4cd1074951456
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240721/857ec90f/attachment-0001.html>


More information about the ghc-commits mailing list