[Git][ghc/ghc][wip/ncg-simd] 2 commits: Add test for C calls & SIMD vectors

sheaf (@sheaf) gitlab at gitlab.haskell.org
Thu Aug 8 18:06:13 UTC 2024



sheaf pushed to branch wip/ncg-simd at Glasgow Haskell Compiler / GHC


Commits:
df2185b0 by sheaf at 2024-08-08T20:05:25+02:00
Add test for C calls & SIMD vectors

- - - - -
ac0602ad by sheaf at 2024-08-08T20:05:59+02:00
Attempt to fix C calls with SIMD vectors

- - - - -


5 changed files:

- compiler/GHC/CmmToAsm/X86/CodeGen.hs
- compiler/GHC/CmmToAsm/X86/Instr.hs
- testsuite/tests/simd/should_run/all.T
- + testsuite/tests/simd/should_run/simd013.hs
- + testsuite/tests/simd/should_run/simd013C.c


Changes:

=====================================
compiler/GHC/CmmToAsm/X86/CodeGen.hs
=====================================
@@ -844,12 +844,12 @@ iselExpr64ParallelBin op e1 e2 = do
 -- targetted for any particular type like Int8, Int32 etc
 data VectorArithInstns = VA_Add | VA_Sub | VA_Mul | VA_Div | VA_Min | VA_Max
 
-getRegister :: CmmExpr -> NatM Register
+getRegister :: HasDebugCallStack => CmmExpr -> NatM Register
 getRegister e = do platform <- getPlatform
                    is32Bit <- is32BitPlatform
                    getRegister' platform is32Bit e
 
-getRegister' :: Platform -> Bool -> CmmExpr -> NatM Register
+getRegister' :: HasDebugCallStack => Platform -> Bool -> CmmExpr -> NatM Register
 
 getRegister' platform is32Bit (CmmReg reg)
   = case reg of
@@ -2306,7 +2306,7 @@ getNonClobberedOperand (CmmLit lit) =
     return (OpAddr addr, code)
   else do
     platform <- getPlatform
-    if is32BitLit platform lit && not (isFloatType (cmmLitType platform lit))
+    if is32BitLit platform lit && isIntFormat (cmmTypeFormat (cmmLitType platform lit))
     then return (OpImm (litToImm lit), nilOL)
     else getNonClobberedOperand_generic (CmmLit lit)
 
@@ -2363,13 +2363,13 @@ getOperand (CmmLit lit) = do
     else do
 
   platform <- getPlatform
-  if is32BitLit platform lit && not (isFloatType (cmmLitType platform lit))
+  if is32BitLit platform lit && (isIntFormat $ cmmTypeFormat (cmmLitType platform lit))
     then return (OpImm (litToImm lit), nilOL)
     else getOperand_generic (CmmLit lit)
 
 getOperand (CmmLoad mem ty _) = do
   is32Bit <- is32BitPlatform
-  if not (isFloatType ty) && (if is32Bit then not (isWord64 ty) else True)
+  if isIntFormat (cmmTypeFormat ty) && (if is32Bit then not (isWord64 ty) else True)
      then do
        Amode src mem_code <- getAmode mem
        return (OpAddr src, mem_code)
@@ -2400,7 +2400,7 @@ addAlignmentCheck align reg =
   where
     check :: Format -> Reg -> InstrBlock
     check fmt reg =
-        assert (not $ isFloatFormat fmt) $
+        assert (isIntFormat fmt) $
         toOL [ TEST fmt (OpImm $ ImmInt $ align-1) (OpReg reg)
              , JXX_GBL NE $ ImmCLbl mkBadAlignmentLabel
              ]
@@ -2445,7 +2445,7 @@ isSuitableFloatingPointLit _ = False
 getRegOrMem :: CmmExpr -> NatM (Operand, InstrBlock)
 getRegOrMem e@(CmmLoad mem ty _) = do
   is32Bit <- is32BitPlatform
-  if not (isFloatType ty) && (if is32Bit then not (isWord64 ty) else True)
+  if isIntFormat (cmmTypeFormat ty) && (if is32Bit then not (isWord64 ty) else True)
      then do
        Amode src mem_code <- getAmode mem
        return (OpAddr src, mem_code)
