[Git][ghc/ghc][wip/T22404] Wibbles

Simon Peyton Jones (@simonpj) gitlab at gitlab.haskell.org
Thu Jul 13 23:14:08 UTC 2023



Simon Peyton Jones pushed to branch wip/T22404 at Glasgow Haskell Compiler / GHC


Commits:
f29e8062 by Simon Peyton Jones at 2023-07-14T00:13:55+01:00
Wibbles

- - - - -


1 changed file:

- compiler/GHC/Core/Opt/OccurAnal.hs


Changes:

=====================================
compiler/GHC/Core/Opt/OccurAnal.hs
=====================================
@@ -945,7 +945,7 @@ occAnalBind
 
 occAnalBind env lvl ire (Rec pairs) thing_inside combine
   = addInScope env (map fst pairs) $ \env ->
-    let WUD body_uds body' = thing_inside env
+    let WUD body_uds body'  = thing_inside env
         WUD bind_uds binds' = occAnalRecBind env lvl ire pairs body_uds
     in WUD bind_uds (combine binds' body')
 
@@ -958,17 +958,15 @@ occAnalBind !env lvl ire (NonRec bndr rhs) thing_inside combine
   -- Analyse the RHS and /then/ the body
   | NotTopLevel <- lvl
   , mb_join@(Just {}) <- isJoinId_maybe bndr
-  , not (isStableUnfolding (realIdUnfolding bndr))
-  , not (idHasRules bndr)
   = let -- Analyse the rhs first, generating rhs_uds
-        WUD rhs_uds rhs' = adjustNonRecRhs mb_join $
-                           occAnalLamTail (setTailCtxt env) rhs
+        (rhs_uds_s, bndr', rhs') = occAnalNonRecRhs env ire mb_join bndr rhs
+        rhs_uds = foldr1 orUDs rhs_uds_s   -- Note orUDs
 
         -- Now analyse the body, adding the join point
         -- into the environment with addJoinPoint
         WUD body_uds (tagged_bndr, body)
