[Git][ghc/ghc][master] Implements MO_S_Mul2 and MO_U_Mul2 using the UMULH, UMULL and SMULH instructions for AArch64

Marge Bot (@marge-bot) gitlab at gitlab.haskell.org
Wed Apr 17 00:06:40 UTC 2024



Marge Bot pushed to branch master at Glasgow Haskell Compiler / GHC


Commits:
dbdf1995 by Alex Mason at 2024-04-15T15:28:26+10:00
Implements MO_S_Mul2 and MO_U_Mul2 using the  UMULH, UMULL and SMULH instructions for AArch64

Also adds a test for MO_S_Mul2

- - - - -


10 changed files:

- compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
- compiler/GHC/CmmToAsm/AArch64/Instr.hs
- compiler/GHC/CmmToAsm/AArch64/Ppr.hs
- compiler/GHC/Driver/Config/StgToCmm.hs
- compiler/GHC/StgToCmm/Config.hs
- compiler/GHC/StgToCmm/Prim.hs
- testsuite/tests/numeric/should_run/all.T
- + testsuite/tests/numeric/should_run/mul2int.hs
- + testsuite/tests/numeric/should_run/mul2int.stdout
- + testsuite/tests/numeric/should_run/mul2int.stdout-ws-32


Changes:

=====================================
compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
=====================================
@@ -1556,7 +1556,7 @@ genCCall target dest_regs arg_regs bid = do
   -- pprTraceM "genCCall target" (ppr target)
   -- pprTraceM "genCCall formal" (ppr dest_regs)
   -- pprTraceM "genCCall actual" (ppr arg_regs)
-
+  platform <- getPlatform
   case target of
     -- The target :: ForeignTarget call can either
     -- be a foreign procedure with an address expr
@@ -1584,7 +1584,6 @@ genCCall target dest_regs arg_regs bid = do
       let (_res_hints, arg_hints) = foreignTargetHints target
           arg_regs'' = zipWith (\(r, f, c) h -> (r,f,h,c)) arg_regs' arg_hints
 
