[Git][ghc/ghc][wip/andreask/infer_exprs] 2 commits: Tag inference: Fix #21954 by retaining tagsigs of vars in function position.

Andreas Klebinger (@AndreasK) gitlab at gitlab.haskell.org
Sat Aug 13 11:02:21 UTC 2022



Andreas Klebinger pushed to branch wip/andreask/infer_exprs at Glasgow Haskell Compiler / GHC


Commits:
18eaf69f by Andreas Klebinger at 2022-08-13T13:01:46+02:00
Tag inference: Fix #21954 by retaining tagsigs of vars in function position.

For an expression like:

    case x of y
      Con z -> z

If we also retain the tag sig for z (not only for variables in scrutinee
position), we can generate code that returns it immediately rather than
calling out to stg_ap_0_fast.

- - - - -
ac810f99 by Andreas Klebinger at 2022-08-13T13:01:53+02:00
Stg.InferTags.Rewrite - Avoid some thunks.

- - - - -


4 changed files:

- compiler/GHC/Stg/InferTags/Rewrite.hs
- testsuite/tests/simplStg/should_compile/all.T
- + testsuite/tests/simplStg/should_compile/inferTags002.hs
- + testsuite/tests/simplStg/should_compile/inferTags002.stderr


Changes:

=====================================
compiler/GHC/Stg/InferTags/Rewrite.hs
=====================================
@@ -128,7 +128,7 @@ getMap :: RM (UniqFM Id TagSig)
 getMap = RM $ ((\(fst,_,_,_) -> fst) <$> get)
 
 setMap :: (UniqFM Id TagSig) -> RM ()
-setMap m = RM $ do
+setMap !m = RM $ do
     (_,us,mod,lcls) <- get
     put (m, us,mod,lcls)
 
@@ -139,7 +139,7 @@ getFVs :: RM IdSet
 getFVs = RM $ ((\(_,_,_,lcls) -> lcls) <$> get)
 
 setFVs :: IdSet -> RM ()
-setFVs fvs = RM $ do
+setFVs !fvs = RM $ do
     (tag_map,us,mod,_lcls) <- get
     put (tag_map, us,mod,fvs)
 
@@ -195,9 +195,9 @@ withBinders NotTopLevel sigs cont = do
 withClosureLcls :: DIdSet -> RM a -> RM a
 withClosureLcls fvs act = do
     old_fvs <- getFVs
-    let fvs' = nonDetStrictFoldDVarSet (flip extendVarSet) old_fvs fvs
+    let !fvs' = nonDetStrictFoldDVarSet (flip extendVarSet) old_fvs fvs
     setFVs fvs'
-    r <- act
+    !r <- act
     setFVs old_fvs
     return r
 
@@ -206,9 +206,9 @@ withClosureLcls fvs act = do
 withLcl :: Id -> RM a -> RM a
 withLcl fv act = do
     old_fvs <- getFVs
-    let fvs' = extendVarSet old_fvs fv
+    let !fvs' = extendVarSet old_fvs fv
     setFVs fvs'
-    r <- act
+    !r <- act
     setFVs old_fvs
     return r
 
@@ -222,7 +222,7 @@ isTagged v = do
             | otherwise -> do -- Local binding
                 !s <- getMap
                 let !sig = lookupWithDefaultUFM s (pprPanic "unknown Id:" (ppr v)) v
-                return $ case sig of
+                return $! case sig of
                     TagSig info ->
                         case info of
                             TagDunno -> False
@@ -234,7 +234,7 @@ isTagged v = do
             , isNullaryRepDataCon con
             -> return True
             | Just lf_info <- idLFInfo_maybe v
-            -> return $
+            -> return $!
                 -- Can we treat the thing as tagged based on it's LFInfo?
                 case lf_info of
                     -- Function, applied not entered.
