[Git][ghc/ghc][wip/ncg-simd] Use xmm registers in genapply

sheaf (@sheaf) gitlab at gitlab.haskell.org
Tue Jul 9 16:58:37 UTC 2024



sheaf pushed to branch wip/ncg-simd at Glasgow Haskell Compiler / GHC


Commits:
717042b7 by sheaf at 2024-07-09T18:58:14+02:00
Use xmm registers in genapply

This commit updates genapply to use xmm, ymm and zmm registers, for
stg_ap_v16/stg_ap_v32/stg_ap_v64, respectively.

It also updates the Cmm lexer and parser to produce Cmm vectors rather
than 128/256/512 bit wide scalars for V16/V32/V64, removing bits128,
bits256 and bits512 in favour of vectors.

The Cmm Lint check is weakened for vectors, as (in practice, e.g. on X86)
it is okay to use a single vector register to hold multiple different
types of data, and we don't know just from seeing e.g. "XMM1" how to
interpret the 128 bits of data within.

Fixes #25062

- - - - -


7 changed files:

- compiler/GHC/Cmm/Lexer.x
- compiler/GHC/Cmm/Lint.hs
- compiler/GHC/Cmm/Parser.y
- compiler/GHC/Cmm/Type.hs
- rts/include/Cmm.h
- utils/deriveConstants/Main.hs
- utils/genapply/Main.hs


Changes:

=====================================
compiler/GHC/Cmm/Lexer.x
=====================================
@@ -104,11 +104,14 @@ $white_no_nl+           ;
   "False"               { kw CmmT_False }
   "likely"              { kw CmmT_likely}
 
-  P@decimal             { global_regN VanillaReg      gcWord }
-  R@decimal             { global_regN VanillaReg       bWord }
-  F@decimal             { global_regN FloatReg  (const $ cmmFloat W32) }
-  D@decimal             { global_regN DoubleReg (const $ cmmFloat W64) }
-  L@decimal             { global_regN LongReg   (const $ cmmBits  W64) }
+  P@decimal             { global_regN 1 VanillaReg      gcWord }
+  R@decimal             { global_regN 1 VanillaReg       bWord }
+  F@decimal             { global_regN 1 FloatReg  (const $ cmmFloat W32) }
+  D@decimal             { global_regN 1 DoubleReg (const $ cmmFloat W64) }
+  L@decimal             { global_regN 1 LongReg   (const $ cmmBits  W64) }
+  XMM@decimal           { global_regN 3 XmmReg    (const $ cmmVec 2 (cmmFloat W64)) }
+  YMM@decimal           { global_regN 3 YmmReg    (const $ cmmVec 4 (cmmFloat W64)) }
+  ZMM@decimal           { global_regN 3 ZmmReg    (const $ cmmVec 8 (cmmFloat W64)) }
   Sp                    { global_reg  Sp               bWord }
   SpLim                 { global_reg  SpLim            bWord }
   Hp                    { global_reg  Hp              gcWord }
@@ -173,9 +176,9 @@ data CmmToken
   | CmmT_bits16
   | CmmT_bits32
   | CmmT_bits64
-  | CmmT_bits128
-  | CmmT_bits256
-  | CmmT_bits512
+  | CmmT_vec128
+  | CmmT_vec256
+  | CmmT_vec512
   | CmmT_float32
   | CmmT_float64
   | CmmT_gcptr
@@ -211,14 +214,16 @@ special_char span buf _len = return (L span (CmmT_SpecChar (currentChar buf)))
 kw :: CmmToken -> Action
 kw tok span _buf _len = return (L span tok)
 
-global_regN :: (Int -> GlobalReg) -> (Platform -> CmmType) -> Action
-global_regN con ty_fn span buf len
+global_regN :: Int -> (Int -> GlobalReg) -> (Platform -> CmmType) -> Action
+global_regN ident_nb_chars con ty_fn span buf len
   = do { platform <- getPlatform
        ; let reg = con (fromIntegral n)
              ty = ty_fn platform
        ; return (L span (CmmT_GlobalReg (GlobalRegUse reg ty))) }
