[Git][ghc/ghc][wip/with2-primop] Work in progress on runRW#

Ben Gamari gitlab at gitlab.haskell.org
Thu Apr 16 01:42:48 UTC 2020



Ben Gamari pushed to branch wip/with2-primop at Glasgow Haskell Compiler / GHC


Commits:
354d630c by Simon Peyton Jones at 2020-04-15T21:42:12-04:00
Work in progress on runRW#

This is a proof of concept, in progress. It treats
    runRW# (\s. e)
specially, in three ways:

* In the simplifier, we transform
    K[ runRW# rr ty (\s. body) ]
    -->  runRW# rr' ty' (\s. K[ body ])
  where K is an evaluation context (the continuation of the call)

* In Lint, jumps to an enclosing join point are allowed inside the
  continuation, e.g.
    join j x = rhs
    in runRW# (\s. case ... of
                     A -> j 1
                     B -> ...
                     C -> j 2)
  much as they are allowed inside the RHS of another join point.

* In OccurAnal, we infer join points using the same rule.

We get much better optimisation as a result.

Still not finished. E.g. float-out may take
    runRW# (\s. e)
and float that lambda out. But really we want to keep
that lambda in runRW#'s argument; otherwise things that
were join points might stop being so.

But it's a start.

- - - - -


4 changed files:

- compiler/GHC/Core/Lint.hs
- compiler/GHC/Core/Op/OccurAnal.hs
- compiler/GHC/Core/Op/Simplify.hs
- libraries/integer-gmp/src/GHC/Integer/Type.hs


Changes:

=====================================
compiler/GHC/Core/Lint.hs
=====================================
@@ -678,22 +678,9 @@ lintRhs :: Id -> CoreExpr -> LintM LintedType
 --     its OccInfo and join-pointer-hood
 lintRhs bndr rhs
     | Just arity <- isJoinId_maybe bndr
-    = lint_join_lams arity arity True rhs
+    = lintJoinLams arity (Just bndr) rhs  -- already a join point: arity enforced
     | AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr)
-    = lint_join_lams arity arity False rhs
-  where
-    lint_join_lams 0 _ _ rhs
-      = lintCoreExpr rhs
-
-    lint_join_lams n tot enforce (Lam var expr)
-      = lintLambda var $ lint_join_lams (n-1) tot enforce expr
-
-    lint_join_lams n tot True _other
-      = failWithL $ mkBadJoinArityMsg bndr tot (tot-n) rhs
-    lint_join_lams _ _ False rhs
-      = markAllJoinsBad $ lintCoreExpr rhs
-          -- Future join point, not yet eta-expanded
-          -- Body is not a tail position
+    = lintJoinLams arity Nothing rhs      -- future join point: too few lambdas OK
 
 -- Allow applications of the data constructor @StaticPtr@ at the top
 -- but produce errors otherwise.
@@ -715,6 +702,18 @@ lintRhs _bndr rhs = fmap lf_check_static_ptrs getLintFlags >>= go
         binders0
     go _ = markAllJoinsBad $ lintCoreExpr rhs
 
+-- | Lint a (possible) join-point RHS of arity @join_arity@. The body under
+-- that many lambdas is linted without marking joins bad. With too few
+-- lambdas: fail, naming @b@, if @enforce = Just b@; else mark joins bad.
+lintJoinLams :: JoinArity -> Maybe Id -> CoreExpr -> LintM LintedType
+lintJoinLams join_arity enforce rhs = go join_arity rhs
+  where
+    go 0 expr           = lintCoreExpr expr
+    go n (Lam var body) = lintLambda var $ go (n-1) body
+    go n other | Just bndr <- enforce -- join point with too few RHS lambdas
+               = failWithL $ mkBadJoinArityMsg bndr join_arity n rhs
+               | otherwise = markAllJoinsBad $ lintCoreExpr other -- not tail
+
 lintIdUnfolding :: Id -> Type -> Unfolding -> LintM ()
 lintIdUnfolding bndr bndr_ty uf
   | isStableUnfolding uf
@@ -854,6 +853,15 @@ lintCoreExpr e@(Let (Rec pairs) body)
     bndrs = map fst pairs
 
 lintCoreExpr e@(App _ _)