@@ -3319,7 +3319,7 @@ genCCall32 addr (ForeignConvention _ argHints _ _) dest_regs args = do
                                      DELTA (delta-8)]
                     )
 
-              | isFloatType arg_ty = do
+              | isFloatType arg_ty || isVecType arg_ty = do
                 (reg, code) <- getSomeReg arg
                 delta <- getDeltaNat
                 setDeltaNat (delta-size)
@@ -3329,11 +3329,10 @@ genCCall32 addr (ForeignConvention _ argHints _ _) dest_regs args = do
                                       let addr = AddrBaseIndex (EABaseReg esp)
                                                                 EAIndexNone
                                                                 (ImmInt 0)
-                                          format = floatFormat (typeWidth arg_ty)
+                                          format = cmmTypeFormat arg_ty
                                       in
 
-                                      -- assume SSE2
-                                       MOV format (OpReg reg) (OpAddr addr)
+                                       movInstr config format (OpReg reg) (OpAddr addr)
 
                                      ]
                                )
@@ -3402,6 +3401,8 @@ genCCall32 addr (ForeignConvention _ argHints _ _) dest_regs args = do
             -- assign the results, if necessary
             assign_code []     = nilOL
             assign_code [dest]
+              | isVecType ty
+              = sorry "X86_32 C call: no support for returning SIMD vectors"
               | isFloatType ty =
                   -- we assume SSE2
                   let tmp_amode = AddrBaseIndex (EABaseReg esp)
@@ -3448,36 +3449,41 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
     let prom_args = map (maybePromoteCArg platform W32) args_hints
 
     let load_args :: [CmmExpr]
-                  -> [RegFormat]         -- int regs avail for args
-                  -> [RegFormat]         -- FP regs avail for args
+                  -> [Reg]         -- int regs avail for args
+                  -> [Reg]         -- FP regs avail for args
+                  -> [RegFormat]   -- used int regs
+                  -> [RegFormat]   -- used FP regs
                   -> InstrBlock    -- code computing args
                   -> InstrBlock    -- code assigning args to ABI regs
                   -> NatM ([CmmExpr],[RegFormat],[RegFormat],InstrBlock,InstrBlock)
         -- no more regs to use
-        load_args args [] [] code acode     =
-            return (args, [], [], code, acode)
+        load_args args [] [] used_aregs used_fregs code acode     =
+            return (args, used_aregs, used_fregs, code, acode)
 
         -- no more args to push
-        load_args [] aregs fregs code acode =
-            return ([], aregs, fregs, code, acode)
-
-        load_args (arg : rest) aregs fregs code acode
-            | isFloatType arg_rep = case fregs of
-                 []     -> push_this_arg
-                 (RegFormat r _fmt:rs) -> do
-                    (code',acode') <- reg_this_arg r
-                    load_args rest aregs rs code' acode'
-            | otherwise           = case aregs of
-                 []     -> push_this_arg
-                 (RegFormat r _fmt:rs) -> do
-                    (code',acode') <- reg_this_arg r
-                    load_args rest rs fregs code' acode'
+        load_args [] _aregs _fregs used_aregs used_fregs code acode =
+            return ([], used_aregs, used_fregs, code, acode)
+
+        load_args (arg : rest) aregs fregs used_aregs used_fregs code acode
+            | isFloatType arg_rep || isVecType arg_rep
+            = case fregs of
+                []     -> push_this_arg
+                (r:rs) -> do
+                   (code',acode') <- reg_this_arg r
+                   load_args rest aregs rs used_aregs (RegFormat r fmt:used_fregs) code' acode'
+            | otherwise
+            = case aregs of
+                []     -> push_this_arg
+                (r:rs) -> do
+                   (code',acode') <- reg_this_arg r
+                   load_args rest rs fregs (RegFormat r fmt:used_aregs) used_fregs code' acode'
             where
