[Git][ghc/ghc][wip/ncg-simd] WIP: improve broadcast, especially on LLVM

sheaf (@sheaf) gitlab at gitlab.haskell.org
Fri Jun 14 15:32:00 UTC 2024



sheaf pushed to branch wip/ncg-simd at Glasgow Haskell Compiler / GHC


Commits:
df28c0e5 by sheaf at 2024-06-14T17:31:32+02:00
WIP: improve broadcast, especially on LLVM

- - - - -


6 changed files:

- compiler/GHC/Cmm/MachOp.hs
- compiler/GHC/Cmm/Opt.hs
- compiler/GHC/CmmToAsm/X86/CodeGen.hs
- compiler/GHC/CmmToLlvm/CodeGen.hs
- compiler/GHC/StgToCmm/Prim.hs
- testsuite/tests/simd/should_run/simd008.hs


Changes:

=====================================
compiler/GHC/Cmm/MachOp.hs
=====================================
@@ -596,10 +596,10 @@ machOpArgReps platform op =
     MO_V_Shuffle  l r _ -> [vecwidth l r, vecwidth l r]
     MO_VF_Shuffle l r _ -> [vecwidth l r, vecwidth l r]
 
-    MO_V_Broadcast l r  -> [vecwidth l r, r]
+    MO_V_Broadcast _ r  -> [r]
     MO_V_Insert   l r   -> [vecwidth l r, r, W32]
     MO_V_Extract  l r   -> [vecwidth l r, W32]
-    MO_VF_Broadcast l r -> [vecwidth l r, r]
+    MO_VF_Broadcast _ r -> [r]
     MO_VF_Insert  l r   -> [vecwidth l r, r, W32]
     MO_VF_Extract l r   -> [vecwidth l r, W32]
       -- SIMD vector indices are always 32 bit


=====================================
compiler/GHC/Cmm/Opt.hs
=====================================
@@ -79,6 +79,10 @@ cmmMachOpFoldM
     -> MachOp
     -> [CmmExpr]
     -> Maybe CmmExpr
+cmmMachOpFoldM _ (MO_V_Broadcast {}) _ = Nothing
+  -- SIMD NCG TODO: don't constant-fold Broadcast (scalar argument, vector result);
+  -- e.g. an Int literal argument would presumably hit the "unknown unary op" panic below.
+cmmMachOpFoldM _ (MO_VF_Broadcast {}) _ = Nothing
 
 cmmMachOpFoldM _ op [CmmLit (CmmInt x rep)]
   = Just $! case op of
@@ -93,7 +97,6 @@ cmmMachOpFoldM _ op [CmmLit (CmmInt x rep)]
       MO_SS_Conv  from to -> CmmLit (CmmInt (narrowS from x) to)
       MO_UU_Conv  from to -> CmmLit (CmmInt (narrowU from x) to)
       MO_XX_Conv  from to -> CmmLit (CmmInt (narrowS from x) to)
-
       _ -> panic $ "cmmMachOpFoldM: unknown unary op: " ++ show op
 
 -- Eliminate shifts that are wider than the shiftee


=====================================
compiler/GHC/CmmToAsm/X86/CodeGen.hs
=====================================
@@ -1008,6 +1008,7 @@ getRegister' _ is32Bit (CmmMachOp (MO_Add W64) [CmmReg (CmmGlobal (GlobalRegUse
         LEA II64 (OpAddr (ripRel (litToImm displacement))) (OpReg dst))
 
 getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
+    sse4_1 <- sse4_1Enabled
     sse2   <- sse2Enabled
     sse    <- sseEnabled
     avx    <- avxEnabled
@@ -1104,6 +1105,19 @@ getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
       -- SIMD NCG TODO
       MO_VS_Neg {} -> needLlvm mop
 
+      MO_VF_Broadcast l W32 | avx       -> vector_float_broadcast_avx l W32 x
+                            | sse4_1    -> vector_float_broadcast_sse l W32 x
+                            | otherwise
+                              -> sorry "Please enable the -mavx or -msse4 flag"
+      MO_VF_Broadcast l W64 | sse2      -> vector_float_broadcast_avx l W64 x
+                            | otherwise -> sorry "Please enable the -msse2 flag"
+      MO_VF_Broadcast {} -> incorrectOperands
+
+      MO_V_Broadcast l W64  | sse2      -> vector_int_broadcast l W64 x
+                            | otherwise -> sorry "Please enable the -msse2 flag"
+      -- SIMD NCG TODO: W32, W16, W8
+      MO_V_Broadcast {} -> needLlvm mop
+
       -- Binary MachOps
       MO_Add {}    -> incorrectOperands
       MO_Sub {}    -> incorrectOperands
