[Git][ghc/ghc][wip/ncg-simd] Fix treatment of signed zero in vector negation

sheaf (@sheaf) gitlab at gitlab.haskell.org
Fri Jun 28 11:18:16 UTC 2024



sheaf pushed to branch wip/ncg-simd at Glasgow Haskell Compiler / GHC


Commits:
8edac7a8 by sheaf at 2024-06-28T13:18:03+02:00
Fix treatment of signed zero in vector negation

This commit fixes the handling of signed zero in floating-point vector
negation.

A slight hack was introduced to work around the fact that Cmm doesn't
currently have a way to express a negative-zero floating-point literal
(see get_float_broadcast_value_reg). This hack can be removed once
CmmFloat can express the value -0.0.

The simd006 test has been updated to use a stricter notion of equality
of floating-point values (comparing bit patterns, so that 0.0 and -0.0
are distinguished), which verifies the validity of this change.

- - - - -


5 changed files:

- compiler/GHC/CmmToAsm/Format.hs
- compiler/GHC/CmmToAsm/X86/CodeGen.hs
- compiler/GHC/CmmToAsm/X86/Instr.hs
- compiler/GHC/CmmToAsm/X86/Ppr.hs
- testsuite/tests/simd/should_run/simd006.hs


Changes:

=====================================
compiler/GHC/CmmToAsm/Format.hs
=====================================
@@ -24,6 +24,7 @@ module GHC.CmmToAsm.Format (
     isVecFormat,
     cmmTypeFormat,
     formatToWidth,
+    scalarWidth,
     formatInBytes,
     isFloatScalarFormat,
     scalarFormatFormat,


=====================================
compiler/GHC/CmmToAsm/X86/CodeGen.hs
=====================================
@@ -944,6 +944,7 @@ getRegister' _ _ (CmmLit lit@(CmmFloat f w)) =
   float_const_sse2  where
   float_const_sse2
     | f == 0.0 = do
+      -- TODO: this mishandles negative zero floating point literals.
       let
           format = floatFormat w
           code dst = unitOL  (XOR format (OpReg dst) (OpReg dst))
@@ -951,9 +952,7 @@ getRegister' _ _ (CmmLit lit@(CmmFloat f w)) =
         -- They all appear to do the same thing --SDM
       return (Any format code)
 
-   | otherwise = do
-      Amode addr code <- memConstant (mkAlignment $ widthInBytes w) lit
-      loadFloatAmode w addr code
+   | otherwise = getFloatLitRegister lit
 
 -- catch simple cases of zero- or sign-extended load
 getRegister' _ _ (CmmMachOp (MO_UU_Conv W8 W32) [CmmLoad addr _ _]) = do
@@ -1100,10 +1099,11 @@ getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
       MO_FS_Truncate from to -> coerceFP2Int from to x
       MO_SF_Round    from to -> coerceInt2FP from to x
 
-      MO_VF_Neg l w  | avx           -> vector_float_negate_avx l w x
-                     | sse && sse2   -> vector_float_negate_sse l w x
+      MO_VF_Neg l w  | avx              -> vector_float_negate_avx l w x
+                     | sse  && w == W32 -> vector_float_negate_sse l w x
+                     | sse2 && w == W64 -> vector_float_negate_sse l w x
                      | otherwise
-                       -> sorry "Please enable the -mavx or -msse, -msse2 flag"
+                     -> sorry "Please enable the -mavx or -msse, -msse2 flag"
       -- SIMD NCG TODO: add integer negation
       MO_VS_Neg {} -> needLlvm mop
 
@@ -1111,7 +1111,7 @@ getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
                             | sse4_1    -> vector_float_broadcast_sse l W32 x
                             | otherwise
                               -> sorry "Please enable the -mavx or -msse4 flag"
-      MO_VF_Broadcast l W64 | sse2      -> vector_float_broadcast_avx l W64 x
+      MO_VF_Broadcast l W64 | sse2      -> vector_float_broadcast_sse l W64 x
                             | otherwise -> sorry "Please enable the -msse2 flag"
       MO_VF_Broadcast {} -> incorrectOperands
 
@@ -1222,70 +1222,77 @@ getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
 
         vector_float_negate_avx :: Length -> Width -> CmmExpr -> NatM Register
         vector_float_negate_avx l w expr = do
-          tmp                  <- getNewRegNat (VecFormat l FmtFloat)
-          (reg, exp)           <- getSomeReg expr
-          Amode addr addr_code <- memConstant (mkAlignment $ widthInBytes W32) (CmmFloat 0.0 W32)
-          let format   = case w of
-                           W32 -> VecFormat l FmtFloat
-                           W64 -> VecFormat l FmtDouble
-                           _ -> pprPanic "Cannot negate vector of width" (ppr w)
-              code dst = case w of
-                           W32 -> exp `appOL` addr_code `snocOL`
-                                  (VBROADCAST format addr tmp) `snocOL`
-                                  (VSUB format (OpReg reg) tmp dst)
-                           W64 -> exp `appOL` addr_code `snocOL`
-                                  (MOVL format (OpAddr addr) (OpReg tmp)) `snocOL`
-                                  (MOVH format (OpAddr addr) (OpReg tmp)) `snocOL`
-                                  (VSUB format (OpReg reg) tmp dst)
-                           _ -> pprPanic "Cannot negate vector of width" (ppr w)
-          return (Any format code)
+          let fmt :: Format
+              mask :: CmmLit
+              (fmt, mask) = case w of
+                       W32 -> (VecFormat l FmtFloat , CmmInt (bit 31) w) -- TODO: these should be negative 0 floating point literals,
+                       W64 -> (VecFormat l FmtDouble, CmmInt (bit 63) w) -- but we don't currently have those in Cmm.
+                       _ -> panic "AVX floating-point negation: elements must be FF32 or FF64"
+          (maskReg, maskCode) <- getSomeReg (CmmMachOp (MO_VF_Broadcast l w) [CmmLit mask])
+          (reg, exp) <- getSomeReg expr
+          let code dst = maskCode `appOL`
+                         exp `snocOL`
+                         (VMOVU fmt (OpReg reg) (OpReg dst)) `snocOL`
+                         (VXOR fmt (OpReg maskReg) dst dst)
+          return (Any fmt code)
 
         vector_float_negate_sse :: Length -> Width -> CmmExpr -> NatM Register
         vector_float_negate_sse l w expr = do
-          tmp                  <- getNewRegNat (VecFormat l FmtFloat)
-          (reg, exp)           <- getSomeReg expr
-          let format   = case w of
-                           W32 -> VecFormat l FmtFloat
-                           W64 -> VecFormat l FmtDouble
-                           _ -> pprPanic "Cannot negate vector of width" (ppr w)
-              code dst = exp `snocOL`
-                         (XOR format (OpReg tmp) (OpReg tmp)) `snocOL`
-                         (MOVU format (OpReg tmp) (OpReg dst)) `snocOL`
-                         (SUB format (OpReg reg) (OpReg dst))
-          return (Any format code)
+          let fmt :: Format
+              mask :: CmmLit
+              (fmt, mask) = case w of
+                       W32 -> (VecFormat l FmtFloat , CmmInt (bit 31) w) -- Same comment as for vector_float_negate_avx,
+                       W64 -> (VecFormat l FmtDouble, CmmInt (bit 63) w) -- these should be -0.0 CmmFloat values.
+                       _ -> panic "SSE floating-point negation: elements must be FF32 or FF64"
+          (maskReg, maskCode) <- getSomeReg (CmmMachOp (MO_VF_Broadcast l w) [CmmLit mask])
+          (reg, exp) <- getSomeReg expr
+          let code dst = maskCode `appOL`
+                         exp `snocOL`
+                         (MOVU fmt (OpReg reg) (OpReg dst)) `snocOL`
+                         (XOR  fmt (OpReg maskReg) (OpReg dst))
+          return (Any fmt code)
 
         -----------------------
+
+        -- Like 'getSomeReg', but with special handling for int literals
+        -- used as floating point values, to work around the fact that we don't
+        -- have negative zero floating point literals in Cmm yet.
+        --
+        -- This should get removed once we have negative zero in CmmFloat.
+        get_float_broadcast_value_reg expr = case expr of
+          CmmLit lit -> do
+            r <- getFloatLitRegister lit
+            case r of
+              Any rep code -> do
+                tmp <- getNewRegNat rep
+                return (tmp, code tmp)
+              Fixed _ reg code ->
+                return (reg, code)
+          _ -> getSomeReg expr
+
         vector_float_broadcast_avx :: Length
                                    -> Width
                                    -> CmmExpr
                                    -> NatM Register
         vector_float_broadcast_avx len W32 expr
           = do
-          (reg, exp) <- getSomeReg expr
+          (reg, exp) <- get_float_broadcast_value_reg expr
           let f    = VecFormat len FmtFloat
               addr = spRel platform 0
-           in return $ Any f (\dst -> exp    `snocOL`
-                                    (MOVU f (OpReg reg) (OpAddr addr)) `snocOL`
-                                    (VBROADCAST f addr dst))
-        vector_float_broadcast_avx len W64 expr
-          = do
-          (reg, exp) <- getSomeReg expr
-          let f    = VecFormat len FmtDouble
-              addr = spRel platform 0
            in return $ Any f (\dst -> exp `snocOL`
                                     (MOVU f (OpReg reg) (OpAddr addr)) `snocOL`
-                                    (MOVL f (OpAddr addr) (OpReg dst)) `snocOL`
-                                    (MOVH f (OpAddr addr) (OpReg dst)))
-        vector_float_broadcast_avx _ _ c
-          = pprPanic "Broadcast not supported for : " (pdoc platform c)
-        -----------------------
+                                    (VBROADCAST f addr dst))
+        vector_float_broadcast_avx l w _
+          -- NB: for some reason, VBROADCASTSD does not support xmm, only ymm.
+          = pprPanic "vector_float_broadcast_avx" (text "l" <+> ppr l $$ text "w" <+> ppr w)
+
         vector_float_broadcast_sse :: Length
                                    -> Width
                                    -> CmmExpr
                                    -> NatM Register
         vector_float_broadcast_sse len W32 expr
           = do