-  where buf' = stepOn buf
-        n = parseUnsignedInteger buf' (len-1) 10 octDecDigit
+  where buf' = go ident_nb_chars buf
+          where go 0 b = b
+                go i b = go (i-1) (stepOn b)
+        n = parseUnsignedInteger buf' (len-ident_nb_chars) 10 octDecDigit
 
 global_reg :: GlobalReg -> (Platform -> CmmType) -> Action
 global_reg reg ty_fn span _buf _len
@@ -269,9 +274,9 @@ reservedWordsFM = listToUFM $
         ( "bits16",             CmmT_bits16 ),
         ( "bits32",             CmmT_bits32 ),
         ( "bits64",             CmmT_bits64 ),
-        ( "bits128",            CmmT_bits128 ),
-        ( "bits256",            CmmT_bits256 ),
-        ( "bits512",            CmmT_bits512 ),
+        ( "vec128",             CmmT_vec128 ),
+        ( "vec256",             CmmT_vec256 ),
+        ( "vec512",             CmmT_vec512 ),
         ( "float32",            CmmT_float32 ),
         ( "float64",            CmmT_float64 ),
 -- New forms
@@ -279,9 +284,6 @@ reservedWordsFM = listToUFM $
         ( "b16",                CmmT_bits16 ),
         ( "b32",                CmmT_bits32 ),
         ( "b64",                CmmT_bits64 ),
-        ( "b128",               CmmT_bits128 ),
-        ( "b256",               CmmT_bits256 ),
-        ( "b512",               CmmT_bits512 ),
         ( "f32",                CmmT_float32 ),
         ( "f64",                CmmT_float64 ),
         ( "gcptr",              CmmT_gcptr ),


=====================================
compiler/GHC/Cmm/Lint.hs
=====================================
@@ -171,7 +171,7 @@ lintCmmMiddle node = case node of
   CmmAssign reg expr -> do
             erep <- lintCmmExpr expr
             let reg_ty = cmmRegType reg
-            unless (erep `cmmEqType_ignoring_ptrhood` reg_ty) $
+            unless (erep `cmmCompatType` reg_ty) $
               cmmLintAssignErr (CmmAssign reg expr) erep reg_ty
 
   CmmStore l r _alignment -> do


=====================================
compiler/GHC/Cmm/Parser.y
=====================================
@@ -381,9 +381,9 @@ import qualified Data.ByteString.Char8 as BS8
         'bits16'        { L _ (CmmT_bits16) }
         'bits32'        { L _ (CmmT_bits32) }
         'bits64'        { L _ (CmmT_bits64) }
-        'bits128'       { L _ (CmmT_bits128) }
-        'bits256'       { L _ (CmmT_bits256) }
-        'bits512'       { L _ (CmmT_bits512) }
+        'vec128'        { L _ (CmmT_vec128) }
+        'vec256'        { L _ (CmmT_vec256) }
+        'vec512'        { L _ (CmmT_vec512) }
         'float32'       { L _ (CmmT_float32) }
         'float64'       { L _ (CmmT_float64) }
         'gcptr'         { L _ (CmmT_gcptr) }
@@ -942,9 +942,9 @@ typenot8 :: { CmmType }
         : 'bits16'              { b16 }
         | 'bits32'              { b32 }
         | 'bits64'              { b64 }
-        | 'bits128'             { b128 }
-        | 'bits256'             { b256 }
-        | 'bits512'             { b512 }
+        | 'vec128'              { cmmVec 2 f64 }
+        | 'vec256'              { cmmVec 4 f64 }
+        | 'vec512'              { cmmVec 8 f64 }
         | 'float32'             { f32 }
         | 'float64'             { f64 }
         | 'gcptr'               {% do platform <- PD.getPlatform; return $ gcWord platform }


=====================================
compiler/GHC/Cmm/Type.hs
=====================================
@@ -4,7 +4,7 @@ module GHC.Cmm.Type
     , cInt
     , cmmBits, cmmFloat
     , typeWidth, setCmmTypeWidth
-    , cmmEqType, cmmEqType_ignoring_ptrhood
+    , cmmEqType, cmmCompatType
     , isFloatType, isGcPtrType, isBitsType
     , isWordAny, isWord32, isWord64
     , isFloat64, isFloat32
@@ -87,21 +87,27 @@ instance Outputable CmmCat where
 cmmEqType :: CmmType -> CmmType -> Bool -- Exact equality
 cmmEqType (CmmType c1 w1) (CmmType c2 w2) = c1==c2 && w1==w2
 
