[Git][ghc/ghc][wip/T17521] Try using PUSH_TAGGED in interpreter

Jaro Reinders (@Noughtmare) gitlab at gitlab.haskell.org
Mon Aug 21 16:15:16 UTC 2023

Jaro Reinders pushed to branch wip/T17521 at Glasgow Haskell Compiler / GHC

b655a3ec by Jaro Reinders at 2023-08-21T18:15:08+02:00
Try using PUSH_TAGGED in interpreter

- - - - -

6 changed files:

- compiler/GHC/ByteCode/Asm.hs
- compiler/GHC/ByteCode/Instr.hs
- compiler/GHC/StgToByteCode.hs
- rts/Disassembler.c
- rts/Interpreter.c
- rts/include/rts/Bytecodes.h


@@ -406,6 +406,9 @@ assembleI platform i = case i of
   PUSH_BCO proto           -> do let ul_bco = assembleBCO platform proto
                                  p <- ioptr (liftM BCOPtrBCO ul_bco)
                                  emit bci_PUSH_G [Op p]
+  PUSH_TAGGED nm dcon      -> do p <- ptr (BCOPtrName nm)
+                                 itbl_no <- lit [BCONPtrItbl (getName dcon)]
+                                 emit bci_PUSH_TAGGED [Op p, Op itbl_no]
   PUSH_ALTS proto pk
                            -> do let ul_bco = assembleBCO platform proto
                                  p <- ioptr (liftM BCOPtrBCO ul_bco)

@@ -88,6 +88,9 @@ data BCInstr
    | PUSH_PRIMOP  PrimOp
    | PUSH_BCO     (ProtoBCO Name)
+   -- Push a pointer to the named closure, tagged according to the given DataCon
+   | PUSH_TAGGED Name DataCon
    -- Push an alt continuation
    | PUSH_ALTS          (ProtoBCO Name) ArgRep
    | PUSH_ALTS_TUPLE    (ProtoBCO Name) -- continuation
@@ -294,6 +297,7 @@ instance Outputable BCInstr where
    ppr (PUSH_UBX32 lit)      = text "PUSH_UBX32" <+> ppr lit
    ppr (PUSH_UBX lit nw)     = text "PUSH_UBX" <+> parens (ppr nw) <+> ppr lit
    ppr (PUSH_ADDR nm)        = text "PUSH_ADDR" <+> ppr nm
+   ppr (PUSH_TAGGED nm tg)   = text "PUSH_TAGGED" <+> ppr nm <+> ppr tg
    ppr PUSH_APPLY_N          = text "PUSH_APPLY_N"
    ppr PUSH_APPLY_V          = text "PUSH_APPLY_V"
    ppr PUSH_APPLY_F          = text "PUSH_APPLY_F"
@@ -390,6 +394,7 @@ bciStackUse PUSH32_W{}            = 1  -- takes exactly 1 word
 bciStackUse PUSH_G{}              = 1
 bciStackUse PUSH_PRIMOP{}         = 1
 bciStackUse PUSH_BCO{}            = 1
+bciStackUse PUSH_TAGGED{}         = 1
 bciStackUse (PUSH_ALTS bco _)     = 2 {- profiling only, restore CCCS -} +
                                     3 + protoBCOStackUse bco
 bciStackUse (PUSH_ALTS_TUPLE bco info _) =

@@ -43,7 +43,6 @@ import GHC.Types.Literal
 import GHC.Builtin.PrimOps
 import GHC.Builtin.PrimOps.Ids (primOpId)
 import GHC.Core.Type
-import GHC.Core.TyCo.Compare (eqType)
 import GHC.Types.RepType
 import GHC.Core.DataCon
 import GHC.Core.TyCon
@@ -58,7 +57,7 @@ import GHC.Data.FastString
 import GHC.Utils.Panic
 import GHC.Utils.Panic.Plain
 import GHC.Utils.Exception (evaluate)
-import GHC.StgToCmm.Closure ( NonVoid(..), fromNonVoid, nonVoidIds, argPrimRep )
+import GHC.StgToCmm.Closure ( NonVoid(..), fromNonVoid, nonVoidIds, argPrimRep, idPrimRep)
 import GHC.StgToCmm.Layout
 import GHC.Runtime.Heap.Layout hiding (WordOff, ByteOff, wordsToBytes)
 import GHC.Data.Bitmap