-          (reg, exp) <- getSomeReg expr
+          (reg, exp) <- get_float_broadcast_value_reg expr
           let f        = VecFormat len FmtFloat
               addr     = spRel platform 0
               code dst = exp `snocOL`
@@ -1299,6 +1306,15 @@ getRegister' platform is32Bit (CmmMachOp mop [x]) = do -- unary MachOps
                     INSERTPS f (ImmInt imm) (OpAddr addr) dst
 
            in return $ Any f code
+        vector_float_broadcast_sse len W64 expr
+          = do
+          (reg, exp) <- get_float_broadcast_value_reg expr
+          let f    = VecFormat len FmtDouble
+              addr = spRel platform 0
+           in return $ Any f (\dst -> exp `snocOL`
+                                    (MOVU f (OpReg reg) (OpAddr addr)) `snocOL`
+                                    (MOVL f (OpAddr addr) (OpReg dst)) `snocOL`
+                                    (MOVH f (OpAddr addr) (OpReg dst)))
         vector_float_broadcast_sse _ _ c
           = pprPanic "Broadcast not supported for : " (pdoc platform c)
 
@@ -2057,33 +2073,48 @@ getRegister' platform is32Bit (CmmLit lit)
         -- note2: all labels are small, because we're assuming the
         -- small memory model. See Note [%rip-relative addressing on x86-64].
 
