[Git][ghc/ghc][wip/ncg-simd] Fix treatment of signed zero in vector negation
sheaf (@sheaf)
gitlab at gitlab.haskell.org
Fri Jun 28 11:18:16 UTC 2024
sheaf pushed to branch wip/ncg-simd at Glasgow Haskell Compiler / GHC
Commits:
8edac7a8 by sheaf at 2024-06-28T13:18:03+02:00
Fix treatment of signed zero in vector negation
This commit fixes the handling of signed zero in floating-point vector
negation.
A slight hack was introduced to work around the fact that Cmm doesn't
currently have a notion of signed floating point literals
(see get_float_broadcast_value_reg). This can be removed once CmmFloat
can express the value -0.0.
The simd006 test has been updated to use a stricter notion of equality
of floating-point values, which ensures the validity of this change.
- - - - -
5 changed files:
- compiler/GHC/CmmToAsm/Format.hs
- compiler/GHC/CmmToAsm/X86/CodeGen.hs
- compiler/GHC/CmmToAsm/X86/Instr.hs
- compiler/GHC/CmmToAsm/X86/Ppr.hs
- testsuite/tests/simd/should_run/simd006.hs
Changes:
=====================================
compiler/GHC/CmmToAsm/Format.hs
=====================================
@@ -24,6 +24,7 @@ module GHC.CmmToAsm.Format (
isVecFormat,
cmmTypeFormat,
formatToWidth,
+ scalarWidth,
formatInBytes,
isFloatScalarFormat,
scalarFormatFormat,
=====================================
compiler/GHC/CmmToAsm/X86/CodeGen.hs
=====================================
@@ -944,6 +944,7 @@ getRegister' _ _ (CmmLit lit@(CmmFloat f w)) =
float_const_sse2 where
float_const_sse2
| f == 0.0 = do
+ -- TODO: this mishandles negative zero floating point literals.
let
format = floatFormat w
code dst = unitOL (XOR format (OpReg dst) (OpReg dst))
@@ -951,9 +952,7 @@ getRegister' _ _ (CmmLit lit@(CmmFloat f w)) =
-- They all appear to do the same thing --SDM
return (Any format code)
- | otherwise = do
- Amode addr code <- memConstant (mkAlignment $ widthInBytes w) lit
- loadFloatAmode w addr code
+ | otherwise = getFloatLitRegister lit
-- catch simple cases of zero- or sign-extended load
getRegister' _ _ (CmmMachOp (MO_UU_Conv W8 W32) [CmmLoad addr _ _]) = do
@@ -1100,10 +1099,11 @@ getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
MO_FS_Truncate from to -> coerceFP2Int from to x
MO_SF_Round from to -> coerceInt2FP from to x
- MO_VF_Neg l w | avx -> vector_float_negate_avx l w x
- | sse && sse2 -> vector_float_negate_sse l w x
+ MO_VF_Neg l w | avx -> vector_float_negate_avx l w x
+ | sse && w == W32 -> vector_float_negate_sse l w x
+ | sse2 && w == W64 -> vector_float_negate_sse l w x
| otherwise
- -> sorry "Please enable the -mavx or -msse, -msse2 flag"
+ -> sorry "Please enable the -mavx or -msse, -msse2 flag"
-- SIMD NCG TODO: add integer negation
MO_VS_Neg {} -> needLlvm mop
@@ -1111,7 +1111,7 @@ getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
| sse4_1 -> vector_float_broadcast_sse l W32 x
| otherwise
-> sorry "Please enable the -mavx or -msse4 flag"
- MO_VF_Broadcast l W64 | sse2 -> vector_float_broadcast_avx l W64 x
+ MO_VF_Broadcast l W64 | sse2 -> vector_float_broadcast_sse l W64 x
| otherwise -> sorry "Please enable the -msse2 flag"
MO_VF_Broadcast {} -> incorrectOperands
@@ -1222,70 +1222,77 @@ getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
vector_float_negate_avx :: Length -> Width -> CmmExpr -> NatM Register
vector_float_negate_avx l w expr = do
- tmp <- getNewRegNat (VecFormat l FmtFloat)
- (reg, exp) <- getSomeReg expr
- Amode addr addr_code <- memConstant (mkAlignment $ widthInBytes W32) (CmmFloat 0.0 W32)
- let format = case w of
- W32 -> VecFormat l FmtFloat
- W64 -> VecFormat l FmtDouble
- _ -> pprPanic "Cannot negate vector of width" (ppr w)
- code dst = case w of
- W32 -> exp `appOL` addr_code `snocOL`
- (VBROADCAST format addr tmp) `snocOL`
- (VSUB format (OpReg reg) tmp dst)
- W64 -> exp `appOL` addr_code `snocOL`
- (MOVL format (OpAddr addr) (OpReg tmp)) `snocOL`
- (MOVH format (OpAddr addr) (OpReg tmp)) `snocOL`
- (VSUB format (OpReg reg) tmp dst)
- _ -> pprPanic "Cannot negate vector of width" (ppr w)
- return (Any format code)
+ let fmt :: Format
+ mask :: CmmLit
+ (fmt, mask) = case w of
+ W32 -> (VecFormat l FmtFloat , CmmInt (bit 31) w) -- TODO: these should be negative 0 floating point literals,
+ W64 -> (VecFormat l FmtDouble, CmmInt (bit 63) w) -- but we don't currently have those in Cmm.
+ _ -> panic "AVX floating-point negation: elements must be FF32 or FF64"
+ (maskReg, maskCode) <- getSomeReg (CmmMachOp (MO_VF_Broadcast l w) [CmmLit mask])
+ (reg, exp) <- getSomeReg expr
+ let code dst = maskCode `appOL`
+ exp `snocOL`
+ (VMOVU fmt (OpReg reg) (OpReg dst)) `snocOL`
+ (VXOR fmt (OpReg maskReg) dst dst)
+ return (Any fmt code)
vector_float_negate_sse :: Length -> Width -> CmmExpr -> NatM Register
vector_float_negate_sse l w expr = do
- tmp <- getNewRegNat (VecFormat l FmtFloat)
- (reg, exp) <- getSomeReg expr
- let format = case w of
- W32 -> VecFormat l FmtFloat
- W64 -> VecFormat l FmtDouble
- _ -> pprPanic "Cannot negate vector of width" (ppr w)
- code dst = exp `snocOL`
- (XOR format (OpReg tmp) (OpReg tmp)) `snocOL`
- (MOVU format (OpReg tmp) (OpReg dst)) `snocOL`
- (SUB format (OpReg reg) (OpReg dst))
- return (Any format code)
+ let fmt :: Format
+ mask :: CmmLit
+ (fmt, mask) = case w of
+ W32 -> (VecFormat l FmtFloat , CmmInt (bit 31) w) -- Same comment as for vector_float_negate_avx,
+ W64 -> (VecFormat l FmtDouble, CmmInt (bit 63) w) -- these should be -0.0 CmmFloat values.
+ _ -> panic "SSE floating-point negation: elements must be FF32 or FF64"
+ (maskReg, maskCode) <- getSomeReg (CmmMachOp (MO_VF_Broadcast l w) [CmmLit mask])
+ (reg, exp) <- getSomeReg expr
+ let code dst = maskCode `appOL`
+ exp `snocOL`
+ (MOVU fmt (OpReg reg) (OpReg dst)) `snocOL`
+ (XOR fmt (OpReg maskReg) (OpReg dst))
+ return (Any fmt code)
-----------------------
+
+ -- Like 'getSomeReg', but literals are loaded via 'getFloatLitRegister',
+ -- so that an integer literal can be used as the bit pattern of a
+ -- floating point value. This works around the fact that we don't
+ -- have negative zero floating point literals in Cmm yet.
+ --
+ -- This should get removed once we have negative zero in CmmFloat.
+ get_float_broadcast_value_reg expr = case expr of
+ CmmLit lit -> do
+ r <- getFloatLitRegister lit
+ case r of
+ Any rep code -> do
+ tmp <- getNewRegNat rep
+ return (tmp, code tmp)
+ Fixed _ reg code ->
+ return (reg, code)
+ _ -> getSomeReg expr
+
vector_float_broadcast_avx :: Length
-> Width
-> CmmExpr
-> NatM Register
vector_float_broadcast_avx len W32 expr
= do
- (reg, exp) <- getSomeReg expr
+ (reg, exp) <- get_float_broadcast_value_reg expr
let f = VecFormat len FmtFloat
addr = spRel platform 0
- in return $ Any f (\dst -> exp `snocOL`
- (MOVU f (OpReg reg) (OpAddr addr)) `snocOL`
- (VBROADCAST f addr dst))
- vector_float_broadcast_avx len W64 expr
- = do
- (reg, exp) <- getSomeReg expr
- let f = VecFormat len FmtDouble
- addr = spRel platform 0
in return $ Any f (\dst -> exp `snocOL`
(MOVU f (OpReg reg) (OpAddr addr)) `snocOL`
- (MOVL f (OpAddr addr) (OpReg dst)) `snocOL`
- (MOVH f (OpAddr addr) (OpReg dst)))
- vector_float_broadcast_avx _ _ c
- = pprPanic "Broadcast not supported for : " (pdoc platform c)
- -----------------------
+ (VBROADCAST f addr dst))
+ vector_float_broadcast_avx l w _
+ -- NB: for some reason, VBROADCASTSD does not support xmm, only ymm.
+ = pprPanic "vector_float_broadcast_avx" (text "l" <+> ppr l $$ text "w" <+> ppr w)
+
vector_float_broadcast_sse :: Length
-> Width
-> CmmExpr
-> NatM Register
vector_float_broadcast_sse len W32 expr
= do
- (reg, exp) <- getSomeReg expr
+ (reg, exp) <- get_float_broadcast_value_reg expr
let f = VecFormat len FmtFloat
addr = spRel platform 0
code dst = exp `snocOL`
@@ -1299,6 +1306,15 @@ getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
INSERTPS f (ImmInt imm) (OpAddr addr) dst
in return $ Any f code
+ vector_float_broadcast_sse len W64 expr
+ = do
+ (reg, exp) <- get_float_broadcast_value_reg expr
+ let f = VecFormat len FmtDouble
+ addr = spRel platform 0
+ in return $ Any f (\dst -> exp `snocOL`
+ (MOVU f (OpReg reg) (OpAddr addr)) `snocOL`
+ (MOVL f (OpAddr addr) (OpReg dst)) `snocOL`
+ (MOVH f (OpAddr addr) (OpReg dst)))
vector_float_broadcast_sse _ _ c
= pprPanic "Broadcast not supported for : " (pdoc platform c)
@@ -2057,33 +2073,48 @@ getRegister' platform is32Bit (CmmLit lit)
-- note2: all labels are small, because we're assuming the
-- small memory model. See Note [%rip-relative addressing on x86-64].
-getRegister' platform _ (CmmLit lit)
- | isVecType cmmtype = vectorRegister cmmtype
- | otherwise = standardRegister cmmtype
- where
- cmmtype = cmmLitType platform lit
- vectorRegister ctype
- | case lit of { CmmVec fs -> all (\case { CmmInt i _ -> i == 0; CmmFloat f _ -> f == 0; _ -> False }) fs; _ -> False }
- = -- NOTE:
- -- This operation is only used to zero a register. For loading a
- -- vector literal there are pack and broadcast operations
- let format = cmmTypeFormat ctype
- code dst = unitOL (XOR format (OpReg dst) (OpReg dst))
- in return (Any format code)
+getRegister' platform _ (CmmLit lit) =
+ case fmt of
+ VecFormat l sFmt
+ | case lit of { CmmVec fs -> all is_zero fs; _ -> False }
+ -> let code dst = unitOL (XOR fmt (OpReg dst) (OpReg dst))
+ in return (Any fmt code)
+ | Just f <- case lit of { CmmVec (f:fs) | all (== f) fs -> Just f; _ -> Nothing }
+ -> do config <- getConfig
+ let w = scalarWidth sFmt
+ broadcast = if isFloatScalarFormat sFmt
+ then MO_VF_Broadcast l w
+ else MO_V_Broadcast l w
+ (valReg, valCode) <- getSomeReg (CmmMachOp broadcast [CmmLit f])
+ let code dst =
+ valCode `snocOL`
+ (mkRegRegMoveInstr config fmt valReg dst)
+ return $ Any fmt code
| otherwise
- = pprPanic "getRegister': no support for (nonzero) vector literals" $
- vcat [ text "lit:" <+> ppr lit ]
- -- SIMD NCG TODO: can we do better here?
- standardRegister ctype
- = do
- let format = cmmTypeFormat ctype
- imm = litToImm lit
- code dst = unitOL (MOV format (OpImm imm) (OpReg dst))
- return (Any format code)
+ -- SIMD NCG TODO: handle this case as well.
+ -> pprPanic "getRegister': non-constant vector literals not supported"
+ (ppr lit)
+ where
+ is_zero (CmmInt i _) = i == 0
+ is_zero (CmmFloat f _) = f == 0 -- TODO: mishandles negative zero
+ is_zero _ = False
+
+ _ -> let imm = litToImm lit
+ code dst = unitOL (MOV fmt (OpImm imm) (OpReg dst))
+ in return (Any fmt code)
+ where
+ cmmTy = cmmLitType platform lit
+ fmt = cmmTypeFormat cmmTy
getRegister' platform _ other
= pprPanic "getRegister(x86)" (pdoc platform other)
+-- | Load a literal into a register as a floating-point value.
+--
+-- The literal is emitted as a memory constant, aligned to its width, and
+-- then loaded with 'loadFloatAmode'. Both 'CmmFloat' and 'CmmInt' literals
+-- are accepted; any other literal causes a panic.
+getFloatLitRegister :: CmmLit -> NatM Register
+getFloatLitRegister lit = do
+  let w :: Width
+      w = case lit of
+        CmmInt   _ lit_w -> lit_w
+        CmmFloat _ lit_w -> lit_w
+        -- Use 'pprPanic' rather than 'panic': 'panic' only takes a String,
+        -- so the offending literal would never be printed.
+        _ -> pprPanic "getFloatLitRegister" (ppr lit)
+  Amode addr code <- memConstant (mkAlignment $ widthInBytes w) lit
+  loadFloatAmode w addr code
intLoadCode :: (Operand -> Operand -> Instr) -> CmmExpr
-> NatM (Reg -> InstrBlock)
=====================================
compiler/GHC/CmmToAsm/X86/Instr.hs
=====================================
@@ -259,6 +259,8 @@ data Instr
| AND Format Operand Operand
| OR Format Operand Operand
| XOR Format Operand Operand
+ -- | AVX bitwise logical XOR operation
+ | VXOR Format Operand Reg Reg
| NOT Format Operand
| NEGI Format Operand -- NEG instruction (name clash with Cond)
| BSWAP Format Reg
@@ -477,9 +479,16 @@ regUsageOfInstr platform instr
OR fmt src dst -> usageRM fmt src dst
XOR fmt (OpReg src) (OpReg dst)
- | src == dst -> mkRU [] [mk fmt dst]
+ | src == dst
+ -> mkRU [] [mk fmt dst]
+ XOR fmt src dst
+ -> usageRM fmt src dst
+ VXOR fmt (OpReg src1) src2 dst
+ | src1 == src2, src1 == dst
+ -> mkRU [] [mk fmt dst]
+ VXOR fmt src1 src2 dst
+ -> mkRU (use_R fmt src1 [mk fmt src2]) [mk fmt dst]
- XOR fmt src dst -> usageRM fmt src dst
NOT fmt op -> usageM fmt op
BSWAP fmt reg -> mkRU [mk fmt reg] [mk fmt reg]
NEGI fmt op -> usageM fmt op
@@ -722,6 +731,7 @@ patchRegsOfInstr platform instr env
AND fmt src dst -> patch2 (AND fmt) src dst
OR fmt src dst -> patch2 (OR fmt) src dst
XOR fmt src dst -> patch2 (XOR fmt) src dst
+ VXOR fmt src1 src2 dst -> VXOR fmt (patchOp src1) (env src2) (env dst)
NOT fmt op -> patch1 (NOT fmt) op
BSWAP fmt reg -> BSWAP fmt (env reg)
NEGI fmt op -> patch1 (NEGI fmt) op
@@ -764,6 +774,8 @@ patchRegsOfInstr platform instr env
LOCATION {} -> instr
UNWIND {} -> instr
DELTA _ -> instr
+ LDATA {} -> instr
+ NEWBLOCK {} -> instr
JXX _ _ -> instr
JXX_GBL _ _ -> instr
@@ -830,8 +842,6 @@ patchRegsOfInstr platform instr env
PUNPCKLQDQ fmt src dst
-> PUNPCKLQDQ fmt (patchOp src) (env dst)
- _other -> panic "patchRegs: unrecognised instr"
-
where
patch1 :: (Operand -> a) -> Operand -> a
patch1 insn op = insn $! patchOp op
=====================================
compiler/GHC/CmmToAsm/X86/Ppr.hs
=====================================
@@ -724,11 +724,14 @@ pprInstr platform i = case i of
XOR format src dst
-> pprFormatOpOp (text "xor") format src dst
+ VXOR fmt src1 src2 dst
+ -> pprVxor fmt src1 src2 dst
+
POPCNT format src dst
-> pprOpOp (text "popcnt") format src (OpReg dst)
LZCNT format src dst
- -> pprOpOp (text "lzcnt") format src (OpReg dst)
+ -> pprOpOp (text "lzcnt") format src (OpReg dst)
TZCNT format src dst
-> pprOpOp (text "tzcnt") format src (OpReg dst)
@@ -1276,6 +1279,23 @@ pprInstr platform i = case i of
pprReg platform format reg3
]
+ -- Pretty-print a 'VXOR' instruction: the mnemonic (chosen from the
+ -- vector element type of the format), then the first source operand,
+ -- the second source register and the destination register.
+ -- Panics if the format is not a vector of Float or Double.
+ pprVxor :: Format -> Operand -> Reg -> Reg -> doc
+ pprVxor fmt src1 src2 dst
+   = line $ hcat [
+         pprGenMnemonic mnemonic fmt,
+         pprOperand platform fmt src1,
+         comma,
+         pprReg platform fmt src2,
+         comma,
+         pprReg platform fmt dst
+     ]
+   where
+     mnemonic = case fmt of
+       VecFormat _ FmtFloat  -> text "vxorps"
+       VecFormat _ FmtDouble -> text "vxorpd"
+       _ -> pprPanic "GHC.CmmToAsm.X86.Ppr.pprVxor: element type must be Float or Double"
+              (ppr fmt)
+
pprInsert :: Line doc -> Format -> Imm -> Operand -> Reg -> doc
pprInsert name format off src dst
= line $ hcat [
=====================================
testsuite/tests/simd/should_run/simd006.hs
=====================================
@@ -127,17 +127,15 @@ instance Arbitrary Word32 where
newtype FloatNT = FloatNT Float
deriving newtype (Show, Num)
instance Eq FloatNT where
- FloatNT f1 == FloatNT f2 = f1 == f2
- -- TODO: tests fail with this equality due to signed zeros
- -- castFloatToWord32 f1 == castFloatToWord32 f2
+ FloatNT f1 == FloatNT f2 =
+ castFloatToWord32 f1 == castFloatToWord32 f2
instance Arbitrary FloatNT where
arbitrary = FloatNT . castWord32ToFloat <$> arbitrary
newtype DoubleNT = DoubleNT Double
deriving newtype (Show, Num)
instance Eq DoubleNT where
- DoubleNT d1 == DoubleNT d2 = d1 == d2
- -- TODO: tests fail with this equality due to signed zeros
- -- castDoubleToWord64 d1 == castDoubleToWord64 d2
+ DoubleNT d1 == DoubleNT d2 =
+ castDoubleToWord64 d1 == castDoubleToWord64 d2
instance Arbitrary DoubleNT where
arbitrary = DoubleNT . castWord64ToDouble <$> arbitrary
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/8edac7a83baf3bbdf2df1f1d4a2ef1e31e6d8aaa
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/8edac7a83baf3bbdf2df1f1d4a2ef1e31e6d8aaa
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240628/c22f7d7c/attachment-0001.html>
More information about the ghc-commits
mailing list