[Git][ghc/ghc][wip/ncg-simd] Use xmm registers in genapply
sheaf (@sheaf)
gitlab at gitlab.haskell.org
Tue Jul 9 17:01:39 UTC 2024
sheaf pushed to branch wip/ncg-simd at Glasgow Haskell Compiler / GHC
Commits:
227ab64d by sheaf at 2024-07-09T19:01:29+02:00
Use xmm registers in genapply
This commit updates genapply to use xmm, ymm and zmm registers, for
stg_ap_v16/stg_ap_v32/stg_ap_v64, respectively.
It also updates the Cmm lexer and parser to produce Cmm vectors rather
than 128/256/512 bit wide scalars for V16/V32/V64, removing bits128,
bits256 and bits512 in favour of vectors.
The Cmm Lint check is weakened for vectors, as (in practice, e.g. on X86)
it is okay to use a single vector register to hold multiple different
types of data, and we don't know just from seeing e.g. "XMM1" how to
interpret the 128 bits of data within.
Fixes #25062
- - - - -
7 changed files:
- compiler/GHC/Cmm/Lexer.x
- compiler/GHC/Cmm/Lint.hs
- compiler/GHC/Cmm/Parser.y
- compiler/GHC/Cmm/Type.hs
- rts/include/Cmm.h
- utils/deriveConstants/Main.hs
- utils/genapply/Main.hs
Changes:
=====================================
compiler/GHC/Cmm/Lexer.x
=====================================
@@ -104,11 +104,14 @@ $white_no_nl+ ;
"False" { kw CmmT_False }
"likely" { kw CmmT_likely}
- P@decimal { global_regN VanillaReg gcWord }
- R@decimal { global_regN VanillaReg bWord }
- F@decimal { global_regN FloatReg (const $ cmmFloat W32) }
- D@decimal { global_regN DoubleReg (const $ cmmFloat W64) }
- L@decimal { global_regN LongReg (const $ cmmBits W64) }
+ P@decimal { global_regN 1 VanillaReg gcWord }
+ R@decimal { global_regN 1 VanillaReg bWord }
+ F@decimal { global_regN 1 FloatReg (const $ cmmFloat W32) }
+ D@decimal { global_regN 1 DoubleReg (const $ cmmFloat W64) }
+ L@decimal { global_regN 1 LongReg (const $ cmmBits W64) }
+ XMM@decimal { global_regN 3 XmmReg (const $ cmmVec 2 (cmmFloat W64)) }
+ YMM@decimal { global_regN 3 YmmReg (const $ cmmVec 4 (cmmFloat W64)) }
+ ZMM@decimal { global_regN 3 ZmmReg (const $ cmmVec 8 (cmmFloat W64)) }
Sp { global_reg Sp bWord }
SpLim { global_reg SpLim bWord }
Hp { global_reg Hp gcWord }
@@ -173,9 +176,9 @@ data CmmToken
| CmmT_bits16
| CmmT_bits32
| CmmT_bits64
- | CmmT_bits128
- | CmmT_bits256
- | CmmT_bits512
+ | CmmT_vec128
+ | CmmT_vec256
+ | CmmT_vec512
| CmmT_float32
| CmmT_float64
| CmmT_gcptr
@@ -211,14 +214,16 @@ special_char span buf _len = return (L span (CmmT_SpecChar (currentChar buf)))
kw :: CmmToken -> Action
kw tok span _buf _len = return (L span tok)
-global_regN :: (Int -> GlobalReg) -> (Platform -> CmmType) -> Action
-global_regN con ty_fn span buf len
+global_regN :: Int -> (Int -> GlobalReg) -> (Platform -> CmmType) -> Action
+global_regN ident_nb_chars con ty_fn span buf len
= do { platform <- getPlatform
; let reg = con (fromIntegral n)
ty = ty_fn platform
; return (L span (CmmT_GlobalReg (GlobalRegUse reg ty))) }
- where buf' = stepOn buf
- n = parseUnsignedInteger buf' (len-1) 10 octDecDigit
+ where buf' = go ident_nb_chars buf
+ where go 0 b = b
+ go i b = go (i-1) (stepOn b)
+ n = parseUnsignedInteger buf' (len-ident_nb_chars) 10 octDecDigit
global_reg :: GlobalReg -> (Platform -> CmmType) -> Action
global_reg reg ty_fn span _buf _len
@@ -269,9 +274,9 @@ reservedWordsFM = listToUFM $
( "bits16", CmmT_bits16 ),
( "bits32", CmmT_bits32 ),
( "bits64", CmmT_bits64 ),
- ( "bits128", CmmT_bits128 ),
- ( "bits256", CmmT_bits256 ),
- ( "bits512", CmmT_bits512 ),
+ ( "vec128", CmmT_vec128 ),
+ ( "vec256", CmmT_vec256 ),
+ ( "vec512", CmmT_vec512 ),
( "float32", CmmT_float32 ),
( "float64", CmmT_float64 ),
-- New forms
@@ -279,9 +284,6 @@ reservedWordsFM = listToUFM $
( "b16", CmmT_bits16 ),
( "b32", CmmT_bits32 ),
( "b64", CmmT_bits64 ),
- ( "b128", CmmT_bits128 ),
- ( "b256", CmmT_bits256 ),
- ( "b512", CmmT_bits512 ),
( "f32", CmmT_float32 ),
( "f64", CmmT_float64 ),
( "gcptr", CmmT_gcptr ),
=====================================
compiler/GHC/Cmm/Lint.hs
=====================================
@@ -171,7 +171,7 @@ lintCmmMiddle node = case node of
CmmAssign reg expr -> do
erep <- lintCmmExpr expr
let reg_ty = cmmRegType reg
- unless (erep `cmmEqType_ignoring_ptrhood` reg_ty) $
+ unless (erep `cmmCompatType` reg_ty) $
cmmLintAssignErr (CmmAssign reg expr) erep reg_ty
CmmStore l r _alignment -> do
=====================================
compiler/GHC/Cmm/Parser.y
=====================================
@@ -381,9 +381,9 @@ import qualified Data.ByteString.Char8 as BS8
'bits16' { L _ (CmmT_bits16) }
'bits32' { L _ (CmmT_bits32) }
'bits64' { L _ (CmmT_bits64) }
- 'bits128' { L _ (CmmT_bits128) }
- 'bits256' { L _ (CmmT_bits256) }
- 'bits512' { L _ (CmmT_bits512) }
+ 'vec128' { L _ (CmmT_vec128) }
+ 'vec256' { L _ (CmmT_vec256) }
+ 'vec512' { L _ (CmmT_vec512) }
'float32' { L _ (CmmT_float32) }
'float64' { L _ (CmmT_float64) }
'gcptr' { L _ (CmmT_gcptr) }
@@ -942,9 +942,9 @@ typenot8 :: { CmmType }
: 'bits16' { b16 }
| 'bits32' { b32 }
| 'bits64' { b64 }
- | 'bits128' { b128 }
- | 'bits256' { b256 }
- | 'bits512' { b512 }
+ | 'vec128' { cmmVec 2 f64 }
+ | 'vec256' { cmmVec 4 f64 }
+ | 'vec512' { cmmVec 8 f64 }
| 'float32' { f32 }
| 'float64' { f64 }
| 'gcptr' {% do platform <- PD.getPlatform; return $ gcWord platform }
=====================================
compiler/GHC/Cmm/Type.hs
=====================================
@@ -4,7 +4,7 @@ module GHC.Cmm.Type
, cInt
, cmmBits, cmmFloat
, typeWidth, setCmmTypeWidth
- , cmmEqType, cmmEqType_ignoring_ptrhood
+ , cmmEqType, cmmCompatType
, isFloatType, isGcPtrType, isBitsType
, isWordAny, isWord32, isWord64
, isFloat64, isFloat32
@@ -87,21 +87,27 @@ instance Outputable CmmCat where
cmmEqType :: CmmType -> CmmType -> Bool -- Exact equality
cmmEqType (CmmType c1 w1) (CmmType c2 w2) = c1==c2 && w1==w2
-cmmEqType_ignoring_ptrhood :: CmmType -> CmmType -> Bool
- -- This equality is temporary; used in CmmLint
- -- but the RTS files are not yet well-typed wrt pointers
-cmmEqType_ignoring_ptrhood (CmmType c1 w1) (CmmType c2 w2)
- = c1 `weak_eq` c2 && w1==w2
+-- | A weaker notion of equality of 'CmmType's than 'cmmEqType',
+-- used (only) in Cmm Lint.
+--
+-- Why "weaker"? Because:
+--
+-- - we don't distinguish GcPtr vs NonGcPtr, because the RTS files
+-- are not yet well-typed wrt pointers,
+-- - for vectors, we only compare the widths, because in practice things like
+-- X86 xmm registers support different types of data (e.g. 4xf32, 2xf64, 2xu64 etc).
+cmmCompatType :: CmmType -> CmmType -> Bool
+cmmCompatType (CmmType c1 w1) (CmmType c2 w2)
+ = c1 `weak_eq` c2 && w1 == w2
where
weak_eq :: CmmCat -> CmmCat -> Bool
- FloatCat `weak_eq` FloatCat = True
- FloatCat `weak_eq` _other = False
- _other `weak_eq` FloatCat = False
- (VecCat l1 cat1) `weak_eq` (VecCat l2 cat2) = l1 == l2
- && cat1 `weak_eq` cat2
- (VecCat {}) `weak_eq` _other = False
- _other `weak_eq` (VecCat {}) = False
- _word1 `weak_eq` _word2 = True -- Ignores GcPtr
+ FloatCat `weak_eq` FloatCat = True
+ FloatCat `weak_eq` _other = False
+ _other `weak_eq` FloatCat = False
+ (VecCat {}) `weak_eq` (VecCat {}) = True -- only compare overall width
+ (VecCat {}) `weak_eq` _other = False
+ _other `weak_eq` (VecCat {}) = False
+ _word1 `weak_eq` _word2 = True -- Ignores GcPtr
--- Simple operations on CmmType -----
typeWidth :: CmmType -> Width
=====================================
rts/include/Cmm.h
=====================================
@@ -101,9 +101,9 @@
#define F_ float32
#define D_ float64
#define L_ bits64
-#define V16_ bits128
-#define V32_ bits256
-#define V64_ bits512
+#define V16_ vec128
+#define V32_ vec256
+#define V64_ vec512
#define SIZEOF_StgDouble 8
#define SIZEOF_StgWord64 8
=====================================
utils/deriveConstants/Main.hs
=====================================
@@ -1048,7 +1048,8 @@ writeHeader fn rs = atomicWriteFile fn xs
genapplyBits = mconcat ["// " ++ _name ++ " " ++ show v ++ "\n" | (_name, v) <- genapplyData]
genapplyData = [(_name, v) | (_, GetWord _name (Snd v)) <- rs, _name `elem` genapplyFields ]
genapplyFields = [
- "MAX_Real_Vanilla_REG", "MAX_Real_Float_REG", "MAX_Real_Double_REG", "MAX_Real_Long_REG",
+ "MAX_Real_Vanilla_REG", "MAX_Real_Float_REG", "MAX_Real_Double_REG",
+ "MAX_Real_Long_REG", "MAX_Real_XMM_REG",
"WORD_SIZE", "TAG_BITS", "BITMAP_BITS_SHIFT"
]
haskellRs = fmap snd $ filter (\r -> fst r `elem` [Haskell,Both]) rs
=====================================
utils/genapply/Main.hs
=====================================
@@ -14,7 +14,8 @@ import Prelude hiding ((<>))
import Text.PrettyPrint
import Data.Word
import Data.Bits
-import Data.List ( intersperse, nub, sort )
+import Data.List ( intercalate, intersperse, nub, sort )
+import Data.Maybe ( mapMaybe )
import System.Environment
import Control.Arrow ((***))
@@ -67,6 +68,7 @@ data TargetInfo = TargetInfo
maxRealFloatReg,
maxRealDoubleReg,
maxRealLongReg,
+ maxRealXmmReg,
wordSize,
tagBits,
tagBitsMax,
@@ -86,6 +88,7 @@ parseTargetInfo path = do
maxRealFloatReg = tups_get "MAX_Real_Float_REG",
maxRealDoubleReg = tups_get "MAX_Real_Double_REG",
maxRealLongReg = tups_get "MAX_Real_Long_REG",
+ maxRealXmmReg = tups_get "MAX_Real_XMM_REG",
wordSize = tups_get "WORD_SIZE",
tagBits = tag_bits,
tagBitsMax = 1 `shiftL` tag_bits,
@@ -105,6 +108,7 @@ data ArgRep
| V16 -- 16-byte (128-bit) vectors
| V32 -- 32-byte (256-bit) vectors
| V64 -- 64-byte (512-bit) vectors
+ deriving (Eq, Show)
-- size of a value in *words*
argSize :: TargetInfo -> ArgRep -> Int
@@ -138,13 +142,15 @@ isPtr _ = False
-- Registers
type Reg = String
+type AvailRegs = ([Reg],[Reg],[Reg],[Reg],[Int])
-availableRegs :: TargetInfo -> ([Reg],[Reg],[Reg],[Reg])
+availableRegs :: TargetInfo -> AvailRegs
availableRegs TargetInfo {..} =
( vanillaRegs maxRealVanillaReg,
floatRegs maxRealFloatReg,
doubleRegs maxRealDoubleReg,
- longRegs maxRealLongReg
+ longRegs maxRealLongReg,
+ xmmRegNos maxRealXmmReg
)
vanillaRegs, floatRegs, doubleRegs, longRegs :: Int -> [Reg]
@@ -153,6 +159,9 @@ floatRegs n = [ "F" ++ show m | m <- [1..n] ]
doubleRegs n = [ "D" ++ show m | m <- [1..n] ]
longRegs n = [ "L" ++ show m | m <- [1..n] ]
+xmmRegNos :: Int -> [Int]
+xmmRegNos n = [1..n]
+
-- -----------------------------------------------------------------------------
-- Loading/saving register arguments to the stack
@@ -176,6 +185,7 @@ assignRegs
Int) -- Sp of left-over args
assignRegs targetInfo sp args = assign targetInfo sp args (availableRegs targetInfo) []
+assign :: TargetInfo -> Int -> [ArgRep] -> AvailRegs -> [(Reg, Int)] -> ([(Reg, Int)], [ArgRep], Int)
assign _ sp [] _regs doc = (doc, [], sp)
assign targetInfo sp (V : args) regs doc = assign targetInfo sp args regs doc
assign targetInfo sp (arg : args) regs doc
@@ -184,28 +194,49 @@ assign targetInfo sp (arg : args) regs doc
((reg, sp) : doc)
Nothing -> (doc, (arg:args), sp)
-findAvailableReg N (vreg:vregs, fregs, dregs, lregs) =
- Just (vreg, (vregs,fregs,dregs,lregs))
-findAvailableReg P (vreg:vregs, fregs, dregs, lregs) =
- Just (vreg, (vregs,fregs,dregs,lregs))
-findAvailableReg F (vregs, freg:fregs, dregs, lregs) =
- Just (freg, (vregs,fregs,dregs,lregs))
-findAvailableReg D (vregs, fregs, dreg:dregs, lregs) =
- Just (dreg, (vregs,fregs,dregs,lregs))
-findAvailableReg L (vregs, fregs, dregs, lreg:lregs) =
- Just (lreg, (vregs,fregs,dregs,lregs))
+findAvailableReg :: ArgRep -> AvailRegs -> Maybe (Reg, AvailRegs)
+-- NB: this will go wrong if we try to generate stg_apply code with overlapping
+-- registers (e.g. stg_ap_df_fast).
+--
+-- This function should instead compute non-overlapping registers,
+-- depending on the platform.
+findAvailableReg N (vreg:vregs, fregs, dregs, lregs, xmmregNos) =
+ Just (vreg, (vregs,fregs,dregs,lregs,xmmregNos))
+findAvailableReg P (vreg:vregs, fregs, dregs, lregs, xmmregNos) =
+ Just (vreg, (vregs,fregs,dregs,lregs,xmmregNos))
+findAvailableReg F (vregs, freg:fregs, dregs, lregs, xmmregNos) =
+ Just (freg, (vregs,fregs,dregs,lregs,xmmregNos))
+findAvailableReg D (vregs, fregs, dreg:dregs, lregs, xmmregNos) =
+ Just (dreg, (vregs,fregs,dregs,lregs,xmmregNos))
+findAvailableReg L (vregs, fregs, dregs, lreg:lregs, xmmregNos) =
+ Just (lreg, (vregs,fregs,dregs,lregs,xmmregNos))
+findAvailableReg v (vregs, fregs, dregs, lregs, xmmregNo:xmmregNos)
+ | Just vecRegNm <-
+ case v of
+ V16 -> Just "XMM"
+ V32 -> Just "YMM"
+ V64 -> Just "ZMM"
+ _ -> Nothing
+ -- NB: here we assume xmm/ymm/zmm registers overlap.
+ = Just (vecRegNm ++ show xmmregNo, (vregs,fregs,dregs,lregs,xmmregNos))
findAvailableReg _ _ = Nothing
+assign_reg_to_stk :: String -> Int -> Doc
assign_reg_to_stk reg sp
= loadSpWordOff (regRep reg) sp <> text " = " <> text reg <> semi
+assign_stk_to_reg :: String -> Int -> Doc
assign_stk_to_reg reg sp
= text reg <> text " = " <> loadSpWordOff (regRep reg) sp <> semi
+regRep :: String -> String
regRep ('F':_) = "F_"
regRep ('D':_) = "D_"
regRep ('L':_) = "L_"
-regRep _ = "W_"
+regRep ('X':'M':'M':_) = "V16_"
+regRep ('Y':'M':'M':_) = "V32_"
+regRep ('Z':'M':'M':_) = "V64_"
+regRep _ = "W_"
loadSpWordOff :: String -> Int -> Doc
loadSpWordOff rep off = text rep <> text "[Sp+WDS(" <> int off <> text ")]"
@@ -649,6 +680,7 @@ formalParam arg n =
text "arg" <> int n <> text ", "
formalParamType arg = argRep arg
+argRep :: ArgRep -> Doc
argRep F = text "F_"
argRep D = text "D_"
argRep L = text "L_"
@@ -878,9 +910,19 @@ genApplyFast targetInfo args =
(reg_locs, _leftovers, sp_offset) = assignRegs targetInfo 1 args
stack_usage = maxStack [fun_stack, (sp_offset,sp_offset)]
+
+ vecs :: [Reg]
+ vecs = mapMaybe (\ (r,_) -> case r of { xyz:'M':'M':_ | xyz `elem` ['X','Y','Z'] -> Just r; _ -> Nothing}) reg_locs
+ vecs_cond :: Doc
+ vecs_cond = case vecs of
+ [] -> empty
+ _ -> text "#if" <+> text (intercalate " && " [ "defined(REG_" ++ r ++ ")" | r <- vecs ])
+
in
- vcat [
+ vcat $ [
fun_fast_label,
+ vecs_cond, -- If we use e.g. ZMM1, wrap the definition in "#if defined(REG_ZMM1)"
+ -- to prevent attempting to compile this code on unsupported architectures.
char '{',
nest 4 (vcat [
text "W_ info;",
@@ -926,7 +968,15 @@ genApplyFast targetInfo args =
char '}'
]),
char '}'
- ]
+ ] ++
+ if null vecs
+ then []
+ else [ text "#else"
+ , char '{'
+ , nest 4 $ text "foreign \"C\" barf(\"" <> mkApplyName args <> text ": unsupported register\", NULL) never returns;"
+ , char '}'
+ , text "#endif"
+ ]
-- -----------------------------------------------------------------------------
-- Making a stack apply
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/227ab64d3a04af4473e5822d8abc83f17fe0fea9
--
This project does not include diff previews in email notifications.
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/227ab64d3a04af4473e5822d8abc83f17fe0fea9
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240709/afc59a42/attachment-0001.html>
More information about the ghc-commits
mailing list