-cmmEqType_ignoring_ptrhood :: CmmType -> CmmType -> Bool
-  -- This equality is temporary; used in CmmLint
-  -- but the RTS files are not yet well-typed wrt pointers
-cmmEqType_ignoring_ptrhood (CmmType c1 w1) (CmmType c2 w2)
-   = c1 `weak_eq` c2 && w1==w2
+-- | A weaker notion of equality of 'CmmType's than 'cmmEqType',
+-- used (only) in Cmm Lint.
+--
+-- Why "weaker"? Because:
+--
+--  - we don't distinguish GcPtr vs NonGcPtr, because the RTS files
+--    are not yet well-typed wrt pointers,
+--  - for vectors, we only compare the widths, because in practice things like
+--    X86 xmm registers support different types of data (e.g. 4xf32, 2xf64, 2xu64 etc).
+cmmCompatType :: CmmType -> CmmType -> Bool
+cmmCompatType (CmmType c1 w1) (CmmType c2 w2)
+   = c1 `weak_eq` c2 && w1 == w2
    where
      weak_eq :: CmmCat -> CmmCat -> Bool
-     FloatCat         `weak_eq` FloatCat         = True
-     FloatCat         `weak_eq` _other           = False
-     _other           `weak_eq` FloatCat         = False
-     (VecCat l1 cat1) `weak_eq` (VecCat l2 cat2) = l1 == l2
-                                                   && cat1 `weak_eq` cat2
-     (VecCat {})      `weak_eq` _other           = False
-     _other           `weak_eq` (VecCat {})      = False
-     _word1           `weak_eq` _word2           = True        -- Ignores GcPtr
+     FloatCat    `weak_eq` FloatCat    = True
+     FloatCat    `weak_eq` _other      = False
+     _other      `weak_eq` FloatCat    = False
+     (VecCat {}) `weak_eq` (VecCat {}) = True  -- only compare overall width
+     (VecCat {}) `weak_eq` _other      = False
+     _other      `weak_eq` (VecCat {}) = False
+     _word1      `weak_eq` _word2      = True  -- Ignores GcPtr
 
 --- Simple operations on CmmType -----
 typeWidth :: CmmType -> Width


=====================================
rts/include/Cmm.h
=====================================
@@ -101,9 +101,9 @@
 #define F_   float32
 #define D_   float64
 #define L_   bits64
-#define V16_ bits128
-#define V32_ bits256
-#define V64_ bits512
+#define V16_ vec128
+#define V32_ vec256
+#define V64_ vec512
 
 #define SIZEOF_StgDouble 8
 #define SIZEOF_StgWord64 8


=====================================
utils/deriveConstants/Main.hs
=====================================
@@ -1048,7 +1048,8 @@ writeHeader fn rs = atomicWriteFile fn xs
           genapplyBits = mconcat ["// " ++ _name ++ " " ++ show v ++ "\n" | (_name, v) <- genapplyData]
           genapplyData = [(_name, v) | (_, GetWord _name (Snd v)) <- rs, _name `elem` genapplyFields ]
           genapplyFields = [
-            "MAX_Real_Vanilla_REG", "MAX_Real_Float_REG", "MAX_Real_Double_REG", "MAX_Real_Long_REG",
+            "MAX_Real_Vanilla_REG", "MAX_Real_Float_REG", "MAX_Real_Double_REG",
+            "MAX_Real_Long_REG", "MAX_Real_XMM_REG",
             "WORD_SIZE", "TAG_BITS", "BITMAP_BITS_SHIFT"
             ]
           haskellRs = fmap snd $ filter (\r -> fst r `elem` [Haskell,Both]) rs


=====================================
utils/genapply/Main.hs
=====================================
@@ -14,7 +14,8 @@ import Prelude hiding ((<>))
 import Text.PrettyPrint
 import Data.Word
 import Data.Bits
-import Data.List        ( intersperse, nub, sort )
+import Data.List        ( intercalate, intersperse, nub, sort )
+import Data.Maybe       ( mapMaybe )
 import System.Environment
 import Control.Arrow ((***))
 