-getRegister' platform _ (CmmLit lit)
-  | isVecType cmmtype = vectorRegister cmmtype
-  | otherwise         = standardRegister cmmtype
-  where
-    cmmtype = cmmLitType platform lit
-    vectorRegister ctype
-      | case lit of { CmmVec fs -> all (\case { CmmInt i _ -> i == 0; CmmFloat f _ -> f == 0; _ -> False }) fs; _ -> False }
-      = -- NOTE:
-        -- This operation is only used to zero a register. For loading a
-        -- vector literal there are pack and broadcast operations
-        let format = cmmTypeFormat ctype
-            code dst = unitOL (XOR format (OpReg dst) (OpReg dst))
-        in return (Any format code)
+getRegister' platform _ (CmmLit lit) =
+  case fmt of
+    VecFormat l sFmt
+      | case lit of { CmmVec fs -> all is_zero fs; _ -> False }
+      -> let code dst = unitOL (XOR fmt (OpReg dst) (OpReg dst))
+         in return (Any fmt code)
+      | Just f <- case lit of { CmmVec (f:fs) | all (== f) fs -> Just f; _ -> Nothing }
+      -> do config <- getConfig
+            let w = scalarWidth sFmt
+                broadcast = if isFloatScalarFormat sFmt
+                            then MO_VF_Broadcast l w
+                            else MO_V_Broadcast l w
+            (valReg, valCode) <- getSomeReg (CmmMachOp broadcast [CmmLit f])
+            let code dst =
+                   valCode `snocOL`
+                   (mkRegRegMoveInstr config fmt valReg dst)
+            return $ Any fmt code
       | otherwise
