[Git][ghc/ghc][wip/ncg-simd] 3 commits: X86 CodeGen: refactor getRegister CmmLit
sheaf (@sheaf)
gitlab at gitlab.haskell.org
Fri Sep 20 15:08:29 UTC 2024
sheaf pushed to branch wip/ncg-simd at Glasgow Haskell Compiler / GHC
Commits:
df5b64c4 by sheaf at 2024-09-20T17:08:06+02:00
X86 CodeGen: refactor getRegister CmmLit
This refactors the code dealing with loading literals into registers,
removing duplication and putting all the code in a single place.
It also changes which XOR instruction is used to place a zero value
into a register, so that we use VPXOR for a 128-bit integer vector
when AVX is supported.
- - - - -
cf56ab12 by sheaf at 2024-09-20T17:08:06+02:00
X86 genCCall64: simplify loadArg code
This commit simplifies the argument loading code by assuming that it
is safe to load each argument directly into its destination register,
because doing so will not clobber any previous assignments.
This assumption is justified by the use of 'evalArgs', which evaluates
any argument that might require non-trivial code generation into
a separate temporary register.
- - - - -
c27747f2 by sheaf at 2024-09-20T17:08:06+02:00
LLVM: propagate GlobalRegUse information
This commit ensures that the LLVM backend keeps track of how each global
register is used. This information drives the LLVM type annotations and
avoids type mismatches of the following form:
argument is not of expected type '<2 x double>'
call ccc <2 x double> (<2 x double>)
(<4 x i32> arg)
- - - - -
9 changed files:
- compiler/GHC/CmmToAsm/Format.hs
- compiler/GHC/CmmToAsm/X86/CodeGen.hs
- compiler/GHC/CmmToAsm/X86/Instr.hs
- compiler/GHC/CmmToAsm/X86/Ppr.hs
- compiler/GHC/CmmToLlvm.hs
- compiler/GHC/CmmToLlvm/Base.hs
- compiler/GHC/CmmToLlvm/CodeGen.hs
- compiler/GHC/CmmToLlvm/Ppr.hs
- compiler/GHC/CmmToLlvm/Regs.hs
Changes:
=====================================
compiler/GHC/CmmToAsm/Format.hs
=====================================
@@ -28,6 +28,7 @@ module GHC.CmmToAsm.Format (
scalarWidth,
formatInBytes,
isFloatScalarFormat,
+ isFloatOrFloatVecFormat,
floatScalarFormat,
scalarFormatFormat,
VirtualRegWithFormat(..),
@@ -134,6 +135,11 @@ isFloatScalarFormat = \case
FmtDouble -> True
_ -> False
+isFloatOrFloatVecFormat :: Format -> Bool
+isFloatOrFloatVecFormat = \case
+ VecFormat _ sFmt -> isFloatScalarFormat sFmt
+ fmt -> isFloatFormat fmt
+
floatScalarFormat :: Width -> ScalarFormat
floatScalarFormat W32 = FmtFloat
floatScalarFormat W64 = FmtDouble
=====================================
compiler/GHC/CmmToAsm/X86/CodeGen.hs
=====================================
@@ -913,20 +913,6 @@ getRegister' _ is32Bit (CmmMachOp (MO_UU_Conv W64 W16) [x])
ro <- getNewRegNat II16
return $ Fixed II16 ro (code `appOL` toOL [ MOVZxL II16 (OpReg rlo) (OpReg ro) ])
-getRegister' _ _ (CmmLit lit@(CmmFloat f w)) =
- float_const_sse2 where
- float_const_sse2
- | f == 0.0 = do
- -- TODO: this mishandles negative zero floating point literals.
- let
- format = floatFormat w
- code dst = unitOL (XOR format (OpReg dst) (OpReg dst))
- -- I don't know why there are xorpd, xorps, and pxor instructions.
- -- They all appear to do the same thing --SDM
- return (Any format code)
-
- | otherwise = getFloatLitRegister lit
-
-- catch simple cases of zero- or sign-extended load
getRegister' _ _ (CmmMachOp (MO_UU_Conv W8 W32) [CmmLoad addr _ _]) = do
code <- intLoadCode (MOVZxL II8) addr
@@ -1932,7 +1918,7 @@ getRegister' platform is32Bit load@(CmmLoad mem ty _)
| isFloatType ty
= do
Amode addr mem_code <- getAmode mem
- loadFloatAmode width addr mem_code
+ loadAmode (floatFormat width) addr mem_code
| is32Bit && not (isWord64 ty)
= do
@@ -1960,20 +1946,6 @@ getRegister' platform is32Bit load@(CmmLoad mem ty _)
format = cmmTypeFormat ty
width = typeWidth ty
-getRegister' _ is32Bit (CmmLit (CmmInt 0 width))
- = let
- format = intFormat width
-
- -- x86_64: 32-bit xor is one byte shorter, and zero-extends to 64 bits
- format1 = if is32Bit then format
- else case format of
- II64 -> II32
- _ -> format
- code dst
- = unitOL (XOR format1 (OpReg dst) (OpReg dst))
- in
- return (Any format code)
-
-- Handle symbol references with LEA and %rip-relative addressing.
-- See Note [%rip-relative addressing on x86-64].
getRegister' platform is32Bit (CmmLit lit)
@@ -1990,80 +1962,102 @@ getRegister' platform is32Bit (CmmLit lit)
is_label (CmmLabelDiffOff {}) = True
is_label _ = False
- -- optimisation for loading small literals on x86_64: take advantage
- -- of the automatic zero-extension from 32 to 64 bits, because the 32-bit
- -- instruction forms are shorter.
-getRegister' platform is32Bit (CmmLit lit)
- | not is32Bit, isWord64 (cmmLitType platform lit), not (isBigLit lit)
- = let
- imm = litToImm lit
- code dst = unitOL (MOV II32 (OpImm imm) (OpReg dst))
- in
- return (Any II64 code)
- where
- isBigLit (CmmInt i _) = i < 0 || i > 0xffffffff
- isBigLit _ = False
+getRegister' platform is32Bit (CmmLit lit) = do
+ avx <- avxEnabled
+
+ -- NB: it is important that the code produced here (to load a literal into
+ -- a register) doesn't clobber any registers other than the destination
+ -- register; the code for generating C calls relies on this property.
+ --
+ -- In particular, we have:
+ --
+ -- > loadIntoRegMightClobberOtherReg (CmmLit _) = False
+ --
+ -- which means that we assume that loading a literal into a register
+ -- will not clobber any other registers.
+
+ -- TODO: this function mishandles floating-point negative zero,
+ -- because -0.0 == 0.0 returns True and because we represent CmmFloat as
+ -- Rational, which can't properly represent negative zero.
+
+ if
+ -- Zero: use XOR.
+ | isZeroLit lit
+ -> let code dst
+ | isIntFormat fmt
+ = let fmt'
+ | is32Bit
+ = fmt
+ | otherwise
+ -- x86_64: 32-bit xor is one byte shorter,
+ -- and zero-extends to 64 bits
+ = case fmt of
+ II64 -> II32
+ _ -> fmt
+ in unitOL (XOR fmt' (OpReg dst) (OpReg dst))
+ | avx
+ = if float_or_floatvec
+ then unitOL (VXOR fmt (OpReg dst) dst dst)
+ else unitOL (VPXOR fmt dst dst dst)
+ | otherwise
+ = if float_or_floatvec
+ then unitOL (XOR fmt (OpReg dst) (OpReg dst))
+ else unitOL (PXOR fmt (OpReg dst) dst)
+ in return $ Any fmt code
+
+ -- Constant vector: use broadcast.
+ | VecFormat l sFmt <- fmt
+ , CmmVec (f:fs) <- lit
+ , all (== f) fs
+ -> do let w = scalarWidth sFmt
+ broadcast = if isFloatScalarFormat sFmt
+ then MO_VF_Broadcast l w
+ else MO_V_Broadcast l w
+ valCode <- getAnyReg (CmmMachOp broadcast [CmmLit f])
+ return $ Any fmt valCode
+
+ -- Optimisation for loading small literals on x86_64: take advantage
+ -- of the automatic zero-extension from 32 to 64 bits, because the 32-bit
+ -- instruction forms are shorter.
+ | not is32Bit, isWord64 cmmTy, not (isBigLit lit)
+ -> let
+ imm = litToImm lit
+ code dst = unitOL (MOV II32 (OpImm imm) (OpReg dst))
+ in
+ return (Any II64 code)
+
+ -- Scalar integer: use an immediate.
+ | isIntFormat fmt
+ -> let imm = litToImm lit
+ code dst = unitOL (MOV fmt (OpImm imm) (OpReg dst))
+ in return (Any fmt code)
+
+ -- General case: load literal from data address.
+ | otherwise
+ -> do let w = formatToWidth fmt
+ Amode addr addr_code <- memConstant (mkAlignment $ widthInBytes w) lit
+ loadAmode fmt addr addr_code
+
+ where
+ cmmTy = cmmLitType platform lit
+ fmt = cmmTypeFormat cmmTy
+ float_or_floatvec = isFloatOrFloatVecFormat fmt
+ isZeroLit (CmmInt i _) = i == 0
+ isZeroLit (CmmFloat f _) = f == 0 -- TODO: mishandles negative zero
+ isZeroLit (CmmVec fs) = all isZeroLit fs
+ isZeroLit _ = False
+
+ isBigLit (CmmInt i _) = i < 0 || i > 0xffffffff
+ isBigLit _ = False
-- note1: not the same as (not.is32BitLit), because that checks for
-- signed literals that fit in 32 bits, but we want unsigned
-- literals here.
-- note2: all labels are small, because we're assuming the
-- small memory model. See Note [%rip-relative addressing on x86-64].
-getRegister' platform _ (CmmLit lit) = do
- avx <- avxEnabled
- case fmt of
- VecFormat l sFmt
- | CmmVec fs <- lit
- , all is_zero fs
- -> let code dst
- | avx
- = if isFloatScalarFormat sFmt
- then unitOL (VXOR fmt (OpReg dst) dst dst)
- else unitOL (VPXOR fmt dst dst dst)
- | otherwise
- = unitOL (XOR fmt (OpReg dst) (OpReg dst))
- in return (Any fmt code)
- | CmmVec (f:fs) <- lit
- , all (== f) fs
- -- TODO: mishandles negative zero (because -0.0 == 0.0 returns True), and because we
- -- represent CmmFloat as Rational which can't properly represent negative zero.
- -> do let w = scalarWidth sFmt
- broadcast = if isFloatScalarFormat sFmt
- then MO_VF_Broadcast l w
- else MO_V_Broadcast l w
- valCode <- getAnyReg (CmmMachOp broadcast [CmmLit f])
- return $ Any fmt valCode
-
- | otherwise
- -> do
- let w = formatToWidth fmt
- config <- getConfig
- Amode addr addr_code <- memConstant (mkAlignment $ widthInBytes w) lit
- let code dst = addr_code `snocOL`
- movInstr config fmt (OpAddr addr) (OpReg dst)
- return (Any fmt code)
- where
- is_zero (CmmInt i _) = i == 0
- is_zero (CmmFloat f _) = f == 0 -- TODO: mishandles negative zero
- is_zero _ = False
-
- _ -> let imm = litToImm lit
- code dst = unitOL (MOV fmt (OpImm imm) (OpReg dst))
- in return (Any fmt code)
- where
- cmmTy = cmmLitType platform lit
- fmt = cmmTypeFormat cmmTy
-
getRegister' platform _ slot@(CmmStackSlot {}) =
pprPanic "getRegister(x86) CmmStackSlot" (pdoc platform slot)
-getFloatLitRegister :: CmmLit -> NatM Register
-getFloatLitRegister lit = do
- let w :: Width
- w = case lit of { CmmInt _ w -> w; CmmFloat _ w -> w; _ -> panic "getFloatLitRegister" (ppr lit) }
- Amode addr code <- memConstant (mkAlignment $ widthInBytes w) lit
- loadFloatAmode w addr code
-
intLoadCode :: (Operand -> Operand -> Instr) -> CmmExpr
-> NatM (Reg -> InstrBlock)
intLoadCode instr mem = do
@@ -2392,15 +2386,12 @@ memConstant align lit = do
`consOL` addr_code
return (Amode addr code)
-
-loadFloatAmode :: Width -> AddrMode -> InstrBlock -> NatM Register
-loadFloatAmode w addr addr_code = do
- let format = floatFormat w
- code dst = addr_code `snocOL`
- MOV format (OpAddr addr) (OpReg dst)
-
- return (Any format code)
-
+-- | Load the value at the given address into any register.
+loadAmode :: Format -> AddrMode -> InstrBlock -> NatM Register
+loadAmode fmt addr addr_code = do
+ config <- getConfig
+ let load dst = movInstr config fmt (OpAddr addr) (OpReg dst)
+ return $ Any fmt (\ dst -> addr_code `snocOL` load dst)
-- if we want a floating-point literal as an operand, we can
-- use it directly from memory. However, if the literal is
@@ -3100,10 +3091,8 @@ genSimplePrim _ op dst args = do
platform <- ncgPlatform <$> getConfig
pprPanic "genSimplePrim: unhandled primop" (ppr (pprCallishMachOp op, dst, fmap (pdoc platform) args))
-{-
-Note [Evaluate C-call arguments before placing in destination registers]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
+{- Note [Evaluate C-call arguments before placing in destination registers]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
When producing code for C calls we must take care when placing arguments
in their final registers. Specifically, we must ensure that temporary register
usage due to evaluation of one argument does not clobber a register in which we
@@ -3154,15 +3143,11 @@ genForeignCall{32,64}.
-- | See Note [Evaluate C-call arguments before placing in destination registers]
evalArgs :: BlockId -> [CmmActual] -> NatM (InstrBlock, [CmmActual])
evalArgs bid actuals
- | any mightContainMachOp actuals = do
+ | any loadIntoRegMightClobberOtherReg actuals = do
regs_blks <- mapM evalArg actuals
return (concatOL $ map fst regs_blks, map snd regs_blks)
| otherwise = return (nilOL, actuals)
where
- mightContainMachOp (CmmReg _) = False
- mightContainMachOp (CmmRegOff _ _) = False
- mightContainMachOp (CmmLit _) = False
- mightContainMachOp _ = True
evalArg :: CmmActual -> NatM (InstrBlock, CmmExpr)
evalArg actual = do
@@ -3176,6 +3161,16 @@ evalArgs bid actuals
newLocalReg :: CmmType -> NatM LocalReg
newLocalReg ty = LocalReg <$> getUniqueM <*> pure ty
+-- | Might the code to put this expression into a register
+-- clobber any other registers?
+loadIntoRegMightClobberOtherReg :: CmmExpr -> Bool
+loadIntoRegMightClobberOtherReg (CmmReg _) = False
+loadIntoRegMightClobberOtherReg (CmmRegOff _ _) = False
+loadIntoRegMightClobberOtherReg (CmmLit _) = False
+ -- NB: this last 'False' is slightly risky, because the code for loading
+ -- a literal into a register is not entirely trivial.
+loadIntoRegMightClobberOtherReg _ = True
+
-- Note [DIV/IDIV for bytes]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~
-- IDIV reminder:
@@ -3437,7 +3432,6 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
{ stackArgs = proper_stack_args
, stackDataArgs = stack_data_args
, usedRegs = arg_regs_used
- , computeArgsCode = compute_args_code
, assignArgsCode = assign_args_code
}
<- loadArgs config prom_args
@@ -3536,7 +3530,6 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
return (align_call_code `appOL`
push_code `appOL`
- compute_args_code `appOL`
assign_args_code `appOL`
load_data_refs `appOL`
shadow_space_code `appOL`
@@ -3557,16 +3550,14 @@ data LoadArgs
, stackDataArgs :: [CmmExpr]
-- | Which registers are we using for argument passing?
, usedRegs :: [RegWithFormat]
- -- | The code to compute arguments into (possibly temporary) registers.
- , computeArgsCode :: InstrBlock
-- | The code to assign arguments to registers used for argument passing.
, assignArgsCode :: InstrBlock
}
instance Semigroup LoadArgs where
- LoadArgs a1 d1 r1 i1 j1 <> LoadArgs a2 d2 r2 i2 j2
- = LoadArgs (a1 ++ a2) (d1 ++ d2) (r1 ++ r2) (i1 S.<> i2) (j1 S.<> j2)
+ LoadArgs a1 d1 r1 j1 <> LoadArgs a2 d2 r2 j2
+ = LoadArgs (a1 ++ a2) (d1 ++ d2) (r1 ++ r2) (j1 S.<> j2)
instance Monoid LoadArgs where
- mempty = LoadArgs [] [] [] nilOL nilOL
+ mempty = LoadArgs [] [] [] nilOL
-- | An argument passed on the stack, either directly or by reference.
--
@@ -3720,7 +3711,6 @@ loadArgsSysV config (arg:rest) = do
LoadArgs
{ stackArgs = map RawStackArg (arg:rest)
, stackDataArgs = []
- , computeArgsCode = nilOL
, assignArgsCode = nilOL
, usedRegs = []
}
@@ -3740,12 +3730,11 @@ loadArgsSysV config (arg:rest) = do
this_arg <-
case mbReg of
Just reg -> do
- (compute_code, assign_code) <- lift $ loadArgIntoReg config arg rest reg
+ assign_code <- lift $ loadArgIntoReg arg reg
return $
LoadArgs
{ stackArgs = [] -- passed in register
, stackDataArgs = []
- , computeArgsCode = compute_code
, assignArgsCode = assign_code
, usedRegs = [RegWithFormat reg arg_fmt]
}
@@ -3755,7 +3744,6 @@ loadArgsSysV config (arg:rest) = do
LoadArgs
{ stackArgs = [RawStackArg arg]
, stackDataArgs = []
- , computeArgsCode = nilOL
, assignArgsCode = nilOL
, usedRegs = []
}
@@ -3807,7 +3795,6 @@ loadArgsWin config (arg:rest) = do
LoadArgs
{ stackArgs = stk_args
, stackDataArgs = data_args
- , computeArgsCode = nilOL
, assignArgsCode = nilOL
, usedRegs = []
}
@@ -3823,8 +3810,7 @@ loadArgsWin config (arg:rest) = do
-- Pass the reference in a register,
-- and the argument data on the stack.
{ stackArgs = [RawStackArgRef (InReg ireg) (argSize platform arg)]
- , stackDataArgs = [arg]
- , computeArgsCode = nilOL -- we don't yet know where the data will reside,
+ , stackDataArgs = [arg] -- we don't yet know where the data will reside,
, assignArgsCode = nilOL -- so we defer computing the reference and storing it
-- in the register until later
, usedRegs = [RegWithFormat ireg II64]
@@ -3835,7 +3821,7 @@ loadArgsWin config (arg:rest) = do
= freg
| otherwise
= ireg
- (compute_code, assign_code) <- loadArgIntoReg config arg rest arg_reg
+ assign_code <- loadArgIntoReg arg arg_reg
-- Recall that, for varargs, we must pass floating-point
-- arguments in both fp and integer registers.
let (assign_code', regs')
@@ -3848,42 +3834,23 @@ loadArgsWin config (arg:rest) = do
LoadArgs
{ stackArgs = [] -- passed in register
, stackDataArgs = []
- , computeArgsCode = compute_code
, assignArgsCode = assign_code'
, usedRegs = regs'
}
-
--- | Return two pieces of code:
---
--- - code to compute a the given 'CmmExpr' into some (possibly temporary) register
--- - code to assign the resulting value to the specified register
+-- | Load an argument into a register.
--
--- Using two separate pieces of code handles clobbering issues reported
--- in e.g. #11792, #12614.
-loadArgIntoReg :: NCGConfig -> CmmExpr -> [CmmExpr] -> Reg -> NatM (InstrBlock, InstrBlock)
-loadArgIntoReg config arg rest reg
- -- "operand" args can be directly assigned into the register
- | isOperand platform arg
- = do arg_code <- getAnyReg arg
- return (nilOL, arg_code reg)
- -- The last non-operand arg can be directly assigned after its
- -- computation without going into a temporary register
- | all (isOperand platform) rest
- = do arg_code <- getAnyReg arg
- return (arg_code reg, nilOL)
- -- Other args need to be computed beforehand to avoid clobbering
- -- previously assigned registers used to pass parameters (see
- -- #11792, #12614). They are assigned into temporary registers
- -- and get assigned to proper call ABI registers after they all
- -- have been computed.
- | otherwise
- = do arg_code <- getAnyReg arg
- tmp <- getNewRegNat arg_fmt
- return (arg_code tmp, unitOL $ mkRegRegMoveInstr config arg_fmt tmp reg)
- where
- platform = ncgPlatform config
- arg_fmt = cmmTypeFormat $ cmmExprType platform arg
+-- Assumes that the expression does not contain any MachOps,
+-- as per Note [Evaluate C-call arguments before placing in destination registers].
+loadArgIntoReg :: CmmExpr -> Reg -> NatM InstrBlock
+loadArgIntoReg arg reg = do
+ when (debugIsOn && loadIntoRegMightClobberOtherReg arg) $ do
+ platform <- getPlatform
+ massertPpr False $
+ vcat [ text "loadArgIntoReg: arg might contain MachOp"
+ , text "arg:" <+> pdoc platform arg ]
+ arg_code <- getAnyReg arg
+ return $ arg_code reg
-- -----------------------------------------------------------------------------
-- Pushing arguments onto the stack for 64-bit C calls.
=====================================
compiler/GHC/CmmToAsm/X86/Instr.hs
=====================================
@@ -306,6 +306,7 @@ data Instr
| VMOVDQU Format Operand Operand
-- logic operations
+ | PXOR Format Operand Reg
| VPXOR Format Reg Reg Reg
-- Arithmetic
@@ -492,6 +493,12 @@ regUsageOfInstr platform instr
MOVDQU fmt src dst -> mkRU (use_R fmt src []) (use_R fmt dst [])
VMOVDQU fmt src dst -> mkRU (use_R fmt src []) (use_R fmt dst [])
+ PXOR fmt (OpReg src) dst
+ | src == dst
+ -> mkRU [] [mk fmt dst]
+ | otherwise
+ -> mkRU [mk fmt src, mk fmt dst] [mk fmt dst]
+
VPXOR fmt s1 s2 dst
| s1 == s2, s1 == dst
-> mkRU [] [mk fmt dst]
@@ -733,6 +740,7 @@ patchRegsOfInstr platform instr env
MOVDQU fmt src dst -> MOVDQU fmt (patchOp src) (patchOp dst)
VMOVDQU fmt src dst -> VMOVDQU fmt (patchOp src) (patchOp dst)
+ PXOR fmt src dst -> PXOR fmt (patchOp src) (env dst)
VPXOR fmt s1 s2 dst -> VPXOR fmt (env s1) (env s2) (env dst)
VADD fmt s1 s2 dst -> VADD fmt (patchOp s1) (env s2) (env dst)
=====================================
compiler/GHC/CmmToAsm/X86/Ppr.hs
=====================================
@@ -1016,6 +1016,8 @@ pprInstr platform i = case i of
VecFormat 64 FmtInt8 -> text "vmovdqu32" -- require the additional AVX512BW extension
_ -> text "vmovdqu"
+ PXOR format src dst
+ -> pprPXor (text "pxor") format src dst
VPXOR format s1 s2 dst
-> pprXor (text "vpxor") format s1 s2 dst
VEXTRACT format offset from to
@@ -1320,6 +1322,15 @@ pprInstr platform i = case i of
pprReg platform format reg3
]
+ pprPXor :: Line doc -> Format -> Operand -> Reg -> doc
+ pprPXor name format src dst
+ = line $ hcat [
+ pprGenMnemonic name format,
+ pprOperand platform format src,
+ comma,
+ pprReg platform format dst
+ ]
+
pprVxor :: Format -> Operand -> Reg -> Reg -> doc
pprVxor fmt src1 src2 dst
= line $ hcat [
=====================================
compiler/GHC/CmmToLlvm.hs
=====================================
@@ -139,7 +139,7 @@ llvmGroupLlvmGens cmm = do
Nothing -> l
Just (CmmStaticsRaw info_lbl _) -> info_lbl
lml <- strCLabel_llvm l'
- funInsert lml =<< llvmFunTy (map globalRegUse_reg live)
+ funInsert lml =<< llvmFunTy live
return Nothing
cdata <- fmap catMaybes $ mapM split cmm
=====================================
compiler/GHC/CmmToLlvm/Base.hs
=====================================
@@ -12,7 +12,7 @@
module GHC.CmmToLlvm.Base (
LlvmCmmDecl, LlvmBasicBlock,
- LiveGlobalRegs,
+ LiveGlobalRegs, LiveGlobalRegUses,
LlvmUnresData, LlvmData, UnresLabel, UnresStatic,
LlvmM,
@@ -29,6 +29,8 @@ module GHC.CmmToLlvm.Base (
llvmFunSig, llvmFunArgs, llvmStdFunAttrs, llvmFunAlign, llvmInfAlign,
llvmPtrBits, tysToParams, llvmFunSection, padLiveArgs, isFPR,
+ lookupRegUse,
+
strCLabel_llvm,
getGlobalPtr, generateExternDecls,
@@ -58,9 +60,11 @@ import GHC.Types.Unique.Set
import GHC.Types.Unique.Supply
import GHC.Utils.Logger
-import Data.Maybe (fromJust)
import Control.Monad.Trans.State (StateT (..))
-import Data.List (isPrefixOf)
+import Control.Applicative (Alternative((<|>)))
+import Data.Maybe (fromJust, mapMaybe)
+
+import Data.List (find, isPrefixOf)
import qualified Data.List.NonEmpty as NE
import Data.Ord (comparing)
@@ -73,6 +77,7 @@ type LlvmBasicBlock = GenBasicBlock LlvmStatement
-- | Global registers live on proc entry
type LiveGlobalRegs = [GlobalReg]
+type LiveGlobalRegUses = [GlobalRegUse]
-- | Unresolved code.
-- Of the form: (data label, data type, unresolved data)
@@ -116,16 +121,16 @@ llvmGhcCC platform
| otherwise = CC_Ghc
-- | Llvm Function type for Cmm function
-llvmFunTy :: LiveGlobalRegs -> LlvmM LlvmType
+llvmFunTy :: LiveGlobalRegUses -> LlvmM LlvmType
llvmFunTy live = return . LMFunction =<< llvmFunSig' live (fsLit "a") ExternallyVisible
-- | Llvm Function signature
-llvmFunSig :: LiveGlobalRegs -> CLabel -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
+llvmFunSig :: LiveGlobalRegUses -> CLabel -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
llvmFunSig live lbl link = do
lbl' <- strCLabel_llvm lbl
llvmFunSig' live lbl' link
-llvmFunSig' :: LiveGlobalRegs -> LMString -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
+llvmFunSig' :: LiveGlobalRegUses -> LMString -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
llvmFunSig' live lbl link
= do let toParams x | isPointer x = (x, [NoAlias, NoCapture])
| otherwise = (x, [])
@@ -149,16 +154,25 @@ llvmFunSection opts lbl
| otherwise = Nothing
-- | A Function's arguments
-llvmFunArgs :: Platform -> LiveGlobalRegs -> [LlvmVar]
+llvmFunArgs :: Platform -> LiveGlobalRegUses -> [LlvmVar]
llvmFunArgs platform live =
- map (lmGlobalRegArg platform) (filter isPassed allRegs)
+ map (lmGlobalRegArg platform) (mapMaybe isPassed allRegs)
where allRegs = activeStgRegs platform
paddingRegs = padLiveArgs platform live
- isLive r = r `elem` alwaysLive
- || r `elem` live
- || r `elem` paddingRegs
- isPassed r = not (isFPR r) || isLive r
-
+ isLive :: GlobalReg -> Maybe GlobalRegUse
+ isLive r =
+ lookupRegUse r (alwaysLive platform)
+ <|>
+ lookupRegUse r live
+ <|>
+ lookupRegUse r paddingRegs
+ isPassed r =
+ if not (isFPR r)
+ then Just $ GlobalRegUse r (globalRegSpillType platform r)
+ else isLive r
+
+lookupRegUse :: GlobalReg -> [GlobalRegUse] -> Maybe GlobalRegUse
+lookupRegUse r = find ((== r) . globalRegUse_reg)
isFPR :: GlobalReg -> Bool
isFPR (FloatReg _) = True
@@ -179,7 +193,7 @@ isFPR _ = False
-- Invariant: Cmm FPR regs with number "n" maps to real registers with number
-- "n" If the calling convention uses registers in a different order or if the
-- invariant doesn't hold, this code probably won't be correct.
-padLiveArgs :: Platform -> LiveGlobalRegs -> LiveGlobalRegs
+padLiveArgs :: Platform -> LiveGlobalRegUses -> LiveGlobalRegUses
padLiveArgs platform live =
if platformUnregisterised platform
then [] -- not using GHC's register convention for platform.
@@ -188,7 +202,7 @@ padLiveArgs platform live =
----------------------------------
-- handle floating-point registers (FPR)
- fprLive = filter isFPR live -- real live FPR registers
+ fprLive = filter (isFPR . globalRegUse_reg) live -- real live FPR registers
-- we group live registers sharing the same classes, i.e. that use the same
-- set of real registers to be passed. E.g. FloatReg, DoubleReg and XmmReg
@@ -196,39 +210,44 @@ padLiveArgs platform live =
--
classes = NE.groupBy sharesClass fprLive
sharesClass a b = globalRegsOverlap platform (norm a) (norm b) -- check if mapped to overlapping registers
- norm x = fpr_ctor x 1 -- get the first register of the family
+ norm x = globalRegUse_reg (fpr_ctor x 1) -- get the first register of the family
-- For each class, we just have to fill missing registers numbers. We use
-- the constructor of the greatest register to build padding registers.
--
-- E.g. sortedRs = [ F2, XMM4, D5]
-- output = [D1, D3]
+ padded :: [GlobalRegUse]
padded = concatMap padClass classes
+
+ padClass :: NE.NonEmpty GlobalRegUse -> [GlobalRegUse]
padClass rs = go (NE.toList sortedRs) 1
where
- sortedRs = NE.sortBy (comparing fpr_num) rs
+ sortedRs = NE.sortBy (comparing (fpr_num . globalRegUse_reg)) rs
maxr = NE.last sortedRs
ctor = fpr_ctor maxr
go [] _ = []
- go (c1:c2:_) _ -- detect bogus case (see #17920)
+ go (GlobalRegUse c1 _: GlobalRegUse c2 _:_) _ -- detect bogus case (see #17920)
| fpr_num c1 == fpr_num c2
, Just real <- globalRegMaybe platform c1
= sorryDoc "LLVM code generator" $
text "Found two different Cmm registers (" <> ppr c1 <> text "," <> ppr c2 <>
text ") both alive AND mapped to the same real register: " <> ppr real <>
text ". This isn't currently supported by the LLVM backend."
- go (c:cs) f
- | fpr_num c == f = go cs f -- already covered by a real register
- | otherwise = ctor f : go (c:cs) (f + 1) -- add padding register
-
- fpr_ctor :: GlobalReg -> Int -> GlobalReg
- fpr_ctor (FloatReg _) = FloatReg
- fpr_ctor (DoubleReg _) = DoubleReg
- fpr_ctor (XmmReg _) = XmmReg
- fpr_ctor (YmmReg _) = YmmReg
- fpr_ctor (ZmmReg _) = ZmmReg
- fpr_ctor _ = error "fpr_ctor expected only FPR regs"
+ go (cu@(GlobalRegUse c _):cs) f
+ | fpr_num c == f = go cs f -- already covered by a real register
+ | otherwise = ctor f : go (cu:cs) (f + 1) -- add padding register
+
+ fpr_ctor :: GlobalRegUse -> Int -> GlobalRegUse
+ fpr_ctor (GlobalRegUse r fmt) i =
+ case r of
+ FloatReg _ -> GlobalRegUse (FloatReg i) fmt
+ DoubleReg _ -> GlobalRegUse (DoubleReg i) fmt
+ XmmReg _ -> GlobalRegUse (XmmReg i) fmt
+ YmmReg _ -> GlobalRegUse (YmmReg i) fmt
+ ZmmReg _ -> GlobalRegUse (ZmmReg i) fmt
+ _ -> error "fpr_ctor expected only FPR regs"
fpr_num :: GlobalReg -> Int
fpr_num (FloatReg i) = i
=====================================
compiler/GHC/CmmToLlvm/CodeGen.hs
=====================================
@@ -37,13 +37,14 @@ import GHC.Utils.Outputable
import qualified GHC.Utils.Panic as Panic
import GHC.Utils.Misc
+import Control.Applicative (Alternative((<|>)))
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Control.Monad
import qualified Data.Semigroup as Semigroup
import Data.List ( nub )
-import Data.Maybe ( catMaybes )
+import Data.Maybe ( catMaybes, isJust )
type Atomic = Maybe MemoryOrdering
type LlvmStatements = OrdList LlvmStatement
@@ -57,7 +58,7 @@ genLlvmProc :: RawCmmDecl -> LlvmM [LlvmCmmDecl]
genLlvmProc (CmmProc infos lbl live graph) = do
let blocks = toBlockListEntryFirstFalseFallthrough graph
- (lmblocks, lmdata) <- basicBlocksCodeGen (map globalRegUse_reg live) blocks
+ (lmblocks, lmdata) <- basicBlocksCodeGen live blocks
let info = mapLookup (g_entry graph) infos
proc = CmmProc info lbl live (ListGraph lmblocks)
return (proc:lmdata)
@@ -76,7 +77,7 @@ newtype UnreachableBlockId = UnreachableBlockId BlockId
-- | Generate code for a list of blocks that make up a complete
-- procedure. The first block in the list is expected to be the entry
-- point.
-basicBlocksCodeGen :: LiveGlobalRegs -> [CmmBlock]
+basicBlocksCodeGen :: LiveGlobalRegUses -> [CmmBlock]
-> LlvmM ([LlvmBasicBlock], [LlvmCmmDecl])
basicBlocksCodeGen _ [] = panic "no entry block!"
basicBlocksCodeGen live cmmBlocks
@@ -152,7 +153,7 @@ stmtToInstrs ubid stmt = case stmt of
-- Tail call
CmmCall { cml_target = arg,
- cml_args_regs = live } -> genJump arg $ map globalRegUse_reg live
+ cml_args_regs = live } -> genJump arg live
_ -> panic "Llvm.CodeGen.stmtToInstrs"
@@ -1050,7 +1051,7 @@ cmmPrimOpFunctions mop = do
-- | Tail function calls
-genJump :: CmmExpr -> [GlobalReg] -> LlvmM StmtData
+genJump :: CmmExpr -> LiveGlobalRegUses -> LlvmM StmtData
-- Call to known function
genJump (CmmLit (CmmLabel lbl)) live = do
@@ -2056,14 +2057,13 @@ getCmmReg (CmmLocal (LocalReg un _))
-- have been assigned a value at some point, triggering
-- "funPrologue" to allocate it on the stack.
-getCmmReg (CmmGlobal g)
- = do let r = globalRegUse_reg g
- onStack <- checkStackReg r
+getCmmReg (CmmGlobal ru@(GlobalRegUse r _))
+ = do onStack <- checkStackReg r
platform <- getPlatform
if onStack
- then return (lmGlobalRegVar platform r)
+ then return (lmGlobalRegVar platform ru)
else pprPanic "getCmmReg: Cmm register " $
- ppr g <> text " not stack-allocated!"
+ ppr r <> text " not stack-allocated!"
-- | Return the value of a given register, as well as its type. Might
-- need to be load from stack.
@@ -2074,7 +2074,7 @@ getCmmRegVal reg =
onStack <- checkStackReg (globalRegUse_reg g)
platform <- getPlatform
if onStack then loadFromStack else do
- let r = lmGlobalRegArg platform (globalRegUse_reg g)
+ let r = lmGlobalRegArg platform g
return (r, getVarType r, nilOL)
_ -> loadFromStack
where loadFromStack = do
@@ -2187,8 +2187,9 @@ convertMemoryOrdering MemOrderSeqCst = SyncSeqCst
-- question is never written. Therefore we skip it where we can to
-- save a few lines in the output and hopefully speed compilation up a
-- bit.
-funPrologue :: LiveGlobalRegs -> [CmmBlock] -> LlvmM StmtData
+funPrologue :: LiveGlobalRegUses -> [CmmBlock] -> LlvmM StmtData
funPrologue live cmmBlocks = do
+ platform <- getPlatform
let getAssignedRegs :: CmmNode O O -> [CmmReg]
getAssignedRegs (CmmAssign reg _) = [reg]
@@ -2196,7 +2197,8 @@ funPrologue live cmmBlocks = do
getAssignedRegs _ = []
getRegsBlock (_, body, _) = concatMap getAssignedRegs $ blockToList body
assignedRegs = nub $ concatMap (getRegsBlock . blockSplit) cmmBlocks
- isLive r = r `elem` alwaysLive || r `elem` live
+ mbLive r =
+ lookupRegUse r (alwaysLive platform) <|> lookupRegUse r live
platform <- getPlatform
stmtss <- forM assignedRegs $ \reg ->
@@ -2205,12 +2207,12 @@ funPrologue live cmmBlocks = do
let (newv, stmts) = allocReg reg
varInsert un (pLower $ getVarType newv)
return stmts
- CmmGlobal (GlobalRegUse r _) -> do
- let reg = lmGlobalRegVar platform r
- arg = lmGlobalRegArg platform r
+ CmmGlobal ru@(GlobalRegUse r _) -> do
+ let reg = lmGlobalRegVar platform ru
+ arg = lmGlobalRegArg platform ru
ty = (pLower . getVarType) reg
trash = LMLitVar $ LMUndefLit ty
- rval = if isLive r then arg else trash
+ rval = if isJust (mbLive r) then arg else trash
alloc = Assignment reg $ Alloca (pLower $ getVarType reg) 1
markStackReg r
return $ toOL [alloc, Store rval reg Nothing []]
@@ -2222,7 +2224,7 @@ funPrologue live cmmBlocks = do
-- | Function epilogue. Load STG variables to use as argument for call.
-- STG Liveness optimisation done here.
-funEpilogue :: LiveGlobalRegs -> LlvmM ([LlvmVar], LlvmStatements)
+funEpilogue :: LiveGlobalRegUses -> LlvmM ([LlvmVar], LlvmStatements)
funEpilogue live = do
platform <- getPlatform
@@ -2248,12 +2250,16 @@ funEpilogue live = do
let allRegs = activeStgRegs platform
loads <- forM allRegs $ \r -> if
-- load live registers
- | r `elem` alwaysLive -> loadExpr (GlobalRegUse r (globalRegSpillType platform r))
- | r `elem` live -> loadExpr (GlobalRegUse r (globalRegSpillType platform r))
+ | Just ru <- lookupRegUse r (alwaysLive platform)
+ -> loadExpr ru
+ | Just ru <- lookupRegUse r live
+ -> loadExpr ru
-- load all non Floating-Point Registers
- | not (isFPR r) -> loadUndef r
+ | not (isFPR r)
+ -> loadUndef (GlobalRegUse r (globalRegSpillType platform r))
-- load padding Floating-Point Registers
- | r `elem` paddingRegs -> loadUndef r
+ | Just ru <- lookupRegUse r paddingRegs
+ -> loadUndef ru
| otherwise -> return (Nothing, nilOL)
let (vars, stmts) = unzip loads
@@ -2263,7 +2269,7 @@ funEpilogue live = do
--
-- This is for Haskell functions, function type is assumed, so doesn't work
-- with foreign functions.
-getHsFunc :: LiveGlobalRegs -> CLabel -> LlvmM ExprData
+getHsFunc :: LiveGlobalRegUses -> CLabel -> LlvmM ExprData
getHsFunc live lbl
= do fty <- llvmFunTy live
name <- strCLabel_llvm lbl
=====================================
compiler/GHC/CmmToLlvm/Ppr.hs
=====================================
@@ -49,9 +49,8 @@ pprLlvmCmmDecl (CmmData _ lmdata) = do
return ( vcat $ map (pprLlvmData opts) lmdata
, vcat $ map (pprLlvmData opts) lmdata)
-pprLlvmCmmDecl (CmmProc mb_info entry_lbl liveWithUses (ListGraph blks))
- = do let live = map globalRegUse_reg liveWithUses
- lbl = case mb_info of
+pprLlvmCmmDecl (CmmProc mb_info entry_lbl live (ListGraph blks))
+ = do let lbl = case mb_info of
Nothing -> entry_lbl
Just (CmmStaticsRaw info_lbl _) -> info_lbl
link = if externallyVisibleCLabel lbl
=====================================
compiler/GHC/CmmToLlvm/Regs.hs
=====================================
@@ -14,25 +14,27 @@ import GHC.Prelude
import GHC.Llvm
import GHC.Cmm.Expr
+import GHC.CmmToAsm.Format
import GHC.Platform
import GHC.Data.FastString
import GHC.Utils.Panic ( panic )
import GHC.Types.Unique
+
-- | Get the LlvmVar function variable storing the real register
-lmGlobalRegVar :: Platform -> GlobalReg -> LlvmVar
+lmGlobalRegVar :: Platform -> GlobalRegUse -> LlvmVar
lmGlobalRegVar platform = pVarLift . lmGlobalReg platform "_Var"
-- | Get the LlvmVar function argument storing the real register
-lmGlobalRegArg :: Platform -> GlobalReg -> LlvmVar
+lmGlobalRegArg :: Platform -> GlobalRegUse -> LlvmVar
lmGlobalRegArg platform = lmGlobalReg platform "_Arg"
{- Need to make sure the names here can't conflict with the unique generated
names. Uniques generated names containing only base62 chars. So using say
the '_' char guarantees this.
-}
-lmGlobalReg :: Platform -> String -> GlobalReg -> LlvmVar
-lmGlobalReg platform suf reg
+lmGlobalReg :: Platform -> String -> GlobalRegUse -> LlvmVar
+lmGlobalReg platform suf (GlobalRegUse reg ty)
= case reg of
BaseReg -> ptrGlobal $ "Base" ++ suf
Sp -> ptrGlobal $ "Sp" ++ suf
@@ -88,13 +90,26 @@ lmGlobalReg platform suf reg
ptrGlobal name = LMNLocalVar (fsLit name) (llvmWordPtr platform)
floatGlobal name = LMNLocalVar (fsLit name) LMFloat
doubleGlobal name = LMNLocalVar (fsLit name) LMDouble
- xmmGlobal name = LMNLocalVar (fsLit name) (LMVector 4 (LMInt 32))
- ymmGlobal name = LMNLocalVar (fsLit name) (LMVector 8 (LMInt 32))
- zmmGlobal name = LMNLocalVar (fsLit name) (LMVector 16 (LMInt 32))
+ fmt = cmmTypeFormat ty
+ xmmGlobal name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+ ymmGlobal name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+ zmmGlobal name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+
+formatLlvmType :: Format -> LlvmType
+formatLlvmType II8 = LMInt 8
+formatLlvmType II16 = LMInt 16
+formatLlvmType II32 = LMInt 32
+formatLlvmType II64 = LMInt 64
+formatLlvmType FF32 = LMFloat
+formatLlvmType FF64 = LMDouble
+formatLlvmType (VecFormat l sFmt) = LMVector l (formatLlvmType $ scalarFormatFormat sFmt)
-- | A list of STG Registers that should always be considered alive
-alwaysLive :: [GlobalReg]
-alwaysLive = [BaseReg, Sp, Hp, SpLim, HpLim, node]
+alwaysLive :: Platform -> [GlobalRegUse]
+alwaysLive platform =
+ [ GlobalRegUse r (globalRegSpillType platform r)
+ | r <- [BaseReg, Sp, Hp, SpLim, HpLim, node]
+ ]
-- | STG Type Based Alias Analysis hierarchy
stgTBAA :: [(Unique, LMString, Maybe Unique)]
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/a34a4dc2269300b31c06173dad5d64f55317ecb4...c27747f25c7209555ae96573f4525839297b7703
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/a34a4dc2269300b31c06173dad5d64f55317ecb4...c27747f25c7209555ae96573f4525839297b7703
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240920/789444a5/attachment-0001.html>
More information about the ghc-commits mailing list