@@ -93,6 +92,8 @@ import Data.Either ( partitionEithers )
 import GHC.Stg.Syntax
 import qualified Data.IntSet as IntSet
 import GHC.CoreToIface
+import GHC.Types.Var.Env (IdEnv, mkVarEnv, lookupVarEnv)
+import GHC.StgToCmm.Types (LambdaFormInfo(LFCon))
 -- -----------------------------------------------------------------------------
 -- Generating byte code for a complete module
@@ -118,9 +119,10 @@ byteCodeGen hsc_env this_mod binds tycs mb_modBreaks
             flattenBind (StgRec bs)     = bs
         stringPtrs <- allocateTopStrings interp strings
+        let flattened_binds = concatMap flattenBind (reverse lifted_binds)
         (BcM_State{..}, proto_bcos) <-
-           runBc hsc_env this_mod mb_modBreaks $ do
-             let flattened_binds = concatMap flattenBind (reverse lifted_binds)
+           runBc hsc_env this_mod mb_modBreaks (mkVarEnv (getDcs flattened_binds)) $ do
              mapM schemeTopBind flattened_binds
         when (notNull ffis)
@@ -150,6 +152,11 @@ byteCodeGen hsc_env this_mod binds tycs mb_modBreaks
         interp  = hscInterp hsc_env
         profile = targetProfile dflags
+-- | Pick out the bindings whose right-hand side is a constructor
+-- application ('StgRhsCon') and pair each such binder with its 'DataCon'.
+-- All other bindings are dropped; the input order is preserved.
+getDcs :: [(Id, CgStgRhs)] -> [(Id, DataCon)]
+getDcs binds = [ (bndr, con) | (bndr, StgRhsCon _ con _ _ _ _) <- binds ]
 -- | see Note [Generating code for top-level string literal bindings]
   :: Interp
@@ -1861,10 +1868,19 @@ pushAtom d p (StgVarArg var)
             -- see Note [Generating code for top-level string literal bindings]
-            | isUnliftedType (idType var) -> do
-              massert (idType var `eqType` addrPrimTy)
+            | idPrimRep var == AddrRep -> do
               return (unitOL (PUSH_ADDR (getName var)), szb)
+            | idPrimRep var == BoxedRep (Just Unlifted) -> do
+              mayDc <- lookupDc var
+              case mayDc of
+                Nothing ->
+                  case idLFInfo_maybe var of
+                    Nothing -> pprPanic "pushAtom: unlifted external id without LFInfo" (ppr var)
+                    Just (LFCon dc) -> return (unitOL (PUSH_TAGGED (getName var) dc), szb)
+                    Just{} -> pprPanic "pushAtom: expected LFCon" (ppr var)
+                Just dc -> return (unitOL (PUSH_TAGGED (getName var) dc), szb)
             | otherwise -> do
               return (unitOL (PUSH_G (getName var)), szb)
@@ -2230,6 +2246,7 @@ data BcM_State
                                          -- Should be free()d when it is GCd
         , modBreaks   :: Maybe ModBreaks -- info about breakpoints
         , breakInfo   :: IntMap CgBreakInfo
+        , bcm_dcs     :: IdEnv DataCon
 newtype BcM r = BcM (BcM_State -> IO (BcM_State, r)) deriving (Functor)
@@ -2239,11 +2256,11 @@ ioToBc io = BcM $ \st -> do
   x <- io
   return (st, x)
-runBc :: HscEnv -> Module -> Maybe ModBreaks
+runBc :: HscEnv -> Module -> Maybe ModBreaks -> IdEnv DataCon
       -> BcM r
       -> IO (BcM_State, r)
-runBc hsc_env this_mod modBreaks (BcM m)
-   = m (BcM_State hsc_env this_mod 0 [] modBreaks IntMap.empty)
+runBc hsc_env this_mod modBreaks dcs (BcM m)
+   = m (BcM_State hsc_env this_mod 0 [] modBreaks IntMap.empty dcs)
 thenBc :: BcM a -> (a -> BcM b) -> BcM b
 thenBc (BcM expr) cont = BcM $ \st0 -> do
@@ -2317,3 +2334,6 @@ getCurrentModBreaks = BcM $ \st -> return (st, modBreaks st)
 tickFS :: FastString
 tickFS = fsLit "ticked"
