[Git][ghc/ghc][wip/supersven/ghc-9.10-riscv-ncg] Add fused multiplication/addition (FMA)

Sven Tennie (@supersven) gitlab at gitlab.haskell.org
Sat Jun 8 12:23:19 UTC 2024



Sven Tennie pushed to branch wip/supersven/ghc-9.10-riscv-ncg at Glasgow Haskell Compiler / GHC


Commits:
53da5e54 by Sven Tennie at 2024-06-08T12:21:52+00:00
Add fused multiplication/addition (FMA)

- - - - -


3 changed files:

- compiler/GHC/CmmToAsm/RV64/CodeGen.hs
- compiler/GHC/CmmToAsm/RV64/Instr.hs
- compiler/GHC/CmmToAsm/RV64/Ppr.hs


Changes:

=====================================
compiler/GHC/CmmToAsm/RV64/CodeGen.hs
=====================================
@@ -755,7 +755,7 @@ getRegister' config plat expr =
       where w' = formatToWidth (cmmTypeFormat (cmmRegType reg))
             r' = getRegisterReg plat reg
 
-    -- Generic case.
+    -- Generic binary case.
     CmmMachOp op [x, y] -> do
       let
           -- A "plain" operation.
@@ -910,6 +910,42 @@ getRegister' config plat expr =
         MO_S_Shr w -> intOp True  w (\d x y -> unitOL $ annExpr expr (ASR d x y))
 
         op -> pprPanic "getRegister' (unhandled dyadic CmmMachOp): " $ pprMachOp op <+> text "in" <+> pdoc plat expr
+
+    -- Generic ternary case.
+    CmmMachOp op [x, y, z] ->
+
+      case op of
+
+        -- Floating-point fused multiply-add operations
+
+        -- x86 fmadd    x * y + z <=> RISC-V fmadd : d =   r1 * r2 + r3
+        -- x86 fmsub    x * y - z <=> RISC-V fmsub : d =   r1 * r2 - r3
+        -- x86 fnmadd - x * y + z <=> RISC-V fnmsub: d = - r1 * r2 + r3
+        -- x86 fnmsub - x * y - z <=> RISC-V fnmadd: d = - r1 * r2 - r3
+
+        MO_FMA var w -> case var of
+          FMAdd  -> float3Op w (\d n m a -> unitOL $ FMA FMAdd  d n m a)
+          FMSub  -> float3Op w (\d n m a -> unitOL $ FMA FMSub d n m a)
+          FNMAdd -> float3Op w (\d n m a -> unitOL $ FMA FNMSub  d n m a)
+          FNMSub -> float3Op w (\d n m a -> unitOL $ FMA FNMAdd d n m a)
+
+        _ -> pprPanic "getRegister' (unhandled ternary CmmMachOp): " $
+                (pprMachOp op) <+> text "in" <+> (pdoc plat expr)
+
+      where
+          float3Op w op = do
+            (reg_fx, format_x, code_fx) <- getFloatReg x
+            (reg_fy, format_y, code_fy) <- getFloatReg y
+            (reg_fz, format_z, code_fz) <- getFloatReg z
+            massertPpr (isFloatFormat format_x && isFloatFormat format_y && isFloatFormat format_z) $
+              text "float3Op: non-float"
+            return $
+              Any (floatFormat w) $ \ dst ->
+                code_fx `appOL`
+                code_fy `appOL`
+                code_fz `appOL`
+                op (OpReg w dst) (OpReg w reg_fx) (OpReg w reg_fy) (OpReg w reg_fz)
+
     CmmMachOp _op _xs
       -> pprPanic "getRegister' (variadic CmmMachOp): " (pdoc plat expr)
 


=====================================
compiler/GHC/CmmToAsm/RV64/Instr.hs
=====================================
@@ -134,6 +134,8 @@ regUsageOfInstr platform instr = case instr of
   SCVTF dst src            -> usage (regOp src, regOp dst)
   FCVTZS dst src           -> usage (regOp src, regOp dst)
   FABS dst src             -> usage (regOp src, regOp dst)