-           = occAnalNonRecBody env lvl bndr $ \env ->
-             thing_inside (addJoinPoint env bndr rhs_uds)
+           = occAnalNonRecBody env NotTopLevel bndr' $ \env ->
+             thing_inside (addJoinPoint env bndr' rhs_uds)
     in
     if isDeadBinder tagged_bndr     -- Drop dead code; see Note [Dead code]
     then WUD body_uds body
@@ -978,13 +976,18 @@ occAnalBind !env lvl ire (NonRec bndr rhs) thing_inside combine
   -- The normal case, including newly-discovered join points
   -- Analyse the body and /then/ the RHS
   | otherwise
-  = let WUD body_uds (tagged_bndr, body) = occAnalNonRecBody env lvl bndr thing_inside
-        WUD bind_uds binds               = occAnalNonRecRhs  env ire tagged_bndr rhs
-    in
-    if isDeadBinder tagged_bndr      -- Drop dead code; see Note [Dead code]
+  = let
+        WUD body_uds (tagged_bndr, body) = occAnalNonRecBody env lvl bndr thing_inside
+    in if isDeadBinder tagged_bndr      -- Drop dead code; see Note [Dead code]
     then WUD body_uds body
-    else WUD (bind_uds `andUDs` body_uds)      -- Note `andUDs`
-             (combine binds body)
+    else let
+        -- Get the join info from the *new* decision
+        -- See Note [Join points and unfoldings/rules]
+        -- => join arity O of Note [Join arity prediction based on joinRhsArity]
+        mb_join = willBeJoinId_maybe tagged_bndr
+        (rhs_uds_s, final_bndr, rhs') = occAnalNonRecRhs env ire mb_join tagged_bndr rhs
+    in WUD (foldr andUDs body_uds rhs_uds_s)      -- Note `andUDs`
+           (combine [NonRec final_bndr rhs'] body)
 
 -----------------
 occAnalNonRecBody :: OccEnv -> TopLevelFlag -> Id
@@ -997,17 +1000,14 @@ occAnalNonRecBody env lvl bndr thing_inside
     in WUD inner_uds (tagged_bndr, res)
 
 -----------------
-occAnalNonRecRhs :: OccEnv -> ImpRuleEdges -> Id -> CoreExpr
-                 -> WithUsageDetails [CoreBind]
-occAnalNonRecRhs !env imp_rule_edges tagged_bndr rhs
-  = WUD (adj_rhs_uds `andUDs` adj_unf_uds `andUDs` adj_rule_uds)
-        [NonRec final_bndr final_rhs]
+occAnalNonRecRhs :: OccEnv -> ImpRuleEdges -> Maybe JoinArity
+                 -> Id -> CoreExpr
+                 -> ([UsageDetails], Id, CoreExpr)
+occAnalNonRecRhs !env imp_rule_edges mb_join bndr rhs
+  = (adj_rhs_uds : adj_unf_uds : adj_rule_uds,
+     final_bndr, final_rhs )
   where
-    -- Get the join info from the *new* decision
-    -- See Note [Join points and unfoldings/rules]
-    -- => join arity O of Note [Join arity prediction based on joinRhsArity]
-    mb_join_arity = willBeJoinId_maybe tagged_bndr
-    is_join_point = isJust mb_join_arity
+    is_join_point = isJust mb_join
 
     --------- Right hand side ---------
     env1 | is_join_point = setTailCtxt env
@@ -1021,24 +1021,26 @@ occAnalNonRecRhs !env imp_rule_edges tagged_bndr rhs
     -- Match join arity O from mb_join_arity with manifest join arity M as
     -- returned by of occAnalLamTail. It's totally OK for them to mismatch;
     -- hence adjust the UDs from the RHS
-    WUD adj_rhs_uds final_rhs = adjustNonRecRhs mb_join_arity $
+    WUD adj_rhs_uds final_rhs = adjustNonRecRhs mb_join $
                                 occAnalLamTail rhs_env rhs
-    final_bndr = tagged_bndr `setIdSpecialisation` mkRuleInfo rules'
-                             `setIdUnfolding` unf2
+    final_bndr = bndr `setIdSpecialisation` mkRuleInfo rules'
+                      `setIdUnfolding` unf2
 
     --------- Unfolding ---------
     -- See Note [Join points and unfoldings/rules]
-    unf = idUnfolding tagged_bndr
+    unf = idUnfolding bndr
     WTUD unf_tuds unf1 = occAnalUnfolding rhs_env unf
-    unf2 = markNonRecUnfoldingOneShots mb_join_arity unf1
-    adj_unf_uds = adjustTailArity mb_join_arity unf_tuds
+    unf2 = markNonRecUnfoldingOneShots mb_join unf1
+    adj_unf_uds = adjustTailArity mb_join unf_tuds
 
     --------- Rules ---------
     -- See Note [Rules are extra RHSs] and Note [Rule dependency info]
     -- and Note [Join points and unfoldings/rules]
-    rules_w_uds  = occAnalRules rhs_env tagged_bndr
+    rules_w_uds  = occAnalRules rhs_env bndr
     rules'       = map fstOf3 rules_w_uds
-    imp_rule_uds = impRulesScopeUsage (lookupImpRules imp_rule_edges tagged_bndr)
+    imp_rule_infos = lookupImpRules imp_rule_edges bndr
+    imp_rule_uds | null imp_rule_infos = []  -- Very common case
+                 | otherwise           = [impRulesScopeUsage imp_rule_infos]
          -- imp_rule_uds: consider
          --     h = ...
          --     g = ...
@@ -1048,19 +1050,19 @@ occAnalNonRecRhs !env imp_rule_edges tagged_bndr rhs
          -- we make g mention h.
 
     adj_rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
-    add_rule_uds (_, l, r) uds
-      = l `andUDs` adjustTailArity mb_join_arity r `andUDs` uds
+    add_rule_uds (_, l, r) uds_s
+      = (l `andUDs` adjustTailArity mb_join r) : uds_s
 
     ----------
-    occ = idOccInfo tagged_bndr
     certainly_inline -- See Note [Cascading inlines]
-      = case occ of
+      = -- certainly_inline is only used for non-join points, so idOccInfo is valid
+        case idOccInfo bndr of
           OneOcc { occ_in_lam = NotInsideLam, occ_n_br = 1 }
             -> active && not_stable
           _ -> False
 
-    dmd        = idDemandInfo tagged_bndr
-    active     = isAlwaysActive (idInlineActivation tagged_bndr)
+    dmd        = idDemandInfo bndr
+    active     = isAlwaysActive (idInlineActivation bndr)
     not_stable = not (isStableUnfolding unf)
 
 -----------------
@@ -1128,7 +1130,7 @@ occAnalRec env lvl (CyclicSCC details_s) (WUD body_uds binds)
     -- See Note [Choosing loop breakers] for loop_breaker_nodes
     final_uds :: UsageDetails
     loop_breaker_nodes :: [LoopBreakerNode]
-    (WUD final_uds loop_breaker_nodes) = mkLoopBreakerNodes env lvl body_uds details_s
+    WUD final_uds loop_breaker_nodes = mkLoopBreakerNodes env lvl body_uds details_s
 
     ------------------------------
     weak_fvs :: VarSet
@@ -2946,7 +2948,7 @@ mkZeroedForm (UD { ud_env = rhs_occs })
   = emptyDetails { ud_env = mapMaybeUFM do_one rhs_occs }
   where
     do_one :: LocalOcc -> Maybe LocalOcc
-    do_one ManyOccL         = Nothing
+    do_one (ManyOccL {})    = Nothing
     do_one occ@(OneOccL {}) = Just (occ { lo_n_br = 0 })
 
 --------------------
@@ -3377,7 +3379,16 @@ data LocalOcc
                           -- Combining (AlwaysTailCalled 2) and (AlwaysTailCalled 3)
                           -- gives NoTailCallInfo
               , lo_int_cxt :: !InterestingCxt }
-    | ManyOccL
+    | ManyOccL !TailCallInfo
+
+instance Outputable LocalOcc where
+  ppr (OneOccL { lo_n_br = n, lo_tail = tci })
+    = text "OneOccL" <> braces (ppr n <> comma <> ppr tci)
+  ppr (ManyOccL tci) = text "ManyOccL" <> braces (ppr tci)
+
+localTailCallInfo :: LocalOcc -> TailCallInfo
+localTailCallInfo (OneOccL  { lo_tail = tci }) = tci
+localTailCallInfo (ManyOccL tci)               = tci
 
 type ZappedSet = OccInfoEnv -- Values are ignored
 
@@ -3393,6 +3404,7 @@ instance Outputable UsageDetails where
   ppr ud = text "UD" <+> (braces $ fsep $ punctuate comma $
            [ ppr uq <+> text ":->" <+> ppr (mkOccInfoByUnique ud uq local_occ)
            | (uq, local_occ) <- nonDetStrictFoldVarEnv_Directly do_one [] (ud_env ud) ])
+           $$ nest 2 (text "ud_z_tail" <+> ppr (ud_z_tail ud))
     where
       do_one :: Unique -> LocalOcc -> [(Unique,LocalOcc)] -> [(Unique,LocalOcc)]
       do_one uniq occ occs = (uniq, occ) : occs
@@ -3442,7 +3454,8 @@ mkOneOcc !env id int_cxt arity
 
 addManyOccId :: UsageDetails -> Id -> UsageDetails
 -- Add the non-committal (id :-> noOccInfo) to the usage details
-addManyOccId ud id = ud { ud_env = extendVarEnv (ud_env ud) id ManyOccL }
+addManyOccId ud id = ud { ud_env = extendVarEnv (ud_env ud) id
+                                       (ManyOccL NoTailCallInfo) }
 
 -- Add several occurrences, assumed not to be tail calls
 addManyOcc :: Var -> UsageDetails -> UsageDetails
@@ -3508,8 +3521,8 @@ lookupLocalDetails uds id = lookupVarEnv (ud_env uds) id
 lookupTailCallInfo :: UsageDetails -> Id -> TailCallInfo
 lookupTailCallInfo uds id
   | not (id `elemVarEnv` ud_z_tail uds)
-  , Just (OneOccL { lo_tail = tail_info }) <- lookupLocalDetails uds id
-  = tail_info
+  , Just occ <- lookupLocalDetails uds id
+  = localTailCallInfo occ
   | otherwise
   = NoTailCallInfo
 
@@ -3555,21 +3568,23 @@ mkOccInfoByUnique (UD { ud_z_many    = z_many
       OneOccL { lo_n_br = n_br, lo_int_cxt = int_cxt
               , lo_tail = tail_info }
           | uniq `elemVarEnvByKey`z_many
-          -> ManyOccs { occ_tail = tail_info' }  -- Hack alert
+          -> ManyOccs { occ_tail = mk_tail_info tail_info }
           | otherwise
           -> OneOcc { occ_in_lam  = in_lam
                     , occ_n_br    = n_br
                     , occ_int_cxt = int_cxt
-                    , occ_tail    = tail_info' }
+                    , occ_tail    = mk_tail_info tail_info }
          where
-           tail_info' | uniq `elemVarEnvByKey` z_tail = NoTailCallInfo
-                      | otherwise                     = tail_info
-
            in_lam | uniq `elemVarEnvByKey` z_in_lam = IsInsideLam
                   | otherwise                       = NotInsideLam
 
-      ManyOccL -> ManyOccs { occ_tail = NoTailCallInfo }
-                  -- I think this is redundant; remove from ManyOccs
+      ManyOccL tail_info -> ManyOccs { occ_tail = mk_tail_info tail_info }
+  where
+    mk_tail_info ti
+        | uniq `elemVarEnvByKey` z_tail = NoTailCallInfo
+        | otherwise                     = ti
+
+
 
 -------------------
 -- See Note [Adjusting right-hand sides]
@@ -3755,7 +3770,8 @@ decideJoinPointHood TopLevel _ _
 decideJoinPointHood NotTopLevel usage bndrs
   | isJoinId bndr1
   = warnPprTrace lost_join_point
-                 "OccurAnal failed to rediscover join point(s)" (ppr bndrs)
+                 "OccurAnal failed to rediscover join point(s)"
+                 lost_join_doc
     all_ok
 --   = assertPpr (not lost_join_point) (ppr bndrs)
 --    True
@@ -3764,21 +3780,6 @@ decideJoinPointHood NotTopLevel usage bndrs
   = all_ok
   where
     bndr1 = NE.head bndrs
-    lost_join_point
-      | isNothing (lookupLocalDetails usage bndr1) = False  -- Dead
-      | all_ok                                     = False
-      | otherwise
-      = pprTrace "djph"
-          (let arity = case lookupTailCallInfo usage bndr1 of
-                         AlwaysTailCalled ar -> ar
-                         NoTailCallInfo -> 0
-           in vcat [ text "bndr1:" <+> ppr bndr1
-                , text "occ:" <+> ppr (lookupDetails usage bndr1)
-                , text "arity:" <+> ppr arity
-                , text "rules:" <+> ppr (idCoreRules bndr1)
-                , text "ok_unf:" <+> ppr (ok_unfolding arity (realIdUnfolding bndr1))
-                , text "ok_type:" <+> ppr (isValidJoinPointType arity (idType bndr1)) ]) $
-        True
 
     -- See Note [Invariants on join points]; invariants cited by number below.
     -- Invariant 2 is always satisfiable by the simplifier by eta expansion.
@@ -3817,6 +3818,24 @@ decideJoinPointHood NotTopLevel usage bndrs
     ok_unfolding _ _
       = True
 
+    lost_join_point :: Bool
+    lost_join_point
+      | isNothing (lookupLocalDetails usage bndr1) = False  -- Dead
+      | all_ok                                     = False
+      | otherwise                                  = True
+
+    lost_join_doc
+      = vcat [ text "bndrs:" <+> ppr bndrs
+             , text "occ:" <+> ppr (lookupDetails usage bndr1)
+             , text "arity:" <+> ppr arity
+             , text "rules:" <+> ppr (idCoreRules bndr1)
+             , text "ok_unf:" <+> ppr (ok_unfolding arity (realIdUnfolding bndr1))
+             , text "ok_type:" <+> ppr (isValidJoinPointType arity (idType bndr1)) ]
+      where
+        arity = case lookupTailCallInfo usage bndr1 of
+                         AlwaysTailCalled ar -> ar
+                         NoTailCallInfo -> 0
+
 willBeJoinId_maybe :: CoreBndr -> Maybe JoinArity
 willBeJoinId_maybe bndr
   | isId bndr
@@ -3867,28 +3886,21 @@ markNonTail :: OccInfo -> OccInfo
 markNonTail IAmDead = IAmDead
 markNonTail occ     = occ { occ_tail = NoTailCallInfo }
 
-andOccInfo, orOccInfo :: LocalOcc -> LocalOcc -> LocalOcc
-
-andOccInfo (OneOccL { lo_n_br = nbr1, lo_int_cxt = int_cxt1, lo_tail = ar1 })
-           (OneOccL { lo_n_br = nbr2, lo_int_cxt = int_cxt2, lo_tail = ar2 })
-  | AlwaysTailCalled n1 <- ar1
-  , AlwaysTailCalled n2 <- ar2
-  , n1 == n2
-  = -- Hack alert
-    OneOccL { lo_n_br    = nbr1 + nbr2
-            , lo_int_cxt = int_cxt1 `mappend` int_cxt2
-            , lo_tail    = AlwaysTailCalled n1 }
-andOccInfo _ _ = ManyOccL
+andOccInfo :: LocalOcc -> LocalOcc -> LocalOcc
+andOccInfo occ1 occ2 = ManyOccL (tci1 `andTailCallInfo` tci2)
+  where
+    !tci1 = localTailCallInfo occ1
+    !tci2 = localTailCallInfo occ2
 
--- (orOccInfo orig new) is used
+orOccInfo :: LocalOcc -> LocalOcc -> LocalOcc
+-- (orOccInfo occ1 occ2) is used
 -- when combining occurrence info from branches of a case
-
-orOccInfo (OneOccL { lo_n_br = nbr1, lo_int_cxt = int_cxt1, lo_tail = ar1 })
-          (OneOccL { lo_n_br = nbr2, lo_int_cxt = int_cxt2, lo_tail = ar2 })
+orOccInfo (OneOccL { lo_n_br = nbr1, lo_int_cxt = int_cxt1, lo_tail = tci1 })
+          (OneOccL { lo_n_br = nbr2, lo_int_cxt = int_cxt2, lo_tail = tci2 })
   = OneOccL { lo_n_br    = nbr1 + nbr2
             , lo_int_cxt = int_cxt1 `mappend` int_cxt2
-            , lo_tail    = ar1 `andTailCallInfo` ar2 }
-orOccInfo _ _  = ManyOccL
+            , lo_tail    = tci1 `andTailCallInfo` tci2 }
+orOccInfo occ1 occ2 = andOccInfo occ1 occ2
 
 andTailCallInfo :: TailCallInfo -> TailCallInfo -> TailCallInfo
 andTailCallInfo info@(AlwaysTailCalled arity1) (AlwaysTailCalled arity2)



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/f29e8062316124752f47a22e500c1ffa35e9610f

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/f29e8062316124752f47a22e500c1ffa35e9610f
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20230713/1174e3b6/attachment-0001.html>


More information about the ghc-commits mailing list