+-- | Look up the 'DataCon' recorded for an 'Id' in the 'bcm_dcs' environment
+-- of the current state. Returns 'Nothing' when the 'Id' has no entry; the
+-- state itself is passed through unchanged.
+lookupDc :: Id -> BcM (Maybe DataCon)
+lookupDc var = BcM $ \st ->
+  let mb_dc = lookupVarEnv (bcm_dcs st) var
+  in return (st, mb_dc)

@@ -134,6 +134,13 @@ disInstr ( StgBCO *bco, int pc )
          debugBelch("PUSH_G   " ); printPtr( ptrs[instrs[pc]] );
          debugBelch("\n" );
          pc += 1; break;
+      case bci_PUSH_TAGGED:
+         debugBelch("PUSH_TAGGED  " );
+         printPtr( ptrs[BCO_NEXT] );
+         debugBelch(" ");
+         printPtr( (StgPtr)literals[BCO_NEXT] );
+         debugBelch("\n");
+         break;
       case bci_PUSH_ALTS_P:
          debugBelch("PUSH_ALTS_P  " ); printPtr( ptrs[instrs[pc]] );

@@ -290,6 +290,9 @@ StgClosure * copyPAP  (Capability *cap, StgPAP *oldpap)
 STATIC_INLINE StgClosure *tagConstr(StgClosure *con) {
     return TAG_CLOSURE(stg_min(TAG_MASK, 1 + GET_TAG(con)), con);
+STATIC_INLINE StgClosure *tagPtr(StgClosure *p, StgInfoTable *itbl) {
+    return TAG_CLOSURE(stg_min(TAG_MASK, 1 + itbl->srt), p);
 static StgWord app_ptrs_itbl[] = {
@@ -1296,11 +1299,26 @@ run_BCO:
         case bci_PUSH_G: {
             W_ o1 = BCO_GET_LARGE_ARG;
+            IF_DEBUG(interpreter,
+                     debugBelch("PUSH_G %ld\n", o1);
+                );
             SpW(-1) = BCO_PTR(o1);
             goto nextInsn;
+        case bci_PUSH_TAGGED: {
+            /* Operands, in order: the index of the closure in the BCO's
+             * pointer table, then the index of an info-table pointer in
+             * the BCO's literal table. */
+            W_ o1 = BCO_GET_LARGE_ARG;
+            W_ o_itbl = BCO_GET_LARGE_ARG;
+            /* W_ operands are printed with FMT_Word rather than %ld: %ld
+             * expects a signed long, which does not match a word-sized
+             * unsigned value on every target (e.g. LLP64). */
+            IF_DEBUG(interpreter,
+                     debugBelch("PUSH_TAGGED %" FMT_Word " %" FMT_Word "\n", o1, o_itbl);
+                );
+            StgInfoTable* itbl = INFO_PTR_TO_STRUCT((StgInfoTable *)BCO_LIT(o_itbl));
+            /* Push the closure pointer with tag bits derived from itbl. */
+            SpW(-1) = (W_)tagPtr((StgClosure *)BCO_PTR(o1), itbl);
+            Sp_subW(1);
+            goto nextInsn;
+        }
         case bci_PUSH_ALTS_P: {
             W_ o_bco  = BCO_GET_LARGE_ARG;
@@ -1677,6 +1695,9 @@ run_BCO:
             W_ i;
             W_ o_itbl         = BCO_GET_LARGE_ARG;
             W_ n_words        = BCO_GET_LARGE_ARG;
+            IF_DEBUG(interpreter,
+                     debugBelch("PACK %ld %ld\n", o_itbl, n_words);
+                );
             StgInfoTable* itbl = INFO_PTR_TO_STRUCT((StgInfoTable *)BCO_LIT(o_itbl));
             int request        = CONSTR_sizeW( itbl->layout.payload.ptrs,
                                                itbl->layout.payload.nptrs );

@@ -112,6 +112,8 @@
 #define bci_PRIMCALL                    87
+#define bci_PUSH_TAGGED                 88
 /* If you need to go past 255 then you will run into the flags */
 /* If you need to go below 0x0100 then you will run into the instructions */

View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/b655a3ecac16dfd474e396437753171a59fa693f
