[Git][ghc/ghc][wip/ncg-simd] 3 commits: X86 CodeGen: refactor getRegister CmmLit

sheaf (@sheaf) gitlab at gitlab.haskell.org
Fri Sep 20 15:08:29 UTC 2024



sheaf pushed to branch wip/ncg-simd at Glasgow Haskell Compiler / GHC


Commits:
df5b64c4 by sheaf at 2024-09-20T17:08:06+02:00
X86 CodeGen: refactor getRegister CmmLit

This refactors the code dealing with loading literals into registers,
removing duplication and putting all the code in a single place.
It also changes which XOR instruction is used to place a zero value
into a register, so that we use VPXOR for a 128-bit integer vector
when AVX is supported.

- - - - -
cf56ab12 by sheaf at 2024-09-20T17:08:06+02:00
X86 genCCall64: simplify loadArg code

This commit simplifies the argument loading code by assuming that it is
safe to load each argument directly into its destination register,
because doing so will not clobber any previously assigned registers.

This assumption is justified by the use of 'evalArgs', which evaluates
any argument that might require non-trivial code generation into a
separate temporary register beforehand.

- - - - -
c27747f2 by sheaf at 2024-09-20T17:08:06+02:00
LLVM: propagate GlobalRegUse information

This commit ensures we keep track of how any particular global register
is being used in the LLVM backend. This informs the LLVM type
annotations, and avoids type mismatches of the following form:

  argument is not of expected type '<2 x double>'
    call ccc <2 x double> (<2 x double>)
      (<4 x i32> arg)

- - - - -


9 changed files:

- compiler/GHC/CmmToAsm/Format.hs
- compiler/GHC/CmmToAsm/X86/CodeGen.hs
- compiler/GHC/CmmToAsm/X86/Instr.hs
- compiler/GHC/CmmToAsm/X86/Ppr.hs
- compiler/GHC/CmmToLlvm.hs
- compiler/GHC/CmmToLlvm/Base.hs
- compiler/GHC/CmmToLlvm/CodeGen.hs
- compiler/GHC/CmmToLlvm/Ppr.hs
- compiler/GHC/CmmToLlvm/Regs.hs


Changes:

=====================================
compiler/GHC/CmmToAsm/Format.hs
=====================================
@@ -28,6 +28,7 @@ module GHC.CmmToAsm.Format (
     scalarWidth,
     formatInBytes,
     isFloatScalarFormat,
+    isFloatOrFloatVecFormat,
     floatScalarFormat,
     scalarFormatFormat,
     VirtualRegWithFormat(..),
@@ -134,6 +135,11 @@ isFloatScalarFormat = \case
   FmtDouble -> True
   _ -> False
 
+isFloatOrFloatVecFormat :: Format -> Bool
+isFloatOrFloatVecFormat = \case
+  VecFormat _ sFmt -> isFloatScalarFormat sFmt
+  fmt -> isFloatFormat fmt
+
 floatScalarFormat :: Width -> ScalarFormat
 floatScalarFormat W32 = FmtFloat
 floatScalarFormat W64 = FmtDouble


=====================================
compiler/GHC/CmmToAsm/X86/CodeGen.hs
=====================================
@@ -913,20 +913,6 @@ getRegister' _ is32Bit (CmmMachOp (MO_UU_Conv W64 W16) [x])
   ro <- getNewRegNat II16
   return $ Fixed II16 ro (code `appOL` toOL [ MOVZxL II16 (OpReg rlo) (OpReg ro) ])
 
-getRegister' _ _ (CmmLit lit@(CmmFloat f w)) =
-  float_const_sse2  where
-  float_const_sse2
-    | f == 0.0 = do
-      -- TODO: this mishandles negative zero floating point literals.
-      let
-          format = floatFormat w
-          code dst = unitOL  (XOR format (OpReg dst) (OpReg dst))
-        -- I don't know why there are xorpd, xorps, and pxor instructions.
-        -- They all appear to do the same thing --SDM
-      return (Any format code)
-
-   | otherwise = getFloatLitRegister lit
-
 -- catch simple cases of zero- or sign-extended load
 getRegister' _ _ (CmmMachOp (MO_UU_Conv W8 W32) [CmmLoad addr _ _]) = do
   code <- intLoadCode (MOVZxL II8) addr
@@ -1932,7 +1918,7 @@ getRegister' platform is32Bit load@(CmmLoad mem ty _)
   | isFloatType ty
   = do
     Amode addr mem_code <- getAmode mem
-    loadFloatAmode width addr mem_code
+    loadAmode (floatFormat width) addr mem_code
 
   | is32Bit && not (isWord64 ty)
   = do
@@ -1960,20 +1946,6 @@ getRegister' platform is32Bit load@(CmmLoad mem ty _)
     format = cmmTypeFormat ty
     width = typeWidth ty
 
-getRegister' _ is32Bit (CmmLit (CmmInt 0 width))
-  = let
-        format = intFormat width
-
-        -- x86_64: 32-bit xor is one byte shorter, and zero-extends to 64 bits
-        format1 = if is32Bit then format
-                           else case format of
-                                II64 -> II32
-                                _ -> format
-        code dst
-           = unitOL (XOR format1 (OpReg dst) (OpReg dst))
-    in
-        return (Any format code)
-
 -- Handle symbol references with LEA and %rip-relative addressing.
 -- See Note [%rip-relative addressing on x86-64].
 getRegister' platform is32Bit (CmmLit lit)
@@ -1990,80 +1962,102 @@ getRegister' platform is32Bit (CmmLit lit)
     is_label (CmmLabelDiffOff {}) = True
     is_label _                    = False
 
-  -- optimisation for loading small literals on x86_64: take advantage
-  -- of the automatic zero-extension from 32 to 64 bits, because the 32-bit
-  -- instruction forms are shorter.
-getRegister' platform is32Bit (CmmLit lit)
-  | not is32Bit, isWord64 (cmmLitType platform lit), not (isBigLit lit)
-  = let
-        imm = litToImm lit
-        code dst = unitOL (MOV II32 (OpImm imm) (OpReg dst))
-    in
-        return (Any II64 code)
-  where
-   isBigLit (CmmInt i _) = i < 0 || i > 0xffffffff
-   isBigLit _ = False
+getRegister' platform is32Bit (CmmLit lit) = do
+  avx <- avxEnabled
+
+  -- NB: it is important that the code produced here (to load a literal into
+  -- a register) doesn't clobber any registers other than the destination
+  -- register; the code for generating C calls relies on this property.
+  --
+  -- In particular, we have:
+  --
+  -- > loadIntoRegMightClobberOtherReg (CmmLit _) = False
+  --
+  -- which means that we assume that loading a literal into a register
+  -- will not clobber any other registers.
+
+  -- TODO: this function mishandles floating-point negative zero,
+  -- because -0.0 == 0.0 returns True and because we represent CmmFloat as
+  -- Rational, which can't properly represent negative zero.
+
+  if
+    -- Zero: use XOR.
+    | isZeroLit lit
+    -> let code dst
+             | isIntFormat fmt
+             = let fmt'
+                     | is32Bit
+                     = fmt
+                     | otherwise
+                     -- x86_64: 32-bit xor is one byte shorter,
+                     -- and zero-extends to 64 bits
+                     = case fmt of
+                         II64 -> II32
+                         _ -> fmt
+               in unitOL (XOR fmt' (OpReg dst) (OpReg dst))
+             | avx
+             = if float_or_floatvec
+               then unitOL (VXOR fmt (OpReg dst) dst dst)
+               else unitOL (VPXOR fmt dst dst dst)
+             | otherwise
+             = if float_or_floatvec
+               then unitOL (XOR fmt (OpReg dst) (OpReg dst))
+               else unitOL (PXOR fmt (OpReg dst) dst)
+       in return $ Any fmt code
+
+    -- Constant vector: use broadcast.
+    | VecFormat l sFmt <- fmt
+    , CmmVec (f:fs) <- lit
+    , all (== f) fs
+    -> do let w = scalarWidth sFmt
+              broadcast = if isFloatScalarFormat sFmt
+                          then MO_VF_Broadcast l w
+                          else MO_V_Broadcast l w
+          valCode <- getAnyReg (CmmMachOp broadcast [CmmLit f])
+          return $ Any fmt valCode
+
+    -- Optimisation for loading small literals on x86_64: take advantage
+    -- of the automatic zero-extension from 32 to 64 bits, because the 32-bit
+    -- instruction forms are shorter.
+    | not is32Bit, isWord64 cmmTy, not (isBigLit lit)
+    -> let
+          imm = litToImm lit
+          code dst = unitOL (MOV II32 (OpImm imm) (OpReg dst))
+      in
+          return (Any II64 code)
+
+    -- Scalar integer: use an immediate.
+    | isIntFormat fmt
+    -> let imm = litToImm lit
+           code dst = unitOL (MOV fmt (OpImm imm) (OpReg dst))
+       in return (Any fmt code)
+
+    -- General case: load literal from data address.
+    | otherwise
+    -> do let w = formatToWidth fmt
+          Amode addr addr_code <- memConstant (mkAlignment $ widthInBytes w) lit
+          loadAmode fmt addr addr_code
+
+    where
+      cmmTy = cmmLitType platform lit
+      fmt = cmmTypeFormat cmmTy
+      float_or_floatvec = isFloatOrFloatVecFormat fmt
+      isZeroLit (CmmInt i _) = i == 0
+      isZeroLit (CmmFloat f _) = f == 0 -- TODO: mishandles negative zero
+      isZeroLit (CmmVec fs) = all isZeroLit fs
+      isZeroLit _ = False
+
+      isBigLit (CmmInt i _) = i < 0 || i > 0xffffffff
+      isBigLit _ = False
         -- note1: not the same as (not.is32BitLit), because that checks for
         -- signed literals that fit in 32 bits, but we want unsigned
         -- literals here.
         -- note2: all labels are small, because we're assuming the
         -- small memory model. See Note [%rip-relative addressing on x86-64].
 
-getRegister' platform _ (CmmLit lit) = do
-  avx <- avxEnabled
-  case fmt of
-    VecFormat l sFmt
-      | CmmVec fs <- lit
-      , all is_zero fs
-      -> let code dst
-               | avx
-               = if isFloatScalarFormat sFmt
-                 then unitOL (VXOR fmt (OpReg dst) dst dst)
-                 else unitOL (VPXOR fmt dst dst dst)
-               | otherwise
-               = unitOL (XOR fmt (OpReg dst) (OpReg dst))
-         in return (Any fmt code)
-      | CmmVec (f:fs) <- lit
-      , all (== f) fs
-      -- TODO: mishandles negative zero (because -0.0 == 0.0 returns True), and because we
-      -- represent CmmFloat as Rational which can't properly represent negative zero.
-      -> do let w = scalarWidth sFmt
-                broadcast = if isFloatScalarFormat sFmt
-                            then MO_VF_Broadcast l w
-                            else MO_V_Broadcast l w
-            valCode <- getAnyReg (CmmMachOp broadcast [CmmLit f])
-            return $ Any fmt valCode
-
-      | otherwise
-      -> do
-           let w = formatToWidth fmt
-           config <- getConfig
-           Amode addr addr_code <- memConstant (mkAlignment $ widthInBytes w) lit
-           let code dst = addr_code `snocOL`
-                            movInstr config fmt (OpAddr addr) (OpReg dst)
-           return (Any fmt code)
-       where
-        is_zero (CmmInt i _) = i == 0
-        is_zero (CmmFloat f _) = f == 0 -- TODO: mishandles negative zero
-        is_zero _ = False
-
-    _ -> let imm = litToImm lit
-             code dst = unitOL (MOV fmt (OpImm imm) (OpReg dst))
-         in return (Any fmt code)
-  where
-    cmmTy = cmmLitType platform lit
-    fmt = cmmTypeFormat cmmTy
-
 getRegister' platform _ slot@(CmmStackSlot {}) =
   pprPanic "getRegister(x86) CmmStackSlot" (pdoc platform slot)
 
-getFloatLitRegister :: CmmLit -> NatM Register
-getFloatLitRegister lit = do
-  let w :: Width
-      w = case lit of { CmmInt _ w -> w; CmmFloat _ w -> w; _ -> panic "getFloatLitRegister" (ppr lit) }
-  Amode addr code <- memConstant (mkAlignment $ widthInBytes w) lit
-  loadFloatAmode w addr code
-
 intLoadCode :: (Operand -> Operand -> Instr) -> CmmExpr
    -> NatM (Reg -> InstrBlock)
 intLoadCode instr mem = do
@@ -2392,15 +2386,12 @@ memConstant align lit = do
         `consOL` addr_code
   return (Amode addr code)
 
-
-loadFloatAmode :: Width -> AddrMode -> InstrBlock -> NatM Register
-loadFloatAmode w addr addr_code = do
-  let format = floatFormat w
-      code dst = addr_code `snocOL`
-                    MOV format (OpAddr addr) (OpReg dst)
-
-  return (Any format code)
-
+-- | Load the value at the given address into any register.
+loadAmode :: Format -> AddrMode -> InstrBlock -> NatM Register
+loadAmode fmt addr addr_code = do
+  config <- getConfig
+  let load dst = movInstr config fmt (OpAddr addr) (OpReg dst)
+  return $ Any fmt (\ dst -> addr_code `snocOL` load dst)
 
 -- if we want a floating-point literal as an operand, we can
 -- use it directly from memory.  However, if the literal is
@@ -3100,10 +3091,8 @@ genSimplePrim _   op                   dst     args           = do
   platform <- ncgPlatform <$> getConfig
   pprPanic "genSimplePrim: unhandled primop" (ppr (pprCallishMachOp op, dst, fmap (pdoc platform) args))
 
-{-
-Note [Evaluate C-call arguments before placing in destination registers]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
+{- Note [Evaluate C-call arguments before placing in destination registers]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 When producing code for C calls we must take care when placing arguments
 in their final registers. Specifically, we must ensure that temporary register
 usage due to evaluation of one argument does not clobber a register in which we
@@ -3154,15 +3143,11 @@ genForeignCall{32,64}.
 -- | See Note [Evaluate C-call arguments before placing in destination registers]
 evalArgs :: BlockId -> [CmmActual] -> NatM (InstrBlock, [CmmActual])
 evalArgs bid actuals
-  | any mightContainMachOp actuals = do
+  | any loadIntoRegMightClobberOtherReg actuals = do
       regs_blks <- mapM evalArg actuals
       return (concatOL $ map fst regs_blks, map snd regs_blks)
   | otherwise = return (nilOL, actuals)
   where
-    mightContainMachOp (CmmReg _)      = False
-    mightContainMachOp (CmmRegOff _ _) = False
-    mightContainMachOp (CmmLit _)      = False
-    mightContainMachOp _               = True
 
     evalArg :: CmmActual -> NatM (InstrBlock, CmmExpr)
     evalArg actual = do
@@ -3176,6 +3161,16 @@ evalArgs bid actuals
     newLocalReg :: CmmType -> NatM LocalReg
     newLocalReg ty = LocalReg <$> getUniqueM <*> pure ty
 
+-- | Might the code to put this expression into a register
+-- clobber any other registers?
+loadIntoRegMightClobberOtherReg :: CmmExpr -> Bool
+loadIntoRegMightClobberOtherReg (CmmReg _)      = False
+loadIntoRegMightClobberOtherReg (CmmRegOff _ _) = False
+loadIntoRegMightClobberOtherReg (CmmLit _)      = False
+  -- NB: this last 'False' is slightly risky, because the code for loading
+  -- a literal into a register is not entirely trivial.
+loadIntoRegMightClobberOtherReg _               = True
+
 -- Note [DIV/IDIV for bytes]
 -- ~~~~~~~~~~~~~~~~~~~~~~~~~
 -- IDIV reminder:
@@ -3437,7 +3432,6 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
       { stackArgs       = proper_stack_args
       , stackDataArgs   = stack_data_args
       , usedRegs        = arg_regs_used
-      , computeArgsCode = compute_args_code
       , assignArgsCode  = assign_args_code
       }
       <- loadArgs config prom_args
@@ -3536,7 +3530,6 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
 
     return (align_call_code     `appOL`
             push_code           `appOL`
-            compute_args_code   `appOL`
             assign_args_code    `appOL`
             load_data_refs      `appOL`
             shadow_space_code   `appOL`
@@ -3557,16 +3550,14 @@ data LoadArgs
   , stackDataArgs :: [CmmExpr]
   -- | Which registers are we using for argument passing?
   , usedRegs      :: [RegWithFormat]
-  -- | The code to compute arguments into (possibly temporary) registers.
-  , computeArgsCode :: InstrBlock
   -- | The code to assign arguments to registers used for argument passing.
   , assignArgsCode :: InstrBlock
   }
 instance Semigroup LoadArgs where
-  LoadArgs a1 d1 r1 i1 j1 <> LoadArgs a2 d2 r2 i2 j2
-    = LoadArgs (a1 ++ a2) (d1 ++ d2) (r1 ++ r2) (i1 S.<> i2) (j1 S.<> j2)
+  LoadArgs a1 d1 r1 j1 <> LoadArgs a2 d2 r2 j2
+    = LoadArgs (a1 ++ a2) (d1 ++ d2) (r1 ++ r2) (j1 S.<> j2)
 instance Monoid LoadArgs where
-  mempty = LoadArgs [] [] [] nilOL nilOL
+  mempty = LoadArgs [] [] [] nilOL
 
 -- | An argument passed on the stack, either directly or by reference.
 --
@@ -3720,7 +3711,6 @@ loadArgsSysV config (arg:rest) = do
           LoadArgs
             { stackArgs       = map RawStackArg (arg:rest)
             , stackDataArgs   = []
-            , computeArgsCode = nilOL
             , assignArgsCode  = nilOL
             , usedRegs        = []
             }
@@ -3740,12 +3730,11 @@ loadArgsSysV config (arg:rest) = do
     this_arg <-
       case mbReg of
         Just reg -> do
-          (compute_code, assign_code) <- lift $ loadArgIntoReg config arg rest reg
+          assign_code <- lift $ loadArgIntoReg arg reg
           return $
             LoadArgs
                 { stackArgs       = [] -- passed in register
                 , stackDataArgs   = []
-                , computeArgsCode = compute_code
                 , assignArgsCode  = assign_code
                 , usedRegs        = [RegWithFormat reg arg_fmt]
                 }
@@ -3755,7 +3744,6 @@ loadArgsSysV config (arg:rest) = do
             LoadArgs
                 { stackArgs       = [RawStackArg arg]
                 , stackDataArgs   = []
-                , computeArgsCode = nilOL
                 , assignArgsCode  = nilOL
                 , usedRegs        = []
                 }
@@ -3807,7 +3795,6 @@ loadArgsWin config (arg:rest) = do
         LoadArgs
           { stackArgs       = stk_args
           , stackDataArgs   = data_args
-          , computeArgsCode = nilOL
           , assignArgsCode  = nilOL
           , usedRegs        = []
           }
@@ -3823,8 +3810,7 @@ loadArgsWin config (arg:rest) = do
                 -- Pass the reference in a register,
                 -- and the argument data on the stack.
                 { stackArgs       = [RawStackArgRef (InReg ireg) (argSize platform arg)]
-                , stackDataArgs   = [arg]
-                , computeArgsCode = nilOL -- we don't yet know where the data will reside,
+                , stackDataArgs   = [arg] -- we don't yet know where the data will reside,
                 , assignArgsCode  = nilOL -- so we defer computing the reference and storing it
                                           -- in the register until later
                 , usedRegs        = [RegWithFormat ireg II64]
@@ -3835,7 +3821,7 @@ loadArgsWin config (arg:rest) = do
                   = freg
                   | otherwise
                   = ireg
-           (compute_code, assign_code) <- loadArgIntoReg config arg rest arg_reg
+           assign_code <- loadArgIntoReg arg arg_reg
            -- Recall that, for varargs, we must pass floating-point
            -- arguments in both fp and integer registers.
            let (assign_code', regs')
@@ -3848,42 +3834,23 @@ loadArgsWin config (arg:rest) = do
              LoadArgs
                { stackArgs       = [] -- passed in register
                , stackDataArgs   = []
-               , computeArgsCode = compute_code
                , assignArgsCode = assign_code'
                , usedRegs = regs'
                }
 
-
--- | Return two pieces of code:
---
---  - code to compute a the given 'CmmExpr' into some (possibly temporary) register
---  - code to assign the resulting value to the specified register
+-- | Load an argument into a register.
 --
--- Using two separate pieces of code handles clobbering issues reported
--- in e.g. #11792, #12614.
-loadArgIntoReg :: NCGConfig -> CmmExpr -> [CmmExpr] -> Reg -> NatM (InstrBlock, InstrBlock)
-loadArgIntoReg config arg rest reg
-  -- "operand" args can be directly assigned into the register
-  | isOperand platform arg
-  = do arg_code <- getAnyReg arg
-       return (nilOL, arg_code reg)
-  -- The last non-operand arg can be directly assigned after its
-  -- computation without going into a temporary register
-  | all (isOperand platform) rest
-  = do arg_code <- getAnyReg arg
-       return (arg_code reg, nilOL)
-  -- Other args need to be computed beforehand to avoid clobbering
-  -- previously assigned registers used to pass parameters (see
-  -- #11792, #12614). They are assigned into temporary registers
-  -- and get assigned to proper call ABI registers after they all
-  -- have been computed.
-  | otherwise
-  = do arg_code <- getAnyReg arg
-       tmp      <- getNewRegNat arg_fmt
-       return (arg_code tmp, unitOL $ mkRegRegMoveInstr config arg_fmt tmp reg)
-  where
-    platform = ncgPlatform config
-    arg_fmt = cmmTypeFormat $ cmmExprType platform arg
+-- Assumes that the expression does not contain any MachOps,
+-- as per Note [Evaluate C-call arguments before placing in destination registers].
+loadArgIntoReg :: CmmExpr -> Reg -> NatM InstrBlock
+loadArgIntoReg arg reg = do
+  when (debugIsOn && loadIntoRegMightClobberOtherReg arg) $ do
+    platform <- getPlatform
+    massertPpr False $
+      vcat [ text "loadArgIntoReg: arg might contain MachOp"
+           , text "arg:" <+> pdoc platform arg ]
+  arg_code <- getAnyReg arg
+  return $ arg_code reg
 
 -- -----------------------------------------------------------------------------
 -- Pushing arguments onto the stack for 64-bit C calls.


=====================================
compiler/GHC/CmmToAsm/X86/Instr.hs
=====================================
@@ -306,6 +306,7 @@ data Instr
         | VMOVDQU     Format Operand Operand
 
         -- logic operations
+        | PXOR        Format Operand Reg
         | VPXOR       Format Reg Reg Reg
 
         -- Arithmetic
@@ -492,6 +493,12 @@ regUsageOfInstr platform instr
     MOVDQU       fmt src dst   -> mkRU (use_R fmt src []) (use_R fmt dst [])
     VMOVDQU      fmt src dst   -> mkRU (use_R fmt src []) (use_R fmt dst [])
 
+    PXOR fmt (OpReg src) dst
+      | src == dst
+      -> mkRU [] [mk fmt dst]
+      | otherwise
+      -> mkRU [mk fmt src, mk fmt dst] [mk fmt dst]
+
     VPXOR        fmt s1 s2 dst
       | s1 == s2, s1 == dst
       -> mkRU [] [mk fmt dst]
@@ -733,6 +740,7 @@ patchRegsOfInstr platform instr env
     MOVDQU     fmt src dst   -> MOVDQU  fmt (patchOp src) (patchOp dst)
     VMOVDQU    fmt src dst   -> VMOVDQU fmt (patchOp src) (patchOp dst)
 
+    PXOR       fmt src dst   -> PXOR fmt (patchOp src) (env dst)
     VPXOR      fmt s1 s2 dst -> VPXOR fmt (env s1) (env s2) (env dst)
 
     VADD       fmt s1 s2 dst -> VADD fmt (patchOp s1) (env s2) (env dst)


=====================================
compiler/GHC/CmmToAsm/X86/Ppr.hs
=====================================
@@ -1016,6 +1016,8 @@ pprInstr platform i = case i of
         VecFormat 64 FmtInt8  -> text "vmovdqu32" -- require the additional AVX512BW extension
         _ -> text "vmovdqu"
 
+   PXOR format src dst
+     -> pprPXor (text "pxor") format src dst
    VPXOR format s1 s2 dst
      -> pprXor (text "vpxor") format s1 s2 dst
    VEXTRACT format offset from to
@@ -1320,6 +1322,15 @@ pprInstr platform i = case i of
            pprReg platform format reg3
        ]
 
+   pprPXor :: Line doc -> Format -> Operand -> Reg -> doc
+   pprPXor name format src dst
+     = line $ hcat [
+           pprGenMnemonic name format,
+           pprOperand platform format src,
+           comma,
+           pprReg platform format dst
+       ]
+
    pprVxor :: Format -> Operand -> Reg -> Reg -> doc
    pprVxor fmt src1 src2 dst
      = line $ hcat [


=====================================
compiler/GHC/CmmToLlvm.hs
=====================================
@@ -139,7 +139,7 @@ llvmGroupLlvmGens cmm = do
                          Nothing                   -> l
                          Just (CmmStaticsRaw info_lbl _) -> info_lbl
               lml <- strCLabel_llvm l'
-              funInsert lml =<< llvmFunTy (map globalRegUse_reg live)
+              funInsert lml =<< llvmFunTy live
               return Nothing
         cdata <- fmap catMaybes $ mapM split cmm
 


=====================================
compiler/GHC/CmmToLlvm/Base.hs
=====================================
@@ -12,7 +12,7 @@
 module GHC.CmmToLlvm.Base (
 
         LlvmCmmDecl, LlvmBasicBlock,
-        LiveGlobalRegs,
+        LiveGlobalRegs, LiveGlobalRegUses,
         LlvmUnresData, LlvmData, UnresLabel, UnresStatic,
 
         LlvmM,
@@ -29,6 +29,8 @@ module GHC.CmmToLlvm.Base (
         llvmFunSig, llvmFunArgs, llvmStdFunAttrs, llvmFunAlign, llvmInfAlign,
         llvmPtrBits, tysToParams, llvmFunSection, padLiveArgs, isFPR,
 
+        lookupRegUse,
+
         strCLabel_llvm,
         getGlobalPtr, generateExternDecls,
 
@@ -58,9 +60,11 @@ import GHC.Types.Unique.Set
 import GHC.Types.Unique.Supply
 import GHC.Utils.Logger
 
-import Data.Maybe (fromJust)
 import Control.Monad.Trans.State (StateT (..))
-import Data.List (isPrefixOf)
+import Control.Applicative (Alternative((<|>)))
+import Data.Maybe (fromJust, mapMaybe)
+
+import Data.List (find, isPrefixOf)
 import qualified Data.List.NonEmpty as NE
 import Data.Ord (comparing)
 
@@ -73,6 +77,7 @@ type LlvmBasicBlock = GenBasicBlock LlvmStatement
 
 -- | Global registers live on proc entry
 type LiveGlobalRegs = [GlobalReg]
+type LiveGlobalRegUses = [GlobalRegUse]
 
 -- | Unresolved code.
 -- Of the form: (data label, data type, unresolved data)
@@ -116,16 +121,16 @@ llvmGhcCC platform
  | otherwise                       = CC_Ghc
 
 -- | Llvm Function type for Cmm function
-llvmFunTy :: LiveGlobalRegs -> LlvmM LlvmType
+llvmFunTy :: LiveGlobalRegUses -> LlvmM LlvmType
 llvmFunTy live = return . LMFunction =<< llvmFunSig' live (fsLit "a") ExternallyVisible
 
 -- | Llvm Function signature
-llvmFunSig :: LiveGlobalRegs ->  CLabel -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
+llvmFunSig :: LiveGlobalRegUses ->  CLabel -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
 llvmFunSig live lbl link = do
   lbl' <- strCLabel_llvm lbl
   llvmFunSig' live lbl' link
 
-llvmFunSig' :: LiveGlobalRegs -> LMString -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
+llvmFunSig' :: LiveGlobalRegUses -> LMString -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
 llvmFunSig' live lbl link
   = do let toParams x | isPointer x = (x, [NoAlias, NoCapture])
                       | otherwise   = (x, [])
@@ -149,16 +154,25 @@ llvmFunSection opts lbl
     | otherwise               = Nothing
 
 -- | A Function's arguments
-llvmFunArgs :: Platform -> LiveGlobalRegs -> [LlvmVar]
+llvmFunArgs :: Platform -> LiveGlobalRegUses -> [LlvmVar]
 llvmFunArgs platform live =
-    map (lmGlobalRegArg platform) (filter isPassed allRegs)
+    map (lmGlobalRegArg platform) (mapMaybe isPassed allRegs)
     where allRegs = activeStgRegs platform
           paddingRegs = padLiveArgs platform live
-          isLive r = r `elem` alwaysLive
-                     || r `elem` live
-                     || r `elem` paddingRegs
-          isPassed r = not (isFPR r) || isLive r
-
+          isLive :: GlobalReg -> Maybe GlobalRegUse
+          isLive r =
+            lookupRegUse r (alwaysLive platform)
+              <|>
+            lookupRegUse r live
+              <|>
+            lookupRegUse r paddingRegs
+          isPassed r =
+            if not (isFPR r)
+            then Just $ GlobalRegUse r (globalRegSpillType platform r)
+            else isLive r
+
+lookupRegUse :: GlobalReg -> [GlobalRegUse] -> Maybe GlobalRegUse
+lookupRegUse r = find ((== r) . globalRegUse_reg)
 
 isFPR :: GlobalReg -> Bool
 isFPR (FloatReg _)  = True
@@ -179,7 +193,7 @@ isFPR _             = False
 -- Invariant: Cmm FPR regs with number "n" maps to real registers with number
 -- "n" If the calling convention uses registers in a different order or if the
 -- invariant doesn't hold, this code probably won't be correct.
-padLiveArgs :: Platform -> LiveGlobalRegs -> LiveGlobalRegs
+padLiveArgs :: Platform -> LiveGlobalRegUses -> LiveGlobalRegUses
 padLiveArgs platform live =
       if platformUnregisterised platform
         then [] -- not using GHC's register convention for platform.
@@ -188,7 +202,7 @@ padLiveArgs platform live =
     ----------------------------------
     -- handle floating-point registers (FPR)
 
-    fprLive = filter isFPR live  -- real live FPR registers
+    fprLive = filter (isFPR . globalRegUse_reg) live  -- real live FPR registers
 
     -- we group live registers sharing the same classes, i.e. that use the same
     -- set of real registers to be passed. E.g. FloatReg, DoubleReg and XmmReg
@@ -196,39 +210,44 @@ padLiveArgs platform live =
     --
     classes         = NE.groupBy sharesClass fprLive
     sharesClass a b = globalRegsOverlap platform (norm a) (norm b) -- check if mapped to overlapping registers
-    norm x          = fpr_ctor x 1                                 -- get the first register of the family
+    norm x = globalRegUse_reg (fpr_ctor x 1)                  -- get the first register of the family
 
     -- For each class, we just have to fill missing registers numbers. We use
     -- the constructor of the greatest register to build padding registers.
     --
     -- E.g. sortedRs = [   F2,   XMM4, D5]
     --      output   = [D1,   D3]
+    padded :: [GlobalRegUse]
     padded      = concatMap padClass classes
+
+    padClass :: NE.NonEmpty GlobalRegUse -> [GlobalRegUse]
     padClass rs = go (NE.toList sortedRs) 1
       where
-         sortedRs = NE.sortBy (comparing fpr_num) rs
+         sortedRs = NE.sortBy (comparing (fpr_num . globalRegUse_reg)) rs
          maxr     = NE.last sortedRs
          ctor     = fpr_ctor maxr
 
          go [] _ = []
-         go (c1:c2:_) _   -- detect bogus case (see #17920)
+         go (GlobalRegUse c1 _: GlobalRegUse c2 _:_) _   -- detect bogus case (see #17920)
             | fpr_num c1 == fpr_num c2
             , Just real <- globalRegMaybe platform c1
             = sorryDoc "LLVM code generator" $
                text "Found two different Cmm registers (" <> ppr c1 <> text "," <> ppr c2 <>
                text ") both alive AND mapped to the same real register: " <> ppr real <>
                text ". This isn't currently supported by the LLVM backend."
-         go (c:cs) f
-            | fpr_num c == f = go cs f                    -- already covered by a real register
-            | otherwise      = ctor f : go (c:cs) (f + 1) -- add padding register
-
-    fpr_ctor :: GlobalReg -> Int -> GlobalReg
-    fpr_ctor (FloatReg _)  = FloatReg
-    fpr_ctor (DoubleReg _) = DoubleReg
-    fpr_ctor (XmmReg _)    = XmmReg
-    fpr_ctor (YmmReg _)    = YmmReg
-    fpr_ctor (ZmmReg _)    = ZmmReg
-    fpr_ctor _ = error "fpr_ctor expected only FPR regs"
+         go (cu@(GlobalRegUse c _):cs) f
+            | fpr_num c == f = go cs f                     -- already covered by a real register
+            | otherwise      = ctor f : go (cu:cs) (f + 1) -- add padding register
+
+    fpr_ctor :: GlobalRegUse -> Int -> GlobalRegUse
+    fpr_ctor (GlobalRegUse r fmt) i =
+      case r of
+        FloatReg _  -> GlobalRegUse (FloatReg  i) fmt
+        DoubleReg _ -> GlobalRegUse (DoubleReg i) fmt
+        XmmReg _    -> GlobalRegUse (XmmReg    i) fmt
+        YmmReg _    -> GlobalRegUse (YmmReg    i) fmt
+        ZmmReg _    -> GlobalRegUse (ZmmReg    i) fmt
+        _           -> error "fpr_ctor expected only FPR regs"
 
     fpr_num :: GlobalReg -> Int
     fpr_num (FloatReg i)  = i


=====================================
compiler/GHC/CmmToLlvm/CodeGen.hs
=====================================
@@ -37,13 +37,14 @@ import GHC.Utils.Outputable
 import qualified GHC.Utils.Panic as Panic
 import GHC.Utils.Misc
 
+import Control.Applicative (Alternative((<|>)))
 import Control.Monad.Trans.Class
 import Control.Monad.Trans.Writer
 import Control.Monad
 
 import qualified Data.Semigroup as Semigroup
 import Data.List ( nub )
-import Data.Maybe ( catMaybes )
+import Data.Maybe ( catMaybes, isJust )
 
 type Atomic = Maybe MemoryOrdering
 type LlvmStatements = OrdList LlvmStatement
@@ -57,7 +58,7 @@ genLlvmProc :: RawCmmDecl -> LlvmM [LlvmCmmDecl]
 genLlvmProc (CmmProc infos lbl live graph) = do
     let blocks = toBlockListEntryFirstFalseFallthrough graph
 
-    (lmblocks, lmdata) <- basicBlocksCodeGen (map globalRegUse_reg live) blocks
+    (lmblocks, lmdata) <- basicBlocksCodeGen live blocks
     let info = mapLookup (g_entry graph) infos
         proc = CmmProc info lbl live (ListGraph lmblocks)
     return (proc:lmdata)
@@ -76,7 +77,7 @@ newtype UnreachableBlockId = UnreachableBlockId BlockId
 -- | Generate code for a list of blocks that make up a complete
 -- procedure. The first block in the list is expected to be the entry
 -- point.
-basicBlocksCodeGen :: LiveGlobalRegs -> [CmmBlock]
+basicBlocksCodeGen :: LiveGlobalRegUses -> [CmmBlock]
                       -> LlvmM ([LlvmBasicBlock], [LlvmCmmDecl])
 basicBlocksCodeGen _    []                     = panic "no entry block!"
 basicBlocksCodeGen live cmmBlocks
@@ -152,7 +153,7 @@ stmtToInstrs ubid stmt = case stmt of
 
     -- Tail call
     CmmCall { cml_target = arg,
-              cml_args_regs = live } -> genJump arg $ map globalRegUse_reg live
+              cml_args_regs = live } -> genJump arg live
 
     _ -> panic "Llvm.CodeGen.stmtToInstrs"
 
@@ -1050,7 +1051,7 @@ cmmPrimOpFunctions mop = do
 
 
 -- | Tail function calls
-genJump :: CmmExpr -> [GlobalReg] -> LlvmM StmtData
+genJump :: CmmExpr -> LiveGlobalRegUses -> LlvmM StmtData
 
 -- Call to known function
 genJump (CmmLit (CmmLabel lbl)) live = do
@@ -2056,14 +2057,13 @@ getCmmReg (CmmLocal (LocalReg un _))
            -- have been assigned a value at some point, triggering
            -- "funPrologue" to allocate it on the stack.
 
-getCmmReg (CmmGlobal g)
-  = do let r = globalRegUse_reg g
-       onStack  <- checkStackReg r
+getCmmReg (CmmGlobal ru@(GlobalRegUse r _))
+  = do onStack  <- checkStackReg r
        platform <- getPlatform
        if onStack
-         then return (lmGlobalRegVar platform r)
+         then return (lmGlobalRegVar platform ru)
          else pprPanic "getCmmReg: Cmm register " $
-                ppr g <> text " not stack-allocated!"
+                ppr r <> text " not stack-allocated!"
 
 -- | Return the value of a given register, as well as its type. Might
 -- need to be load from stack.
@@ -2074,7 +2074,7 @@ getCmmRegVal reg =
       onStack <- checkStackReg (globalRegUse_reg g)
       platform <- getPlatform
       if onStack then loadFromStack else do
-        let r = lmGlobalRegArg platform (globalRegUse_reg g)
+        let r = lmGlobalRegArg platform g
         return (r, getVarType r, nilOL)
     _ -> loadFromStack
  where loadFromStack = do
@@ -2187,8 +2187,9 @@ convertMemoryOrdering MemOrderSeqCst  = SyncSeqCst
 -- question is never written. Therefore we skip it where we can to
 -- save a few lines in the output and hopefully speed compilation up a
 -- bit.
-funPrologue :: LiveGlobalRegs -> [CmmBlock] -> LlvmM StmtData
+funPrologue :: LiveGlobalRegUses -> [CmmBlock] -> LlvmM StmtData
 funPrologue live cmmBlocks = do
+  platform <- getPlatform
 
   let getAssignedRegs :: CmmNode O O -> [CmmReg]
       getAssignedRegs (CmmAssign reg _)  = [reg]
@@ -2196,7 +2197,8 @@ funPrologue live cmmBlocks = do
       getAssignedRegs _                  = []
       getRegsBlock (_, body, _)          = concatMap getAssignedRegs $ blockToList body
       assignedRegs = nub $ concatMap (getRegsBlock . blockSplit) cmmBlocks
-      isLive r     = r `elem` alwaysLive || r `elem` live
+      mbLive r     =
+        lookupRegUse r (alwaysLive platform) <|> lookupRegUse r live
 
   platform <- getPlatform
   stmtss <- forM assignedRegs $ \reg ->
@@ -2205,12 +2207,12 @@ funPrologue live cmmBlocks = do
         let (newv, stmts) = allocReg reg
         varInsert un (pLower $ getVarType newv)
         return stmts
-      CmmGlobal (GlobalRegUse r _) -> do
-        let reg   = lmGlobalRegVar platform r
-            arg   = lmGlobalRegArg platform r
+      CmmGlobal ru@(GlobalRegUse r _) -> do
+        let reg   = lmGlobalRegVar platform ru
+            arg   = lmGlobalRegArg platform ru
             ty    = (pLower . getVarType) reg
             trash = LMLitVar $ LMUndefLit ty
-            rval  = if isLive r then arg else trash
+            rval  = if isJust (mbLive r) then arg else trash
             alloc = Assignment reg $ Alloca (pLower $ getVarType reg) 1
         markStackReg r
         return $ toOL [alloc, Store rval reg Nothing []]
@@ -2222,7 +2224,7 @@ funPrologue live cmmBlocks = do
 
 -- | Function epilogue. Load STG variables to use as argument for call.
 -- STG Liveness optimisation done here.
-funEpilogue :: LiveGlobalRegs -> LlvmM ([LlvmVar], LlvmStatements)
+funEpilogue :: LiveGlobalRegUses -> LlvmM ([LlvmVar], LlvmStatements)
 funEpilogue live = do
     platform <- getPlatform
 
@@ -2248,12 +2250,16 @@ funEpilogue live = do
     let allRegs = activeStgRegs platform
     loads <- forM allRegs $ \r -> if
       -- load live registers
-      | r `elem` alwaysLive  -> loadExpr (GlobalRegUse r (globalRegSpillType platform r))
-      | r `elem` live        -> loadExpr (GlobalRegUse r (globalRegSpillType platform r))
+      | Just ru <- lookupRegUse r (alwaysLive platform)
+      -> loadExpr ru
+      | Just ru <- lookupRegUse r live
+      -> loadExpr ru
       -- load all non Floating-Point Registers
-      | not (isFPR r)        -> loadUndef r
+      | not (isFPR r)
+      -> loadUndef (GlobalRegUse r (globalRegSpillType platform r))
       -- load padding Floating-Point Registers
-      | r `elem` paddingRegs -> loadUndef r
+      | Just ru <- lookupRegUse r paddingRegs
+      -> loadUndef ru
       | otherwise            -> return (Nothing, nilOL)
 
     let (vars, stmts) = unzip loads
@@ -2263,7 +2269,7 @@ funEpilogue live = do
 --
 -- This is for Haskell functions, function type is assumed, so doesn't work
 -- with foreign functions.
-getHsFunc :: LiveGlobalRegs -> CLabel -> LlvmM ExprData
+getHsFunc :: LiveGlobalRegUses -> CLabel -> LlvmM ExprData
 getHsFunc live lbl
   = do fty <- llvmFunTy live
        name <- strCLabel_llvm lbl


=====================================
compiler/GHC/CmmToLlvm/Ppr.hs
=====================================
@@ -49,9 +49,8 @@ pprLlvmCmmDecl (CmmData _ lmdata) = do
   return ( vcat $ map (pprLlvmData opts) lmdata
          , vcat $ map (pprLlvmData opts) lmdata)
 
-pprLlvmCmmDecl (CmmProc mb_info entry_lbl liveWithUses (ListGraph blks))
-  = do let live = map globalRegUse_reg liveWithUses
-           lbl = case mb_info of
+pprLlvmCmmDecl (CmmProc mb_info entry_lbl live (ListGraph blks))
+  = do let lbl = case mb_info of
                      Nothing -> entry_lbl
                      Just (CmmStaticsRaw info_lbl _) -> info_lbl
            link = if externallyVisibleCLabel lbl


=====================================
compiler/GHC/CmmToLlvm/Regs.hs
=====================================
@@ -14,25 +14,27 @@ import GHC.Prelude
 import GHC.Llvm
 
 import GHC.Cmm.Expr
+import GHC.CmmToAsm.Format
 import GHC.Platform
 import GHC.Data.FastString
 import GHC.Utils.Panic ( panic )
 import GHC.Types.Unique
 
+
 -- | Get the LlvmVar function variable storing the real register
-lmGlobalRegVar :: Platform -> GlobalReg -> LlvmVar
+lmGlobalRegVar :: Platform -> GlobalRegUse -> LlvmVar
 lmGlobalRegVar platform = pVarLift . lmGlobalReg platform "_Var"
 
 -- | Get the LlvmVar function argument storing the real register
-lmGlobalRegArg :: Platform -> GlobalReg -> LlvmVar
+lmGlobalRegArg :: Platform -> GlobalRegUse -> LlvmVar
 lmGlobalRegArg platform = lmGlobalReg platform "_Arg"
 
 {- Need to make sure the names here can't conflict with the unique generated
    names. Uniques generated names containing only base62 chars. So using say
    the '_' char guarantees this.
 -}
-lmGlobalReg :: Platform -> String -> GlobalReg -> LlvmVar
-lmGlobalReg platform suf reg
+lmGlobalReg :: Platform -> String -> GlobalRegUse -> LlvmVar
+lmGlobalReg platform suf (GlobalRegUse reg ty)
   = case reg of
         BaseReg        -> ptrGlobal $ "Base" ++ suf
         Sp             -> ptrGlobal $ "Sp" ++ suf
@@ -88,13 +90,26 @@ lmGlobalReg platform suf reg
         ptrGlobal    name = LMNLocalVar (fsLit name) (llvmWordPtr platform)
         floatGlobal  name = LMNLocalVar (fsLit name) LMFloat
         doubleGlobal name = LMNLocalVar (fsLit name) LMDouble
-        xmmGlobal    name = LMNLocalVar (fsLit name) (LMVector 4 (LMInt 32))
-        ymmGlobal    name = LMNLocalVar (fsLit name) (LMVector 8 (LMInt 32))
-        zmmGlobal    name = LMNLocalVar (fsLit name) (LMVector 16 (LMInt 32))
+        fmt = cmmTypeFormat ty
+        xmmGlobal    name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+        ymmGlobal    name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+        zmmGlobal    name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+
+formatLlvmType :: Format -> LlvmType
+formatLlvmType II8 = LMInt 8
+formatLlvmType II16 = LMInt 16
+formatLlvmType II32 = LMInt 32
+formatLlvmType II64 = LMInt 64
+formatLlvmType FF32 = LMFloat
+formatLlvmType FF64 = LMDouble
+formatLlvmType (VecFormat l sFmt) = LMVector l (formatLlvmType $ scalarFormatFormat sFmt)
 
 -- | A list of STG Registers that should always be considered alive
-alwaysLive :: [GlobalReg]
-alwaysLive = [BaseReg, Sp, Hp, SpLim, HpLim, node]
+alwaysLive :: Platform -> [GlobalRegUse]
+alwaysLive platform =
+  [ GlobalRegUse r (globalRegSpillType platform r)
+  | r <- [BaseReg, Sp, Hp, SpLim, HpLim, node]
+  ]
 
 -- | STG Type Based Alias Analysis hierarchy
 stgTBAA :: [(Unique, LMString, Maybe Unique)]



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/a34a4dc2269300b31c06173dad5d64f55317ecb4...c27747f25c7209555ae96573f4525839297b7703

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/a34a4dc2269300b31c06173dad5d64f55317ecb4...c27747f25c7209555ae96573f4525839297b7703
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240920/789444a5/attachment-0001.html>


More information about the ghc-commits mailing list