[Git][ghc/ghc][wip/supersven/riscv64-ncg] Fix signed shift right

Sven Tennie (@supersven) gitlab at gitlab.haskell.org
Sun May 21 15:53:46 UTC 2023



Sven Tennie pushed to branch wip/supersven/riscv64-ncg at Glasgow Haskell Compiler / GHC


Commits:
6c908960 by Sven Tennie at 2023-05-21T17:52:50+02:00
Fix signed shift right

This includes overhauling the sign extension and width truncation logic.

- - - - -


4 changed files:

- compiler/GHC/CmmToAsm/RV64/CodeGen.hs
- compiler/GHC/CmmToAsm/RV64/Instr.hs
- compiler/GHC/CmmToAsm/RV64/Ppr.hs
- + tests/compiler/cmm/shift_right.cmm


Changes:

=====================================
compiler/GHC/CmmToAsm/RV64/CodeGen.hs
=====================================
@@ -201,7 +201,7 @@ ann doc instr {- debugIsOn -} = ANN doc instr
 -- forced until we actually force them, and without -dppr-debug they should
 -- never end up being forced.
 annExpr :: CmmExpr -> Instr -> Instr
-annExpr e instr {- debugIsOn -} = ANN (text . show $ e) instr
+annExpr e {- debugIsOn -} = ANN (text . show $ e)  -- eta-reduced: the Instr to annotate is ANN's remaining argument
 -- annExpr e instr {- debugIsOn -} = ANN (pprExpr genericPlatform e) instr
 -- annExpr _ instr = instr
 {-# INLINE annExpr #-}
@@ -708,24 +708,20 @@ getRegister' config plat expr
       return $ Any (intFormat w) (\dst -> code_x `snocOL` annExpr expr (LSL (OpReg w dst) (OpReg w reg_x) (OpImm (ImmInteger n))))
 
     CmmMachOp (MO_S_Shr w) [x, (CmmLit (CmmInt n _))] | fitsIn12bitImm n -> do
-      (reg_x, _format_x, code_x) <- getSomeReg x
-      return $ Any (intFormat w) (\dst ->
-          code_x `appOL` toOL [ SUB sp sp (OpImm (ImmInt (widthInBits w)))
-                              , STR (intFormat w) (OpReg w reg_x) (OpAddr (AddrRegImm sp_reg (ImmInt 0)))
-                              , LDR (intFormat w)   (OpReg w reg_x)   (OpAddr (AddrRegImm sp_reg (ImmInt 0)))
-                              , ADD sp sp (OpImm (ImmInt (widthInBits w)))
-                              , ASR (OpReg w dst) (OpReg w reg_x) (OpImm (ImmInteger n))
-                              ])
+      (reg_x, format_x, code_x) <- getSomeReg x
+      (reg_x', code_x') <- signExtendReg (formatToWidth format_x) w reg_x
+      return $ Any (intFormat w) (
+        \dst ->
+          code_x `appOL` code_x' `snocOL` annExpr expr (ASR (OpReg w dst) (OpReg w reg_x') (OpImm (ImmInteger n)))
+        )
     CmmMachOp (MO_S_Shr w) [x, y] -> do
-      (reg_x, _format_x, code_x) <- getSomeReg x
+      (reg_x, format_x, code_x) <- getSomeReg x
       (reg_y, _format_y, code_y) <- getSomeReg y
-      return $ Any (intFormat w) (\dst ->
-          code_x `appOL` code_y `appOL` toOL [ SUB sp sp (OpImm (ImmInt (widthInBits w)))
-                              , STR (intFormat w) (OpReg w reg_x) (OpAddr (AddrRegImm sp_reg (ImmInt 0)))
-                              , LDR (intFormat w)   (OpReg w reg_x)   (OpAddr (AddrRegImm sp_reg (ImmInt 0)))
-                              , ADD sp sp (OpImm (ImmInt (widthInBits w)))
-                              , ASR (OpReg w dst) (OpReg w reg_y) (OpImm (ImmInteger 0))
-                              ])
+      (reg_x', code_x') <- signExtendReg (formatToWidth format_x) w reg_x
+      return $ Any (intFormat w) (
+        \dst ->
+          code_x `appOL` code_x' `appOL` code_y `snocOL` annExpr expr (ASR (OpReg w dst) (OpReg w reg_x') (OpReg w reg_y))
+        )
 
     CmmMachOp (MO_U_Shr w) [x, (CmmLit (CmmInt n _))] | w == W8, 0 <= n, n < 8 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
