[Git][ghc/ghc][wip/ncg-simd] LLVM: propagate GlobalRegUse information
sheaf (@sheaf)
gitlab at gitlab.haskell.org
Fri Aug 30 23:06:00 UTC 2024
sheaf pushed to branch wip/ncg-simd at Glasgow Haskell Compiler / GHC
Commits:
b792b70c by sheaf at 2024-08-31T01:05:49+02:00
LLVM: propagate GlobalRegUse information
This commit ensures we keep track of how any particular global register
is being used in the LLVM backend. This informs the LLVM type
annotations, and avoids type mismatches of the following form:
argument is not of expected type '<2 x double>'
call ccc <2 x double> (<2 x double>)
(<4 x i32> arg)
- - - - -
5 changed files:
- compiler/GHC/CmmToLlvm.hs
- compiler/GHC/CmmToLlvm/Base.hs
- compiler/GHC/CmmToLlvm/CodeGen.hs
- compiler/GHC/CmmToLlvm/Ppr.hs
- compiler/GHC/CmmToLlvm/Regs.hs
Changes:
=====================================
compiler/GHC/CmmToLlvm.hs
=====================================
@@ -139,7 +139,7 @@ llvmGroupLlvmGens cmm = do
Nothing -> l
Just (CmmStaticsRaw info_lbl _) -> info_lbl
lml <- strCLabel_llvm l'
- funInsert lml =<< llvmFunTy (map globalRegUseGlobalReg live)
+ funInsert lml =<< llvmFunTy live
return Nothing
cdata <- fmap catMaybes $ mapM split cmm
=====================================
compiler/GHC/CmmToLlvm/Base.hs
=====================================
@@ -12,7 +12,7 @@
module GHC.CmmToLlvm.Base (
LlvmCmmDecl, LlvmBasicBlock,
- LiveGlobalRegs,
+ LiveGlobalRegs, LiveGlobalRegUses,
LlvmUnresData, LlvmData, UnresLabel, UnresStatic,
LlvmM,
@@ -29,6 +29,8 @@ module GHC.CmmToLlvm.Base (
llvmFunSig, llvmFunArgs, llvmStdFunAttrs, llvmFunAlign, llvmInfAlign,
llvmPtrBits, tysToParams, llvmFunSection, padLiveArgs, isFPR,
+ lookupRegUse,
+
strCLabel_llvm,
getGlobalPtr, generateExternDecls,
@@ -58,9 +60,11 @@ import GHC.Types.Unique.Set
import GHC.Types.Unique.Supply
import GHC.Utils.Logger
-import Data.Maybe (fromJust)
import Control.Monad.Trans.State (StateT (..))
-import Data.List (isPrefixOf)
+import Control.Applicative (Alternative((<|>)))
+import Data.Maybe (fromJust, mapMaybe)
+
+import Data.List (find, isPrefixOf)
import qualified Data.List.NonEmpty as NE
import Data.Ord (comparing)
@@ -73,6 +77,7 @@ type LlvmBasicBlock = GenBasicBlock LlvmStatement
-- | Global registers live on proc entry
type LiveGlobalRegs = [GlobalReg]
+type LiveGlobalRegUses = [GlobalRegUse]
-- | Unresolved code.
-- Of the form: (data label, data type, unresolved data)
@@ -116,16 +121,16 @@ llvmGhcCC platform
| otherwise = CC_Ghc
-- | Llvm Function type for Cmm function
-llvmFunTy :: LiveGlobalRegs -> LlvmM LlvmType
+llvmFunTy :: LiveGlobalRegUses -> LlvmM LlvmType
llvmFunTy live = return . LMFunction =<< llvmFunSig' live (fsLit "a") ExternallyVisible
-- | Llvm Function signature
-llvmFunSig :: LiveGlobalRegs -> CLabel -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
+llvmFunSig :: LiveGlobalRegUses -> CLabel -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
llvmFunSig live lbl link = do
lbl' <- strCLabel_llvm lbl
llvmFunSig' live lbl' link
-llvmFunSig' :: LiveGlobalRegs -> LMString -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
+llvmFunSig' :: LiveGlobalRegUses -> LMString -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
llvmFunSig' live lbl link
= do let toParams x | isPointer x = (x, [NoAlias, NoCapture])
| otherwise = (x, [])
@@ -149,16 +154,25 @@ llvmFunSection opts lbl
| otherwise = Nothing
-- | A Function's arguments
-llvmFunArgs :: Platform -> LiveGlobalRegs -> [LlvmVar]
+llvmFunArgs :: Platform -> LiveGlobalRegUses -> [LlvmVar]
llvmFunArgs platform live =
- map (lmGlobalRegArg platform) (filter isPassed allRegs)
+ map (lmGlobalRegArg platform) (mapMaybe isPassed allRegs)
where allRegs = activeStgRegs platform
paddingRegs = padLiveArgs platform live
- isLive r = r `elem` alwaysLive
- || r `elem` live
- || r `elem` paddingRegs
- isPassed r = not (isFPR r) || isLive r
-
+ isLive :: GlobalReg -> Maybe GlobalRegUse
+ isLive r =
+ lookupRegUse r (alwaysLive platform)
+ <|>
+ lookupRegUse r live
+ <|>
+ lookupRegUse r paddingRegs
+ isPassed r =
+ if not (isFPR r)
+ then Just $ GlobalRegUse r (globalRegSpillType platform r)
+ else isLive r
+
+lookupRegUse :: GlobalReg -> [GlobalRegUse] -> Maybe GlobalRegUse
+lookupRegUse r = find ((== r) . globalRegUseGlobalReg)
isFPR :: GlobalReg -> Bool
isFPR (FloatReg _) = True
@@ -179,7 +193,7 @@ isFPR _ = False
-- Invariant: Cmm FPR regs with number "n" maps to real registers with number
-- "n" If the calling convention uses registers in a different order or if the
-- invariant doesn't hold, this code probably won't be correct.
-padLiveArgs :: Platform -> LiveGlobalRegs -> LiveGlobalRegs
+padLiveArgs :: Platform -> LiveGlobalRegUses -> LiveGlobalRegUses
padLiveArgs platform live =
if platformUnregisterised platform
then [] -- not using GHC's register convention for platform.
@@ -188,7 +202,7 @@ padLiveArgs platform live =
----------------------------------
-- handle floating-point registers (FPR)
- fprLive = filter isFPR live -- real live FPR registers
+ fprLive = filter (isFPR . globalRegUseGlobalReg) live -- real live FPR registers
-- we group live registers sharing the same classes, i.e. that use the same
-- set of real registers to be passed. E.g. FloatReg, DoubleReg and XmmReg
@@ -196,39 +210,44 @@ padLiveArgs platform live =
--
classes = NE.groupBy sharesClass fprLive
sharesClass a b = globalRegsOverlap platform (norm a) (norm b) -- check if mapped to overlapping registers
- norm x = fpr_ctor x 1 -- get the first register of the family
+ norm x = globalRegUseGlobalReg (fpr_ctor x 1) -- get the first register of the family
-- For each class, we just have to fill missing registers numbers. We use
-- the constructor of the greatest register to build padding registers.
--
-- E.g. sortedRs = [ F2, XMM4, D5]
-- output = [D1, D3]
+ padded :: [GlobalRegUse]
padded = concatMap padClass classes
+
+ padClass :: NE.NonEmpty GlobalRegUse -> [GlobalRegUse]
padClass rs = go (NE.toList sortedRs) 1
where
- sortedRs = NE.sortBy (comparing fpr_num) rs
+ sortedRs = NE.sortBy (comparing (fpr_num . globalRegUseGlobalReg)) rs
maxr = NE.last sortedRs
ctor = fpr_ctor maxr
go [] _ = []
- go (c1:c2:_) _ -- detect bogus case (see #17920)
+ go (GlobalRegUse c1 _: GlobalRegUse c2 _:_) _ -- detect bogus case (see #17920)
| fpr_num c1 == fpr_num c2
, Just real <- globalRegMaybe platform c1
= sorryDoc "LLVM code generator" $
text "Found two different Cmm registers (" <> ppr c1 <> text "," <> ppr c2 <>
text ") both alive AND mapped to the same real register: " <> ppr real <>
text ". This isn't currently supported by the LLVM backend."
- go (c:cs) f
- | fpr_num c == f = go cs f -- already covered by a real register
- | otherwise = ctor f : go (c:cs) (f + 1) -- add padding register
-
- fpr_ctor :: GlobalReg -> Int -> GlobalReg
- fpr_ctor (FloatReg _) = FloatReg
- fpr_ctor (DoubleReg _) = DoubleReg
- fpr_ctor (XmmReg _) = XmmReg
- fpr_ctor (YmmReg _) = YmmReg
- fpr_ctor (ZmmReg _) = ZmmReg
- fpr_ctor _ = error "fpr_ctor expected only FPR regs"
+ go (cu@(GlobalRegUse c _):cs) f
+ | fpr_num c == f = go cs f -- already covered by a real register
+ | otherwise = ctor f : go (cu:cs) (f + 1) -- add padding register
+
+ fpr_ctor :: GlobalRegUse -> Int -> GlobalRegUse
+ fpr_ctor (GlobalRegUse r fmt) i =
+ case r of
+ FloatReg _ -> GlobalRegUse (FloatReg i) fmt
+ DoubleReg _ -> GlobalRegUse (DoubleReg i) fmt
+ XmmReg _ -> GlobalRegUse (XmmReg i) fmt
+ YmmReg _ -> GlobalRegUse (YmmReg i) fmt
+ ZmmReg _ -> GlobalRegUse (ZmmReg i) fmt
+ _ -> error "fpr_ctor expected only FPR regs"
fpr_num :: GlobalReg -> Int
fpr_num (FloatReg i) = i
=====================================
compiler/GHC/CmmToLlvm/CodeGen.hs
=====================================
@@ -37,13 +37,14 @@ import GHC.Utils.Outputable
import qualified GHC.Utils.Panic as Panic
import GHC.Utils.Misc
+import Control.Applicative (Alternative((<|>)))
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Control.Monad
import qualified Data.Semigroup as Semigroup
import Data.List ( nub )
-import Data.Maybe ( catMaybes )
+import Data.Maybe ( catMaybes, isJust )
type Atomic = Maybe MemoryOrdering
type LlvmStatements = OrdList LlvmStatement
@@ -57,7 +58,7 @@ genLlvmProc :: RawCmmDecl -> LlvmM [LlvmCmmDecl]
genLlvmProc (CmmProc infos lbl live graph) = do
let blocks = toBlockListEntryFirstFalseFallthrough graph
- (lmblocks, lmdata) <- basicBlocksCodeGen (map globalRegUseGlobalReg live) blocks
+ (lmblocks, lmdata) <- basicBlocksCodeGen live blocks
let info = mapLookup (g_entry graph) infos
proc = CmmProc info lbl live (ListGraph lmblocks)
return (proc:lmdata)
@@ -76,7 +77,7 @@ newtype UnreachableBlockId = UnreachableBlockId BlockId
-- | Generate code for a list of blocks that make up a complete
-- procedure. The first block in the list is expected to be the entry
-- point.
-basicBlocksCodeGen :: LiveGlobalRegs -> [CmmBlock]
+basicBlocksCodeGen :: LiveGlobalRegUses -> [CmmBlock]
-> LlvmM ([LlvmBasicBlock], [LlvmCmmDecl])
basicBlocksCodeGen _ [] = panic "no entry block!"
basicBlocksCodeGen live cmmBlocks
@@ -152,7 +153,7 @@ stmtToInstrs ubid stmt = case stmt of
-- Tail call
CmmCall { cml_target = arg,
- cml_args_regs = live } -> genJump arg $ map globalRegUseGlobalReg live
+ cml_args_regs = live } -> genJump arg live
_ -> panic "Llvm.CodeGen.stmtToInstrs"
@@ -1050,7 +1051,7 @@ cmmPrimOpFunctions mop = do
-- | Tail function calls
-genJump :: CmmExpr -> [GlobalReg] -> LlvmM StmtData
+genJump :: CmmExpr -> LiveGlobalRegUses -> LlvmM StmtData
-- Call to known function
genJump (CmmLit (CmmLabel lbl)) live = do
@@ -2056,14 +2057,13 @@ getCmmReg (CmmLocal (LocalReg un _))
-- have been assigned a value at some point, triggering
-- "funPrologue" to allocate it on the stack.
-getCmmReg (CmmGlobal g)
- = do let r = globalRegUseGlobalReg g
- onStack <- checkStackReg r
+getCmmReg (CmmGlobal ru@(GlobalRegUse r _))
+ = do onStack <- checkStackReg r
platform <- getPlatform
if onStack
- then return (lmGlobalRegVar platform r)
+ then return (lmGlobalRegVar platform ru)
else pprPanic "getCmmReg: Cmm register " $
- ppr g <> text " not stack-allocated!"
+ ppr r <> text " not stack-allocated!"
-- | Return the value of a given register, as well as its type. Might
-- need to be load from stack.
@@ -2074,7 +2074,7 @@ getCmmRegVal reg =
onStack <- checkStackReg (globalRegUseGlobalReg g)
platform <- getPlatform
if onStack then loadFromStack else do
- let r = lmGlobalRegArg platform (globalRegUseGlobalReg g)
+ let r = lmGlobalRegArg platform g
return (r, getVarType r, nilOL)
_ -> loadFromStack
where loadFromStack = do
@@ -2187,8 +2187,9 @@ convertMemoryOrdering MemOrderSeqCst = SyncSeqCst
-- question is never written. Therefore we skip it where we can to
-- save a few lines in the output and hopefully speed compilation up a
-- bit.
-funPrologue :: LiveGlobalRegs -> [CmmBlock] -> LlvmM StmtData
+funPrologue :: LiveGlobalRegUses -> [CmmBlock] -> LlvmM StmtData
funPrologue live cmmBlocks = do
+ platform <- getPlatform
let getAssignedRegs :: CmmNode O O -> [CmmReg]
getAssignedRegs (CmmAssign reg _) = [reg]
@@ -2196,7 +2197,8 @@ funPrologue live cmmBlocks = do
getAssignedRegs _ = []
getRegsBlock (_, body, _) = concatMap getAssignedRegs $ blockToList body
assignedRegs = nub $ concatMap (getRegsBlock . blockSplit) cmmBlocks
- isLive r = r `elem` alwaysLive || r `elem` live
+ mbLive r =
+ lookupRegUse r (alwaysLive platform) <|> lookupRegUse r live
platform <- getPlatform
stmtss <- forM assignedRegs $ \reg ->
@@ -2205,12 +2207,12 @@ funPrologue live cmmBlocks = do
let (newv, stmts) = allocReg reg
varInsert un (pLower $ getVarType newv)
return stmts
- CmmGlobal (GlobalRegUse r _) -> do
- let reg = lmGlobalRegVar platform r
- arg = lmGlobalRegArg platform r
+ CmmGlobal ru@(GlobalRegUse r _) -> do
+ let reg = lmGlobalRegVar platform ru
+ arg = lmGlobalRegArg platform ru
ty = (pLower . getVarType) reg
trash = LMLitVar $ LMUndefLit ty
- rval = if isLive r then arg else trash
+ rval = if isJust (mbLive r) then arg else trash
alloc = Assignment reg $ Alloca (pLower $ getVarType reg) 1
markStackReg r
return $ toOL [alloc, Store rval reg Nothing []]
@@ -2222,7 +2224,7 @@ funPrologue live cmmBlocks = do
-- | Function epilogue. Load STG variables to use as argument for call.
-- STG Liveness optimisation done here.
-funEpilogue :: LiveGlobalRegs -> LlvmM ([LlvmVar], LlvmStatements)
+funEpilogue :: LiveGlobalRegUses -> LlvmM ([LlvmVar], LlvmStatements)
funEpilogue live = do
platform <- getPlatform
@@ -2248,12 +2250,16 @@ funEpilogue live = do
let allRegs = activeStgRegs platform
loads <- forM allRegs $ \r -> if
-- load live registers
- | r `elem` alwaysLive -> loadExpr (GlobalRegUse r (globalRegSpillType platform r))
- | r `elem` live -> loadExpr (GlobalRegUse r (globalRegSpillType platform r))
+ | Just ru <- lookupRegUse r (alwaysLive platform)
+ -> loadExpr ru
+ | Just ru <- lookupRegUse r live
+ -> loadExpr ru
-- load all non Floating-Point Registers
- | not (isFPR r) -> loadUndef r
+ | not (isFPR r)
+ -> loadUndef (GlobalRegUse r (globalRegSpillType platform r))
-- load padding Floating-Point Registers
- | r `elem` paddingRegs -> loadUndef r
+ | Just ru <- lookupRegUse r paddingRegs
+ -> loadUndef ru
| otherwise -> return (Nothing, nilOL)
let (vars, stmts) = unzip loads
@@ -2263,7 +2269,7 @@ funEpilogue live = do
--
-- This is for Haskell functions, function type is assumed, so doesn't work
-- with foreign functions.
-getHsFunc :: LiveGlobalRegs -> CLabel -> LlvmM ExprData
+getHsFunc :: LiveGlobalRegUses -> CLabel -> LlvmM ExprData
getHsFunc live lbl
= do fty <- llvmFunTy live
name <- strCLabel_llvm lbl
=====================================
compiler/GHC/CmmToLlvm/Ppr.hs
=====================================
@@ -49,9 +49,8 @@ pprLlvmCmmDecl (CmmData _ lmdata) = do
return ( vcat $ map (pprLlvmData opts) lmdata
, vcat $ map (pprLlvmData opts) lmdata)
-pprLlvmCmmDecl (CmmProc mb_info entry_lbl liveWithUses (ListGraph blks))
- = do let live = map globalRegUseGlobalReg liveWithUses
- lbl = case mb_info of
+pprLlvmCmmDecl (CmmProc mb_info entry_lbl live (ListGraph blks))
+ = do let lbl = case mb_info of
Nothing -> entry_lbl
Just (CmmStaticsRaw info_lbl _) -> info_lbl
link = if externallyVisibleCLabel lbl
=====================================
compiler/GHC/CmmToLlvm/Regs.hs
=====================================
@@ -14,25 +14,27 @@ import GHC.Prelude
import GHC.Llvm
import GHC.Cmm.Expr
+import GHC.CmmToAsm.Format
import GHC.Platform
import GHC.Data.FastString
import GHC.Utils.Panic ( panic )
import GHC.Types.Unique
+
-- | Get the LlvmVar function variable storing the real register
-lmGlobalRegVar :: Platform -> GlobalReg -> LlvmVar
+lmGlobalRegVar :: Platform -> GlobalRegUse -> LlvmVar
lmGlobalRegVar platform = pVarLift . lmGlobalReg platform "_Var"
-- | Get the LlvmVar function argument storing the real register
-lmGlobalRegArg :: Platform -> GlobalReg -> LlvmVar
+lmGlobalRegArg :: Platform -> GlobalRegUse -> LlvmVar
lmGlobalRegArg platform = lmGlobalReg platform "_Arg"
{- Need to make sure the names here can't conflict with the unique generated
names. Uniques generated names containing only base62 chars. So using say
the '_' char guarantees this.
-}
-lmGlobalReg :: Platform -> String -> GlobalReg -> LlvmVar
-lmGlobalReg platform suf reg
+lmGlobalReg :: Platform -> String -> GlobalRegUse -> LlvmVar
+lmGlobalReg platform suf (GlobalRegUse reg ty)
= case reg of
BaseReg -> ptrGlobal $ "Base" ++ suf
Sp -> ptrGlobal $ "Sp" ++ suf
@@ -88,13 +90,26 @@ lmGlobalReg platform suf reg
ptrGlobal name = LMNLocalVar (fsLit name) (llvmWordPtr platform)
floatGlobal name = LMNLocalVar (fsLit name) LMFloat
doubleGlobal name = LMNLocalVar (fsLit name) LMDouble
- xmmGlobal name = LMNLocalVar (fsLit name) (LMVector 4 (LMInt 32))
- ymmGlobal name = LMNLocalVar (fsLit name) (LMVector 8 (LMInt 32))
- zmmGlobal name = LMNLocalVar (fsLit name) (LMVector 16 (LMInt 32))
+ fmt = cmmTypeFormat ty
+ xmmGlobal name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+ ymmGlobal name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+ zmmGlobal name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+
+formatLlvmType :: Format -> LlvmType
+formatLlvmType II8 = LMInt 8
+formatLlvmType II16 = LMInt 16
+formatLlvmType II32 = LMInt 32
+formatLlvmType II64 = LMInt 64
+formatLlvmType FF32 = LMFloat
+formatLlvmType FF64 = LMDouble
+formatLlvmType (VecFormat l sFmt) = LMVector l (formatLlvmType $ scalarFormatFormat sFmt)
-- | A list of STG Registers that should always be considered alive
-alwaysLive :: [GlobalReg]
-alwaysLive = [BaseReg, Sp, Hp, SpLim, HpLim, node]
+alwaysLive :: Platform -> [GlobalRegUse]
+alwaysLive platform =
+ [ GlobalRegUse r (globalRegSpillType platform r)
+ | r <- [BaseReg, Sp, Hp, SpLim, HpLim, node]
+ ]
-- | STG Type Based Alias Analysis hierarchy
stgTBAA :: [(Unique, LMString, Maybe Unique)]
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/b792b70cb22976009486514d578ec10de0c90b63
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/b792b70cb22976009486514d578ec10de0c90b63
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240830/ccb01eec/attachment-0001.html>
More information about the ghc-commits mailing list