-      = pprPanic "getRegister': no support for (nonzero) vector literals" $
-          vcat [ text "lit:" <+> ppr lit ]
-      -- SIMD NCG TODO: can we do better here?
-    standardRegister ctype
-      = do
-      let format = cmmTypeFormat ctype
-          imm = litToImm lit
-          code dst = unitOL (MOV format (OpImm imm) (OpReg dst))
-      return (Any format code)
+      -- SIMD NCG TODO: handle this case as well.
+      -> pprPanic "getRegister': non-constant vector literals not supported"
+          (ppr lit)
+       where
+        is_zero (CmmInt i _) = i == 0
+        is_zero (CmmFloat f _) = f == 0 -- TODO: mishandles negative zero
+        is_zero _ = False
+
+    _ -> let imm = litToImm lit
+             code dst = unitOL (MOV fmt (OpImm imm) (OpReg dst))
+         in return (Any fmt code)
+  where
+    cmmTy = cmmLitType platform lit
+    fmt = cmmTypeFormat cmmTy
 
 getRegister' platform _ other
   = pprPanic "getRegister(x86)" (pdoc platform other)
 
+getFloatLitRegister :: CmmLit -> NatM Register
+getFloatLitRegister lit = do
+  let w :: Width
+      w = case lit of { CmmInt _ w -> w; CmmFloat _ w -> w; _ -> panic "getFloatLitRegister" (ppr lit) }
+  Amode addr code <- memConstant (mkAlignment $ widthInBytes w) lit
+  loadFloatAmode w addr code
 
 intLoadCode :: (Operand -> Operand -> Instr) -> CmmExpr
    -> NatM (Reg -> InstrBlock)


=====================================
compiler/GHC/CmmToAsm/X86/Instr.hs
=====================================
@@ -259,6 +259,8 @@ data Instr
         | AND         Format Operand Operand
         | OR          Format Operand Operand
         | XOR         Format Operand Operand
+        -- | AVX bitwise logical XOR operation
+        | VXOR        Format Operand Reg Reg
         | NOT         Format Operand
         | NEGI        Format Operand         -- NEG instruction (name clash with Cond)
         | BSWAP       Format Reg
@@ -477,9 +479,16 @@ regUsageOfInstr platform instr
     OR     fmt src dst    -> usageRM fmt src dst
 
     XOR    fmt (OpReg src) (OpReg dst)
-        | src == dst    -> mkRU [] [mk fmt dst]
+      | src == dst
+      -> mkRU [] [mk fmt dst]
+    XOR    fmt src dst
+      -> usageRM fmt src dst
+    VXOR fmt (OpReg src1) src2 dst
+      | src1 == src2, src1 == dst
+      -> mkRU [] [mk fmt dst]
+    VXOR fmt src1 src2 dst
+      -> mkRU (use_R fmt src1 [mk fmt src2]) [mk fmt dst]
 
-    XOR    fmt src dst    -> usageRM fmt src dst
     NOT    fmt op         -> usageM fmt op
     BSWAP  fmt reg        -> mkRU [mk fmt reg] [mk fmt reg]
     NEGI   fmt op         -> usageM fmt op
@@ -722,6 +731,7 @@ patchRegsOfInstr platform instr env
     AND  fmt src dst     -> patch2 (AND  fmt) src dst
     OR   fmt src dst     -> patch2 (OR   fmt) src dst
     XOR  fmt src dst     -> patch2 (XOR  fmt) src dst