@@ -752,7 +748,7 @@ getRegister' config plat expr
       return $ Any (intFormat w) (\dst -> code_x `snocOL` annExpr expr (LSR (OpReg w dst) (OpReg w reg_x) (OpImm (ImmInteger n))))
 
     -- 3. Logic &&, ||
-    CmmMachOp (MO_And w) [(CmmReg reg), CmmLit (CmmInt n _)] | isBitMaskImmediate (fromIntegral n) ->
+    CmmMachOp (MO_And w) [(CmmReg reg), CmmLit (CmmInt n _)] | fitsIn12bitImm n ->
       return $ Any (intFormat w) (\d -> unitOL $ annExpr expr (AND (OpReg w d) (OpReg w' r') (OpImm (ImmInteger n))))
       where w' = formatToWidth (cmmTypeFormat (cmmRegType plat reg))
             r' = getRegisterReg plat reg
@@ -934,17 +930,6 @@ getRegister' config plat expr
   where
     isNbitEncodeable :: Int -> Integer -> Bool
     isNbitEncodeable n i = let shift = n - 1 in (-1 `shiftL` shift) <= i && i < (1 `shiftL` shift)
-    -- FIXME: These are wrong, they are for AArch64, not RISCV! I'm not even sure we need them for RISCV
-    isBitMaskImmediate :: Integer -> Bool
-    isBitMaskImmediate i = i `elem` [0b0000_0001, 0b0000_0010, 0b0000_0100, 0b0000_1000, 0b0001_0000, 0b0010_0000, 0b0100_0000, 0b1000_0000
-                                    ,0b0000_0011, 0b0000_0110, 0b0000_1100, 0b0001_1000, 0b0011_0000, 0b0110_0000, 0b1100_0000
-                                    ,0b0000_0111, 0b0000_1110, 0b0001_1100, 0b0011_1000, 0b0111_0000, 0b1110_0000
-                                    ,0b0000_1111, 0b0001_1110, 0b0011_1100, 0b0111_1000, 0b1111_0000
-                                    ,0b0001_1111, 0b0011_1110, 0b0111_1100, 0b1111_1000
-                                    ,0b0011_1111, 0b0111_1110, 0b1111_1100
-                                    ,0b0111_1111, 0b1111_1110
-                                    ,0b1111_1111]
-
     -- N.B. MUL does not set the overflow flag.
     do_mul_may_oflo :: Width -> CmmExpr -> CmmExpr -> NatM Register
     do_mul_may_oflo w at W64 x y = do
@@ -984,35 +969,55 @@ getRegister' config plat expr
             mul (OpReg tmp_w tmp) (OpReg w reg_x) (OpReg w reg_y) `snocOL`
             CSET (OpReg w dst) (OpReg tmp_w tmp) (OpRegExt tmp_w tmp ext_mode 0) NE)
 
 -- | Instructions to sign-extend the value in the given register from width @w@
 -- up to width @w'@.
+--
+-- The result is written to a fresh register (the returned 'Reg'); the source
+-- register @r@ is left unmodified. When @w@ is W64 no code is emitted and @r@
+-- itself is returned.
 signExtendReg :: Width -> Width -> Reg -> NatM (Reg, OrdList Instr)
