[Git][ghc/ghc][wip/supersven/aarch64-jump-tables] AArch64: Implement switch/jump tables (#19912)

Sven Tennie (@supersven) gitlab at gitlab.haskell.org
Sun Sep 8 16:34:36 UTC 2024



Sven Tennie pushed to branch wip/supersven/aarch64-jump-tables at Glasgow Haskell Compiler / GHC


Commits:
c6721691 by Sven Tennie at 2024-09-08T18:34:04+02:00
AArch64: Implement switch/jump tables (#19912)

This improves the performance of Cmm switch statements (compared to a
chain of if statements).

- - - - -


3 changed files:

- compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
- compiler/GHC/CmmToAsm/AArch64/Instr.hs
- compiler/GHC/CmmToAsm/AArch64/Ppr.hs


Changes:

=====================================
compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
=====================================
@@ -23,7 +23,7 @@ import GHC.Cmm.DebugBlock
 import GHC.CmmToAsm.Monad
    ( NatM, getNewRegNat
    , getPicBaseMaybeNat, getPlatform, getConfig
-   , getDebugBlock, getFileId
+   , getDebugBlock, getFileId, getNewLabelNat
    )
 -- import GHC.CmmToAsm.Instr
 import GHC.CmmToAsm.PIC
@@ -50,7 +50,7 @@ import GHC.Types.Unique.Supply
 import GHC.Data.OrdList
 import GHC.Utils.Outputable
 
-import Control.Monad    ( mapAndUnzipM, foldM )
+import Control.Monad    ( mapAndUnzipM )
 import GHC.Float
 
 import GHC.Types.Basic
@@ -209,43 +209,79 @@ annExpr e instr {- debugIsOn -} = ANN (text . show $ e) instr
 -- -----------------------------------------------------------------------------
 -- Generating a table-branch
 
--- TODO jump tables would be a lot faster, but we'll use bare bones for now.
--- this is usually done by sticking the jump table ids into an instruction
--- and then have the @generateJumpTableForInstr@ callback produce the jump
--- table as a static.
+-- | Generate jump to jump table target
 --
--- See Ticket 19912
---
--- data SwitchTargets =
---    SwitchTargets
---        Bool                       -- Signed values
---        (Integer, Integer)         -- Range
---        (Maybe Label)              -- Default value
---        (M.Map Integer Label)      -- The branches
---
--- Non Jumptable plan:
--- xE <- expr
+-- The index into the jump table is calculated by evaluating @expr@. The
+-- corresponding (8-byte) table entry contains the address to jump to, relative
+-- to the jump table's first entry / the table's own label.
+genSwitch :: NCGConfig -> CmmExpr -> SwitchTargets -> NatM InstrBlock
+genSwitch config expr targets = do
+  (reg, fmt1, e_code) <- getSomeReg indexExpr
+  let fmt = II64
+  targetReg <- getNewRegNat fmt
+  lbl <- getNewLabelNat
+  dynRef <- cmmMakeDynamicReference config DataReference lbl
+  (tableReg, fmt2, t_code) <- getSomeReg dynRef
+  let code =
+        toOL
+          [ COMMENT (text "indexExpr" <+> (text . show) indexExpr),
+            COMMENT (text "dynRef" <+> (text . show) dynRef)
+          ]
+          `appOL` e_code
+          `appOL` t_code
+          `appOL` toOL
+            [ COMMENT (ftext "Jump table for switch"),
+              -- byte offset of the entry (index * 8); written to targetReg so reg is not clobbered
+              annExpr expr (LSL (OpReg W64 targetReg) (OpReg (formatToWidth fmt1) reg) (OpImm (ImmInt 3))),
+              -- calculate table entry address (tableReg + byte offset)
+              ADD (OpReg W64 targetReg) (OpReg W64 targetReg) (OpReg (formatToWidth fmt2) tableReg),
+              -- load table entry (relative offset from tableReg (first entry) to target label)
+              LDR II64 (OpReg W64 targetReg) (OpAddr (AddrRegImm targetReg (ImmInt 0))),
+              -- calculate absolute address of the target label
+              ADD (OpReg W64 targetReg) (OpReg W64 targetReg) (OpReg W64 tableReg),
+              -- prepare jump to target label
+              J_TBL ids (Just lbl) targetReg
+            ]
+  return code
+  where
+    -- See Note [Sub-word subtlety during jump-table indexing] in
+    -- GHC.CmmToAsm.X86.CodeGen for why we must first offset, then widen.
+    indexExpr0 = cmmOffset platform expr offset
+    -- We widen to a native-width register to sanitize the high bits
+    indexExpr =
+      CmmMachOp
+        (MO_UU_Conv expr_w (platformWordWidth platform))
+        [indexExpr0]
+    expr_w = cmmExprWidth platform expr
+    (offset, ids) = switchTargetsToTable targets
+    platform = ncgPlatform config
+
+-- | Generate jump table data (if required)
 --
-genSwitch :: CmmExpr -> SwitchTargets -> NatM InstrBlock
-genSwitch expr targets = do -- pprPanic "genSwitch" (ppr expr)
-  (reg, format, code) <- getSomeReg expr
-  let w = formatToWidth format
-  let mkbranch acc (key, bid) = do
-        (keyReg, _format, code) <- getSomeReg (CmmLit (CmmInt key w))
-        return $ code `appOL`
-                 toOL [ CMP (OpReg w reg) (OpReg w keyReg)
-                      , BCOND EQ (TBlock bid)
-                      ] `appOL` acc
-      def_code = case switchTargetsDefault targets of
-        Just bid -> unitOL (B (TBlock bid))
-        Nothing  -> nilOL
-
-  switch_code <- foldM mkbranch nilOL (switchTargetsCases targets)
-  return $ code `appOL` switch_code `appOL` def_code
-
--- We don't do jump tables for now, see Ticket 19912
-generateJumpTableForInstr :: NCGConfig -> Instr
-  -> Maybe (NatCmmDecl RawCmmStatics Instr)
+-- One word-sized entry is emitted per table slot, in the order of the @ids@
+-- carried by the 'J_TBL' instruction. A slot with a target holds the distance
+-- from the table's own label to the target block's label; that difference is
+-- emitted symbolically and left for the assembler/linker to resolve.
+generateJumpTableForInstr ::
+  NCGConfig ->
+  Instr ->
+  Maybe (NatCmmDecl RawCmmStatics Instr)
+generateJumpTableForInstr config (J_TBL ids (Just lbl) _) =
+  Just (CmmData (Section ReadOnlyData lbl) (CmmStaticsRaw lbl entries))
+  where
+    -- Width shared by every entry of the table.
+    entryWidth = ncgWordWidth config
+    -- The table contents: one static per slot.
+    entries = map toEntry ids
+
+    -- Slots without a target are filled with a zero word.
+    toEntry Nothing =
+      CmmStaticLit (CmmInt 0 entryWidth)
+    -- Slots with a target hold (target block label - table label).
+    toEntry (Just blockid) =
+      CmmStaticLit (CmmLabelDiffOff targetLabel lbl 0 entryWidth)
+      where
+        targetLabel = blockLbl blockid
 generateJumpTableForInstr _ _ = Nothing
 
 -- -----------------------------------------------------------------------------
@@ -266,6 +302,7 @@ stmtToInstrs :: CmmNode e x -- ^ Cmm Statement
 stmtToInstrs stmt = do
   -- traceM $ "-- -------------------------- stmtToInstrs -------------------------- --\n"
   --     ++ showSDocUnsafe (ppr stmt)
+  config <- getConfig
   platform <- getPlatform
   case stmt of
     CmmUnsafeForeignCall target result_regs args
@@ -294,7 +331,7 @@ stmtToInstrs stmt = do
       CmmCondBranch arg true false _prediction ->
           genCondBranch true false arg
 
-      CmmSwitch arg ids -> genSwitch arg ids
+      CmmSwitch arg ids -> genSwitch config arg ids
 
       CmmCall { cml_target = arg } -> genJump arg
 
@@ -339,12 +376,6 @@ getRegisterReg platform (CmmGlobal reg@(GlobalRegUse mid _))
         -- ones which map to a real machine register on this
         -- platform.  Hence if it's not mapped to a registers something
         -- went wrong earlier in the pipeline.
--- | Convert a BlockId to some CmmStatic data
--- TODO: Add JumpTable Logic, see Ticket 19912
--- jumpTableEntry :: NCGConfig -> Maybe BlockId -> CmmStatic
--- jumpTableEntry config Nothing   = CmmStaticLit (CmmInt 0 (ncgWordWidth config))
--- jumpTableEntry _ (Just blockid) = CmmStaticLit (CmmLabel blockLabel)
---     where blockLabel = blockLbl blockid
 
 -- -----------------------------------------------------------------------------
 -- General things for putting together code sequences


=====================================
compiler/GHC/CmmToAsm/AArch64/Instr.hs
=====================================
@@ -27,7 +27,7 @@ import GHC.Types.Unique.Supply
 
 import GHC.Utils.Panic
 
-import Data.Maybe (fromMaybe)
+import Data.Maybe (fromMaybe, catMaybes)
 
 import GHC.Stack
 
@@ -118,6 +118,7 @@ regUsageOfInstr platform instr = case instr of
   ORR dst src1 src2        -> usage (regOp src1 ++ regOp src2, regOp dst)
   -- 4. Branch Instructions ----------------------------------------------------
   J t                      -> usage (regTarget t, [])
+  J_TBL _ _ t              -> usage ([t], [])
   B t                      -> usage (regTarget t, [])
   BCOND _ t                -> usage (regTarget t, [])
   BL t ps                  -> usage (regTarget t ++ ps, callerSavedRegisters)
@@ -264,10 +265,11 @@ patchRegsOfInstr instr env = case instr of
     ORR o1 o2 o3   -> ORR  (patchOp o1) (patchOp o2) (patchOp o3)
 
     -- 4. Branch Instructions --------------------------------------------------
-    J t            -> J (patchTarget t)
-    B t            -> B (patchTarget t)
-    BL t rs        -> BL (patchTarget t) rs
-    BCOND c t      -> BCOND c (patchTarget t)
+    J t               -> J (patchTarget t)
+    J_TBL ids mbLbl t -> J_TBL ids mbLbl (env t)
+    B t               -> B (patchTarget t)
+    BL t rs           -> BL (patchTarget t) rs
+    BCOND c t         -> BCOND c (patchTarget t)
 
     -- 5. Atomic Instructions --------------------------------------------------
     -- 6. Conditional Instructions ---------------------------------------------
@@ -319,6 +321,7 @@ isJumpishInstr instr = case instr of
     CBZ{} -> True
     CBNZ{} -> True
     J{} -> True
+    J_TBL{} -> True
     B{} -> True
     BL{} -> True
     BCOND{} -> True
@@ -332,6 +335,7 @@ jumpDestsOfInstr (ANN _ i) = jumpDestsOfInstr i
 jumpDestsOfInstr (CBZ _ t) = [ id | TBlock id <- [t]]
 jumpDestsOfInstr (CBNZ _ t) = [ id | TBlock id <- [t]]
 jumpDestsOfInstr (J t) = [id | TBlock id <- [t]]
+jumpDestsOfInstr (J_TBL ids _mbLbl _r) = catMaybes ids
 jumpDestsOfInstr (B t) = [id | TBlock id <- [t]]
 jumpDestsOfInstr (BL t _) = [ id | TBlock id <- [t]]
 jumpDestsOfInstr (BCOND _ t) = [ id | TBlock id <- [t]]
@@ -340,6 +344,11 @@ jumpDestsOfInstr _ = []
 canFallthroughTo :: Instr -> BlockId -> Bool
 canFallthroughTo (ANN _ i) bid = canFallthroughTo i bid
 canFallthroughTo (J (TBlock target)) bid = bid == target
+-- A jump table can fall through to @bid@ only if every populated entry
+-- targets @bid@; 'Nothing' entries are treated as compatible with any block,
+-- so an empty or all-'Nothing' table also yields True.
+canFallthroughTo (J_TBL targets _ _) bid =
+  all (maybe True (== bid)) targets
 canFallthroughTo (B (TBlock target)) bid = bid == target
 canFallthroughTo _ _ = False
 
@@ -353,6 +362,7 @@ patchJumpInstr instr patchF
         CBZ r (TBlock bid) -> CBZ r (TBlock (patchF bid))
         CBNZ r (TBlock bid) -> CBNZ r (TBlock (patchF bid))
         J (TBlock bid) -> J (TBlock (patchF bid))
+        J_TBL ids mbLbl r -> J_TBL (map (fmap patchF) ids) mbLbl r
         B (TBlock bid) -> B (TBlock (patchF bid))
         BL (TBlock bid) ps -> BL (TBlock (patchF bid)) ps
         BCOND c (TBlock bid) -> BCOND c (TBlock (patchF bid))
@@ -516,6 +526,7 @@ allocMoreStack platform slots proc@(CmmProc info lbl live (ListGraph code)) = do
 
       insert_dealloc insn r = case insn of
         J _ -> dealloc ++ (insn : r)
+        J_TBL {} -> dealloc ++ (insn : r)
         ANN _ (J _) -> dealloc ++ (insn : r)
         _other | jumpDestsOfInstr insn /= []
             -> patchJumpInstr insn retarget : r
@@ -644,6 +655,7 @@ data Instr
     | CBNZ Operand Target -- if op /= 0, then branch.
     -- Branching.
     | J Target            -- like B, but only generated from genJump. Used to distinguish genJumps from others.
+    | J_TBL [Maybe BlockId] (Maybe CLabel) Reg -- A jump instruction with data for switch/jump tables
     | B Target            -- unconditional branching b/br. (To a blockid, label or register)
     | BL Target [Reg] -- branch and link (e.g. set x30 to next pc, and branch)
     | BCOND Cond Target   -- branch with condition. b.<cond>
@@ -730,6 +742,7 @@ instrCon i =
       CBZ{} -> "CBZ"
       CBNZ{} -> "CBNZ"
       J{} -> "J"
+      J_TBL {} -> "J_TBL"
       B{} -> "B"
       BL{} -> "BL"
       BCOND{} -> "BCOND"


=====================================
compiler/GHC/CmmToAsm/AArch64/Ppr.hs
=====================================
@@ -426,6 +426,7 @@ pprInstr platform instr = case instr of
 
   -- 4. Branch Instructions ----------------------------------------------------
   J t            -> pprInstr platform (B t)
+  J_TBL _ _ r    -> pprInstr platform (B (TReg r))
   B (TBlock bid) -> line $ text "\tb" <+> pprAsmLabel platform (mkLocalBlockLabel (getUnique bid))
   B (TLabel lbl) -> line $ text "\tb" <+> pprAsmLabel platform lbl
   B (TReg r)     -> line $ text "\tbr" <+> pprReg W64 r



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/c67216918cd9c31391e3b30ba2ddbd50e0ab6958

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/c67216918cd9c31391e3b30ba2ddbd50e0ab6958
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240908/9e642a85/attachment-0001.html>


More information about the ghc-commits mailing list