@@ -1156,8 +1170,6 @@ getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
       MO_VF_Sub {}        -> incorrectOperands
       MO_VF_Mul {}        -> incorrectOperands
       MO_VF_Quot {}       -> incorrectOperands
-      MO_V_Broadcast {}   -> incorrectOperands
-      MO_VF_Broadcast {}  -> incorrectOperands
 
       -- Ternary MachOps
       MO_FMA {}           -> incorrectOperands
@@ -1240,9 +1252,74 @@ getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
                          (SUB format (OpReg reg) (OpReg dst))
           return (Any format code)
 
+        -- Broadcast a float (W32) or double (W64) through the scratch slot at SP+0.
+        vector_float_broadcast_avx :: Length
+                                   -> Width
+                                   -> CmmExpr
+                                   -> NatM Register
+        vector_float_broadcast_avx len W32 expr -- W32: uses AVX VBROADCAST
+          = do
+          (reg, exp) <- getSomeReg expr
+          let f    = VecFormat len FmtFloat
+              addr = spRel platform 0
+           in return $ Any f (\dst -> exp    `snocOL`
+                                    (MOVU f (OpReg reg) (OpAddr addr)) `snocOL` -- store scalar to scratch slot
+                                    (VBROADCAST f addr dst)) -- splat the stored float into every lane
+        vector_float_broadcast_avx len W64 expr -- W64: no AVX instructions; caller also picks this with SSE2 only
+          = do
+          (reg, exp) <- getSomeReg expr
+          let f    = VecFormat len FmtDouble
+              addr = spRel platform 0
+           in return $ Any f (\dst -> exp `snocOL`
+                                    (MOVU f (OpReg reg) (OpAddr addr)) `snocOL` -- store scalar to scratch slot
+                                    (MOVL f (OpAddr addr) (OpReg dst)) `snocOL` -- load into low 64 bits
+                                    (MOVH f (OpAddr addr) (OpReg dst))) -- load into high 64 bits; NB(review): only 128 bits filled, confirm len == 2
+        vector_float_broadcast_avx _ _ c
+          = pprPanic "Broadcast not supported for : " (pdoc platform c)
+        -- Broadcast a W32 float with SSE4.1 INSERTPS, reading it back from SP+0.
+        vector_float_broadcast_sse :: Length
+                                   -> Width
+                                   -> CmmExpr
+                                   -> NatM Register
+        vector_float_broadcast_sse len W32 expr
+          = do
+          (reg, exp) <- getSomeReg expr
+          let f        = VecFormat len FmtFloat
+              addr     = spRel platform 0
+              code dst = exp `snocOL`
+                         (MOVU f (OpReg reg) (OpAddr addr)) `snocOL` -- store scalar to scratch slot
+                         (insertps $ 0b1110) `snocOL` -- dest lane 0; zmask zeroes lanes 1-3
+                         (insertps $ 16) `snocOL` -- imm bits 5:4 = 1: dest lane 1
+                         (insertps $ 32) `snocOL` -- dest lane 2
+                         (insertps $ 48) -- dest lane 3
+                where
+                  insertps imm =
+                    INSERTPS f (ImmInt imm) (OpAddr addr) dst
+
+           in return $ Any f code
+        vector_float_broadcast_sse _ _ c
+          = pprPanic "Broadcast not supported for : " (pdoc platform c)
+
+        -- Broadcast a 64-bit integer: move it into the low quadword, then duplicate it.
+        vector_int_broadcast :: Length
+                             -> Width
+                             -> CmmExpr
+                             -> NatM Register
+        vector_int_broadcast len W64 expr
+          = do
+          (reg, exp) <- getSomeReg expr
+          let fmt = VecFormat len FmtInt64
+          return $ Any fmt (\dst -> exp `snocOL`
+                                    (MOV II64 (OpReg reg) (OpReg dst)) `snocOL`
+                                    -- A single PUNPCKLQDQ dst,dst copies the low
+                                    -- quadword into the high one; repeating it is a
+                                    -- no-op. NB(review): only 128 bits are filled.
+                                    (PUNPCKLQDQ fmt (OpReg dst) dst))
+        vector_int_broadcast _ _ c
+          = pprPanic "Broadcast not supported for : " (pdoc platform c)
+
 
 getRegister' platform is32Bit (CmmMachOp mop [x, y]) = do -- dyadic MachOps