-signExtendReg w w' r =
-    case w of
-      W64 -> noop
-      W32
-        | w' == W32 -> noop
-        | otherwise -> extend SXTH
-      W16           -> extend SXTH
-      W8            -> extend SXTB
-      _             -> panic "intOp"
-  where
-    noop = return (r, nilOL)
-    extend instr = do
-        r' <- getNewRegNat II64
-        return (r', unitOL $ instr (OpReg w' r') (OpReg w' r))
+signExtendReg w _w' r | w == W64 = pure (r, nilOL)
+signExtendReg _w w' _r | w' > W64 = pprPanic "Cannot sign extend to width bigger than register size:" (ppr w')
+signExtendReg w _w' r | w > W64 = pprPanic "Unexpected register size (max is 64bit):" $ text (show r) <> char ':' <+> ppr w
+signExtendReg w w' r | w == W32 = do
+  r' <- getNewRegNat (intFormat w')
+  -- `ADDIW r' r 0` is the pseudo-op SEXT.W. Ppr prints an ADD of a W32 register
+  -- and an immediate into a W64 destination as addiw, hence the W64 annotation.
+  pure (r', unitOL $
+          ann (text "sign-extend register" <+> ppr r <+> ppr w <> text "->" <> ppr w')
+              (ADD (OpReg W64 r') (OpReg w r) (OpImm (ImmInt 0)))
+       )
+signExtendReg w w' r = do
+  r' <- getNewRegNat (intFormat w')
+  -- Shift the w-bit value left so its sign bit lands in bit 63, then shift it
+  -- back arithmetically so the sign bit is replicated into all upper bits.
+  -- This avoids a stack round trip and leaves @r@ untouched.
+  pure (r', toOL [ ann (text "sign-extend register" <+> ppr r <+> ppr w <> text "->" <> ppr w')
+                       (LSL (OpReg W64 r') (OpReg W64 r) (OpImm (ImmInt shift)))
+                 , ASR (OpReg W64 r') (OpReg W64 r') (OpImm (ImmInt shift))
+                 ])
+  where
+    shift = 64 - widthInBits w
 
 -- | Instructions to truncate the value in the given register from width @w@
 -- down to width @w'@.
+-- N.B.: This ignores signedness: the bits above @w'@ are cleared, not
+-- sign-filled. The register @r@ is modified in place.
 truncateReg :: Width -> Width -> Reg -> OrdList Instr
-truncateReg w _w' _r | w == W64 = nilOL
+truncateReg _w w' _r | w' == W64 = nilOL
+truncateReg _w w' r | w' > W64 = pprPanic "Cannot truncate to width bigger than register size (max is 64bit):" $ text (show r) <> char ':' <+> ppr w'
+truncateReg w _w' r | w > W64 = pprPanic "Unexpected register size (max is 64bit):" $ text (show r) <> char ':' <+> ppr w
+truncateReg w w' _r | w < w' = pprPanic "This is not a truncation." $ ppr w <+> char '<' <+> ppr w'
 truncateReg w w' _r | w == w' = nilOL
-truncateReg w w' r =
-  toOL [ SUB sp sp (OpImm (ImmInt (widthInBits w)))
-       , STR (intFormat w) (OpReg w r) (OpAddr (AddrRegImm sp_reg (ImmInt 0)))
-       , LDR (intFormat w') (OpReg w' r)   (OpAddr (AddrRegImm sp_reg (ImmInt 0)))
-       , ADD sp sp (OpImm (ImmInt (widthInBits w)))
-       ]
+truncateReg w w' r = toOL [ann (text "truncate register" <+> ppr r <+> ppr w <> text "->" <> ppr w')
+                               (LSL (OpReg W64 r) (OpReg W64 r) (OpImm (ImmInt shift)))
+                          -- LSR is a logical shift: it shifts in zeros, clearing the upper bits.
+                          , LSR (OpReg W64 r) (OpReg W64 r) (OpImm (ImmInt shift))
+                          ]
+  where
+    -- Move the low w' bits to the top of the 64-bit register and back down.
+    -- The amount depends only on the target width w'.
+    shift = 64 - widthInBits w'
 
 -- -----------------------------------------------------------------------------
 --  The 'Amode' type: Memory addressing modes passed up the tree.


=====================================
compiler/GHC/CmmToAsm/RV64/Instr.hs
=====================================
@@ -94,7 +94,6 @@ regUsageOfInstr platform instr = case instr of
   -- 2. Bit Manipulation Instructions ------------------------------------------
   SBFM dst src _ _         -> usage (regOp src, regOp dst)
   UBFM dst src _ _         -> usage (regOp src, regOp dst)
-  SBFX dst src _ _         -> usage (regOp src, regOp dst)
   UBFX dst src _ _         -> usage (regOp src, regOp dst)
   SXTB dst src             -> usage (regOp src, regOp dst)
   UXTB dst src             -> usage (regOp src, regOp dst)
@@ -234,7 +233,6 @@ patchRegsOfInstr instr env = case instr of
     -- 2. Bit Manipulation Instructions ----------------------------------------
     SBFM o1 o2 o3 o4 -> SBFM (patchOp o1) (patchOp o2) (patchOp o3) (patchOp o4)
     UBFM o1 o2 o3 o4 -> UBFM (patchOp o1) (patchOp o2) (patchOp o3) (patchOp o4)