+    VXOR fmt src1 src2 dst -> VXOR fmt (patchOp src1) (env src2) (env dst)
     NOT  fmt op          -> patch1 (NOT  fmt) op
     BSWAP fmt reg        -> BSWAP fmt (env reg)
     NEGI fmt op          -> patch1 (NEGI fmt) op
@@ -764,6 +774,8 @@ patchRegsOfInstr platform instr env
     LOCATION {}         -> instr
     UNWIND {}           -> instr
     DELTA _             -> instr
+    LDATA {}            -> instr
+    NEWBLOCK {}         -> instr
 
     JXX _ _             -> instr
     JXX_GBL _ _         -> instr
@@ -830,8 +842,6 @@ patchRegsOfInstr platform instr env
     PUNPCKLQDQ fmt src dst
       -> PUNPCKLQDQ fmt (patchOp src) (env dst)
 
-    _other              -> panic "patchRegs: unrecognised instr"
-
   where
     patch1 :: (Operand -> a) -> Operand -> a
     patch1 insn op      = insn $! patchOp op


=====================================
compiler/GHC/CmmToAsm/X86/Ppr.hs
=====================================
@@ -724,11 +724,14 @@ pprInstr platform i = case i of
    XOR format src dst
       -> pprFormatOpOp (text "xor") format src dst
 
+   VXOR fmt src1 src2 dst
+      -> pprVxor fmt src1 src2 dst
+
    POPCNT format src dst
       -> pprOpOp (text "popcnt") format src (OpReg dst)
 
    LZCNT format src dst
-      ->  pprOpOp (text "lzcnt") format src (OpReg dst)
+      -> pprOpOp (text "lzcnt") format src (OpReg dst)
 
    TZCNT format src dst
       -> pprOpOp (text "tzcnt") format src (OpReg dst)
@@ -1276,6 +1279,23 @@ pprInstr platform i = case i of
            pprReg platform format reg3
        ]
 
+   pprVxor :: Format -> Operand -> Reg -> Reg -> doc
+   pprVxor fmt src1 src2 dst
+     = line $ hcat [
+           pprGenMnemonic mem fmt,
+           pprOperand platform fmt src1,
+           comma,
+           pprReg platform fmt src2,
+           comma,
+           pprReg platform fmt dst
+       ]
+     where
+      mem = case fmt of
+        VecFormat _ FmtFloat -> text "vxorps"
+        VecFormat _ FmtDouble -> text "vxorpd"
+        _ -> pprPanic "GHC.CmmToAsm.X86.Ppr.pprVxor: elementy type must be Float or Double"
+              (ppr fmt)
+
    pprInsert :: Line doc -> Format -> Imm -> Operand -> Reg -> doc
    pprInsert name format off src dst
      = line $ hcat [


=====================================
testsuite/tests/simd/should_run/simd006.hs
=====================================
@@ -127,17 +127,15 @@ instance Arbitrary Word32 where
 newtype FloatNT = FloatNT Float
   deriving newtype (Show, Num)
 instance Eq FloatNT where
-  FloatNT f1 == FloatNT f2 = f1 == f2
-    -- TODO: tests fail with this equality due to signed zeros
-    -- castFloatToWord32 f1 == castFloatToWord32 f2
+  FloatNT f1 == FloatNT f2 =
+    castFloatToWord32 f1 == castFloatToWord32 f2
 instance Arbitrary FloatNT where
   arbitrary = FloatNT . castWord32ToFloat <$> arbitrary
 newtype DoubleNT = DoubleNT Double
   deriving newtype (Show, Num)
 instance Eq DoubleNT where
-  DoubleNT d1 == DoubleNT d2 = d1 == d2
-    -- TODO: tests fail with this equality due to signed zeros
-    -- castDoubleToWord64 d1 == castDoubleToWord64 d2
+  DoubleNT d1 == DoubleNT d2 =
+    castDoubleToWord64 d1 == castDoubleToWord64 d2
 instance Arbitrary DoubleNT where
   arbitrary = DoubleNT . castWord64ToDouble <$> arbitrary
 



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/8edac7a83baf3bbdf2df1f1d4a2ef1e31e6d8aaa

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/8edac7a83baf3bbdf2df1f1d4a2ef1e31e6d8aaa
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240628/c22f7d7c/attachment-0001.html>


More information about the ghc-commits mailing list