[Git][ghc/ghc][master] Implements MO_S_Mul2 and MO_U_Mul2 using the UMULH, UMULL and SMULH instructions for AArch64
Marge Bot (@marge-bot)
gitlab at gitlab.haskell.org
Wed Apr 17 00:06:40 UTC 2024
Marge Bot pushed to branch master at Glasgow Haskell Compiler / GHC
Commits:
dbdf1995 by Alex Mason at 2024-04-15T15:28:26+10:00
Implements MO_S_Mul2 and MO_U_Mul2 using the UMULH, UMULL and SMULH instructions for AArch64
Also adds a test for MO_S_Mul2
- - - - -
10 changed files:
- compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
- compiler/GHC/CmmToAsm/AArch64/Instr.hs
- compiler/GHC/CmmToAsm/AArch64/Ppr.hs
- compiler/GHC/Driver/Config/StgToCmm.hs
- compiler/GHC/StgToCmm/Config.hs
- compiler/GHC/StgToCmm/Prim.hs
- testsuite/tests/numeric/should_run/all.T
- + testsuite/tests/numeric/should_run/mul2int.hs
- + testsuite/tests/numeric/should_run/mul2int.stdout
- + testsuite/tests/numeric/should_run/mul2int.stdout-ws-32
Changes:
=====================================
compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
=====================================
@@ -1556,7 +1556,7 @@ genCCall target dest_regs arg_regs bid = do
-- pprTraceM "genCCall target" (ppr target)
-- pprTraceM "genCCall formal" (ppr dest_regs)
-- pprTraceM "genCCall actual" (ppr arg_regs)
-
+ platform <- getPlatform
case target of
-- The target :: ForeignTarget call can either
-- be a foreign procedure with an address expr
@@ -1584,7 +1584,6 @@ genCCall target dest_regs arg_regs bid = do
let (_res_hints, arg_hints) = foreignTargetHints target
arg_regs'' = zipWith (\(r, f, c) h -> (r,f,h,c)) arg_regs' arg_hints
- platform <- getPlatform
let packStack = platformOS platform == OSDarwin
(stackSpace', passRegs, passArgumentsCode) <- passArguments packStack allGpArgRegs allFpArgRegs arg_regs'' 0 [] nilOL
@@ -1625,6 +1624,139 @@ genCCall target dest_regs arg_regs bid = do
| [arg_reg] <- arg_regs, [dest_reg] <- dest_regs ->
unaryFloatOp W64 (\d x -> unitOL $ FABS d x) arg_reg dest_reg
+ PrimTarget (MO_S_Mul2 w)
+ -- Life is easier when we're working with word sized operands,
+ -- we can use SMULH to compute the high 64 bits, and dst_needed
+ -- is 1 unless all of the high half's bits are copies of the low half's
+ -- top bit (i.e. unless the signed product fits in 64 bits).
+ | w == W64
+ , [src_a, src_b] <- arg_regs
+ -- dst_needed = 1 iff the result does NOT fit into just the low half
+ , [dst_needed, dst_hi, dst_lo] <- dest_regs
+ -> do
+ (reg_a, _format_x, code_x) <- getSomeReg src_a
+ (reg_b, _format_y, code_y) <- getSomeReg src_b
+
+ let lo = getRegisterReg platform (CmmLocal dst_lo)
+ hi = getRegisterReg platform (CmmLocal dst_hi)
+ nd = getRegisterReg platform (CmmLocal dst_needed)
+ return (
+ code_x `appOL`
+ code_y `snocOL`
+ MUL (OpReg W64 lo) (OpReg W64 reg_a) (OpReg W64 reg_b) `snocOL`
+ SMULH (OpReg W64 hi) (OpReg W64 reg_a) (OpReg W64 reg_b) `snocOL`
+ -- Are all high bits equal to the sign bit of the low word?
+ -- nd = (hi != ASR(lo,width-1)) ? 1 : 0, via CMP + CSET NE
+ CMP (OpReg W64 hi) (OpRegShift W64 lo SASR (widthInBits w - 1)) `snocOL`
+ CSET (OpReg W64 nd) NE
+ , Nothing)
+ -- For sizes < platform width, we can just perform a multiply and shift
+ -- using the normal 64 bit multiply. Calculating the dst_needed value is
+ -- complicated a little by the need to be careful when truncation happens.
+ -- Currently this case can't be generated since
+ -- timesInt2# :: Int# -> Int# -> (# Int#, Int#, Int# #)
+ -- TODO: Should this be removed or would other primops be useful?
+ | w < W64
+ , [src_a, src_b] <- arg_regs
+ , [dst_needed, dst_hi, dst_lo] <- dest_regs
+ -> do
+ (reg_a', _format_x, code_a) <- getSomeReg src_a
+ (reg_b', _format_y, code_b) <- getSomeReg src_b
+
+ let lo = getRegisterReg platform (CmmLocal dst_lo)
+ hi = getRegisterReg platform (CmmLocal dst_hi)
+ nd = getRegisterReg platform (CmmLocal dst_needed)
+ -- Do everything in full 64 bit registers
+ w' = platformWordWidth platform
+
+ (reg_a, code_a') <- signExtendReg w w' reg_a'
+ (reg_b, code_b') <- signExtendReg w w' reg_b'
+
+ return (
+ code_a `appOL`
+ code_b `appOL`
+ code_a' `appOL`
+ code_b' `snocOL`
+ -- the low 2w bits of lo contain the full product;
+ -- eg: int8 * int8 -> int16 result: low result in bits [w-1:0], high result in bits [2w-1:w].
+ -- NOTE(review): SMULL presumably takes W sources like UMULL (Xd = Wn * Wm), but here they are OpReg w' = W64; confirm this assembles.
+ SMULL (OpReg w' lo) (OpReg w' reg_a) (OpReg w' reg_b) `snocOL`
+ -- Make sure we hold onto the sign bits for dst_needed
+ ASR (OpReg w' hi) (OpReg w' lo) (OpImm (ImmInt $ widthInBits w)) `appOL`
+ -- lo can now be truncated so we can get at its top bit easily.
+ truncateReg w' w lo `snocOL`
+ -- Note the use of CMN (compare negative), not CMP: we want to
+ -- test if the top half is negative one and the top
+ -- bit of the bottom half is positive one. eg:
+ -- hi = 0b1111_1111 (actually 64 bits)
+ -- lo = 0b1010_1111 (-81, so the result didn't need the top half)
+ -- lo' = LSR(lo,7) (second operand of CMN)
+ -- = 0b0000_0001 (the shift gives us 1 for negative,
+ -- and 0 for positive)
+ -- hi == -lo'?
+ -- 0b1111_1111 == 0b1111_1111 (yes, top half is just sign extension)
+ -- Another way to think of this is if hi + lo' == 0, which is what
+ -- CMN really is under the hood.
+ CMN (OpReg w' hi) (OpRegShift w' lo SLSR (widthInBits w - 1)) `snocOL`
+ -- NOTE(review): CSET EQ gives nd = 1 when hi == -lo', i.e. when the top half is NOT needed (see example above); the W64 case uses NE, so this looks inverted — confirm before enabling.
+ CSET (OpReg w' nd) EQ `appOL`
+ -- Finally truncate hi to drop any extraneous sign bits.
+ truncateReg w' w hi
+ , Nothing)
+ -- Can't handle > 64 bit operands
+ | otherwise -> unsupported (MO_S_Mul2 w)
+ PrimTarget (MO_U_Mul2 w)
+ -- The unsigned case is much simpler than the signed one: all we need to
+ -- do is multiply straight into the destination registers (MUL + UMULH).
+ | w == W64
+ , [src_a, src_b] <- arg_regs
+ , [dst_hi, dst_lo] <- dest_regs
+ -> do
+ (reg_a, _format_x, code_x) <- getSomeReg src_a
+ (reg_b, _format_y, code_y) <- getSomeReg src_b
+
+ let lo = getRegisterReg platform (CmmLocal dst_lo)
+ hi = getRegisterReg platform (CmmLocal dst_hi)
+ return (
+ code_x `appOL`
+ code_y `snocOL`
+ MUL (OpReg W64 lo) (OpReg W64 reg_a) (OpReg W64 reg_b) `snocOL`
+ UMULH (OpReg W64 hi) (OpReg W64 reg_a) (OpReg W64 reg_b)
+ , Nothing)
+ -- For sizes < platform width, we can just perform a multiply and shift.
+ -- Need to be careful to truncate the low half, but the upper half should
+ -- be ok if the invariant in [Signed arithmetic on AArch64] is maintained.
+ -- Currently this case can't be produced by the compiler since
+ -- timesWord2# :: Word# -> Word# -> (# Word#, Word# #)
+ -- TODO: Remove? Or would the extra primop be useful for avoiding the extra
+ -- steps needed to do this in userland?
+ | w < W64
+ , [src_a, src_b] <- arg_regs
+ , [dst_hi, dst_lo] <- dest_regs
+ -> do
+ (reg_a, _format_x, code_x) <- getSomeReg src_a
+ (reg_b, _format_y, code_y) <- getSomeReg src_b
+
+ let lo = getRegisterReg platform (CmmLocal dst_lo)
+ hi = getRegisterReg platform (CmmLocal dst_hi)
+ w' = opRegWidth w
+ return (
+ code_x `appOL`
+ code_y `snocOL`
+ -- UMULL: Xd = Wa * Wb with 64 bit result
+ -- (W64 inputs are handled by the w == W64 case above)
+ UMULL (OpReg W64 lo) (OpReg w' reg_a) (OpReg w' reg_b) `snocOL`
+ -- Extract and truncate high result
+ -- hi[w-1:0] = lo[2w-1:w]
+ UBFX (OpReg W64 hi) (OpReg W64 lo)
+ (OpImm (ImmInt $ widthInBits w)) -- lsb
+ (OpImm (ImmInt $ widthInBits w)) -- width to extract
+ `appOL`
+ truncateReg W64 w lo
+ , Nothing)
+ | otherwise -> unsupported (MO_U_Mul2 w)
+
+
-- or a possibly side-effecting machine operation
-- mop :: CallishMachOp (see GHC.Cmm.MachOp)
PrimTarget mop -> do
@@ -1714,7 +1846,6 @@ genCCall target dest_regs arg_regs bid = do
-- Arithmatic
-- These are not supported on X86, so I doubt they are used much.
- MO_S_Mul2 _w -> unsupported mop
MO_S_QuotRem _w -> unsupported mop
MO_U_QuotRem _w -> unsupported mop
MO_U_QuotRem2 _w -> unsupported mop
@@ -1723,7 +1854,6 @@ genCCall target dest_regs arg_regs bid = do
MO_SubWordC _w -> unsupported mop
MO_AddIntC _w -> unsupported mop
MO_SubIntC _w -> unsupported mop
- MO_U_Mul2 _w -> unsupported mop
-- Memory Ordering
MO_AcquireFence -> return (unitOL DMBISH, Nothing)
=====================================
compiler/GHC/CmmToAsm/AArch64/Instr.hs
=====================================
@@ -79,11 +79,14 @@ regUsageOfInstr platform instr = case instr of
-- 1. Arithmetic Instructions ------------------------------------------------
ADD dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst)
CMP l r -> usage (regOp l ++ regOp r, [])
+ CMN l r -> usage (regOp l ++ regOp r, [])
MSUB dst src1 src2 src3 -> usage (regOp src1 ++ regOp src2 ++ regOp src3, regOp dst)
MUL dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst)
NEG dst src -> usage (regOp src, regOp dst)
SMULH dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst)
SMULL dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst)
+ UMULH dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst)
+ UMULL dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst)
SDIV dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst)
SUB dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst)
UDIV dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst)
@@ -209,11 +212,14 @@ patchRegsOfInstr instr env = case instr of
-- 1. Arithmetic Instructions ----------------------------------------------
ADD o1 o2 o3 -> ADD (patchOp o1) (patchOp o2) (patchOp o3)
CMP o1 o2 -> CMP (patchOp o1) (patchOp o2)
+ CMN o1 o2 -> CMN (patchOp o1) (patchOp o2)
MSUB o1 o2 o3 o4 -> MSUB (patchOp o1) (patchOp o2) (patchOp o3) (patchOp o4)
MUL o1 o2 o3 -> MUL (patchOp o1) (patchOp o2) (patchOp o3)
NEG o1 o2 -> NEG (patchOp o1) (patchOp o2)
SMULH o1 o2 o3 -> SMULH (patchOp o1) (patchOp o2) (patchOp o3)
SMULL o1 o2 o3 -> SMULL (patchOp o1) (patchOp o2) (patchOp o3)
+ UMULH o1 o2 o3 -> UMULH (patchOp o1) (patchOp o2) (patchOp o3)
+ UMULL o1 o2 o3 -> UMULL (patchOp o1) (patchOp o2) (patchOp o3)
SDIV o1 o2 o3 -> SDIV (patchOp o1) (patchOp o2) (patchOp o3)
SUB o1 o2 o3 -> SUB (patchOp o1) (patchOp o2) (patchOp o3)
UDIV o1 o2 o3 -> UDIV (patchOp o1) (patchOp o2) (patchOp o3)
@@ -540,6 +546,7 @@ data Instr
-- | ADR ...
-- | ADRP ...
| CMP Operand Operand -- rd - op2
+ | CMN Operand Operand -- rd + op2
-- | MADD ...
-- | MNEG ...
| MSUB Operand Operand Operand Operand -- rd = ra - rn × rm
@@ -562,8 +569,8 @@ data Instr
-- | UMADDL ... -- Xd = Xa + Wn × Wm
-- | UMNEGL ... -- Xd = - Wn × Wm
-- | UMSUBL ... -- Xd = Xa - Wn × Wm
- -- | UMULH ... -- Xd = (Xn × Xm)_127:64
- -- | UMULL ... -- Xd = Wn × Wm
+ | UMULH Operand Operand Operand -- Xd = (Xn × Xm)_127:64
+ | UMULL Operand Operand Operand -- Xd = Wn × Wm
-- 2. Bit Manipulation Instructions ----------------------------------------
| SBFM Operand Operand Operand Operand -- rd = rn[i,j]
@@ -644,12 +651,15 @@ instrCon i =
POP_STACK_FRAME{} -> "POP_STACK_FRAME"
ADD{} -> "ADD"
CMP{} -> "CMP"
+ CMN{} -> "CMN"
MSUB{} -> "MSUB"
MUL{} -> "MUL"
NEG{} -> "NEG"
SDIV{} -> "SDIV"
SMULH{} -> "SMULH"
SMULL{} -> "SMULL"
+ UMULH{} -> "UMULH"
+ UMULL{} -> "UMULL"
SUB{} -> "SUB"
UDIV{} -> "UDIV"
SBFM{} -> "SBFM"
=====================================
compiler/GHC/CmmToAsm/AArch64/Ppr.hs
=====================================
@@ -372,12 +372,15 @@ pprInstr platform instr = case instr of
CMP o1 o2
| isFloatOp o1 && isFloatOp o2 -> op2 (text "\tfcmp") o1 o2
| otherwise -> op2 (text "\tcmp") o1 o2
+ CMN o1 o2 -> op2 (text "\tcmn") o1 o2
MSUB o1 o2 o3 o4 -> op4 (text "\tmsub") o1 o2 o3 o4
MUL o1 o2 o3
| isFloatOp o1 && isFloatOp o2 && isFloatOp o3 -> op3 (text "\tfmul") o1 o2 o3
| otherwise -> op3 (text "\tmul") o1 o2 o3
SMULH o1 o2 o3 -> op3 (text "\tsmulh") o1 o2 o3
SMULL o1 o2 o3 -> op3 (text "\tsmull") o1 o2 o3
+ UMULH o1 o2 o3 -> op3 (text "\tumulh") o1 o2 o3
+ UMULL o1 o2 o3 -> op3 (text "\tumull") o1 o2 o3
NEG o1 o2
| isFloatOp o1 && isFloatOp o2 -> op2 (text "\tfneg") o1 o2
| otherwise -> op2 (text "\tneg") o1 o2
=====================================
compiler/GHC/Driver/Config/StgToCmm.hs
=====================================
@@ -76,7 +76,8 @@ initStgToCmmConfig dflags mod = StgToCmmConfig
| otherwise
-> const True
- , stgToCmmAllowIntMul2Instr = (ncg && x86ish) || llvm
+ , stgToCmmAllowIntMul2Instr = (ncg && (x86ish || aarch64)) || llvm
+ , stgToCmmAllowWordMul2Instr = (ncg && (x86ish || ppc || aarch64)) || llvm
-- SIMD flags
, stgToCmmVecInstrsErr = vec_err
, stgToCmmAvx = isAvxEnabled dflags
@@ -92,6 +93,9 @@ initStgToCmmConfig dflags mod = StgToCmmConfig
JSPrimitives -> (False, False)
NcgPrimitives -> (True, False)
LlvmPrimitives -> (False, True)
+ aarch64 = case platformArch platform of
+ ArchAArch64 -> True
+ _ -> False
x86ish = case platformArch platform of
ArchX86 -> True
ArchX86_64 -> True
=====================================
compiler/GHC/StgToCmm/Config.hs
=====================================
@@ -70,6 +70,7 @@ data StgToCmmConfig = StgToCmmConfig
, stgToCmmAllowQuotRem2 :: !Bool -- ^ Allowed to generate QuotRem
, stgToCmmAllowExtendedAddSubInstrs :: !Bool -- ^ Allowed to generate AddWordC, SubWordC, Add2, etc.
, stgToCmmAllowIntMul2Instr :: !Bool -- ^ Allowed to generate IntMul2 instruction
+ , stgToCmmAllowWordMul2Instr :: !Bool -- ^ Allowed to generate WordMul2 instruction
, stgToCmmAllowFMAInstr :: FMASign -> Bool -- ^ Allowed to generate FMA instruction
, stgToCmmTickyAP :: !Bool -- ^ Disable use of precomputed standard thunks.
------------------------------ SIMD flags ------------------------------------
=====================================
compiler/GHC/StgToCmm/Prim.hs
=====================================
@@ -1623,7 +1623,7 @@ emitPrimOp cfg primop =
else Right genericIntSubCOp
WordMul2Op -> \args -> opCallishHandledLater args $
- if allowExtAdd
+ if allowWord2Mul
then Left (MO_U_Mul2 (wordWidth platform))
else Right genericWordMul2Op
@@ -1850,6 +1850,7 @@ emitPrimOp cfg primop =
allowQuotRem2 = stgToCmmAllowQuotRem2 cfg
allowExtAdd = stgToCmmAllowExtendedAddSubInstrs cfg
allowInt2Mul = stgToCmmAllowIntMul2Instr cfg
+ allowWord2Mul = stgToCmmAllowWordMul2Instr cfg
allowFMA = stgToCmmAllowFMAInstr cfg
=====================================
testsuite/tests/numeric/should_run/all.T
=====================================
@@ -50,6 +50,7 @@ test('T4383', normal, compile_and_run, [''])
test('add2', normal, compile_and_run, ['-fobject-code'])
test('mul2', normal, compile_and_run, ['-fobject-code'])
+test('mul2int', normal, compile_and_run, ['-fobject-code'])
test('quotRem2', normal, compile_and_run, ['-fobject-code'])
test('T5863', normal, compile_and_run, [''])
=====================================
testsuite/tests/numeric/should_run/mul2int.hs
=====================================
@@ -0,0 +1,35 @@
+{-# LANGUAGE MagicHash, UnboxedTuples #-}
+
+import GHC.Exts
+import Data.Bits
+ -- Exercise timesInt2# on small, negative and overflowing products (large literals wrap on 32-bit targets).
+main :: IO ()
+main = do g 5 6
+ g (-5) 6
+ g 0x7ECA71DBFF1B7D8C 49
+ g (-0x7ECA71DBFF1B7D8C) 49
+ g 0x7ECA71DBFF1B7D8C 0x7E0EC51DFD94FE35
+ g 0x7ECA71DBFF1B7D8C (-0x7E0EC51DFD94FE35)
+
+ -- Print the (needed, high, low) result of timesInt2# and compare the reconstructed product with exact Integer multiplication.
+g :: Int -> Int -> IO ()
+g wx@(I# x) wy@(I# y)
+ = do putStrLn "-----"
+ putStrLn ("Doing " ++ show wx ++ " * " ++ show wy)
+ case x `timesInt2#` y of
+ (# n, h, l #) ->
+ do let wh = I# h
+ wl = I# l
+ wlw = W# (int2Word# l)
+ wn = I# n
+ r | wn == 1 = shiftL (fromIntegral wh) (finiteBitSize wh)
+ + fromIntegral wlw
+ | otherwise = fromIntegral wl
+
+ putStrLn ("High: " ++ show wh)
+ putStrLn ("Low: " ++ show wl)
+ putStrLn ("Needed: " ++ show wn)
+ putStrLn ("Result: " ++ show (r :: Integer))
+ putStrLn ("Should be: " ++ show (fromIntegral wx * fromIntegral wy :: Integer))
+
+
=====================================
testsuite/tests/numeric/should_run/mul2int.stdout
=====================================
@@ -0,0 +1,42 @@
+-----
+Doing 5 * 6
+High: 0
+Low: 30
+Needed: 0
+Result: 30
+Should be: 30
+-----
+Doing -5 * 6
+High: -1
+Low: -30
+Needed: 0
+Result: -30
+Should be: -30
+-----
+Doing 9136239983766240652 * 49
+High: 24
+Low: 4953901435516553164
+Needed: 1
+Result: 447675759204545791948
+Should be: 447675759204545791948
+-----
+Doing -9136239983766240652 * 49
+High: -25
+Low: -4953901435516553164
+Needed: 1
+Result: -447675759204545791948
+Should be: -447675759204545791948
+-----
+Doing 9136239983766240652 * 9083414231051992629
+High: 4498802171008813567
+Low: 3355592377236579836
+Needed: 1
+Result: 82988252286848496451678442784944154108
+Should be: 82988252286848496451678442784944154108
+-----
+Doing 9136239983766240652 * -9083414231051992629
+High: -4498802171008813568
+Low: -3355592377236579836
+Needed: 1
+Result: -82988252286848496451678442784944154108
+Should be: -82988252286848496451678442784944154108
=====================================
testsuite/tests/numeric/should_run/mul2int.stdout-ws-32
=====================================
@@ -0,0 +1,42 @@
+-----
+Doing 5 * 6
+High: 0
+Low: 30
+Needed: 0
+Result: 30
+Should be: 30
+-----
+Doing -5 * 6
+High: -1
+Low: -30
+Needed: 0
+Result: -30
+Should be: -30
+-----
+Doing -14975604 * 49
+High: -1
+Low: -733804596
+Needed: 0
+Result: -733804596
+Should be: -733804596
+-----
+Doing 14975604 * 49
+High: 0
+Low: 733804596
+Needed: 0
+Result: 733804596
+Should be: 733804596
+-----
+Doing -14975604 * -40567243
+High: 141449
+Low: 137487868
+Needed: 1
+Result: 607518966539772
+Should be: 607518966539772
+-----
+Doing -14975604 * 40567243
+High: -141450
+Low: -137487868
+Needed: 1
+Result: -607518966539772
+Should be: -607518966539772
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/dbdf1995956a7457c34b6895c67ef48f6c8384f2
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/dbdf1995956a7457c34b6895c67ef48f6c8384f2
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240416/98fdd99a/attachment-0001.html>
More information about the ghc-commits mailing list