@@ -67,6 +68,7 @@ data TargetInfo = TargetInfo
     maxRealFloatReg,
     maxRealDoubleReg,
     maxRealLongReg,
+    maxRealXmmReg,
     wordSize,
     tagBits,
     tagBitsMax,
@@ -86,6 +88,7 @@ parseTargetInfo path = do
     maxRealFloatReg = tups_get "MAX_Real_Float_REG",
     maxRealDoubleReg = tups_get "MAX_Real_Double_REG",
     maxRealLongReg = tups_get "MAX_Real_Long_REG",
+    maxRealXmmReg = tups_get "MAX_Real_XMM_REG",
     wordSize = tups_get "WORD_SIZE",
     tagBits = tag_bits,
     tagBitsMax = 1 `shiftL` tag_bits,
@@ -105,6 +108,7 @@ data ArgRep
   | V16 -- 16-byte (128-bit) vectors
   | V32 -- 32-byte (256-bit) vectors
   | V64 -- 64-byte (512-bit) vectors
+  deriving (Eq, Show)
 
 -- size of a value in *words*
 argSize :: TargetInfo -> ArgRep -> Int
@@ -138,13 +142,15 @@ isPtr _ = False
 -- Registers
 
 type Reg = String
+type AvailRegs = ([Reg],[Reg],[Reg],[Reg],[Int])
 
-availableRegs :: TargetInfo -> ([Reg],[Reg],[Reg],[Reg])
+availableRegs :: TargetInfo -> AvailRegs
 availableRegs TargetInfo {..} =
   ( vanillaRegs maxRealVanillaReg,
     floatRegs   maxRealFloatReg,
     doubleRegs  maxRealDoubleReg,
-    longRegs    maxRealLongReg
+    longRegs    maxRealLongReg,
+    xmmRegNos   maxRealXmmReg
   )
 
 vanillaRegs, floatRegs, doubleRegs, longRegs :: Int -> [Reg]
@@ -153,6 +159,9 @@ floatRegs   n = [ "F" ++ show m | m <- [1..n] ]
 doubleRegs  n = [ "D" ++ show m | m <- [1..n] ]
 longRegs    n = [ "L" ++ show m | m <- [1..n] ]
 
+xmmRegNos :: Int -> [Int]
+xmmRegNos n = [1..n]
+
 -- -----------------------------------------------------------------------------
 -- Loading/saving register arguments to the stack
 
@@ -176,6 +185,7 @@ assignRegs
             Int)                -- Sp of left-over args
 assignRegs targetInfo sp args = assign targetInfo sp args (availableRegs targetInfo) []
 
+assign :: TargetInfo -> Int -> [ArgRep] -> AvailRegs -> [(Reg, Int)] -> ([(Reg, Int)], [ArgRep], Int)
 assign _ sp [] _regs doc = (doc, [], sp)
 assign targetInfo sp (V : args) regs doc = assign targetInfo sp args regs doc
 assign targetInfo sp (arg : args) regs doc
@@ -184,28 +194,49 @@ assign targetInfo sp (arg : args) regs doc
                             ((reg, sp) : doc)
     Nothing -> (doc, (arg:args), sp)
 