+              fmt = cmmTypeFormat arg_rep
 
               -- put arg into the list of stack pushed args
               push_this_arg = do
                  (args',ars,frs,code',acode')
-                     <- load_args rest aregs fregs code acode
+                     <- load_args rest aregs fregs used_aregs used_fregs code acode
                  return (arg:args', ars, frs, code', acode')
 
               -- pass the arg into the given register
@@ -3522,8 +3528,8 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
             -- no more args to push
         load_args_win (arg : rest) usedInt usedFP
                       ((ireg, freg) : regs) code
-            | isFloatType arg_rep = do
-                 arg_code <- getAnyReg arg
+            | isFloatType arg_rep
+            = do arg_code <- getAnyReg arg
                  load_args_win rest (RegFormat ireg II64: usedInt) (RegFormat freg FF64 : usedFP) regs
                                (code `appOL`
                                 arg_code freg `snocOL`
@@ -3531,26 +3537,34 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
                                 -- then we need to define ireg as well
                                 -- as freg
                                 MOVD FF64 (OpReg freg) (OpReg ireg))
-            | otherwise = do
-                 arg_code <- getAnyReg arg
+            | isVecType arg_rep
+            , let fmt = cmmTypeFormat arg_rep
+            = do arg_code <- getAnyReg arg
+                 load_args_win rest (RegFormat ireg II64: usedInt) (RegFormat freg fmt : usedFP) regs
+                               (code `appOL` arg_code freg)
+                -- SIMD NCG TODO:
+                --   Vector arguments in a varargs function should be passed
+                --   after other arguments. For the time being we ignore this issue.
+            | otherwise
+            = do arg_code <- getAnyReg arg
                  load_args_win rest (RegFormat ireg II64: usedInt) usedFP regs
                                (code `appOL` arg_code ireg)
             where
               arg_rep = cmmExprType platform arg
 
-        arg_size = 8 -- always, at the mo
+        expr_size arg = max (widthInBytes (wordWidth platform)) $ widthInBytes (typeWidth $ cmmExprType platform arg)
 
         push_args [] code = return code
         push_args (arg:rest) code
-           | isFloatType arg_rep = do
+           | isFloatType arg_rep || isVecType arg_rep = do
              (arg_reg, arg_code) <- getSomeReg arg
              delta <- getDeltaNat
              setDeltaNat (delta-arg_size)
-             let fmt = floatFormat width
+             let fmt = cmmTypeFormat arg_rep
                  code' = code `appOL` arg_code `appOL` toOL [
                             SUB (intFormat (wordWidth platform)) (OpImm (ImmInt arg_size)) (OpReg rsp),
                             DELTA (delta-arg_size),
-                            MOV fmt (OpReg arg_reg) (OpAddr (spRel platform 0))]
+                            movInstr config fmt (OpReg arg_reg) (OpAddr (spRel platform 0))]
              push_args rest code'
 
            | otherwise = do
@@ -3566,28 +3580,28 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
                                     DELTA (delta-arg_size)]
              push_args rest code'
             where
+              arg_size = expr_size arg
               arg_rep = cmmExprType platform arg
               width = typeWidth arg_rep
 
         leaveStackSpace n = do
              delta <- getDeltaNat
-             setDeltaNat (delta - n * arg_size)
+             setDeltaNat (delta - n * 8)
              return $ toOL [
                          SUB II64 (OpImm (ImmInt (n * platformWordSizeInBytes platform))) (OpReg rsp),
-                         DELTA (delta - n * arg_size)]
+                         DELTA (delta - n * 8)]
+              -- NB: the shadow store is always 8 * 4 = 32 bytes large,
+              -- i.e. the cumulative size of rcx, rdx, r8, r9 (see 'allArgRegs').
 
     (stack_args, int_regs_used, fp_regs_used, load_args_code, assign_args_code)
          <-
         if platformOS platform == OSMinGW32
         then load_args_win prom_args [] [] (allArgRegs platform) nilOL
         else do
-           let intArgRegs = map (\r -> RegFormat r II64) $ allIntArgRegs platform
-               fpArgRegs = map (\r -> RegFormat r FF64) $ allFPArgRegs platform
-           (stack_args, aregs, fregs, load_args_code, assign_args_code)
-               <- load_args prom_args intArgRegs fpArgRegs nilOL nilOL
-           let used_regs rs as = dropTail (length rs) as
-               fregs_used      = used_regs fregs fpArgRegs
-               aregs_used      = used_regs aregs intArgRegs
+           let intArgRegs = allIntArgRegs platform
+               fpArgRegs = allFPArgRegs platform
+           (stack_args, aregs_used, fregs_used, load_args_code, assign_args_code)
+               <- load_args prom_args intArgRegs fpArgRegs [] [] nilOL nilOL
            return (stack_args, aregs_used, fregs_used, load_args_code
                                                       , assign_args_code)
 
