[Git][ghc/ghc][wip/ncg-simd] LLVM: propagate GlobalRegUse information

sheaf (@sheaf) gitlab at gitlab.haskell.org
Fri Aug 30 23:06:00 UTC 2024



sheaf pushed to branch wip/ncg-simd at Glasgow Haskell Compiler / GHC


Commits:
b792b70c by sheaf at 2024-08-31T01:05:49+02:00
LLVM: propagate GlobalRegUse information

This commit ensures we keep track of how any particular global register
is being used in the LLVM backend. This information determines the LLVM
type annotations and avoids type mismatches of the following form:

  argument is not of expected type '<2 x double>'
    call ccc <2 x double> (<2 x double>)
      (<4 x i32> arg)

- - - - -


5 changed files:

- compiler/GHC/CmmToLlvm.hs
- compiler/GHC/CmmToLlvm/Base.hs
- compiler/GHC/CmmToLlvm/CodeGen.hs
- compiler/GHC/CmmToLlvm/Ppr.hs
- compiler/GHC/CmmToLlvm/Regs.hs


Changes:

=====================================
compiler/GHC/CmmToLlvm.hs
=====================================
@@ -139,7 +139,7 @@ llvmGroupLlvmGens cmm = do
                          Nothing                   -> l
                          Just (CmmStaticsRaw info_lbl _) -> info_lbl
               lml <- strCLabel_llvm l'
-              funInsert lml =<< llvmFunTy (map globalRegUseGlobalReg live)
+              funInsert lml =<< llvmFunTy live
               return Nothing
         cdata <- fmap catMaybes $ mapM split cmm
 