-      platform <- getPlatform
       let packStack = platformOS platform == OSDarwin
 
       (stackSpace', passRegs, passArgumentsCode) <- passArguments packStack allGpArgRegs allFpArgRegs arg_regs'' 0 [] nilOL
@@ -1625,6 +1624,139 @@ genCCall target dest_regs arg_regs bid = do
       | [arg_reg] <- arg_regs, [dest_reg] <- dest_regs ->
         unaryFloatOp W64 (\d x -> unitOL $ FABS d x) arg_reg dest_reg
 
+    PrimTarget (MO_S_Mul2 w)
+          -- Life is easier when we're working with word sized operands,
+          -- we can use SMULH to compute the high 64 bits, and dst_needed
+          -- checks if the high half's bits are all the same as the low half's
+          -- top bit.
+          | w == W64
+          , [src_a, src_b] <- arg_regs
+          -- dst_needed = 1 iff the result did NOT fit into just the low half
+          , [dst_needed, dst_hi, dst_lo] <- dest_regs
+            ->  do
+              (reg_a, _format_x, code_x) <- getSomeReg src_a
+              (reg_b, _format_y, code_y) <- getSomeReg src_b
+
+              let lo = getRegisterReg platform (CmmLocal dst_lo)
+                  hi = getRegisterReg platform (CmmLocal dst_hi)
+                  nd = getRegisterReg platform (CmmLocal dst_needed)
+              return (
+                  code_x `appOL`
+                  code_y `snocOL`
+                  MUL   (OpReg W64 lo) (OpReg W64 reg_a) (OpReg W64 reg_b) `snocOL`
+                  SMULH (OpReg W64 hi) (OpReg W64 reg_a) (OpReg W64 reg_b) `snocOL`
+                  -- Are all high bits equal to the sign bit of the low word?
+                  -- nd = (hi == ASR(lo,width-1)) ? 1 : 0
+                  CMP   (OpReg W64 hi) (OpRegShift W64 lo SASR (widthInBits w - 1)) `snocOL`
+                  CSET  (OpReg W64 nd) NE
+                  , Nothing)
+            -- For sizes < platform width, we can just perform a multiply and shift
+            -- using the normal 64 bit multiply. Calculating the dst_needed value is
+            -- complicated a little by the need to be careful when truncation happens.
+            -- Currently this case can't be generated since
+            -- timesInt2# :: Int# -> Int# -> (# Int#, Int#, Int# #)
+            -- TODO: Should this be removed or would other primops be useful?
+          | w < W64
+          , [src_a, src_b] <- arg_regs
+          , [dst_needed, dst_hi, dst_lo] <- dest_regs
+            ->  do
+              (reg_a', _format_x, code_a) <- getSomeReg src_a
+              (reg_b', _format_y, code_b) <- getSomeReg src_b
+
+              let lo = getRegisterReg platform (CmmLocal dst_lo)
+                  hi = getRegisterReg platform (CmmLocal dst_hi)
+                  nd = getRegisterReg platform (CmmLocal dst_needed)
+                  -- Do everything in full 64 bit registers
+                  w' = platformWordWidth platform
+
+              (reg_a, code_a') <- signExtendReg w w' reg_a'
+              (reg_b, code_b') <- signExtendReg w w' reg_b'
+
+              return (
+                  code_a  `appOL`
+                  code_b  `appOL`
+                  code_a' `appOL`
+                  code_b' `snocOL`
+                  -- the low 2w bits of lo contain the full multiplication;
+                  -- eg: int8 * int8 -> int16 result
+                  -- so lo is in the last w of the register, and hi is in the second w.
+                  SMULL (OpReg w' lo) (OpReg w' reg_a) (OpReg w' reg_b) `snocOL`
+                  -- Make sure we hold onto the sign bits for dst_needed
+                  ASR (OpReg w' hi) (OpReg w' lo)    (OpImm (ImmInt $ widthInBits w)) `appOL`
+                  -- lo can now be truncated so we can get at its top bit easily.
+                  truncateReg w' w lo `snocOL`
+                  -- Note the use of CMN (compare negative), not CMP: we want to
+                  -- test if the top half is negative one and the top
+                  -- bit of the bottom half is positive one. eg:
+                  -- hi = 0b1111_1111  (actually 64 bits)
+                  -- lo = 0b1010_1111  (-81, so the result didn't need the top half)
+                  -- lo' = LSR(lo,7)   (second operand of CMN)
+                  --     = 0b0000_0001 (the shift gives us 1 for negative,
+                  --                    and 0 for positive)
+                  -- hi == -lo'?
+                  -- 0b1111_1111 == 0b1111_1111 (yes, top half is just overflow)
+                  -- Another way to think of this is if hi + lo' == 0, which is what
+                  -- CMN really is under the hood.
+                  CMN   (OpReg w' hi) (OpRegShift w' lo SLSR (widthInBits w - 1)) `snocOL`
+                  -- Set dst_needed to 1 if hi == -lo'. NOTE(review): the W64 case sets it on NE (hi is not pure sign extension); confirm EQ is intended here.
+                  CSET  (OpReg w' nd) EQ `appOL`
+                  -- Finally truncate hi to drop any extraneous sign bits.
+                  truncateReg w' w hi
+                  , Nothing)
+          -- Can't handle > 64 bit operands
+          | otherwise -> unsupported (MO_S_Mul2 w)
+    PrimTarget (MO_U_Mul2  w)
+          -- The unsigned case is much simpler than the signed, all we need to
+          -- do is the multiplication straight into the destination registers.
+          | w == W64
+          , [src_a, src_b] <- arg_regs
+          , [dst_hi, dst_lo] <- dest_regs
+            ->  do
+              (reg_a, _format_x, code_x) <- getSomeReg src_a
+              (reg_b, _format_y, code_y) <- getSomeReg src_b
+
+              let lo = getRegisterReg platform (CmmLocal dst_lo)
+                  hi = getRegisterReg platform (CmmLocal dst_hi)
+              return (
+                  code_x `appOL`
+                  code_y `snocOL`
+                  MUL   (OpReg W64 lo) (OpReg W64 reg_a) (OpReg W64 reg_b) `snocOL`
+                  UMULH (OpReg W64 hi) (OpReg W64 reg_a) (OpReg W64 reg_b)
+                  , Nothing)
+            -- For sizes < platform width, we can just perform a multiply and shift
+            -- Need to be careful to truncate the low half, but the upper half should be
+            -- ok if the invariant in [Signed arithmetic on AArch64] is maintained.
+            -- Currently this case can't be produced by the compiler since
+            -- timesWord2# :: Word# -> Word# -> (# Word#, Word# #)
+            -- TODO: Remove? Or would the extra primop be useful for avoiding the extra
+            -- steps needed to do this in userland?
+          | w < W64
+          , [src_a, src_b] <- arg_regs
+          , [dst_hi, dst_lo] <- dest_regs
+            ->  do
+              (reg_a, _format_x, code_x) <- getSomeReg src_a
+              (reg_b, _format_y, code_y) <- getSomeReg src_b
+
+              let lo = getRegisterReg platform (CmmLocal dst_lo)
+                  hi = getRegisterReg platform (CmmLocal dst_hi)
+                  w' = opRegWidth w
+              return (
+                  code_x `appOL`
+                  code_y `snocOL`
+                  -- UMULL: Xd = Wa * Wb with 64 bit result
+                  -- W64 inputs should have been caught by case above
+                  UMULL (OpReg W64 lo) (OpReg w' reg_a) (OpReg w' reg_b) `snocOL`
+                  -- Extract and truncate high result
+                  -- hi[w:0] = lo[2w:w]
+                  UBFX (OpReg W64 hi) (OpReg W64 lo)
+                      (OpImm (ImmInt $ widthInBits w)) -- lsb
+                      (OpImm (ImmInt $ widthInBits w)) -- width to extract
+                      `appOL`
+                  truncateReg W64 w lo
+                  , Nothing)
+          | otherwise -> unsupported (MO_U_Mul2  w)
+
+
     -- or a possibly side-effecting machine operation
     -- mop :: CallishMachOp (see GHC.Cmm.MachOp)
     PrimTarget mop -> do
@@ -1714,7 +1846,6 @@ genCCall target dest_regs arg_regs bid = do
 
         -- Arithmatic
         -- These are not supported on X86, so I doubt they are used much.
-        MO_S_Mul2     _w -> unsupported mop
         MO_S_QuotRem  _w -> unsupported mop
         MO_U_QuotRem  _w -> unsupported mop
         MO_U_QuotRem2 _w -> unsupported mop
@@ -1723,7 +1854,6 @@ genCCall target dest_regs arg_regs bid = do
         MO_SubWordC   _w -> unsupported mop
         MO_AddIntC    _w -> unsupported mop
         MO_SubIntC    _w -> unsupported mop
-        MO_U_Mul2     _w -> unsupported mop
 
         -- Memory Ordering
         MO_AcquireFence     ->  return (unitOL DMBISH, Nothing)


=====================================
compiler/GHC/CmmToAsm/AArch64/Instr.hs
=====================================
@@ -79,11 +79,14 @@ regUsageOfInstr platform instr = case instr of
   -- 1. Arithmetic Instructions ------------------------------------------------
   ADD dst src1 src2        -> usage (regOp src1 ++ regOp src2, regOp dst)
   CMP l r                  -> usage (regOp l ++ regOp r, [])
+  CMN l r                  -> usage (regOp l ++ regOp r, [])
   MSUB dst src1 src2 src3  -> usage (regOp src1 ++ regOp src2 ++ regOp src3, regOp dst)
   MUL dst src1 src2        -> usage (regOp src1 ++ regOp src2, regOp dst)
   NEG dst src              -> usage (regOp src, regOp dst)
   SMULH dst src1 src2      -> usage (regOp src1 ++ regOp src2, regOp dst)
   SMULL dst src1 src2      -> usage (regOp src1 ++ regOp src2, regOp dst)
+  UMULH dst src1 src2      -> usage (regOp src1 ++ regOp src2, regOp dst)
+  UMULL dst src1 src2      -> usage (regOp src1 ++ regOp src2, regOp dst)
   SDIV dst src1 src2       -> usage (regOp src1 ++ regOp src2, regOp dst)
   SUB dst src1 src2        -> usage (regOp src1 ++ regOp src2, regOp dst)
   UDIV dst src1 src2       -> usage (regOp src1 ++ regOp src2, regOp dst)
@@ -209,11 +212,14 @@ patchRegsOfInstr instr env = case instr of
     -- 1. Arithmetic Instructions ----------------------------------------------
     ADD o1 o2 o3   -> ADD (patchOp o1) (patchOp o2) (patchOp o3)
     CMP o1 o2      -> CMP (patchOp o1) (patchOp o2)
+    CMN o1 o2      -> CMN (patchOp o1) (patchOp o2)
     MSUB o1 o2 o3 o4 -> MSUB (patchOp o1) (patchOp o2) (patchOp o3) (patchOp o4)
     MUL o1 o2 o3   -> MUL (patchOp o1) (patchOp o2) (patchOp o3)
     NEG o1 o2      -> NEG (patchOp o1) (patchOp o2)
     SMULH o1 o2 o3 -> SMULH (patchOp o1) (patchOp o2)  (patchOp o3)
     SMULL o1 o2 o3 -> SMULL (patchOp o1) (patchOp o2)  (patchOp o3)
+    UMULH o1 o2 o3 -> UMULH (patchOp o1) (patchOp o2)  (patchOp o3)
+    UMULL o1 o2 o3 -> UMULL (patchOp o1) (patchOp o2)  (patchOp o3)
     SDIV o1 o2 o3  -> SDIV (patchOp o1) (patchOp o2) (patchOp o3)
     SUB o1 o2 o3   -> SUB  (patchOp o1) (patchOp o2) (patchOp o3)
     UDIV o1 o2 o3  -> UDIV (patchOp o1) (patchOp o2) (patchOp o3)
@@ -540,6 +546,7 @@ data Instr
     -- | ADR ...
     -- | ADRP ...
     | CMP Operand Operand -- rd - op2
+    | CMN Operand Operand -- rd + op2
     -- | MADD ...
     -- | MNEG ...
     | MSUB Operand Operand Operand Operand -- rd = ra - rn × rm
@@ -562,8 +569,8 @@ data Instr
     -- | UMADDL ...  -- Xd = Xa + Wn × Wm
     -- | UMNEGL ... -- Xd = - Wn × Wm
     -- | UMSUBL ... -- Xd = Xa - Wn × Wm
-    -- | UMULH ... -- Xd = (Xn × Xm)_127:64
-    -- | UMULL ... -- Xd = Wn × Wm
+    | UMULH Operand Operand Operand -- Xd = (Xn × Xm)_127:64
+    | UMULL Operand Operand Operand -- Xd = Wn × Wm
 
     -- 2. Bit Manipulation Instructions ----------------------------------------
     | SBFM Operand Operand Operand Operand -- rd = rn[i,j]
@@ -644,12 +651,15 @@ instrCon i =
       POP_STACK_FRAME{} -> "POP_STACK_FRAME"
       ADD{} -> "ADD"
       CMP{} -> "CMP"
+      CMN{} -> "CMN"
       MSUB{} -> "MSUB"
       MUL{} -> "MUL"
       NEG{} -> "NEG"
       SDIV{} -> "SDIV"
       SMULH{} -> "SMULH"
       SMULL{} -> "SMULL"
+      UMULH{} -> "UMULH"
+      UMULL{} -> "UMULL"
       SUB{} -> "SUB"
       UDIV{} -> "UDIV"
       SBFM{} -> "SBFM"


=====================================
compiler/GHC/CmmToAsm/AArch64/Ppr.hs
=====================================
@@ -372,12 +372,15 @@ pprInstr platform instr = case instr of
   CMP  o1 o2
     | isFloatOp o1 && isFloatOp o2 -> op2 (text "\tfcmp") o1 o2
     | otherwise -> op2 (text "\tcmp") o1 o2
+  CMN  o1 o2       -> op2 (text "\tcmn") o1 o2
   MSUB o1 o2 o3 o4 -> op4 (text "\tmsub") o1 o2 o3 o4
   MUL  o1 o2 o3
     | isFloatOp o1 && isFloatOp o2 && isFloatOp o3 -> op3 (text "\tfmul") o1 o2 o3
     | otherwise -> op3 (text "\tmul") o1 o2 o3
   SMULH o1 o2 o3 -> op3 (text "\tsmulh") o1 o2 o3
   SMULL o1 o2 o3 -> op3 (text "\tsmull") o1 o2 o3
+  UMULH o1 o2 o3 -> op3 (text "\tumulh") o1 o2 o3
+  UMULL o1 o2 o3 -> op3 (text "\tumull") o1 o2 o3
   NEG  o1 o2
     | isFloatOp o1 && isFloatOp o2 -> op2 (text "\tfneg") o1 o2
     | otherwise -> op2 (text "\tneg") o1 o2


=====================================
compiler/GHC/Driver/Config/StgToCmm.hs
=====================================
@@ -76,7 +76,8 @@ initStgToCmmConfig dflags mod = StgToCmmConfig
         | otherwise
         -> const True
 
-  , stgToCmmAllowIntMul2Instr         = (ncg && x86ish) || llvm
+  , stgToCmmAllowIntMul2Instr         = (ncg && (x86ish || aarch64)) || llvm
+  , stgToCmmAllowWordMul2Instr        = (ncg && (x86ish || ppc || aarch64)) || llvm
   -- SIMD flags
   , stgToCmmVecInstrsErr  = vec_err
   , stgToCmmAvx           = isAvxEnabled                   dflags
@@ -92,6 +93,9 @@ initStgToCmmConfig dflags mod = StgToCmmConfig
                           JSPrimitives      -> (False, False)
                           NcgPrimitives     -> (True, False)
                           LlvmPrimitives    -> (False, True)
+          aarch64 = case platformArch platform of
+                      ArchAArch64  -> True
+                      _            -> False
           x86ish  = case platformArch platform of
                       ArchX86    -> True
                       ArchX86_64 -> True


=====================================
compiler/GHC/StgToCmm/Config.hs
=====================================
@@ -70,6 +70,7 @@ data StgToCmmConfig = StgToCmmConfig
   , stgToCmmAllowQuotRem2             :: !Bool   -- ^ Allowed to generate QuotRem
   , stgToCmmAllowExtendedAddSubInstrs :: !Bool   -- ^ Allowed to generate AddWordC, SubWordC, Add2, etc.
   , stgToCmmAllowIntMul2Instr         :: !Bool   -- ^ Allowed to generate IntMul2 instruction
+  , stgToCmmAllowWordMul2Instr        :: !Bool   -- ^ Allowed to generate WordMul2 instruction
   , stgToCmmAllowFMAInstr             :: FMASign -> Bool -- ^ Allowed to generate FMA instruction
   , stgToCmmTickyAP                   :: !Bool   -- ^ Disable use of precomputed standard thunks.
   ------------------------------ SIMD flags ------------------------------------


=====================================
compiler/GHC/StgToCmm/Prim.hs
=====================================
@@ -1623,7 +1623,7 @@ emitPrimOp cfg primop =
     else Right genericIntSubCOp
 
   WordMul2Op -> \args -> opCallishHandledLater args $
-    if allowExtAdd
+    if allowWord2Mul
     then Left (MO_U_Mul2     (wordWidth platform))
     else Right genericWordMul2Op
 
@@ -1850,6 +1850,7 @@ emitPrimOp cfg primop =
   allowQuotRem2 = stgToCmmAllowQuotRem2             cfg
   allowExtAdd   = stgToCmmAllowExtendedAddSubInstrs cfg
   allowInt2Mul  = stgToCmmAllowIntMul2Instr         cfg
+  allowWord2Mul = stgToCmmAllowWordMul2Instr        cfg
 
   allowFMA = stgToCmmAllowFMAInstr cfg
 


=====================================
testsuite/tests/numeric/should_run/all.T
=====================================
@@ -50,6 +50,7 @@ test('T4383', normal, compile_and_run, [''])
 
 test('add2', normal, compile_and_run, ['-fobject-code'])
 test('mul2', normal, compile_and_run, ['-fobject-code'])
+test('mul2int', normal, compile_and_run, ['-fobject-code'])
 test('quotRem2', normal, compile_and_run, ['-fobject-code'])
 test('T5863', normal, compile_and_run, [''])
 


=====================================
testsuite/tests/numeric/should_run/mul2int.hs
=====================================
@@ -0,0 +1,35 @@
+{-# LANGUAGE MagicHash, UnboxedTuples #-}
+
+import GHC.Exts
+import Data.Bits
+
+main :: IO ()
+main = do g 5 6
+          g (-5) 6
+          g 0x7ECA71DBFF1B7D8C 49
+          g (-0x7ECA71DBFF1B7D8C) 49
+          g 0x7ECA71DBFF1B7D8C 0x7E0EC51DFD94FE35
+          g 0x7ECA71DBFF1B7D8C (-0x7E0EC51DFD94FE35)
+
+
+g :: Int -> Int -> IO ()
+g wx@(I# x) wy@(I# y)
+    = do putStrLn "-----"
+         putStrLn ("Doing " ++ show wx ++ " * " ++ show wy)
+         case x `timesInt2#` y of
+             (# n, h, l #) ->
+                 do let wh = I# h
+                        wl = I# l
+                        wlw = W# (int2Word# l)
+                        wn = I# n
+                        r | wn == 1   = shiftL (fromIntegral wh) (finiteBitSize wh)
+                                      + fromIntegral wlw
+                          | otherwise = fromIntegral wl
+
+                    putStrLn ("High:      " ++ show wh)
+                    putStrLn ("Low:       " ++ show wl)
+                    putStrLn ("Needed:    " ++ show wn)
+                    putStrLn ("Result:    " ++ show (r :: Integer))
+                    putStrLn ("Should be: " ++ show (fromIntegral wx * fromIntegral wy :: Integer))
+
+


=====================================
testsuite/tests/numeric/should_run/mul2int.stdout
=====================================
@@ -0,0 +1,42 @@
+-----
+Doing 5 * 6
+High:      0
+Low:       30
+Needed:    0
+Result:    30
+Should be: 30
+-----
+Doing -5 * 6
+High:      -1
+Low:       -30
+Needed:    0
+Result:    -30
+Should be: -30
+-----
+Doing 9136239983766240652 * 49
+High:      24
+Low:       4953901435516553164
+Needed:    1
+Result:    447675759204545791948
+Should be: 447675759204545791948
+-----
+Doing -9136239983766240652 * 49
+High:      -25
+Low:       -4953901435516553164
+Needed:    1
+Result:    -447675759204545791948
+Should be: -447675759204545791948
+-----
+Doing 9136239983766240652 * 9083414231051992629
+High:      4498802171008813567
+Low:       3355592377236579836
+Needed:    1
+Result:    82988252286848496451678442784944154108
+Should be: 82988252286848496451678442784944154108
+-----
+Doing 9136239983766240652 * -9083414231051992629
+High:      -4498802171008813568
+Low:       -3355592377236579836
+Needed:    1
+Result:    -82988252286848496451678442784944154108
+Should be: -82988252286848496451678442784944154108


=====================================
testsuite/tests/numeric/should_run/mul2int.stdout-ws-32
=====================================
@@ -0,0 +1,42 @@
+-----
+Doing 5 * 6
+High:      0
+Low:       30
+Needed:    0
+Result:    30
+Should be: 30
+-----
+Doing -5 * 6
+High:      -1
+Low:       -30
+Needed:    0
+Result:    -30
+Should be: -30
+-----
+Doing -14975604 * 49
+High:      -1
+Low:       -733804596
+Needed:    0
+Result:    -733804596
+Should be: -733804596
+-----
+Doing 14975604 * 49
+High:      0
+Low:       733804596
+Needed:    0
+Result:    733804596
+Should be: 733804596
+-----
+Doing -14975604 * -40567243
+High:      141449
+Low:       137487868
+Needed:    1
+Result:    607518966539772
+Should be: 607518966539772
+-----
+Doing -14975604 * 40567243
+High:      -141450
+Low:       -137487868
+Needed:    1
+Result:    -607518966539772
+Should be: -607518966539772



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/dbdf1995956a7457c34b6895c67ef48f6c8384f2

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/dbdf1995956a7457c34b6895c67ef48f6c8384f2
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240416/98fdd99a/attachment-0001.html>


More information about the ghc-commits mailing list