+  FMA _ dst src1 src2 src3 ->
+    usage (regOp src1 ++ regOp src2 ++ regOp src3, regOp dst)
 
   _ -> panic $ "regUsageOfInstr: " ++ instrCon instr
 
@@ -253,6 +255,8 @@ patchRegsOfInstr instr env = case instr of
     SCVTF o1 o2    -> SCVTF (patchOp o1) (patchOp o2)
     FCVTZS o1 o2   -> FCVTZS (patchOp o1) (patchOp o2)
     FABS o1 o2     -> FABS (patchOp o1) (patchOp o2)
+    FMA s o1 o2 o3 o4 ->
+      FMA s (patchOp o1) (patchOp o2) (patchOp o3) (patchOp o4)
     _              -> panic $ "patchRegsOfInstr: " ++ instrCon instr
     where
         patchOp :: Operand -> Operand
@@ -634,6 +638,13 @@ data Instr
     | FCVTZS Operand Operand
     -- Float ABSolute value
     | FABS Operand Operand
+    -- | Floating-point fused multiply-add instructions
+    --
+    -- - fmadd : d =   r1 * r2 + r3
+    -- - fmsub : d =   r1 * r2 - r3
+    -- - fnmsub: d = - r1 * r2 + r3
+    -- - fnmadd: d = - r1 * r2 - r3
+    | FMA FMASign Operand Operand Operand Operand
 
 data DmbType = DmbRead | DmbWrite | DmbReadWrite
 
@@ -683,6 +694,12 @@ instrCon i =
       SCVTF{} -> "SCVTF"
       FCVTZS{} -> "FCVTZS"
       FABS{} -> "FABS"
+      FMA variant _ _ _ _ ->
+        case variant of
+          FMAdd  -> "FMADD"
+          FMSub  -> "FMSUB"
+          FNMAdd -> "FNMADD"
+          FNMSub -> "FNMSUB"
 
 data Target
     = TBlock BlockId


=====================================
compiler/GHC/CmmToAsm/RV64/Ppr.hs
=====================================
@@ -648,12 +648,23 @@ pprInstr platform instr = case instr of
 
   FABS o1 o2 | isSingleOp o2 -> op2 (text "\tfabs.s") o1 o2
   FABS o1 o2 | isDoubleOp o2 -> op2 (text "\tfabs.d") o1 o2
+  FMA variant d r1 r2 r3 ->
+    let fma = case variant of
+                FMAdd  -> text "\tfmadd" <> dot <> floatPrecission d
+                FMSub  -> text "\tfmsub" <> dot <> floatPrecission d
+                FNMAdd -> text "\tfnmadd" <> dot <> floatPrecission d
+                FNMSub -> text "\tfnmsub" <> dot <> floatPrecission d
+    in op4 fma d r1 r2 r3
   instr -> panic $ "RV64.pprInstr - Unknown instruction: " ++ instrCon instr
  where op2 op o1 o2        = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2
        op3 op o1 o2 o3     = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2 <> comma <+> pprOp platform o3
+       op4 op o1 o2 o3 o4  = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2 <> comma <+> pprOp platform o3 <> comma <+> pprOp platform o4
        pprDmbType DmbRead = text "r"
        pprDmbType DmbWrite = text "w"
        pprDmbType DmbReadWrite = text "rw"
+       floatPrecission o | isSingleOp o = text "s"
+                         | isDoubleOp o = text "d"
+                         | otherwise  = pprPanic "Impossible floating point precission: " (pprOp platform o)
 
 floatOpPrecision :: Platform -> Operand -> Operand -> String
 floatOpPrecision _p l r | isFloatOp l && isFloatOp r && isSingleOp l && isSingleOp r = "s" -- single precision



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/53da5e54a2021cc9042a9a3274dcd5bf840c6d6c

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/53da5e54a2021cc9042a9a3274dcd5bf840c6d6c
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240608/5b6fc7eb/attachment-0001.html>


More information about the ghc-commits mailing list