+  -- runRW# rr ty k: k is linted like a join-point RHS of arity 1.
+  | Var fun_id <- fun
+  , fun_id `hasKey` runRWKey
+  , [arg_ty1, arg_ty2, arg3] <- args
+  = do { fun_ty1 <- lintCoreArg (idType fun_id) arg_ty1
+       ; fun_ty2 <- lintCoreArg fun_ty1         arg_ty2
+       ; arg3_ty <- lintJoinLams 1 Nothing arg3  -- non-lambda k (runRW# m) OK
+       ; lintValApp arg3 fun_ty2 arg3_ty }
+  | otherwise
   = do { fun_ty <- lintCoreFun fun (length args)
        ; lintCoreArgs fun_ty args }
   where
@@ -2751,11 +2759,11 @@ mkInvalidJoinPointMsg var ty
         2 (ppr var <+> dcolon <+> ppr ty)
 
 mkBadJoinArityMsg :: Var -> Int -> Int -> CoreExpr -> SDoc
-mkBadJoinArityMsg var ar nlams rhs
+mkBadJoinArityMsg var ar n rhs  -- n = lambdas still missing; ar - n were found
   = vcat [ text "Join point has too few lambdas",
            text "Join var:" <+> ppr var,
            text "Join arity:" <+> ppr ar,
-           text "Number of lambdas:" <+> ppr nlams,
+           text "Number of lambdas:" <+> ppr (ar - n),
            text "Rhs = " <+> ppr rhs
            ]
 


=====================================
compiler/GHC/Core/Op/OccurAnal.hs
=====================================
@@ -39,6 +39,7 @@ import GHC.Types.Demand ( argOneShots, argsOneShots )
 import Digraph          ( SCC(..), Node(..)
                         , stronglyConnCompFromEdgedVerticesUniq
                         , stronglyConnCompFromEdgedVerticesUniqR )
+import PrelNames( runRWKey )
 import GHC.Types.Unique
 import GHC.Types.Unique.FM
 import GHC.Types.Unique.Set
@@ -1880,8 +1881,12 @@ occAnalApp :: OccEnv
            -> (UsageDetails, Expr CoreBndr)
 -- Naked variables (not applied) end up here too
 occAnalApp env (Var fun, args, ticks)
-  | null ticks = (all_uds, mkApps fun' args')
-  | otherwise  = (all_uds, mkTicks ticks $ mkApps fun' args')
+  | fun `hasKey` runRWKey
+  , [t1, t2, arg]  <- args
+  , let (usage, arg') = occAnalRhs env (Just 1) arg  -- like a join RHS, arity 1
+  = (usage, mkTicks ticks $ mkApps (Var fun) [t1, t2, arg'])
+  | otherwise
+  = (all_uds, mkTicks ticks $ mkApps fun' args')
   where
     (fun', fun_id') = lookupVarEnv (occ_bs_env env) fun
                       `orElse` (Var fun, fun)


=====================================
compiler/GHC/Core/Op/Simplify.hs
=====================================
@@ -37,10 +37,12 @@ import GHC.Core.DataCon
    , StrictnessMark (..) )
 import GHC.Core.Op.Monad ( Tick(..), SimplMode(..) )
 import GHC.Core
+import PrelNames        ( runRWKey )
 import GHC.Types.Demand ( StrictSig(..), dmdTypeDepth, isStrictDmd
                         , mkClosedStrictSig, topDmd, botDiv )
 import GHC.Types.Cpr    ( mkCprSig, botCpr )
 import GHC.Core.Ppr     ( pprCoreExpr )
+import GHC.Types.Unique ( hasKey )
 import GHC.Core.Unfold
 import GHC.Core.Utils
 import GHC.Core.SimpleOpt ( pushCoTyArg, pushCoValArg
@@ -1852,6 +1854,20 @@ rebuildCall env (ArgInfo { ai_fun = fun, ai_args = rev_args, ai_strs = [] }) con
     res     = argInfoExpr fun rev_args
     cont_ty = contResultType cont
 
+-- runRW# :: forall (r :: RuntimeRep) (o :: TYPE r). (State# RealWorld -> o) -> o
+-- K[ runRW# rr ty (\s. body) ]  -->  runRW# rr' ty' (\s. K[ body ])
+rebuildCall env (ArgInfo { ai_fun = fun, ai_args = rev_args }) cont
+  | fun `hasKey` runRWKey
+  , not (contIsStop cont)  -- empty K: nothing to push in, don't re-simplify
+  , [ ValArg (Lam s body), TyArg {}, TyArg {} ] <- rev_args
+  = do { (env', s') <- simplLamBndr (zapSubstEnv env) s  -- arg already simplified
+       ; body' <- simplExprC env' body cont
+       ; let arg'  = Lam s' body'
+             ty'   = contResultType cont
+             rr'   = getRuntimeRep ty'
+             call' = mkApps (Var fun) [mkTyArg rr', mkTyArg ty', arg']
+       ; return (emptyFloats env, call') }
+
 ---------- Try rewrite RULES --------------
 -- See Note [Trying rewrite rules]
 rebuildCall env info@(ArgInfo { ai_fun = fun, ai_args = rev_args


=====================================
libraries/integer-gmp/src/GHC/Integer/Type.hs
=====================================
@@ -2110,7 +2110,7 @@ liftIO (IO m) = m
 
 -- NB: equivalent of GHC.IO.unsafeDupablePerformIO, see notes there
 runS :: S RealWorld a -> a
-runS m = case runRW# m of (# _, a #) -> a
+runS m = case runRW# (\s -> m s) of (# _, a #) -> a  -- explicit Lam: the runRW# case in Simplify's rebuildCall matches only a lambda
 
 -- stupid hack
 fail :: [Char] -> S s a



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/354d630c8258e4f81058962a5399036014a5cbd7

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/354d630c8258e4f81058962a5399036014a5cbd7
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20200415/b322ff66/attachment-0001.html>


More information about the ghc-commits mailing list