[Git][ghc/ghc][wip/ncg-simd] 3 commits: X86 genCCall: promote arg before calling evalArgs

sheaf (@sheaf) gitlab at gitlab.haskell.org
Sat Sep 21 13:59:58 UTC 2024



sheaf pushed to branch wip/ncg-simd at Glasgow Haskell Compiler / GHC


Commits:
13bd1182 by sheaf at 2024-09-21T15:59:34+02:00
X86 genCCall: promote arg before calling evalArgs

The job of evalArgs is to ensure each argument is put into a temporary
register, so that it can then be loaded directly into one of the
argument registers for the C call, without the generated code clobbering
any other register used for argument passing.

However, if we promote arguments after calling evalArgs, there is the
possibility that the code used for the promotion will clobber a register,
defeating the work of evalArgs.
To avoid this, we first promote arguments, and only then call evalArgs.

- - - - -
e9685fea by sheaf at 2024-09-21T15:59:34+02:00
X86 genCCall64: simplify loadArg code

This commit simplifies the argument loading code by making the
assumption that it is safe to directly load the argument into a register,
because doing so will not clobber any previous assignments.

This assumption stems from the use of 'evalArgs', which evaluates
any arguments which might necessitate non-trivial code generation into
separate temporary registers.

- - - - -
a7b364be by sheaf at 2024-09-21T15:59:34+02:00
LLVM: propagate GlobalRegUse information

This commit ensures we keep track of how any particular global register
is being used in the LLVM backend. This informs the LLVM type
annotations, and avoids type mismatches of the following form:

  argument is not of expected type '<2 x double>'
    call ccc <2 x double> (<2 x double>)
      (<4 x i32> arg)

- - - - -


6 changed files:

- compiler/GHC/CmmToAsm/X86/CodeGen.hs
- compiler/GHC/CmmToLlvm.hs
- compiler/GHC/CmmToLlvm/Base.hs
- compiler/GHC/CmmToLlvm/CodeGen.hs
- compiler/GHC/CmmToLlvm/Ppr.hs
- compiler/GHC/CmmToLlvm/Regs.hs


Changes:

=====================================
compiler/GHC/CmmToAsm/X86/CodeGen.hs
=====================================
@@ -1965,6 +1965,17 @@ getRegister' platform is32Bit (CmmLit lit)
 getRegister' platform is32Bit (CmmLit lit) = do
   avx <- avxEnabled
 
+  -- NB: it is important that the code produced here (to load a literal into
+  -- a register) doesn't clobber any registers other than the destination
+  -- register; the code for generating C calls relies on this property.
+  --
+  -- In particular, we have:
+  --
+  -- > loadIntoRegMightClobberOtherReg (CmmLit _) = False
+  --
+  -- which means that we assume that loading a literal into a register
+  -- will not clobber any other registers.
+
   -- TODO: this function mishandles floating-point negative zero,
   -- because -0.0 == 0.0 returns True and because we represent CmmFloat as
   -- Rational, which can't properly represent negative zero.
@@ -3080,10 +3091,8 @@ genSimplePrim _   op                   dst     args           = do
   platform <- ncgPlatform <$> getConfig
   pprPanic "genSimplePrim: unhandled primop" (ppr (pprCallishMachOp op, dst, fmap (pdoc platform) args))
 