-findAvailableReg N (vreg:vregs, fregs, dregs, lregs) =
-  Just (vreg, (vregs,fregs,dregs,lregs))
-findAvailableReg P (vreg:vregs, fregs, dregs, lregs) =
-  Just (vreg, (vregs,fregs,dregs,lregs))
-findAvailableReg F (vregs, freg:fregs, dregs, lregs) =
-  Just (freg, (vregs,fregs,dregs,lregs))
-findAvailableReg D (vregs, fregs, dreg:dregs, lregs) =
-  Just (dreg, (vregs,fregs,dregs,lregs))
-findAvailableReg L (vregs, fregs, dregs, lreg:lregs) =
-  Just (lreg, (vregs,fregs,dregs,lregs))
+findAvailableReg :: ArgRep -> AvailRegs -> Maybe (Reg, AvailRegs)
+-- NB: this will go wrong if we try to generate stg_apply code with overlapping
+-- registers (e.g. stg_ap_df_fast).
+--
+-- This function should instead compute non-overlapping registers,
+-- depending on the platform.
+findAvailableReg N (vreg:vregs, fregs, dregs, lregs, xmmregNos) =
+  Just (vreg, (vregs,fregs,dregs,lregs,xmmregNos))
+findAvailableReg P (vreg:vregs, fregs, dregs, lregs, xmmregNos) =
+  Just (vreg, (vregs,fregs,dregs,lregs,xmmregNos))
+findAvailableReg F (vregs, freg:fregs, dregs, lregs, xmmregNos) =
+  Just (freg, (vregs,fregs,dregs,lregs,xmmregNos))
+findAvailableReg D (vregs, fregs, dreg:dregs, lregs, xmmregNos) =
+  Just (dreg, (vregs,fregs,dregs,lregs,xmmregNos))
+findAvailableReg L (vregs, fregs, dregs, lreg:lregs, xmmregNos) =
+  Just (lreg, (vregs,fregs,dregs,lregs,xmmregNos))
+findAvailableReg v (vregs, fregs, dregs, lregs, xmmregNo:xmmregNos)
+  | Just vecRegNm <-
+      case v of
+        V16 -> Just "XMM"
+        V32 -> Just "YMM"
+        V64 -> Just "ZMM"
+        _   -> Nothing
+  -- NB: here we assume xmm/ymm/zmm registers overlap.
+  = Just (vecRegNm ++ show xmmregNo, (vregs,fregs,dregs,lregs,xmmregNos))
 findAvailableReg _ _ = Nothing
 
+assign_reg_to_stk :: String -> Int -> Doc
 assign_reg_to_stk reg sp
    = loadSpWordOff (regRep reg) sp <> text " = " <> text reg <> semi
 
+assign_stk_to_reg :: String -> Int -> Doc
 assign_stk_to_reg reg sp
    = text reg <> text " = " <> loadSpWordOff (regRep reg) sp <> semi
 
+regRep :: String -> String
 regRep ('F':_) = "F_"
 regRep ('D':_) = "D_"
 regRep ('L':_) = "L_"
-regRep _       = "W_"
+regRep ('X':'M':'M':_) = "V16_"
+regRep ('Y':'M':'M':_) = "V32_"
+regRep ('Z':'M':'M':_) = "V64_"
+regRep _ = "W_"
 
 loadSpWordOff :: String -> Int -> Doc
 loadSpWordOff rep off = text rep <> text "[Sp+WDS(" <> int off <> text ")]"
@@ -649,6 +680,7 @@ formalParam arg n =
     text "arg" <> int n <> text ", "
 formalParamType arg = argRep arg
 
+argRep :: ArgRep -> Doc
 argRep F   = text "F_"
 argRep D   = text "D_"
 argRep L   = text "L_"
@@ -878,9 +910,18 @@ genApplyFast targetInfo args =
     (reg_locs, _leftovers, sp_offset) = assignRegs targetInfo 1 args
 
     stack_usage = maxStack [fun_stack, (sp_offset,sp_offset)]
+
+    vecs :: [Reg]
+    vecs = mapMaybe (\ (r,_) -> case r of { xyz:'M':'M':_ | xyz `elem` ['X','Y','Z'] -> Just r; _ -> Nothing}) reg_locs
+    vecs_cond :: Doc
+    vecs_cond = case vecs of
+      [] -> empty
+      _ -> text "#if" <+> text (intercalate " && " [ "defined(REG_" ++ r ++ ")" | r <- vecs ])
+
    in
-    vcat [
+    vcat $ [
      fun_fast_label,
+     vecs_cond,
      char '{',
      nest 4 (vcat [
         text "W_ info;",
@@ -926,7 +967,15 @@ genApplyFast targetInfo args =
        char '}'
      ]),
      char '}'
-   ]
+   ] ++
+    if null vecs
+    then []
+    else [ text "#else"
+         , char '{'
+         , nest 4 $ text "foreign \"C\" barf(\"" <> mkApplyName args <> text ": unsupported register\", NULL) never returns;"
+         , char '}'
+         , text "#endif"
+         ]
 
 -- -----------------------------------------------------------------------------
 -- Making a stack apply



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/717042b734aba2d2395558377fe897306de19c1b

-- 
This project does not include diff previews in email notifications.
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/717042b734aba2d2395558377fe897306de19c1b
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240709/b331b38a/attachment-0001.html>


More information about the ghc-commits mailing list