-  sse4_1 <- sse4_1Enabled
   sse2   <- sse2Enabled
   sse    <- sseEnabled
   avx    <- avxEnabled
@@ -1299,7 +1376,7 @@ getRegister' platform is32Bit (CmmMachOp mop [x, y]) = do -- dyadic MachOps
       MO_S_Shr rep -> shift_code rep SAR x y {-False-}
 
       MO_VF_Shuffle l w is
-        | l * widthInBytes w == 128
+        | l * widthInBits w == 128
         -> if
             | avx
             -> vector_shuffle_float l w x y is
@@ -1308,19 +1385,6 @@ getRegister' platform is32Bit (CmmMachOp mop [x, y]) = do -- dyadic MachOps
         | otherwise
         -> sorry "Please use -fllvm for wide shuffle instructions"
 
-      MO_VF_Broadcast l W32 | avx       -> vector_float_broadcast_avx l W32 x y
-                            | sse4_1    -> vector_float_broadcast_sse l W32 x y
-                            | otherwise
-                              -> sorry "Please enable the -mavx or -msse4 flag"
-      MO_VF_Broadcast l W64 | sse2      -> vector_float_broadcast_avx l W64 x y
-                            | otherwise -> sorry "Please enable the -msse2 flag"
-      MO_VF_Broadcast {} -> incorrectOperands
-
-      MO_V_Broadcast l W64  | sse2      -> vector_int_broadcast l W64 x y
-                            | otherwise -> sorry "Please enable the -msse2 flag"
-      -- SIMD NCG TODO: W32, W16, W8
-      MO_V_Broadcast {} -> needLlvm mop
-
       MO_VF_Extract l W32   | avx       -> vector_float_unpack l W32 x y
                             | sse       -> vector_float_unpack_sse l W32 x y
                             | otherwise
@@ -1384,6 +1448,8 @@ getRegister' platform is32Bit (CmmMachOp mop [x, y]) = do -- dyadic MachOps
       MO_AlignmentCheck {} -> incorrectOperands
       MO_VS_Neg {} -> incorrectOperands
       MO_VF_Neg {} -> incorrectOperands
+      MO_V_Broadcast {} -> incorrectOperands
+      MO_VF_Broadcast {} -> incorrectOperands
 
       -- Ternary MachOps
       MO_FMA {} -> incorrectOperands
@@ -1677,78 +1743,6 @@ getRegister' platform is32Bit (CmmMachOp mop [x, y]) = do -- dyadic MachOps
     vector_float_unpack_sse _ w c e
       = pprPanic "Unpack not supported for : " (pdoc platform c $$ pdoc platform e $$ ppr w)
     -----------------------
