[Git][ghc/ghc][wip/supersven/ghc-9.10-riscv-ncg] Add fused multiplication/addition (FMA)
Sven Tennie (@supersven)
gitlab at gitlab.haskell.org
Sat Jun 8 12:23:19 UTC 2024
Sven Tennie pushed to branch wip/supersven/ghc-9.10-riscv-ncg at Glasgow Haskell Compiler / GHC
Commits:
53da5e54 by Sven Tennie at 2024-06-08T12:21:52+00:00
Add fused multiplication/addition (FMA)
- - - - -
3 changed files:
- compiler/GHC/CmmToAsm/RV64/CodeGen.hs
- compiler/GHC/CmmToAsm/RV64/Instr.hs
- compiler/GHC/CmmToAsm/RV64/Ppr.hs
Changes:
=====================================
compiler/GHC/CmmToAsm/RV64/CodeGen.hs
=====================================
@@ -755,7 +755,7 @@ getRegister' config plat expr =
where w' = formatToWidth (cmmTypeFormat (cmmRegType reg))
r' = getRegisterReg plat reg
- -- Generic case.
+ -- Generic binary case.
CmmMachOp op [x, y] -> do
let
-- A "plain" operation.
@@ -910,6 +910,42 @@ getRegister' config plat expr =
MO_S_Shr w -> intOp True w (\d x y -> unitOL $ annExpr expr (ASR d x y))
op -> pprPanic "getRegister' (unhandled dyadic CmmMachOp): " $ pprMachOp op <+> text "in" <+> pdoc plat expr
+
+ -- Generic ternary case.
+ CmmMachOp op [x, y, z] ->
+
+ case op of
+
+ -- Floating-point fused multiply-add operations
+
+ -- x86 fmadd x * y + z <=> RISC-V fmadd : d = r1 * r2 + r3
+ -- x86 fmsub x * y - z <=> RISC-V fmsub : d = r1 * r2 - r3
+ -- x86 fnmadd - x * y + z <=> RISC-V fnmsub: d = - r1 * r2 + r3
+ -- x86 fnmsub - x * y - z <=> RISC-V fnmadd: d = - r1 * r2 - r3
+
+ MO_FMA var w -> case var of
+ FMAdd -> float3Op w (\d n m a -> unitOL $ FMA FMAdd d n m a)
+ FMSub -> float3Op w (\d n m a -> unitOL $ FMA FMSub d n m a)
+ FNMAdd -> float3Op w (\d n m a -> unitOL $ FMA FNMSub d n m a)
+ FNMSub -> float3Op w (\d n m a -> unitOL $ FMA FNMAdd d n m a)
+
+ _ -> pprPanic "getRegister' (unhandled ternary CmmMachOp): " $
+ (pprMachOp op) <+> text "in" <+> (pdoc plat expr)
+
+ where
+ float3Op w op = do
+ (reg_fx, format_x, code_fx) <- getFloatReg x
+ (reg_fy, format_y, code_fy) <- getFloatReg y
+ (reg_fz, format_z, code_fz) <- getFloatReg z
+ massertPpr (isFloatFormat format_x && isFloatFormat format_y && isFloatFormat format_z) $
+ text "float3Op: non-float"
+ return $
+ Any (floatFormat w) $ \ dst ->
+ code_fx `appOL`
+ code_fy `appOL`
+ code_fz `appOL`
+ op (OpReg w dst) (OpReg w reg_fx) (OpReg w reg_fy) (OpReg w reg_fz)
+
CmmMachOp _op _xs
-> pprPanic "getRegister' (variadic CmmMachOp): " (pdoc plat expr)
=====================================
compiler/GHC/CmmToAsm/RV64/Instr.hs
=====================================
@@ -134,6 +134,8 @@ regUsageOfInstr platform instr = case instr of
SCVTF dst src -> usage (regOp src, regOp dst)
FCVTZS dst src -> usage (regOp src, regOp dst)
FABS dst src -> usage (regOp src, regOp dst)
+ FMA _ dst src1 src2 src3 ->
+ usage (regOp src1 ++ regOp src2 ++ regOp src3, regOp dst)
_ -> panic $ "regUsageOfInstr: " ++ instrCon instr
@@ -253,6 +255,8 @@ patchRegsOfInstr instr env = case instr of
SCVTF o1 o2 -> SCVTF (patchOp o1) (patchOp o2)
FCVTZS o1 o2 -> FCVTZS (patchOp o1) (patchOp o2)
FABS o1 o2 -> FABS (patchOp o1) (patchOp o2)
+ FMA s o1 o2 o3 o4 ->
+ FMA s (patchOp o1) (patchOp o2) (patchOp o3) (patchOp o4)
_ -> panic $ "patchRegsOfInstr: " ++ instrCon instr
where
patchOp :: Operand -> Operand
@@ -634,6 +638,13 @@ data Instr
| FCVTZS Operand Operand
-- Float ABSolute value
| FABS Operand Operand
+ -- | Floating-point fused multiply-add instructions
+ --
+ -- - fmadd : d = r1 * r2 + r3
+ -- - fmsub : d = r1 * r2 - r3
+ -- - fnmsub: d = - r1 * r2 + r3
+ -- - fnmadd: d = - r1 * r2 - r3
+ | FMA FMASign Operand Operand Operand Operand
data DmbType = DmbRead | DmbWrite | DmbReadWrite
@@ -683,6 +694,12 @@ instrCon i =
SCVTF{} -> "SCVTF"
FCVTZS{} -> "FCVTZS"
FABS{} -> "FABS"
+ FMA variant _ _ _ _ ->
+ case variant of
+ FMAdd -> "FMADD"
+ FMSub -> "FMSUB"
+ FNMAdd -> "FNMADD"
+ FNMSub -> "FNMSUB"
data Target
= TBlock BlockId
=====================================
compiler/GHC/CmmToAsm/RV64/Ppr.hs
=====================================
@@ -648,12 +648,23 @@ pprInstr platform instr = case instr of
FABS o1 o2 | isSingleOp o2 -> op2 (text "\tfabs.s") o1 o2
FABS o1 o2 | isDoubleOp o2 -> op2 (text "\tfabs.d") o1 o2
+ FMA variant d r1 r2 r3 ->
+ let fma = case variant of
+ FMAdd -> text "\tfmadd" <> dot <> floatPrecission d
+ FMSub -> text "\tfmsub" <> dot <> floatPrecission d
+ FNMAdd -> text "\tfnmadd" <> dot <> floatPrecission d
+ FNMSub -> text "\tfnmsub" <> dot <> floatPrecission d
+ in op4 fma d r1 r2 r3
instr -> panic $ "RV64.pprInstr - Unknown instruction: " ++ instrCon instr
where op2 op o1 o2 = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2
op3 op o1 o2 o3 = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2 <> comma <+> pprOp platform o3
+ op4 op o1 o2 o3 o4 = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2 <> comma <+> pprOp platform o3 <> comma <+> pprOp platform o4
pprDmbType DmbRead = text "r"
pprDmbType DmbWrite = text "w"
pprDmbType DmbReadWrite = text "rw"
+ floatPrecission o | isSingleOp o = text "s"
+ | isDoubleOp o = text "d"
+ | otherwise = pprPanic "Impossible floating point precission: " (pprOp platform o)
floatOpPrecision :: Platform -> Operand -> Operand -> String
floatOpPrecision _p l r | isFloatOp l && isFloatOp r && isSingleOp l && isSingleOp r = "s" -- single precision
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/53da5e54a2021cc9042a9a3274dcd5bf840c6d6c
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/53da5e54a2021cc9042a9a3274dcd5bf840c6d6c
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240608/5b6fc7eb/attachment-0001.html>
More information about the ghc-commits
mailing list