@@ -3597,11 +3611,10 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
         arg_regs = [RegFormat eax wordFmt] ++ arg_regs_used
                 -- for annotating the call instruction with
         sse_regs = length fp_regs_used
-        arg_stack_slots = if platformOS platform == OSMinGW32
-                          then length stack_args + length (allArgRegs platform)
-                          else length stack_args
-        tot_arg_size = arg_size * arg_stack_slots
-
+        shadow_store = if platformOS platform == OSMinGW32
+                       then 8 * length (allArgRegs platform)
+                       else 0
+        tot_arg_size = shadow_store + sum (map expr_size stack_args)
 
     -- Align stack to 16n for calls, assuming a starting stack
     -- alignment of 16n - word_size on procedure entry. Which we
@@ -3622,7 +3635,7 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
     -- On Win64, we also have to leave stack space for the arguments
     -- that we are passing in registers
     lss_code <- if platformOS platform == OSMinGW32
-                then leaveStackSpace (length (allArgRegs platform))
+                then leaveStackSpace $ length (allArgRegs platform)
                 else return nilOL
     delta <- getDeltaNat
 


=====================================
compiler/GHC/CmmToAsm/X86/Instr.hs
=====================================
@@ -892,7 +892,7 @@ mkLoadInstr config (RegFormat reg fmt) delta slot =
 
 -- | A move instruction for moving the entire contents of an operand
 -- at the given 'Format'.
-movInstr :: NCGConfig -> Format -> (Operand -> Operand -> Instr)
+movInstr :: HasDebugCallStack => NCGConfig -> Format -> (Operand -> Operand -> Instr)
 movInstr config fmt =
   case fmt of
     VecFormat _ sFmt ->
@@ -914,17 +914,38 @@ movInstr config fmt =
         _ -> sorry $ "Unhandled SIMD vector width: " ++ show (8 * bytes) ++ " bits"
     _ -> MOV fmt
   where
+    plat    = ncgPlatform config
     bytes   = formatInBytes fmt
     avx     = ncgAvxEnabled config
     avx2    = ncgAvx2Enabled config
     avx512f = ncgAvx512fEnabled config
     avx_move sFmt =
       if isFloatScalarFormat sFmt
-      then VMOVU   fmt
+      then \ op1 op2 ->
+              if
+                | OpReg r1 <- op1
+                , OpReg r2 <- op2
+                , targetClassOfReg plat r1 /= targetClassOfReg plat r2
+                -> pprPanic "movInstr: VMOVU between incompatible registers"
+                     ( vcat [ text "fmt:" <+> ppr fmt
+                            , text "r1:" <+> ppr r1
+                            , text "r2:" <+> ppr r2 ] )
+                | otherwise
+                -> VMOVU   fmt op1 op2
       else VMOVDQU fmt
     sse_move sFmt =
       if isFloatScalarFormat sFmt
-      then MOVU   fmt
+      then \ op1 op2 ->
+              if
+                | OpReg r1 <- op1
+                , OpReg r2 <- op2
+                , targetClassOfReg plat r1 /= targetClassOfReg plat r2
+                -> pprPanic "movInstr: MOVU between incompatible registers"
+                     ( vcat [ text "fmt:" <+> ppr fmt
+                            , text "r1:" <+> ppr r1
+                            , text "r2:" <+> ppr r2 ] )
+                | otherwise
+                -> MOVU   fmt op1 op2
       else MOVDQU fmt
     -- NB: we are using {V}MOVU and not {V}MOVA, because we have no guarantees
     -- about the stack being sufficiently aligned (even for even numbered stack slots).
@@ -989,12 +1010,7 @@ mkRegRegMoveInstr
     -> Reg
     -> Instr
 mkRegRegMoveInstr config fmt src dst =