-    vector_float_broadcast_avx :: Length
-                               -> Width
-                               -> CmmExpr
-                               -> CmmExpr
-                               -> NatM Register
-    vector_float_broadcast_avx len W32 expr1 expr2
-      = do
-      fn        <- getAnyReg expr1
-      (r', exp) <- getSomeReg expr2
-      let f    = VecFormat len FmtFloat
-          addr = spRel platform 0
-       in return $ Any f (\r -> exp    `appOL`
-                                (fn r) `snocOL`
-                                (MOVU f (OpReg r') (OpAddr addr)) `snocOL`
-                                (VBROADCAST f addr r))
-    vector_float_broadcast_avx len W64 expr1 expr2
-      = do
-      fn        <- getAnyReg  expr1
-      (r', exp) <- getSomeReg expr2
-      let f    = VecFormat len FmtDouble
-          addr = spRel platform 0
-       in return $ Any f (\r -> exp    `appOL`
-                                (fn r) `snocOL`
-                                (MOVU f (OpReg r') (OpAddr addr)) `snocOL`
-                                (MOVL f (OpAddr addr) (OpReg r)) `snocOL`
-                                (MOVH f (OpAddr addr) (OpReg r)))
-    vector_float_broadcast_avx _ _ c _
-      = pprPanic "Broadcast not supported for : " (pdoc platform c)
-    -----------------------
-    vector_float_broadcast_sse :: Length
-                               -> Width
-                               -> CmmExpr
-                               -> CmmExpr
-                               -> NatM Register
-    vector_float_broadcast_sse len W32 expr1 expr2
-      = do
-      fn       <- getAnyReg  expr1  -- destination
-      (r, exp) <- getSomeReg expr2  -- source
-      let f        = VecFormat len FmtFloat
-          addr     = spRel platform 0
-          code dst = exp `appOL`
-                     (fn dst) `snocOL`
-                     (MOVU f (OpReg r) (OpAddr addr)) `snocOL`
-                     (insertps 0) `snocOL`
-                     (insertps 16) `snocOL`
-                     (insertps 32) `snocOL`
-                     (insertps 48)
-            where
-              insertps off =
-                INSERTPS f (litToImm $ CmmInt off W32) (OpAddr addr) dst
-
-       in return $ Any f code
-    vector_float_broadcast_sse _ _ c _
-      = pprPanic "Broadcast not supported for : " (pdoc platform c)
-
-    vector_int_broadcast :: Length
-                         -> Width
-                         -> CmmExpr
-                         -> CmmExpr
-                         -> NatM Register
-    vector_int_broadcast len W64 expr1 expr2
-      = do
-      fn        <- getAnyReg  expr1
-      (val, exp) <- getSomeReg expr2
-      let fmt    = VecFormat len FmtInt64
-      return $ Any fmt (\dst -> exp `appOL`
-                                (fn dst) `snocOL`
-                                (MOV II64 (OpReg val) (OpReg dst)) `snocOL`
-                                (PUNPCKLQDQ fmt (OpReg dst) dst))
-    vector_int_broadcast _ _ c _
-      = pprPanic "Broadcast not supported for : " (pdoc platform c)
-    -----------------------
 
     vector_int_unpack_sse :: Length
                           -> Width


=====================================
compiler/GHC/CmmToLlvm/CodeGen.hs
=====================================
@@ -1460,6 +1460,9 @@ genMachOp _ op [x] = case op of
             all0s = LMLitVar $ LMVectorLit (replicate len all0)
         in negateVec vecty all0s LM_MO_FSub
 
+    MO_V_Broadcast  l w -> genBroadcastOp l w x
+    MO_VF_Broadcast l w -> genBroadcastOp l w x
+
     MO_RelaxedRead w -> exprToVar (CmmLoad x (cmmBits w) NaturallyAligned)
 
     MO_AlignmentCheck _ _ -> panic "-falignment-sanitisation is not supported by -fllvm"
@@ -1520,8 +1523,6 @@ genMachOp _ op [x] = case op of
     MO_VU_Quot    _ _ -> panicOp
     MO_VU_Rem     _ _ -> panicOp
 
-    MO_VF_Broadcast _ _ -> panicOp
-    MO_V_Broadcast _ _ -> panicOp
     MO_VF_Insert  _ _ -> panicOp
     MO_VF_Extract _ _ -> panicOp
 
@@ -1719,12 +1720,11 @@ genMachOp_slow opt op [x, y] = case op of
     MO_WF_Bitcast _to ->  panicOp
     MO_FW_Bitcast _to ->  panicOp
 
-    MO_V_Insert  {} -> panicOp
-
     MO_VS_Neg {} -> panicOp
 
-    MO_V_Broadcast  {} -> panicOp
-    MO_VF_Broadcast  {} -> panicOp
+    MO_VF_Broadcast {} -> panicOp
+    MO_V_Broadcast {} -> panicOp
+    MO_V_Insert  {} -> panicOp
     MO_VF_Insert  {} -> panicOp
 
     MO_V_Shuffle _ _ is -> genShuffleOp is x y
@@ -1818,12 +1818,12 @@ genMachOp_slow opt op [x, y] = case op of
                     pprPanic "isSMulOK: Not bit type! " $
                         lparen <> ppr word <> rparen
 
-        panicOp = panic $ "LLVM.CodeGen.genMachOp_slow: non-binary op encountered"
+        panicOp = panic $ "LLVM.CodeGen.genMachOp_slow: non-binary op encountered "
                        ++ "with two arguments! (" ++ show op ++ ")"
 
 genMachOp_slow _opt op [x, y, z] = do
   let
-    panicOp = panic $ "LLVM.CodeGen.genMachOp_slow: non-ternary op encountered"
+    panicOp = panic $ "LLVM.CodeGen.genMachOp_slow: non-ternary op encountered "
                    ++ "with three arguments! (" ++ show op ++ ")"
   case op of
     MO_FMA var lg width ->
@@ -1846,6 +1846,21 @@ genMachOp_slow _opt op [x, y, z] = do
 -- More than three expressions, invalid!
 genMachOp_slow _ _ _ = panic "genMachOp_slow: More than 3 expressions in MachOp!"
 
+genBroadcastOp :: Int -> Width -> CmmExpr -> LlvmM ExprData
+genBroadcastOp lg _width x = runExprData $ do -- element type is taken from x's LLVM type, so the width is unused
+  -- Broadcast the scalar x to an lg-lane vector by
+  --   1. inserting x at index 0 of an all-zeros vector, then
+  --   2. shuffling with an all-zero mask, which copies lane 0 into every lane.
+  var_x <- exprToVarW x
+  let tx = getVarType var_x
+      tv = LMVector lg tx
+      z = if isFloat tx
+          then LMFloatLit 0 tx
+          else LMIntLit   0 tx
+      zs = LMLitVar $ LMVectorLit $ replicate lg z
+  w <- doExprW tv $ Insert zs var_x (LMLitVar $ LMIntLit 0 (LMInt 32))
+  doExprW tv $ Shuffle w w (replicate lg 0)
+
 genShuffleOp :: [Int] -> CmmExpr -> CmmExpr -> LlvmM ExprData
 genShuffleOp is x y = runExprData $ do
   vx <- exprToVarW x


=====================================
compiler/GHC/StgToCmm/Prim.hs
=====================================
@@ -949,16 +949,8 @@ emitPrimOp cfg primop =
 -- SIMD primops
   (VecBroadcastOp vcat n w) -> \[e] -> opIntoRegs $ \[res] -> do
     checkVecCompatibility cfg vcat n w
-    doVecBroadcastOp ty zeros e res
+    doVecBroadcastOp ty e res
    where
-    zeros :: CmmExpr
-    zeros = CmmLit $ CmmVec (replicate n zero)
-
-    zero :: CmmLit
-    zero = case vcat of
-             IntVec   -> CmmInt 0 w
-             WordVec  -> CmmInt 0 w
-             FloatVec -> CmmFloat 0 w
 
     ty :: CmmType
     ty = vecVmmType vcat n w
@@ -2612,28 +2604,17 @@ checkVecCompatibility cfg vcat l w =
 -- Helpers for translating vector packing and unpacking.
 
 doVecBroadcastOp :: CmmType       -- Type of vector
-                 -> CmmExpr       -- Initial vector
-                 -> CmmExpr     -- Elements
+                 -> CmmExpr       -- Scalar element to broadcast
                  -> CmmFormal     -- Destination for result
                  -> FCode ()
-doVecBroadcastOp ty z es res = do
-    dst <- newTemp ty
-    emitAssign (CmmLocal dst) z
-    vecBroadcast dst es 0
+doVecBroadcastOp ty e dst
  | isFloatType (vecElemType ty)
+  = emitAssign (CmmLocal dst) (CmmMachOp (MO_VF_Broadcast len wid) [e])
+  | otherwise
+  = emitAssign (CmmLocal dst) (CmmMachOp (MO_V_Broadcast len wid) [e])
   where
-    vecBroadcast :: CmmFormal -> CmmExpr -> Int -> FCode ()
-    vecBroadcast src e _ = do
-        dst <- newTemp ty
-        if isFloatType (vecElemType ty)
-          then emitAssign (CmmLocal dst) (CmmMachOp (MO_VF_Broadcast len wid)
-                                                    [CmmReg (CmmLocal src), e])
-          else emitAssign (CmmLocal dst) (CmmMachOp (MO_V_Broadcast len wid)
-                                                    [CmmReg (CmmLocal src), e])
-        emitAssign (CmmLocal res) (CmmReg (CmmLocal dst))
-
     len :: Length
     len = vecLength ty
-
     wid :: Width
     wid = typeWidth (vecElemType ty)
 


=====================================
testsuite/tests/simd/should_run/simd008.hs
=====================================
@@ -1,6 +1,5 @@
 {-# OPTIONS_GHC -mavx #-}
 {-# OPTIONS_GHC -msse4 #-}
-{-# OPTIONS_GHC -ddump-asm-native -ddump-asm-regalloc -ddump-asm-liveness #-}
 {-# LANGUAGE MagicHash #-}
 {-# LANGUAGE UnboxedTuples #-}
 {-# LANGUAGE ExtendedLiterals #-}



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/df28c0e55e4a921024ac2a7d26d6d3e3bc3393f8

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/df28c0e55e4a921024ac2a7d26d6d3e3bc3393f8
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240614/38c73d87/attachment-0001.html>


More information about the ghc-commits mailing list