@@ -336,7 +336,7 @@ rewriteRhs (_id, _tagSig) (StgRhsCon ccs con cn ticks args) = {-# SCC rewriteRhs
 rewriteRhs _binding (StgRhsClosure fvs ccs flag args body) = do
     withBinders NotTopLevel args $
         withClosureLcls fvs $
-            StgRhsClosure fvs ccs flag (map fst args) <$> rewriteExpr False body
+            StgRhsClosure fvs ccs flag (map fst args) <$> rewriteExpr body
         -- return (closure)
 
 fvArgs :: [StgArg] -> RM DVarSet
@@ -345,40 +345,36 @@ fvArgs args = do
     -- pprTraceM "fvArgs" (text "args:" <> ppr args $$ text "lcls:" <> pprVarSet (fv_lcls) (braces . fsep . map ppr) )
     return $ mkDVarSet [ v | StgVarArg v <- args, elemVarSet v fv_lcls]
 
-type IsScrut = Bool
-
 rewriteArgs :: [StgArg] -> RM [StgArg]
 rewriteArgs = mapM rewriteArg
 rewriteArg :: StgArg -> RM StgArg
 rewriteArg (StgVarArg v) = StgVarArg <$!> rewriteId v
 rewriteArg  (lit at StgLitArg{}) = return lit
 
--- Attach a tagSig if it's tagged
 rewriteId :: Id -> RM Id
 rewriteId v = do
-    is_tagged <- isTagged v
+    !is_tagged <- isTagged v
     if is_tagged then return $! setIdTagSig v (TagSig TagProper)
                  else return v
 
-rewriteExpr :: IsScrut -> InferStgExpr -> RM TgStgExpr
-rewriteExpr _ (e at StgCase {})          = rewriteCase e
-rewriteExpr _ (e at StgLet {})           = rewriteLet e
-rewriteExpr _ (e at StgLetNoEscape {})   = rewriteLetNoEscape e
-rewriteExpr isScrut (StgTick t e)     = StgTick t <$!> rewriteExpr isScrut e
-rewriteExpr _ e@(StgConApp {})        = rewriteConApp e
-
-rewriteExpr isScrut e@(StgApp {})     = rewriteApp isScrut e
-rewriteExpr _ (StgLit lit)           = return $! (StgLit lit)
-rewriteExpr _ (StgOpApp op@(StgPrimOp DataToTagOp)  args res_ty) = do
+rewriteExpr :: InferStgExpr -> RM TgStgExpr
+rewriteExpr (e at StgCase {})          = rewriteCase e
+rewriteExpr (e at StgLet {})           = rewriteLet e
+rewriteExpr (e at StgLetNoEscape {})   = rewriteLetNoEscape e
+rewriteExpr (StgTick t e)     = StgTick t <$!> rewriteExpr e
+rewriteExpr e@(StgConApp {})        = rewriteConApp e
+rewriteExpr e@(StgApp {})     = rewriteApp e
+rewriteExpr (StgLit lit)           = return $! (StgLit lit)
+rewriteExpr (StgOpApp op@(StgPrimOp DataToTagOp) args res_ty) = do
         (StgOpApp op) <$!> rewriteArgs args <*> pure res_ty
-rewriteExpr _ (StgOpApp op args res_ty) = return $! (StgOpApp op args res_ty)
+rewriteExpr (StgOpApp op args res_ty) = return $! (StgOpApp op args res_ty)
 
 
 rewriteCase :: InferStgExpr -> RM TgStgExpr
 rewriteCase (StgCase scrut bndr alt_type alts) =
     withBinder NotTopLevel bndr $
         pure StgCase <*>
-            rewriteExpr True scrut <*>
+            rewriteExpr scrut <*>
             pure (fst bndr) <*>
             pure alt_type <*>
             mapM rewriteAlt alts
@@ -388,7 +384,7 @@ rewriteCase _ = panic "Impossible: nodeCase"
 rewriteAlt :: InferStgAlt -> RM TgStgAlt
 rewriteAlt alt at GenStgAlt{alt_con=_, alt_bndrs=bndrs, alt_rhs=rhs} =
     withBinders NotTopLevel bndrs $ do
-        !rhs' <- rewriteExpr False rhs
+        !rhs' <- rewriteExpr rhs
         return $! alt {alt_bndrs = map fst bndrs, alt_rhs = rhs'}
 
 rewriteLet :: InferStgExpr -> RM TgStgExpr
@@ -396,7 +392,7 @@ rewriteLet (StgLet xt bind expr) = do
     (!bind') <- rewriteBinds NotTopLevel bind
     withBind NotTopLevel bind $ do
         -- pprTraceM "withBindLet" (ppr $ bindersOfX bind)
-        !expr' <- rewriteExpr False expr
+        !expr' <- rewriteExpr expr
         return $! (StgLet xt bind' expr')
 rewriteLet _ = panic "Impossible"
 
@@ -404,7 +400,7 @@ rewriteLetNoEscape :: InferStgExpr -> RM TgStgExpr
 rewriteLetNoEscape (StgLetNoEscape xt bind expr) = do
     (!bind') <- rewriteBinds NotTopLevel bind
     withBind NotTopLevel bind $ do
-        !expr' <- rewriteExpr False expr
+        !expr' <- rewriteExpr expr
         return $! (StgLetNoEscape xt bind' expr')
 rewriteLetNoEscape _ = panic "Impossible"
 
@@ -424,19 +420,12 @@ rewriteConApp (StgConApp con cn args tys) = do
 
 rewriteConApp _ = panic "Impossible"
 
--- Special case: Expressions like `case x of { ... }`
-rewriteApp :: IsScrut -> InferStgExpr -> RM TgStgExpr
-rewriteApp True (StgApp f []) = do
-    -- pprTraceM "rewriteAppScrut" (ppr f)
-    f_tagged <- isTagged f
-    -- isTagged looks at more than the result of our analysis.
-    -- So always update here if useful.
-    let f' = if f_tagged
-                -- TODO: We might consisder using a subst env instead of setting the sig only for select places.
-                then setIdTagSig f (TagSig TagProper)
-                else f
+-- Special case: Atomic binders, usually in a case context like `case f of ...`.
+rewriteApp :: InferStgExpr -> RM TgStgExpr
+rewriteApp (StgApp f []) = do
+    f' <- rewriteId f
     return $! StgApp f' []
-rewriteApp _ (StgApp f args)
+rewriteApp (StgApp f args)
     -- pprTrace "rewriteAppOther" (ppr f <+> ppr args) False
     -- = undefined
     | Just marks <- idCbvMarks_maybe f
@@ -457,8 +446,8 @@ rewriteApp _ (StgApp f args)
             cbvArgIds = [x | StgVarArg x <- map fstOf3 cbvArgInfo] :: [Id]
         mkSeqs args cbvArgIds (\cbv_args -> StgApp f cbv_args)
 
-rewriteApp _ (StgApp f args) = return $ StgApp f args
-rewriteApp _ _ = panic "Impossible"
+rewriteApp (StgApp f args) = return $ StgApp f args
+rewriteApp _ = panic "Impossible"
 
 -- `mkSeq` x x' e generates `case x of x' -> e`
 -- We could also substitute x' for x in e but that's so rarely beneficial


=====================================
testsuite/tests/simplStg/should_compile/all.T
=====================================
@@ -11,3 +11,4 @@ setTestOpts(f)
 
 test('T13588', [ grep_errmsg('case') ] , compile, ['-dverbose-stg2stg -fno-worker-wrapper'])
 test('T19717', normal, compile, ['-ddump-stg-final -dsuppress-uniques -dno-typeable-binds'])
+test('inferTags002', [ only_ways(['optasm']), grep_errmsg('(call stg\_ap\_0)', [1])], compile, ['-ddump-cmm -dsuppress-uniques -dno-typeable-binds -O'])


=====================================
testsuite/tests/simplStg/should_compile/inferTags002.hs
=====================================
@@ -0,0 +1,7 @@
+module M where
+
+data T a = MkT !Bool !a
+
+-- The rhs of the case alternative should not result in a call to stg_ap_0_fast.
+f x = case x of
+    MkT y z -> z


=====================================
testsuite/tests/simplStg/should_compile/inferTags002.stderr
=====================================
@@ -0,0 +1,171 @@
+
+==================== Output Cmm ====================
+[M.$WMkT_entry() { //  [R3, R2]
+         { info_tbls: [(cym,
+                        label: block_cym_info
+                        rep: StackRep [False]
+                        srt: Nothing),
+                       (cyp,
+                        label: M.$WMkT_info
+                        rep: HeapRep static { Fun {arity: 2 fun_type: ArgSpec 15} }
+                        srt: Nothing),
+                       (cys,
+                        label: block_cys_info
+                        rep: StackRep [False]
+                        srt: Nothing)]
+           stack_info: arg_space: 8
+         }
+     {offset
+       cyp: // global
+           if ((Sp + -16) < SpLim) (likely: False) goto cyv; else goto cyw;
+       cyv: // global
+           R1 = M.$WMkT_closure;
+           call (stg_gc_fun)(R3, R2, R1) args: 8, res: 0, upd: 8;
+       cyw: // global
+           I64[Sp - 16] = cym;
+           R1 = R2;
+           P64[Sp - 8] = R3;
+           Sp = Sp - 16;
+           if (R1 & 7 != 0) goto cym; else goto cyn;
+       cyn: // global
+           call (I64[R1])(R1) returns to cym, args: 8, res: 8, upd: 8;
+       cym: // global
+           I64[Sp] = cys;
+           _sy8::P64 = R1;
+           R1 = P64[Sp + 8];
+           P64[Sp + 8] = _sy8::P64;
+           call stg_ap_0_fast(R1) returns to cys, args: 8, res: 8, upd: 8;
+       cys: // global
+           Hp = Hp + 24;
+           if (Hp > HpLim) (likely: False) goto cyA; else goto cyz;
+       cyA: // global
+           HpAlloc = 24;
+           call stg_gc_unpt_r1(R1) returns to cys, args: 8, res: 8, upd: 8;
+       cyz: // global
+           I64[Hp - 16] = M.MkT_con_info;
+           P64[Hp - 8] = P64[Sp + 8];
+           P64[Hp] = R1;
+           R1 = Hp - 15;
+           Sp = Sp + 16;
+           call (P64[Sp])(R1) args: 8, res: 0, upd: 8;
+     }
+ },
+ section ""data" . M.$WMkT_closure" {
+     M.$WMkT_closure:
+         const M.$WMkT_info;
+ }]
+
+
+
+==================== Output Cmm ====================
+[M.f_entry() { //  [R2]
+         { info_tbls: [(cyK,
+                        label: block_cyK_info
+                        rep: StackRep []
+                        srt: Nothing),
+                       (cyN,
+                        label: M.f_info
+                        rep: HeapRep static { Fun {arity: 1 fun_type: ArgSpec 5} }
+                        srt: Nothing)]
+           stack_info: arg_space: 8
+         }
+     {offset
+       cyN: // global
+           if ((Sp + -8) < SpLim) (likely: False) goto cyO; else goto cyP;
+       cyO: // global
+           R1 = M.f_closure;
+           call (stg_gc_fun)(R2, R1) args: 8, res: 0, upd: 8;
+       cyP: // global
+           I64[Sp - 8] = cyK;
+           R1 = R2;
+           Sp = Sp - 8;
+           if (R1 & 7 != 0) goto cyK; else goto cyL;
+       cyL: // global
+           call (I64[R1])(R1) returns to cyK, args: 8, res: 8, upd: 8;
+       cyK: // global
+           R1 = P64[R1 + 15];
+           Sp = Sp + 8;
+           call (P64[Sp])(R1) args: 8, res: 0, upd: 8;
+     }
+ },
+ section ""data" . M.f_closure" {
+     M.f_closure:
+         const M.f_info;
+ }]
+
+
+
+==================== Output Cmm ====================
+[M.MkT_entry() { //  [R3, R2]
+         { info_tbls: [(cz1,
+                        label: block_cz1_info
+                        rep: StackRep [False]
+                        srt: Nothing),
+                       (cz4,
+                        label: M.MkT_info
+                        rep: HeapRep static { Fun {arity: 2 fun_type: ArgSpec 15} }
+                        srt: Nothing),
+                       (cz7,
+                        label: block_cz7_info
+                        rep: StackRep [False]
+                        srt: Nothing)]
+           stack_info: arg_space: 8
+         }
+     {offset
+       cz4: // global
+           if ((Sp + -16) < SpLim) (likely: False) goto cza; else goto czb;
+       cza: // global
+           R1 = M.MkT_closure;
+           call (stg_gc_fun)(R3, R2, R1) args: 8, res: 0, upd: 8;
+       czb: // global
+           I64[Sp - 16] = cz1;
+           R1 = R2;
+           P64[Sp - 8] = R3;
+           Sp = Sp - 16;
+           if (R1 & 7 != 0) goto cz1; else goto cz2;
+       cz2: // global
+           call (I64[R1])(R1) returns to cz1, args: 8, res: 8, upd: 8;
+       cz1: // global
+           I64[Sp] = cz7;
+           _tyf::P64 = R1;
+           R1 = P64[Sp + 8];
+           P64[Sp + 8] = _tyf::P64;
+           call stg_ap_0_fast(R1) returns to cz7, args: 8, res: 8, upd: 8;
+       cz7: // global
+           Hp = Hp + 24;
+           if (Hp > HpLim) (likely: False) goto czf; else goto cze;
+       czf: // global
+           HpAlloc = 24;
+           call stg_gc_unpt_r1(R1) returns to cz7, args: 8, res: 8, upd: 8;
+       cze: // global
+           I64[Hp - 16] = M.MkT_con_info;
+           P64[Hp - 8] = P64[Sp + 8];
+           P64[Hp] = R1;
+           R1 = Hp - 15;
+           Sp = Sp + 16;
+           call (P64[Sp])(R1) args: 8, res: 0, upd: 8;
+     }
+ },
+ section ""data" . M.MkT_closure" {
+     M.MkT_closure:
+         const M.MkT_info;
+ }]
+
+
+
+==================== Output Cmm ====================
+[M.MkT_con_entry() { //  []
+         { info_tbls: [(czl,
+                        label: M.MkT_con_info
+                        rep: HeapRep 2 ptrs { Con {tag: 0 descr:"main:M.MkT"} }
+                        srt: Nothing)]
+           stack_info: arg_space: 8
+         }
+     {offset
+       czl: // global
+           R1 = R1 + 1;
+           call (P64[Sp])(R1) args: 8, res: 0, upd: 8;
+     }
+ }]
+
+



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/d04a6a5586785f52ed48fd5545e553e008f2165c...ac810f994792c79acef386136aafe4c3a0f1e1a1

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/d04a6a5586785f52ed48fd5545e553e008f2165c...ac810f994792c79acef386136aafe4c3a0f1e1a1
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20220813/795f8148/attachment-0001.html>


More information about the ghc-commits mailing list