-  assertPpr (targetClassOfReg platform src == targetClassOfReg platform dst)
-    (vcat [ text "mkRegRegMoveInstr: incompatible register classes"
-          , text "fmt:" <+> ppr fmt
-          , text "src:" <+> ppr src
-          , text "dst:" <+> ppr dst ]) $
-    movInstr config fmt' (OpReg src) (OpReg dst)
+  movInstr config fmt' (OpReg src) (OpReg dst)
       -- Move the platform word size, at a minimum
   where
     platform = ncgPlatform config


=====================================
testsuite/tests/simd/should_run/all.T
=====================================
@@ -40,6 +40,12 @@ test('simd011', [ unless(have_cpu_feature('fma'), skip)
                 , extra_hc_opts('-mfma')
                 ], compile_and_run, [''])
 test('simd012', [], compile_and_run, [''])
+test('simd013',
+     [ req_c
+     , unless(arch('x86_64'), skip) # because the C file uses Intel intrinsics
+     ],
+     compile_and_run, ['simd013C.c'])
+
 
 test('T25062_V16', [], compile_and_run, [''])
 test('T25062_V32', [ unless(have_cpu_feature('avx2'), skip)


=====================================
testsuite/tests/simd/should_run/simd013.hs
=====================================
@@ -0,0 +1,37 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE UnboxedTuples #-}
+{-# LANGUAGE UnliftedFFITypes #-}
+
+-- | Test C calls with SIMD vectors: 'DoubleX2#' values are passed to, and
+-- returned from, the C functions defined in simd013C.c.
+module Main where
+
+-- GHC.Exts re-exports the vector types and primops from GHC.Prim.
+import GHC.Exts
+
+-- | Lane-wise sum of two vectors (two vector arguments, vector result).
+foreign import ccall "f"
+  f :: DoubleX2# -> DoubleX2# -> DoubleX2#
+
+-- | Lane-wise sum of five vectors (five vector arguments, vector result).
+--
+-- NOTE(review): the System V x86-64 ABI provides eight XMM argument registers,
+-- so presumably all five arguments travel in registers there; a variant with
+-- more than eight vector arguments would also cover stack-passed vectors.
+foreign import ccall "g"
+  g :: DoubleX2# -> DoubleX2# -> DoubleX2# -> DoubleX2# -> DoubleX2# -> DoubleX2#
+
+-- | Calls 'f' and 'g' on fixed inputs and prints both lanes of each result.
+main :: IO ()
+main = do
+  let x1, x2, x3, x4, x5 :: DoubleX2#
+      !x1 = packDoubleX2# (#     1.1##,         2.1## #)
+      !x2 = packDoubleX2# (#    10.02##,       20.02## #)
+      !x3 = packDoubleX2# (#   100.003##,     200.003## #)
+      !x4 = packDoubleX2# (#  1000.0004##,   2000.0004## #)
+      !x5 = packDoubleX2# (# 10000.00005##, 20000.00005## #)
+      !(# a, b #) = unpackDoubleX2# ( f x1 x2 )
+      !(# c, d #) = unpackDoubleX2# ( g x1 x2 x3 x4 x5 )
+  print ( D# a, D# b )
+  print ( D# c, D# d )


=====================================
testsuite/tests/simd/should_run/simd013C.c
=====================================
@@ -0,0 +1,17 @@
+// C side of the simd013 test (C calls with SIMD vectors).
+//
+// __m128d and _mm_add_pd are SSE2 and are declared in <emmintrin.h>;
+// include it directly rather than relying on <xmmintrin.h> to pull it in.
+#include <emmintrin.h>
+
+// Lane-wise sum of two vectors.
+__m128d f(__m128d x, __m128d y)
+{
+  return _mm_add_pd(x,y);
+}
+
+// Lane-wise sum of five vectors.
+__m128d g(__m128d x1, __m128d x2, __m128d x3, __m128d x4, __m128d x5)
+{
+  return _mm_add_pd(x1,_mm_add_pd(x2,_mm_add_pd(x3,_mm_add_pd(x4,x5))));
+}



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/55da517aa6ab3e88e53d31d2f13355212b06c2b5...ac0602ad249d5a7076cecc6522fd5de9f37dd3e1

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/55da517aa6ab3e88e53d31d2f13355212b06c2b5...ac0602ad249d5a7076cecc6522fd5de9f37dd3e1
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240808/a4959742/attachment-0001.html>


More information about the ghc-commits mailing list