-{-
-Note [Evaluate C-call arguments before placing in destination registers]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
+{- Note [Evaluate C-call arguments before placing in destination registers]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 When producing code for C calls we must take care when placing arguments
 in their final registers. Specifically, we must ensure that temporary register
 usage due to evaluation of one argument does not clobber a register in which we
@@ -3134,15 +3143,11 @@ genForeignCall{32,64}.
 -- | See Note [Evaluate C-call arguments before placing in destination registers]
 evalArgs :: BlockId -> [CmmActual] -> NatM (InstrBlock, [CmmActual])
 evalArgs bid actuals
-  | any mightContainMachOp actuals = do
+  | any loadIntoRegMightClobberOtherReg actuals = do
       regs_blks <- mapM evalArg actuals
       return (concatOL $ map fst regs_blks, map snd regs_blks)
   | otherwise = return (nilOL, actuals)
   where
-    mightContainMachOp (CmmReg _)      = False
-    mightContainMachOp (CmmRegOff _ _) = False
-    mightContainMachOp (CmmLit _)      = False
-    mightContainMachOp _               = True
 
     evalArg :: CmmActual -> NatM (InstrBlock, CmmExpr)
     evalArg actual = do
@@ -3156,6 +3161,16 @@ evalArgs bid actuals
     newLocalReg :: CmmType -> NatM LocalReg
     newLocalReg ty = LocalReg <$> getUniqueM <*> pure ty
 
+-- | Might the code to put this expression into a register
+-- clobber any other registers?
+loadIntoRegMightClobberOtherReg :: CmmExpr -> Bool
+loadIntoRegMightClobberOtherReg (CmmReg _)      = False
+loadIntoRegMightClobberOtherReg (CmmRegOff _ _) = False
+loadIntoRegMightClobberOtherReg (CmmLit _)      = False
+  -- NB: this last 'False' is slightly risky, because the code for loading
+  -- a literal into a register is not entirely trivial.
+loadIntoRegMightClobberOtherReg _               = True
+
 -- Note [DIV/IDIV for bytes]
 -- ~~~~~~~~~~~~~~~~~~~~~~~~~
 -- IDIV reminder:
@@ -3228,33 +3243,39 @@ genCCall
   -> [CmmFormal]
   -> [CmmActual]
   -> NatM InstrBlock
-genCCall bid addr conv dest_regs args = do
+genCCall bid addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
+  platform <- getPlatform
   is32Bit <- is32BitPlatform
-  (instrs0, args') <- evalArgs bid args
+  let args_hints = zip args (argHints ++ repeat NoHint)
+      prom_args = map (maybePromoteCArgToW32 platform) args_hints
+  (instrs0, args') <- evalArgs bid prom_args
   instrs1 <- if is32Bit
     then genCCall32 addr conv dest_regs args'
     else genCCall64 addr conv dest_regs args'
   return (instrs0 `appOL` instrs1)
 
-maybePromoteCArg :: Platform -> Width -> (CmmExpr, ForeignHint) -> CmmExpr
-maybePromoteCArg platform wto (arg, hint)
- | wfrom < wto = case hint of
-     SignedHint -> CmmMachOp (MO_SS_Conv wfrom wto) [arg]
-     _          -> CmmMachOp (MO_UU_Conv wfrom wto) [arg]
+maybePromoteCArgToW32 :: Platform -> (CmmExpr, ForeignHint) -> CmmExpr
+maybePromoteCArgToW32 platform (arg, hint)
+ | wfrom < wto =
+    -- As wto=W32, we only need to handle integer conversions,
+    -- never Float -> Double.
+    case hint of
+      SignedHint -> CmmMachOp (MO_SS_Conv wfrom wto) [arg]
+      _          -> CmmMachOp (MO_UU_Conv wfrom wto) [arg]
  | otherwise   = arg
  where
-   wfrom = cmmExprWidth platform arg
+   ty = cmmExprType platform arg
+   wfrom = typeWidth ty
+   wto = W32
 
 genCCall32 :: CmmExpr           -- ^ address of the function to call
            -> ForeignConvention -- ^ calling convention
            -> [CmmFormal]       -- ^ where to put the result
            -> [CmmActual]       -- ^ arguments (of mixed type)
            -> NatM InstrBlock
-genCCall32 addr (ForeignConvention _ argHints _ _) dest_regs args = do
+genCCall32 addr _conv dest_regs args = do
         config <- getConfig
         let platform = ncgPlatform config
-            args_hints = zip args (argHints ++ repeat NoHint)
-            prom_args = map (maybePromoteCArg platform W32) args_hints
 
             -- If the size is smaller than the word, we widen things (see maybePromoteCArg)
             arg_size_bytes :: CmmType -> Int
@@ -3324,7 +3345,7 @@ genCCall32 addr (ForeignConvention _ argHints _ _) dest_regs args = do
         delta0 <- getDeltaNat
         setDeltaNat (delta0 - arg_pad_size)
 
-        push_codes <- mapM push_arg (reverse prom_args)
+        push_codes <- mapM push_arg (reverse args)
         delta <- getDeltaNat
         massert (delta == delta0 - tot_arg_size)
 
@@ -3399,11 +3420,9 @@ genCCall64 :: CmmExpr           -- ^ address of function to call
            -> [CmmFormal]       -- ^ where to put the result
            -> [CmmActual]       -- ^ arguments (of mixed type)
            -> NatM InstrBlock
-genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
+genCCall64 addr conv dest_regs args = do
     config <- getConfig
     let platform = ncgPlatform config
-        args_hints = zip args (argHints ++ repeat NoHint)
-        prom_args = map (maybePromoteCArg platform W32) args_hints
         word_size = platformWordSizeInBytes platform
         wordFmt = archWordFormat (target32Bit platform)
 
@@ -3417,10 +3436,9 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
       { stackArgs       = proper_stack_args
       , stackDataArgs   = stack_data_args
       , usedRegs        = arg_regs_used
-      , computeArgsCode = compute_args_code
       , assignArgsCode  = assign_args_code
       }
-      <- loadArgs config prom_args
+      <- loadArgs config args
 
     let
 
@@ -3516,7 +3534,6 @@ genCCall64 addr conv@(ForeignConvention _ argHints _ _) dest_regs args = do
 
     return (align_call_code     `appOL`
             push_code           `appOL`
-            compute_args_code   `appOL`
             assign_args_code    `appOL`
             load_data_refs      `appOL`
             shadow_space_code   `appOL`
@@ -3537,16 +3554,14 @@ data LoadArgs
   , stackDataArgs :: [CmmExpr]
   -- | Which registers are we using for argument passing?
   , usedRegs      :: [RegWithFormat]
-  -- | The code to compute arguments into (possibly temporary) registers.
-  , computeArgsCode :: InstrBlock
   -- | The code to assign arguments to registers used for argument passing.
   , assignArgsCode :: InstrBlock
   }
 instance Semigroup LoadArgs where
-  LoadArgs a1 d1 r1 i1 j1 <> LoadArgs a2 d2 r2 i2 j2
-    = LoadArgs (a1 ++ a2) (d1 ++ d2) (r1 ++ r2) (i1 S.<> i2) (j1 S.<> j2)
+  LoadArgs a1 d1 r1 j1 <> LoadArgs a2 d2 r2 j2
+    = LoadArgs (a1 ++ a2) (d1 ++ d2) (r1 ++ r2) (j1 S.<> j2)
 instance Monoid LoadArgs where
-  mempty = LoadArgs [] [] [] nilOL nilOL
+  mempty = LoadArgs [] [] [] nilOL
 
 -- | An argument passed on the stack, either directly or by reference.
 --
@@ -3703,7 +3718,6 @@ loadArgsSysV config (arg:rest) = do
           LoadArgs
             { stackArgs       = map RawStackArg (arg:rest)
             , stackDataArgs   = []
-            , computeArgsCode = nilOL
             , assignArgsCode  = nilOL
             , usedRegs        = []
             }
@@ -3723,12 +3737,11 @@ loadArgsSysV config (arg:rest) = do
     this_arg <-
       case mbReg of
         Just reg -> do
-          (compute_code, assign_code) <- lift $ loadArgIntoReg config arg rest reg
+          assign_code <- lift $ loadArgIntoReg arg reg
           return $
             LoadArgs
                 { stackArgs       = [] -- passed in register
                 , stackDataArgs   = []
-                , computeArgsCode = compute_code
                 , assignArgsCode  = assign_code
                 , usedRegs        = [RegWithFormat reg arg_fmt]
                 }
@@ -3738,7 +3751,6 @@ loadArgsSysV config (arg:rest) = do
             LoadArgs
                 { stackArgs       = [RawStackArg arg]
                 , stackDataArgs   = []
-                , computeArgsCode = nilOL
                 , assignArgsCode  = nilOL
                 , usedRegs        = []
                 }
@@ -3790,7 +3802,6 @@ loadArgsWin config (arg:rest) = do
         LoadArgs
           { stackArgs       = stk_args
           , stackDataArgs   = data_args
-          , computeArgsCode = nilOL
           , assignArgsCode  = nilOL
           , usedRegs        = []
           }
@@ -3806,8 +3817,7 @@ loadArgsWin config (arg:rest) = do
                 -- Pass the reference in a register,
                 -- and the argument data on the stack.
                 { stackArgs       = [RawStackArgRef (InReg ireg) (argSize platform arg)]
-                , stackDataArgs   = [arg]
-                , computeArgsCode = nilOL -- we don't yet know where the data will reside,
+                , stackDataArgs   = [arg] -- we don't yet know where the data will reside,
                 , assignArgsCode  = nilOL -- so we defer computing the reference and storing it
                                           -- in the register until later
                 , usedRegs        = [RegWithFormat ireg II64]
@@ -3818,7 +3828,7 @@ loadArgsWin config (arg:rest) = do
                   = freg
                   | otherwise
                   = ireg
-           (compute_code, assign_code) <- loadArgIntoReg config arg rest arg_reg
+           assign_code <- loadArgIntoReg arg arg_reg
            -- Recall that, for varargs, we must pass floating-point
            -- arguments in both fp and integer registers.
            let (assign_code', regs')
@@ -3831,42 +3841,23 @@ loadArgsWin config (arg:rest) = do
              LoadArgs
                { stackArgs       = [] -- passed in register
                , stackDataArgs   = []
-               , computeArgsCode = compute_code
                , assignArgsCode = assign_code'
                , usedRegs = regs'
                }
 
-
--- | Return two pieces of code:
---
---  - code to compute a the given 'CmmExpr' into some (possibly temporary) register
---  - code to assign the resulting value to the specified register
+-- | Load an argument into a register.
 --
--- Using two separate pieces of code handles clobbering issues reported
--- in e.g. #11792, #12614.
-loadArgIntoReg :: NCGConfig -> CmmExpr -> [CmmExpr] -> Reg -> NatM (InstrBlock, InstrBlock)
-loadArgIntoReg config arg rest reg
-  -- "operand" args can be directly assigned into the register
-  | isOperand platform arg
-  = do arg_code <- getAnyReg arg
-       return (nilOL, arg_code reg)
-  -- The last non-operand arg can be directly assigned after its
-  -- computation without going into a temporary register
-  | all (isOperand platform) rest
-  = do arg_code <- getAnyReg arg
-       return (arg_code reg, nilOL)
-  -- Other args need to be computed beforehand to avoid clobbering
-  -- previously assigned registers used to pass parameters (see
-  -- #11792, #12614). They are assigned into temporary registers
-  -- and get assigned to proper call ABI registers after they all
-  -- have been computed.
-  | otherwise
-  = do arg_code <- getAnyReg arg
-       tmp      <- getNewRegNat arg_fmt
-       return (arg_code tmp, unitOL $ mkRegRegMoveInstr config arg_fmt tmp reg)
-  where
-    platform = ncgPlatform config
-    arg_fmt = cmmTypeFormat $ cmmExprType platform arg
+-- Assumes that the expression does not contain any MachOps,
+-- as per Note [Evaluate C-call arguments before placing in destination registers].
+loadArgIntoReg :: CmmExpr -> Reg -> NatM InstrBlock
+loadArgIntoReg arg reg = do
+  when (debugIsOn && loadIntoRegMightClobberOtherReg arg) $ do
+    platform <- getPlatform
+    massertPpr False $
+      vcat [ text "loadArgIntoReg: arg might contain MachOp"
+           , text "arg:" <+> pdoc platform arg ]
+  arg_code <- getAnyReg arg
+  return $ arg_code reg
 
 -- -----------------------------------------------------------------------------
 -- Pushing arguments onto the stack for 64-bit C calls.


=====================================
compiler/GHC/CmmToLlvm.hs
=====================================
@@ -139,7 +139,7 @@ llvmGroupLlvmGens cmm = do
                          Nothing                   -> l
                          Just (CmmStaticsRaw info_lbl _) -> info_lbl
               lml <- strCLabel_llvm l'
-              funInsert lml =<< llvmFunTy (map globalRegUse_reg live)
+              funInsert lml =<< llvmFunTy live
               return Nothing
         cdata <- fmap catMaybes $ mapM split cmm
 


=====================================
compiler/GHC/CmmToLlvm/Base.hs
=====================================
@@ -12,7 +12,7 @@
 module GHC.CmmToLlvm.Base (
 
         LlvmCmmDecl, LlvmBasicBlock,
-        LiveGlobalRegs,
+        LiveGlobalRegs, LiveGlobalRegUses,
         LlvmUnresData, LlvmData, UnresLabel, UnresStatic,
 
         LlvmM,
@@ -29,6 +29,8 @@ module GHC.CmmToLlvm.Base (
         llvmFunSig, llvmFunArgs, llvmStdFunAttrs, llvmFunAlign, llvmInfAlign,
         llvmPtrBits, tysToParams, llvmFunSection, padLiveArgs, isFPR,
 
+        lookupRegUse,
+
         strCLabel_llvm,
         getGlobalPtr, generateExternDecls,
 
@@ -58,9 +60,11 @@ import GHC.Types.Unique.Set
 import GHC.Types.Unique.Supply
 import GHC.Utils.Logger
 
-import Data.Maybe (fromJust)
 import Control.Monad.Trans.State (StateT (..))
-import Data.List (isPrefixOf)
+import Control.Applicative (Alternative((<|>)))
+import Data.Maybe (fromJust, mapMaybe)
+
+import Data.List (find, isPrefixOf)
 import qualified Data.List.NonEmpty as NE
 import Data.Ord (comparing)
 
@@ -73,6 +77,7 @@ type LlvmBasicBlock = GenBasicBlock LlvmStatement
 
 -- | Global registers live on proc entry
 type LiveGlobalRegs = [GlobalReg]
+type LiveGlobalRegUses = [GlobalRegUse]
 
 -- | Unresolved code.
 -- Of the form: (data label, data type, unresolved data)
@@ -116,16 +121,16 @@ llvmGhcCC platform
  | otherwise                       = CC_Ghc
 
 -- | Llvm Function type for Cmm function
-llvmFunTy :: LiveGlobalRegs -> LlvmM LlvmType
+llvmFunTy :: LiveGlobalRegUses -> LlvmM LlvmType
 llvmFunTy live = return . LMFunction =<< llvmFunSig' live (fsLit "a") ExternallyVisible
 
 -- | Llvm Function signature
-llvmFunSig :: LiveGlobalRegs ->  CLabel -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
+llvmFunSig :: LiveGlobalRegUses ->  CLabel -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
 llvmFunSig live lbl link = do
   lbl' <- strCLabel_llvm lbl
   llvmFunSig' live lbl' link
 
-llvmFunSig' :: LiveGlobalRegs -> LMString -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
+llvmFunSig' :: LiveGlobalRegUses -> LMString -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
 llvmFunSig' live lbl link
   = do let toParams x | isPointer x = (x, [NoAlias, NoCapture])
                       | otherwise   = (x, [])
@@ -149,16 +154,25 @@ llvmFunSection opts lbl
     | otherwise               = Nothing
 
 -- | A Function's arguments
-llvmFunArgs :: Platform -> LiveGlobalRegs -> [LlvmVar]
+llvmFunArgs :: Platform -> LiveGlobalRegUses -> [LlvmVar]
 llvmFunArgs platform live =
-    map (lmGlobalRegArg platform) (filter isPassed allRegs)
+    map (lmGlobalRegArg platform) (mapMaybe isPassed allRegs)
     where allRegs = activeStgRegs platform
           paddingRegs = padLiveArgs platform live
-          isLive r = r `elem` alwaysLive
-                     || r `elem` live
-                     || r `elem` paddingRegs
-          isPassed r = not (isFPR r) || isLive r
-
+          isLive :: GlobalReg -> Maybe GlobalRegUse
+          isLive r =
+            lookupRegUse r (alwaysLive platform)
+              <|>
+            lookupRegUse r live
+              <|>
+            lookupRegUse r paddingRegs
+          isPassed r =
+            if not (isFPR r)
+            then Just $ GlobalRegUse r (globalRegSpillType platform r)
+            else isLive r
+
+lookupRegUse :: GlobalReg -> [GlobalRegUse] -> Maybe GlobalRegUse
+lookupRegUse r = find ((== r) . globalRegUse_reg)
 
 isFPR :: GlobalReg -> Bool
 isFPR (FloatReg _)  = True
@@ -179,7 +193,7 @@ isFPR _             = False
 -- Invariant: Cmm FPR regs with number "n" maps to real registers with number
 -- "n" If the calling convention uses registers in a different order or if the
 -- invariant doesn't hold, this code probably won't be correct.
-padLiveArgs :: Platform -> LiveGlobalRegs -> LiveGlobalRegs
+padLiveArgs :: Platform -> LiveGlobalRegUses -> LiveGlobalRegUses
 padLiveArgs platform live =
       if platformUnregisterised platform
         then [] -- not using GHC's register convention for platform.
@@ -188,7 +202,7 @@ padLiveArgs platform live =
     ----------------------------------
     -- handle floating-point registers (FPR)
 
-    fprLive = filter isFPR live  -- real live FPR registers
+    fprLive = filter (isFPR . globalRegUse_reg) live  -- real live FPR registers
 
     -- we group live registers sharing the same classes, i.e. that use the same
     -- set of real registers to be passed. E.g. FloatReg, DoubleReg and XmmReg
@@ -196,39 +210,44 @@ padLiveArgs platform live =
     --
     classes         = NE.groupBy sharesClass fprLive
     sharesClass a b = globalRegsOverlap platform (norm a) (norm b) -- check if mapped to overlapping registers
-    norm x          = fpr_ctor x 1                                 -- get the first register of the family
+    norm x = globalRegUse_reg (fpr_ctor x 1)                  -- get the first register of the family
 
     -- For each class, we just have to fill missing registers numbers. We use
     -- the constructor of the greatest register to build padding registers.
     --
     -- E.g. sortedRs = [   F2,   XMM4, D5]
     --      output   = [D1,   D3]
+    padded :: [GlobalRegUse]
     padded      = concatMap padClass classes
+
+    padClass :: NE.NonEmpty GlobalRegUse -> [GlobalRegUse]
     padClass rs = go (NE.toList sortedRs) 1
       where
-         sortedRs = NE.sortBy (comparing fpr_num) rs
+         sortedRs = NE.sortBy (comparing (fpr_num . globalRegUse_reg)) rs
          maxr     = NE.last sortedRs
          ctor     = fpr_ctor maxr
 
          go [] _ = []
-         go (c1:c2:_) _   -- detect bogus case (see #17920)
+         go (GlobalRegUse c1 _: GlobalRegUse c2 _:_) _   -- detect bogus case (see #17920)
             | fpr_num c1 == fpr_num c2
             , Just real <- globalRegMaybe platform c1
             = sorryDoc "LLVM code generator" $
                text "Found two different Cmm registers (" <> ppr c1 <> text "," <> ppr c2 <>
                text ") both alive AND mapped to the same real register: " <> ppr real <>
                text ". This isn't currently supported by the LLVM backend."
-         go (c:cs) f
-            | fpr_num c == f = go cs f                    -- already covered by a real register
-            | otherwise      = ctor f : go (c:cs) (f + 1) -- add padding register
-
-    fpr_ctor :: GlobalReg -> Int -> GlobalReg
-    fpr_ctor (FloatReg _)  = FloatReg
-    fpr_ctor (DoubleReg _) = DoubleReg
-    fpr_ctor (XmmReg _)    = XmmReg
-    fpr_ctor (YmmReg _)    = YmmReg
-    fpr_ctor (ZmmReg _)    = ZmmReg
-    fpr_ctor _ = error "fpr_ctor expected only FPR regs"
+         go (cu@(GlobalRegUse c _):cs) f
+            | fpr_num c == f = go cs f                     -- already covered by a real register
+            | otherwise      = ctor f : go (cu:cs) (f + 1) -- add padding register
+
+    fpr_ctor :: GlobalRegUse -> Int -> GlobalRegUse
+    fpr_ctor (GlobalRegUse r fmt) i =
+      case r of
+        FloatReg _  -> GlobalRegUse (FloatReg  i) fmt
+        DoubleReg _ -> GlobalRegUse (DoubleReg i) fmt
+        XmmReg _    -> GlobalRegUse (XmmReg    i) fmt
+        YmmReg _    -> GlobalRegUse (YmmReg    i) fmt
+        ZmmReg _    -> GlobalRegUse (ZmmReg    i) fmt
+        _           -> error "fpr_ctor expected only FPR regs"
 
     fpr_num :: GlobalReg -> Int
     fpr_num (FloatReg i)  = i


=====================================
compiler/GHC/CmmToLlvm/CodeGen.hs
=====================================
@@ -37,13 +37,14 @@ import GHC.Utils.Outputable
 import qualified GHC.Utils.Panic as Panic
 import GHC.Utils.Misc
 
+import Control.Applicative (Alternative((<|>)))
 import Control.Monad.Trans.Class
 import Control.Monad.Trans.Writer
 import Control.Monad
 
 import qualified Data.Semigroup as Semigroup
 import Data.List ( nub )
-import Data.Maybe ( catMaybes )
+import Data.Maybe ( catMaybes, isJust )
 
 type Atomic = Maybe MemoryOrdering
 type LlvmStatements = OrdList LlvmStatement
@@ -57,7 +58,7 @@ genLlvmProc :: RawCmmDecl -> LlvmM [LlvmCmmDecl]
 genLlvmProc (CmmProc infos lbl live graph) = do
     let blocks = toBlockListEntryFirstFalseFallthrough graph
 
-    (lmblocks, lmdata) <- basicBlocksCodeGen (map globalRegUse_reg live) blocks
+    (lmblocks, lmdata) <- basicBlocksCodeGen live blocks
     let info = mapLookup (g_entry graph) infos
         proc = CmmProc info lbl live (ListGraph lmblocks)
     return (proc:lmdata)
@@ -76,7 +77,7 @@ newtype UnreachableBlockId = UnreachableBlockId BlockId
 -- | Generate code for a list of blocks that make up a complete
 -- procedure. The first block in the list is expected to be the entry
 -- point.
-basicBlocksCodeGen :: LiveGlobalRegs -> [CmmBlock]
+basicBlocksCodeGen :: LiveGlobalRegUses -> [CmmBlock]
                       -> LlvmM ([LlvmBasicBlock], [LlvmCmmDecl])
 basicBlocksCodeGen _    []                     = panic "no entry block!"
 basicBlocksCodeGen live cmmBlocks
@@ -152,7 +153,7 @@ stmtToInstrs ubid stmt = case stmt of
 
     -- Tail call
     CmmCall { cml_target = arg,
-              cml_args_regs = live } -> genJump arg $ map globalRegUse_reg live
+              cml_args_regs = live } -> genJump arg live
 
     _ -> panic "Llvm.CodeGen.stmtToInstrs"
 
@@ -1050,7 +1051,7 @@ cmmPrimOpFunctions mop = do
 
 
 -- | Tail function calls
-genJump :: CmmExpr -> [GlobalReg] -> LlvmM StmtData
+genJump :: CmmExpr -> LiveGlobalRegUses -> LlvmM StmtData
 
 -- Call to known function
 genJump (CmmLit (CmmLabel lbl)) live = do
@@ -2056,14 +2057,13 @@ getCmmReg (CmmLocal (LocalReg un _))
            -- have been assigned a value at some point, triggering
            -- "funPrologue" to allocate it on the stack.
 
-getCmmReg (CmmGlobal g)
-  = do let r = globalRegUse_reg g
-       onStack  <- checkStackReg r
+getCmmReg (CmmGlobal ru@(GlobalRegUse r _))
+  = do onStack  <- checkStackReg r
        platform <- getPlatform
        if onStack
-         then return (lmGlobalRegVar platform r)
+         then return (lmGlobalRegVar platform ru)
          else pprPanic "getCmmReg: Cmm register " $
-                ppr g <> text " not stack-allocated!"
+                ppr r <> text " not stack-allocated!"
 
 -- | Return the value of a given register, as well as its type. Might
 -- need to be load from stack.
@@ -2074,7 +2074,7 @@ getCmmRegVal reg =
       onStack <- checkStackReg (globalRegUse_reg g)
       platform <- getPlatform
       if onStack then loadFromStack else do
-        let r = lmGlobalRegArg platform (globalRegUse_reg g)
+        let r = lmGlobalRegArg platform g
         return (r, getVarType r, nilOL)
     _ -> loadFromStack
  where loadFromStack = do
@@ -2187,8 +2187,9 @@ convertMemoryOrdering MemOrderSeqCst  = SyncSeqCst
 -- question is never written. Therefore we skip it where we can to
 -- save a few lines in the output and hopefully speed compilation up a
 -- bit.
-funPrologue :: LiveGlobalRegs -> [CmmBlock] -> LlvmM StmtData
+funPrologue :: LiveGlobalRegUses -> [CmmBlock] -> LlvmM StmtData
 funPrologue live cmmBlocks = do
+  platform <- getPlatform
 
   let getAssignedRegs :: CmmNode O O -> [CmmReg]
       getAssignedRegs (CmmAssign reg _)  = [reg]
@@ -2196,7 +2197,8 @@ funPrologue live cmmBlocks = do
       getAssignedRegs _                  = []
       getRegsBlock (_, body, _)          = concatMap getAssignedRegs $ blockToList body
       assignedRegs = nub $ concatMap (getRegsBlock . blockSplit) cmmBlocks
-      isLive r     = r `elem` alwaysLive || r `elem` live
+      mbLive r     =
+        lookupRegUse r (alwaysLive platform) <|> lookupRegUse r live
 
   platform <- getPlatform
   stmtss <- forM assignedRegs $ \reg ->
@@ -2205,12 +2207,12 @@ funPrologue live cmmBlocks = do
         let (newv, stmts) = allocReg reg
         varInsert un (pLower $ getVarType newv)
         return stmts
-      CmmGlobal (GlobalRegUse r _) -> do
-        let reg   = lmGlobalRegVar platform r
-            arg   = lmGlobalRegArg platform r
+      CmmGlobal ru@(GlobalRegUse r _) -> do
+        let reg   = lmGlobalRegVar platform ru
+            arg   = lmGlobalRegArg platform ru
             ty    = (pLower . getVarType) reg
             trash = LMLitVar $ LMUndefLit ty
-            rval  = if isLive r then arg else trash
+            rval  = if isJust (mbLive r) then arg else trash
             alloc = Assignment reg $ Alloca (pLower $ getVarType reg) 1
         markStackReg r
         return $ toOL [alloc, Store rval reg Nothing []]
@@ -2222,7 +2224,7 @@ funPrologue live cmmBlocks = do
 
 -- | Function epilogue. Load STG variables to use as argument for call.
 -- STG Liveness optimisation done here.
-funEpilogue :: LiveGlobalRegs -> LlvmM ([LlvmVar], LlvmStatements)
+funEpilogue :: LiveGlobalRegUses -> LlvmM ([LlvmVar], LlvmStatements)
 funEpilogue live = do
     platform <- getPlatform
 
@@ -2248,12 +2250,16 @@ funEpilogue live = do
     let allRegs = activeStgRegs platform
     loads <- forM allRegs $ \r -> if
       -- load live registers
-      | r `elem` alwaysLive  -> loadExpr (GlobalRegUse r (globalRegSpillType platform r))
-      | r `elem` live        -> loadExpr (GlobalRegUse r (globalRegSpillType platform r))
+      | Just ru <- lookupRegUse r (alwaysLive platform)
+      -> loadExpr ru
+      | Just ru <- lookupRegUse r live
+      -> loadExpr ru
       -- load all non Floating-Point Registers
-      | not (isFPR r)        -> loadUndef r
+      | not (isFPR r)
+      -> loadUndef (GlobalRegUse r (globalRegSpillType platform r))
       -- load padding Floating-Point Registers
-      | r `elem` paddingRegs -> loadUndef r
+      | Just ru <- lookupRegUse r paddingRegs
+      -> loadUndef ru
       | otherwise            -> return (Nothing, nilOL)
 
     let (vars, stmts) = unzip loads
@@ -2263,7 +2269,7 @@ funEpilogue live = do
 --
 -- This is for Haskell functions, function type is assumed, so doesn't work
 -- with foreign functions.
-getHsFunc :: LiveGlobalRegs -> CLabel -> LlvmM ExprData
+getHsFunc :: LiveGlobalRegUses -> CLabel -> LlvmM ExprData
 getHsFunc live lbl
   = do fty <- llvmFunTy live
        name <- strCLabel_llvm lbl


=====================================
compiler/GHC/CmmToLlvm/Ppr.hs
=====================================
@@ -49,9 +49,8 @@ pprLlvmCmmDecl (CmmData _ lmdata) = do
   return ( vcat $ map (pprLlvmData opts) lmdata
          , vcat $ map (pprLlvmData opts) lmdata)
 
-pprLlvmCmmDecl (CmmProc mb_info entry_lbl liveWithUses (ListGraph blks))
-  = do let live = map globalRegUse_reg liveWithUses
-           lbl = case mb_info of
+pprLlvmCmmDecl (CmmProc mb_info entry_lbl live (ListGraph blks))
+  = do let lbl = case mb_info of
                      Nothing -> entry_lbl
                      Just (CmmStaticsRaw info_lbl _) -> info_lbl
            link = if externallyVisibleCLabel lbl


=====================================
compiler/GHC/CmmToLlvm/Regs.hs
=====================================
@@ -14,25 +14,27 @@ import GHC.Prelude
 import GHC.Llvm
 
 import GHC.Cmm.Expr
+import GHC.CmmToAsm.Format
 import GHC.Platform
 import GHC.Data.FastString
 import GHC.Utils.Panic ( panic )
 import GHC.Types.Unique
 
+
 -- | Get the LlvmVar function variable storing the real register
-lmGlobalRegVar :: Platform -> GlobalReg -> LlvmVar
+lmGlobalRegVar :: Platform -> GlobalRegUse -> LlvmVar
 lmGlobalRegVar platform = pVarLift . lmGlobalReg platform "_Var"
 
 -- | Get the LlvmVar function argument storing the real register
-lmGlobalRegArg :: Platform -> GlobalReg -> LlvmVar
+lmGlobalRegArg :: Platform -> GlobalRegUse -> LlvmVar
 lmGlobalRegArg platform = lmGlobalReg platform "_Arg"
 
 {- Need to make sure the names here can't conflict with the unique generated
    names. Uniques generated names containing only base62 chars. So using say
    the '_' char guarantees this.
 -}
-lmGlobalReg :: Platform -> String -> GlobalReg -> LlvmVar
-lmGlobalReg platform suf reg
+lmGlobalReg :: Platform -> String -> GlobalRegUse -> LlvmVar
+lmGlobalReg platform suf (GlobalRegUse reg ty)
   = case reg of
         BaseReg        -> ptrGlobal $ "Base" ++ suf
         Sp             -> ptrGlobal $ "Sp" ++ suf
@@ -88,13 +90,26 @@ lmGlobalReg platform suf reg
         ptrGlobal    name = LMNLocalVar (fsLit name) (llvmWordPtr platform)
         floatGlobal  name = LMNLocalVar (fsLit name) LMFloat
         doubleGlobal name = LMNLocalVar (fsLit name) LMDouble
-        xmmGlobal    name = LMNLocalVar (fsLit name) (LMVector 4 (LMInt 32))
-        ymmGlobal    name = LMNLocalVar (fsLit name) (LMVector 8 (LMInt 32))
-        zmmGlobal    name = LMNLocalVar (fsLit name) (LMVector 16 (LMInt 32))
+        fmt = cmmTypeFormat ty
+        xmmGlobal    name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+        ymmGlobal    name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+        zmmGlobal    name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+
+formatLlvmType :: Format -> LlvmType
+formatLlvmType II8 = LMInt 8
+formatLlvmType II16 = LMInt 16
+formatLlvmType II32 = LMInt 32
+formatLlvmType II64 = LMInt 64
+formatLlvmType FF32 = LMFloat
+formatLlvmType FF64 = LMDouble
+formatLlvmType (VecFormat l sFmt) = LMVector l (formatLlvmType $ scalarFormatFormat sFmt)
 
 -- | A list of STG Registers that should always be considered alive
-alwaysLive :: [GlobalReg]
-alwaysLive = [BaseReg, Sp, Hp, SpLim, HpLim, node]
+alwaysLive :: Platform -> [GlobalRegUse]
+alwaysLive platform =
+  [ GlobalRegUse r (globalRegSpillType platform r)
+  | r <- [BaseReg, Sp, Hp, SpLim, HpLim, node]
+  ]
 
 -- | STG Type Based Alias Analysis hierarchy
 stgTBAA :: [(Unique, LMString, Maybe Unique)]



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/0a288d07bb5458c5c0d3e17a5fcc63f8059782a2...a7b364be1d319a7f78246060eafb3d483bd94beb

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/0a288d07bb5458c5c0d3e17a5fcc63f8059782a2...a7b364be1d319a7f78246060eafb3d483bd94beb
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240921/1518db3c/attachment-0001.html>


More information about the ghc-commits mailing list