=====================================
compiler/GHC/CmmToLlvm/Base.hs
=====================================
@@ -12,7 +12,7 @@
 module GHC.CmmToLlvm.Base (
 
         LlvmCmmDecl, LlvmBasicBlock,
-        LiveGlobalRegs,
+        LiveGlobalRegs, LiveGlobalRegUses,
         LlvmUnresData, LlvmData, UnresLabel, UnresStatic,
 
         LlvmM,
@@ -29,6 +29,8 @@ module GHC.CmmToLlvm.Base (
         llvmFunSig, llvmFunArgs, llvmStdFunAttrs, llvmFunAlign, llvmInfAlign,
         llvmPtrBits, tysToParams, llvmFunSection, padLiveArgs, isFPR,
 
+        lookupRegUse,
+
         strCLabel_llvm,
         getGlobalPtr, generateExternDecls,
 
@@ -58,9 +60,11 @@ import GHC.Types.Unique.Set
 import GHC.Types.Unique.Supply
 import GHC.Utils.Logger
 
-import Data.Maybe (fromJust)
 import Control.Monad.Trans.State (StateT (..))
-import Data.List (isPrefixOf)
+import Control.Applicative (Alternative((<|>)))
+import Data.Maybe (fromJust, mapMaybe)
+
+import Data.List (find, isPrefixOf)
 import qualified Data.List.NonEmpty as NE
 import Data.Ord (comparing)
 
@@ -73,6 +77,7 @@ type LlvmBasicBlock = GenBasicBlock LlvmStatement
 
 -- | Global registers live on proc entry
 type LiveGlobalRegs = [GlobalReg]
+type LiveGlobalRegUses = [GlobalRegUse]
 
 -- | Unresolved code.
 -- Of the form: (data label, data type, unresolved data)
@@ -116,16 +121,16 @@ llvmGhcCC platform
  | otherwise                       = CC_Ghc
 
 -- | Llvm Function type for Cmm function
-llvmFunTy :: LiveGlobalRegs -> LlvmM LlvmType
+llvmFunTy :: LiveGlobalRegUses -> LlvmM LlvmType
 llvmFunTy live = return . LMFunction =<< llvmFunSig' live (fsLit "a") ExternallyVisible
 
 -- | Llvm Function signature
-llvmFunSig :: LiveGlobalRegs ->  CLabel -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
+llvmFunSig :: LiveGlobalRegUses ->  CLabel -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
 llvmFunSig live lbl link = do
   lbl' <- strCLabel_llvm lbl
   llvmFunSig' live lbl' link
 
-llvmFunSig' :: LiveGlobalRegs -> LMString -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
+llvmFunSig' :: LiveGlobalRegUses -> LMString -> LlvmLinkageType -> LlvmM LlvmFunctionDecl
 llvmFunSig' live lbl link
   = do let toParams x | isPointer x = (x, [NoAlias, NoCapture])
                       | otherwise   = (x, [])
@@ -149,16 +154,25 @@ llvmFunSection opts lbl
     | otherwise               = Nothing
 
 -- | A Function's arguments
-llvmFunArgs :: Platform -> LiveGlobalRegs -> [LlvmVar]
+llvmFunArgs :: Platform -> LiveGlobalRegUses -> [LlvmVar]
 llvmFunArgs platform live =
-    map (lmGlobalRegArg platform) (filter isPassed allRegs)
+    map (lmGlobalRegArg platform) (mapMaybe isPassed allRegs)
     where allRegs = activeStgRegs platform
           paddingRegs = padLiveArgs platform live
-          isLive r = r `elem` alwaysLive
-                     || r `elem` live
-                     || r `elem` paddingRegs
-          isPassed r = not (isFPR r) || isLive r
-
+          isLive :: GlobalReg -> Maybe GlobalRegUse
+          isLive r =
+            lookupRegUse r (alwaysLive platform)
+              <|>
+            lookupRegUse r live
+              <|>
+            lookupRegUse r paddingRegs
+          isPassed r =
+            if not (isFPR r)
+            then Just $ GlobalRegUse r (globalRegSpillType platform r)
+            else isLive r
+
+lookupRegUse :: GlobalReg -> [GlobalRegUse] -> Maybe GlobalRegUse
+lookupRegUse r = find ((== r) . globalRegUseGlobalReg)
 
 isFPR :: GlobalReg -> Bool
 isFPR (FloatReg _)  = True
@@ -179,7 +193,7 @@ isFPR _             = False
 -- Invariant: Cmm FPR regs with number "n" maps to real registers with number
 -- "n" If the calling convention uses registers in a different order or if the
 -- invariant doesn't hold, this code probably won't be correct.
-padLiveArgs :: Platform -> LiveGlobalRegs -> LiveGlobalRegs
+padLiveArgs :: Platform -> LiveGlobalRegUses -> LiveGlobalRegUses
 padLiveArgs platform live =
       if platformUnregisterised platform
         then [] -- not using GHC's register convention for platform.
@@ -188,7 +202,7 @@ padLiveArgs platform live =
     ----------------------------------
     -- handle floating-point registers (FPR)
 
-    fprLive = filter isFPR live  -- real live FPR registers
+    fprLive = filter (isFPR . globalRegUseGlobalReg) live  -- real live FPR registers
 
     -- we group live registers sharing the same classes, i.e. that use the same
     -- set of real registers to be passed. E.g. FloatReg, DoubleReg and XmmReg
@@ -196,39 +210,44 @@ padLiveArgs platform live =
     --
     classes         = NE.groupBy sharesClass fprLive
     sharesClass a b = globalRegsOverlap platform (norm a) (norm b) -- check if mapped to overlapping registers
-    norm x          = fpr_ctor x 1                                 -- get the first register of the family
+    norm x = globalRegUseGlobalReg (fpr_ctor x 1)                  -- get the first register of the family
 
     -- For each class, we just have to fill missing registers numbers. We use
     -- the constructor of the greatest register to build padding registers.
     --
     -- E.g. sortedRs = [   F2,   XMM4, D5]
     --      output   = [D1,   D3]
+    padded :: [GlobalRegUse]
     padded      = concatMap padClass classes
+
+    padClass :: NE.NonEmpty GlobalRegUse -> [GlobalRegUse]
     padClass rs = go (NE.toList sortedRs) 1
       where
-         sortedRs = NE.sortBy (comparing fpr_num) rs
+         sortedRs = NE.sortBy (comparing (fpr_num . globalRegUseGlobalReg)) rs
          maxr     = NE.last sortedRs
          ctor     = fpr_ctor maxr
 
          go [] _ = []
-         go (c1:c2:_) _   -- detect bogus case (see #17920)
+         go (GlobalRegUse c1 _: GlobalRegUse c2 _:_) _   -- detect bogus case (see #17920)
             | fpr_num c1 == fpr_num c2
             , Just real <- globalRegMaybe platform c1
             = sorryDoc "LLVM code generator" $
                text "Found two different Cmm registers (" <> ppr c1 <> text "," <> ppr c2 <>
                text ") both alive AND mapped to the same real register: " <> ppr real <>
                text ". This isn't currently supported by the LLVM backend."
-         go (c:cs) f
-            | fpr_num c == f = go cs f                    -- already covered by a real register
-            | otherwise      = ctor f : go (c:cs) (f + 1) -- add padding register
-
-    fpr_ctor :: GlobalReg -> Int -> GlobalReg
-    fpr_ctor (FloatReg _)  = FloatReg
-    fpr_ctor (DoubleReg _) = DoubleReg
-    fpr_ctor (XmmReg _)    = XmmReg
-    fpr_ctor (YmmReg _)    = YmmReg
-    fpr_ctor (ZmmReg _)    = ZmmReg
-    fpr_ctor _ = error "fpr_ctor expected only FPR regs"
+         go (cu@(GlobalRegUse c _):cs) f
+            | fpr_num c == f = go cs f                     -- already covered by a real register
+            | otherwise      = ctor f : go (cu:cs) (f + 1) -- add padding register
+
+    fpr_ctor :: GlobalRegUse -> Int -> GlobalRegUse
+    fpr_ctor (GlobalRegUse r fmt) i =
+      case r of
+        FloatReg _  -> GlobalRegUse (FloatReg  i) fmt
+        DoubleReg _ -> GlobalRegUse (DoubleReg i) fmt
+        XmmReg _    -> GlobalRegUse (XmmReg    i) fmt
+        YmmReg _    -> GlobalRegUse (YmmReg    i) fmt
+        ZmmReg _    -> GlobalRegUse (ZmmReg    i) fmt
+        _           -> error "fpr_ctor expected only FPR regs"
 
     fpr_num :: GlobalReg -> Int
     fpr_num (FloatReg i)  = i


=====================================
compiler/GHC/CmmToLlvm/CodeGen.hs
=====================================
@@ -37,13 +37,14 @@ import GHC.Utils.Outputable
 import qualified GHC.Utils.Panic as Panic
 import GHC.Utils.Misc
 
+import Control.Applicative (Alternative((<|>)))
 import Control.Monad.Trans.Class
 import Control.Monad.Trans.Writer
 import Control.Monad
 
 import qualified Data.Semigroup as Semigroup
 import Data.List ( nub )
-import Data.Maybe ( catMaybes )
+import Data.Maybe ( catMaybes, isJust )
 
 type Atomic = Maybe MemoryOrdering
 type LlvmStatements = OrdList LlvmStatement
@@ -57,7 +58,7 @@ genLlvmProc :: RawCmmDecl -> LlvmM [LlvmCmmDecl]
 genLlvmProc (CmmProc infos lbl live graph) = do
     let blocks = toBlockListEntryFirstFalseFallthrough graph
 
-    (lmblocks, lmdata) <- basicBlocksCodeGen (map globalRegUseGlobalReg live) blocks
+    (lmblocks, lmdata) <- basicBlocksCodeGen live blocks
     let info = mapLookup (g_entry graph) infos
         proc = CmmProc info lbl live (ListGraph lmblocks)
     return (proc:lmdata)
@@ -76,7 +77,7 @@ newtype UnreachableBlockId = UnreachableBlockId BlockId
 -- | Generate code for a list of blocks that make up a complete
 -- procedure. The first block in the list is expected to be the entry
 -- point.
-basicBlocksCodeGen :: LiveGlobalRegs -> [CmmBlock]
+basicBlocksCodeGen :: LiveGlobalRegUses -> [CmmBlock]
                       -> LlvmM ([LlvmBasicBlock], [LlvmCmmDecl])
 basicBlocksCodeGen _    []                     = panic "no entry block!"
 basicBlocksCodeGen live cmmBlocks
@@ -152,7 +153,7 @@ stmtToInstrs ubid stmt = case stmt of
 
     -- Tail call
     CmmCall { cml_target = arg,
-              cml_args_regs = live } -> genJump arg $ map globalRegUseGlobalReg live
+              cml_args_regs = live } -> genJump arg live
 
     _ -> panic "Llvm.CodeGen.stmtToInstrs"
 
@@ -1050,7 +1051,7 @@ cmmPrimOpFunctions mop = do
 
 
 -- | Tail function calls
-genJump :: CmmExpr -> [GlobalReg] -> LlvmM StmtData
+genJump :: CmmExpr -> LiveGlobalRegUses -> LlvmM StmtData
 
 -- Call to known function
 genJump (CmmLit (CmmLabel lbl)) live = do
@@ -2056,14 +2057,13 @@ getCmmReg (CmmLocal (LocalReg un _))
            -- have been assigned a value at some point, triggering
            -- "funPrologue" to allocate it on the stack.
 
-getCmmReg (CmmGlobal g)
-  = do let r = globalRegUseGlobalReg g
-       onStack  <- checkStackReg r
+getCmmReg (CmmGlobal ru@(GlobalRegUse r _))
+  = do onStack  <- checkStackReg r
        platform <- getPlatform
        if onStack
-         then return (lmGlobalRegVar platform r)
+         then return (lmGlobalRegVar platform ru)
          else pprPanic "getCmmReg: Cmm register " $
-                ppr g <> text " not stack-allocated!"
+                ppr r <> text " not stack-allocated!"
 
 -- | Return the value of a given register, as well as its type. Might
 -- need to be load from stack.
@@ -2074,7 +2074,7 @@ getCmmRegVal reg =
       onStack <- checkStackReg (globalRegUseGlobalReg g)
       platform <- getPlatform
       if onStack then loadFromStack else do
-        let r = lmGlobalRegArg platform (globalRegUseGlobalReg g)
+        let r = lmGlobalRegArg platform g
         return (r, getVarType r, nilOL)
     _ -> loadFromStack
  where loadFromStack = do
@@ -2187,8 +2187,9 @@ convertMemoryOrdering MemOrderSeqCst  = SyncSeqCst
 -- question is never written. Therefore we skip it where we can to
 -- save a few lines in the output and hopefully speed compilation up a
 -- bit.
-funPrologue :: LiveGlobalRegs -> [CmmBlock] -> LlvmM StmtData
+funPrologue :: LiveGlobalRegUses -> [CmmBlock] -> LlvmM StmtData
 funPrologue live cmmBlocks = do
+  platform <- getPlatform
 
   let getAssignedRegs :: CmmNode O O -> [CmmReg]
       getAssignedRegs (CmmAssign reg _)  = [reg]
@@ -2196,7 +2197,8 @@ funPrologue live cmmBlocks = do
       getAssignedRegs _                  = []
       getRegsBlock (_, body, _)          = concatMap getAssignedRegs $ blockToList body
       assignedRegs = nub $ concatMap (getRegsBlock . blockSplit) cmmBlocks
-      isLive r     = r `elem` alwaysLive || r `elem` live
+      mbLive r     =
+        lookupRegUse r (alwaysLive platform) <|> lookupRegUse r live
 
   platform <- getPlatform
   stmtss <- forM assignedRegs $ \reg ->
@@ -2205,12 +2207,12 @@ funPrologue live cmmBlocks = do
         let (newv, stmts) = allocReg reg
         varInsert un (pLower $ getVarType newv)
         return stmts
-      CmmGlobal (GlobalRegUse r _) -> do
-        let reg   = lmGlobalRegVar platform r
-            arg   = lmGlobalRegArg platform r
+      CmmGlobal ru@(GlobalRegUse r _) -> do
+        let reg   = lmGlobalRegVar platform ru
+            arg   = lmGlobalRegArg platform ru
             ty    = (pLower . getVarType) reg
             trash = LMLitVar $ LMUndefLit ty
-            rval  = if isLive r then arg else trash
+            rval  = if isJust (mbLive r) then arg else trash
             alloc = Assignment reg $ Alloca (pLower $ getVarType reg) 1
         markStackReg r
         return $ toOL [alloc, Store rval reg Nothing []]
@@ -2222,7 +2224,7 @@ funPrologue live cmmBlocks = do
 
 -- | Function epilogue. Load STG variables to use as argument for call.
 -- STG Liveness optimisation done here.
-funEpilogue :: LiveGlobalRegs -> LlvmM ([LlvmVar], LlvmStatements)
+funEpilogue :: LiveGlobalRegUses -> LlvmM ([LlvmVar], LlvmStatements)
 funEpilogue live = do
     platform <- getPlatform
 
@@ -2248,12 +2250,16 @@ funEpilogue live = do
     let allRegs = activeStgRegs platform
     loads <- forM allRegs $ \r -> if
       -- load live registers
-      | r `elem` alwaysLive  -> loadExpr (GlobalRegUse r (globalRegSpillType platform r))
-      | r `elem` live        -> loadExpr (GlobalRegUse r (globalRegSpillType platform r))
+      | Just ru <- lookupRegUse r (alwaysLive platform)
+      -> loadExpr ru
+      | Just ru <- lookupRegUse r live
+      -> loadExpr ru
       -- load all non Floating-Point Registers
-      | not (isFPR r)        -> loadUndef r
+      | not (isFPR r)
+      -> loadUndef (GlobalRegUse r (globalRegSpillType platform r))
       -- load padding Floating-Point Registers
-      | r `elem` paddingRegs -> loadUndef r
+      | Just ru <- lookupRegUse r paddingRegs
+      -> loadUndef ru
       | otherwise            -> return (Nothing, nilOL)
 
     let (vars, stmts) = unzip loads
@@ -2263,7 +2269,7 @@ funEpilogue live = do
 --
 -- This is for Haskell functions, function type is assumed, so doesn't work
 -- with foreign functions.
-getHsFunc :: LiveGlobalRegs -> CLabel -> LlvmM ExprData
+getHsFunc :: LiveGlobalRegUses -> CLabel -> LlvmM ExprData
 getHsFunc live lbl
   = do fty <- llvmFunTy live
        name <- strCLabel_llvm lbl


=====================================
compiler/GHC/CmmToLlvm/Ppr.hs
=====================================
@@ -49,9 +49,8 @@ pprLlvmCmmDecl (CmmData _ lmdata) = do
   return ( vcat $ map (pprLlvmData opts) lmdata
          , vcat $ map (pprLlvmData opts) lmdata)
 
-pprLlvmCmmDecl (CmmProc mb_info entry_lbl liveWithUses (ListGraph blks))
-  = do let live = map globalRegUseGlobalReg liveWithUses
-           lbl = case mb_info of
+pprLlvmCmmDecl (CmmProc mb_info entry_lbl live (ListGraph blks))
+  = do let lbl = case mb_info of
                      Nothing -> entry_lbl
                      Just (CmmStaticsRaw info_lbl _) -> info_lbl
            link = if externallyVisibleCLabel lbl


=====================================
compiler/GHC/CmmToLlvm/Regs.hs
=====================================
@@ -14,25 +14,27 @@ import GHC.Prelude
 import GHC.Llvm
 
 import GHC.Cmm.Expr
+import GHC.CmmToAsm.Format
 import GHC.Platform
 import GHC.Data.FastString
 import GHC.Utils.Panic ( panic )
 import GHC.Types.Unique
 
+
 -- | Get the LlvmVar function variable storing the real register
-lmGlobalRegVar :: Platform -> GlobalReg -> LlvmVar
+lmGlobalRegVar :: Platform -> GlobalRegUse -> LlvmVar
 lmGlobalRegVar platform = pVarLift . lmGlobalReg platform "_Var"
 
 -- | Get the LlvmVar function argument storing the real register
-lmGlobalRegArg :: Platform -> GlobalReg -> LlvmVar
+lmGlobalRegArg :: Platform -> GlobalRegUse -> LlvmVar
 lmGlobalRegArg platform = lmGlobalReg platform "_Arg"
 
 {- Need to make sure the names here can't conflict with the unique generated
    names. Uniques generated names containing only base62 chars. So using say
    the '_' char guarantees this.
 -}
-lmGlobalReg :: Platform -> String -> GlobalReg -> LlvmVar
-lmGlobalReg platform suf reg
+lmGlobalReg :: Platform -> String -> GlobalRegUse -> LlvmVar
+lmGlobalReg platform suf (GlobalRegUse reg ty)
   = case reg of
         BaseReg        -> ptrGlobal $ "Base" ++ suf
         Sp             -> ptrGlobal $ "Sp" ++ suf
@@ -88,13 +90,26 @@ lmGlobalReg platform suf reg
         ptrGlobal    name = LMNLocalVar (fsLit name) (llvmWordPtr platform)
         floatGlobal  name = LMNLocalVar (fsLit name) LMFloat
         doubleGlobal name = LMNLocalVar (fsLit name) LMDouble
-        xmmGlobal    name = LMNLocalVar (fsLit name) (LMVector 4 (LMInt 32))
-        ymmGlobal    name = LMNLocalVar (fsLit name) (LMVector 8 (LMInt 32))
-        zmmGlobal    name = LMNLocalVar (fsLit name) (LMVector 16 (LMInt 32))
+        fmt = cmmTypeFormat ty
+        xmmGlobal    name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+        ymmGlobal    name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+        zmmGlobal    name = LMNLocalVar (fsLit name) (formatLlvmType fmt)
+
+formatLlvmType :: Format -> LlvmType
+formatLlvmType II8 = LMInt 8
+formatLlvmType II16 = LMInt 16
+formatLlvmType II32 = LMInt 32
+formatLlvmType II64 = LMInt 64
+formatLlvmType FF32 = LMFloat
+formatLlvmType FF64 = LMDouble
+formatLlvmType (VecFormat l sFmt) = LMVector l (formatLlvmType $ scalarFormatFormat sFmt)
 
 -- | A list of STG Registers that should always be considered alive
-alwaysLive :: [GlobalReg]
-alwaysLive = [BaseReg, Sp, Hp, SpLim, HpLim, node]
+alwaysLive :: Platform -> [GlobalRegUse]
+alwaysLive platform =
+  [ GlobalRegUse r (globalRegSpillType platform r)
+  | r <- [BaseReg, Sp, Hp, SpLim, HpLim, node]
+  ]
 
 -- | STG Type Based Alias Analysis hierarchy
 stgTBAA :: [(Unique, LMString, Maybe Unique)]



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/b792b70cb22976009486514d578ec10de0c90b63

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/b792b70cb22976009486514d578ec10de0c90b63
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240830/ccb01eec/attachment-0001.html>


More information about the ghc-commits mailing list