-    SBFX o1 o2 o3 o4 -> SBFX (patchOp o1) (patchOp o2) (patchOp o3) (patchOp o4)
     UBFX o1 o2 o3 o4 -> UBFX (patchOp o1) (patchOp o2) (patchOp o3) (patchOp o4)
     SXTB o1 o2       -> SXTB (patchOp o1) (patchOp o2)
     UXTB o1 o2       -> UXTB (patchOp o1) (patchOp o2)
@@ -632,7 +630,6 @@ data Instr
     -- UXTB = UBFM <Wd>, <Wn>, #0, #7
     -- UXTH = UBFM <Wd>, <Wn>, #0, #15
     -- Signed/Unsigned bitfield extract
-    | SBFX Operand Operand Operand Operand -- rd = rn[i,j]
     | UBFX Operand Operand Operand Operand -- rd = rn[i,j]
 
     -- 3. Logical and Move Instructions ----------------------------------------
@@ -717,7 +714,6 @@ instrCon i =
       UDIV{} -> "UDIV"
       SBFM{} -> "SBFM"
       UBFM{} -> "UBFM"
-      SBFX{} -> "SBFX"
       UBFX{} -> "UBFX"
       AND{} -> "AND"
       -- ANDS{} -> "ANDS"


=====================================
compiler/GHC/CmmToAsm/RV64/Ppr.hs
=====================================
@@ -456,6 +456,8 @@ pprInstr platform instr = case instr of
   -- 1. Arithmetic Instructions ------------------------------------------------
   ADD  o1 o2 o3
     | isFloatOp o1 && isFloatOp o2 && isFloatOp o3 -> op3 (text "\tfadd") o1 o2 o3
+    -- addiw = 32-bit add, sign-extended to 64 bits (SEXT.W with imm 0); only valid for W32 sources.
+    | OpReg W64 _ <- o1 , OpReg W32 _ <- o2, isImmOp o3 -> op3 (text "\taddiw") o1 o2 o3
     | otherwise -> op3 (text "\tadd") o1 o2 o3
   -- CMN  o1 o2    -> op2 (text "\tcmn") o1 o2
   -- CMP  o1 o2
@@ -487,7 +489,6 @@ pprInstr platform instr = case instr of
   SBFM o1 o2 o3 o4 -> op4 (text "\tsbfm") o1 o2 o3 o4
   UBFM o1 o2 o3 o4 -> op4 (text "\tubfm") o1 o2 o3 o4
   -- signed and unsigned bitfield extract
-  SBFX o1 o2 o3 o4 -> op4 (text "\tsbfx") o1 o2 o3 o4
   UBFX o1 o2 o3 o4 -> op4 (text "\tubfx") o1 o2 o3 o4
   SXTB o1 o2       -> op2 (text "\tsxtb") o1 o2
   UXTB o1 o2       -> op2 (text "\tuxtb") o1 o2


=====================================
tests/compiler/cmm/shift_right.cmm
=====================================
@@ -0,0 +1,25 @@
+// RUN: "$HC" -debug -dppr-debug -cpp -dcmm-lint -keep-s-file -O0 -c "$1" && cat "${1%%.*}.s" | FileCheck "$1" -check-prefix=CHECK-RV64
+// RUN: "$CC" "${1%%.*}.o" -o "${1%%.*}.exe"
+// RUN: "$EXEC" "${1%%.cmm}.exe"
+
+#include "Cmm.h"
+#include "Types.h"
+
+main() {
+    I32 a, b, c, d;
+
+    I64 arr;
+    (arr) = foreign "C" malloc(1024);
+    bits64[arr] = 2;
+
+    a = I32[arr];
+    b = %mul(a, 32 :: I32);
+    c = %neg(b);
+    // CHECK-RV64: sra
+    d = %shra(c, 4::I64);
+
+    foreign "C" printf("a: %d b: %d c: %d d: %d", a, b, c, d);
+
+    // Exit code 0 (success) iff the arithmetic shift produced -64 >> 4 == -4.
+    foreign "C" exit(d != -4 :: I32);
+}



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/6c908960575d8a0bc0cc65897ff347260503d7d1

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/6c908960575d8a0bc0cc65897ff347260503d7d1
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20230521/09aa18f7/attachment-0001.html>


More information about the ghc-commits mailing list