[Git][ghc/ghc][wip/T14620] WIP: Fix #14620 by introducing WW to detect more join points

Sebastian Graf gitlab at gitlab.haskell.org
Wed Sep 23 11:41:07 UTC 2020



Sebastian Graf pushed to branch wip/T14620 at Glasgow Haskell Compiler / GHC


Commits:
e6ea5c5a by Sebastian Graf at 2020-09-23T13:40:57+02:00
WIP: Fix #14620 by introducing WW to detect more join points

- - - - -


8 changed files:

- compiler/GHC/Core.hs
- compiler/GHC/Core/Opt/OccurAnal.hs
- compiler/GHC/Core/Opt/Simplify.hs
- compiler/GHC/Core/SimpleOpt.hs
- compiler/GHC/Core/TyCo/Subst.hs
- compiler/GHC/Core/Type.hs
- compiler/GHC/Core/Utils.hs
- compiler/GHC/Driver/Ppr.hs


Changes:

=====================================
compiler/GHC/Core.hs
=====================================
@@ -831,10 +831,14 @@ Now we can move the case inward and we only have to change the jump:
 (Core Lint would still check that the body of the join point has the right type;
 that type would simply not be reflected in the join id.)
 
-Note [The polymorphism rule of join points]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Invariant 4 of Note [Invariants on join points] forbids a join point to be
-polymorphic in its return type. That is, if its type is
+Historic Note [The polymorphism rule of join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Before we introduced Note [Join point worker/wrapper], Invariant 4 of
+Note [Invariants on join points] used to read:
+
+  4. The binding's type must not be polymorphic in its return type.
+
+That is, if its type is
 
   forall a1 ... ak. t1 -> ... -> tn -> r
 


=====================================
compiler/GHC/Core/Opt/OccurAnal.hs
=====================================
@@ -2984,9 +2984,6 @@ decideJoinPointHood NotTopLevel usage bndrs
         -- Invariant 2a: stable unfoldings
         -- See Note [Join points and INLINE pragmas]
       , ok_unfolding arity (realIdUnfolding bndr)
-
-        -- Invariant 4: Satisfies polymorphism rule
-      , isValidJoinPointType arity (idType bndr)
       = True
 
       | otherwise


=====================================
compiler/GHC/Core/Opt/Simplify.hs
=====================================
@@ -53,7 +53,7 @@ import GHC.Core.Utils
 import GHC.Core.Opt.Arity ( ArityType(..), arityTypeArity, isBotArityType
                           , pushCoTyArg, pushCoValArg
                           , idArityType, etaExpandAT )
-import GHC.Core.SimpleOpt ( joinPointBinding_maybe, joinPointBindings_maybe )
+import GHC.Core.SimpleOpt ( tryJoinPointWW, tryJoinPointWWs )
 import GHC.Core.FVs     ( mkRuleInfo )
 import GHC.Core.Rules   ( lookupRule, getRules, initRuleOpts )
 import GHC.Types.Basic
@@ -1052,8 +1052,9 @@ simplExprF1 env (Case scrut bndr _ alts) cont
                                  , sc_env = env, sc_cont = cont })
 
 simplExprF1 env (Let (Rec pairs) body) cont
-  | Just pairs' <- joinPointBindings_maybe pairs
-  = {-#SCC "simplRecJoinPoin" #-} simplRecJoinPoint env pairs' body cont
+  | Just (pairs', wrappers) <- tryJoinPointWWs (getInScope env) (exprType body) pairs
+  -- , null wrappers ||  pprTrace "simple join Rec" (ppr pairs'<+> ppr (exprType body)) True
+  = {-#SCC "simplRecJoinPoin" #-} simplRecJoinPoint env pairs' wrappers body cont
 
   | otherwise
   = {-#SCC "simplRecE" #-} simplRecE env pairs body cont
@@ -1065,8 +1066,9 @@ simplExprF1 env (Let (NonRec bndr rhs) body) cont
     do { ty' <- simplType env ty
        ; simplExprF (extendTvSubst env bndr ty') body cont }
 
-  | Just (bndr', rhs') <- joinPointBinding_maybe bndr rhs
-  = {-#SCC "simplNonRecJoinPoint" #-} simplNonRecJoinPoint env bndr' rhs' body cont
+  | Just (bndr', rhs', wrappers) <- tryJoinPointWW (getInScope env) (exprType body) bndr rhs
+  -- , null wrappers || pprTrace "simple join NonRec" (ppr bndr' $$ ppr (idType bndr') $$ ppr (exprType body) $$ ppr (isJoinId bndr) $$ ppr wrappers) True
+  = {-#SCC "simplNonRecJoinPoint" #-} simplNonRecJoinPoint env bndr' rhs' wrappers body cont
 
   | otherwise
   = {-#SCC "simplNonRecE" #-} simplNonRecE env bndr (rhs, env) ([], body) cont
@@ -1684,33 +1686,41 @@ type MaybeJoinCont = Maybe SimplCont
   -- Just k  => This is a join binding with continuation k
   -- See Note [Rules and unfolding for join points]
 
-simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr
-                     -> InExpr -> SimplCont
-                     -> SimplM (SimplFloats, OutExpr)
-simplNonRecJoinPoint env bndr rhs body cont
+preInlineJoinWrappers :: SimplEnv -> [(InId, InExpr)] -> SimplEnv
+preInlineJoinWrappers env binds
+  = foldl' (\env (b,r) -> extendIdSubst env b (mkContEx env r)) env' binds
+  where
+    env' = addNewInScopeIds env (map fst binds)
+
+simplNonRecJoinPoint
+  :: SimplEnv -> InId -> InExpr -> [(InId, InExpr)] -> InExpr -> SimplCont
+  -> SimplM (SimplFloats, OutExpr)
+simplNonRecJoinPoint env bndr rhs wrappers body cont
   | ASSERT( isJoinId bndr ) True
-  , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs env
+  , ASSERT( wrappers `lengthAtMost` 1 ) True
+  , Just env1 <- preInlineUnconditionally env NotTopLevel bndr rhs env
   = do { tick (PreInlineUnconditionally bndr)
-       ; simplExprF env' body cont }
-
-   | otherwise
-   = wrapJoinCont env cont $ \ env cont ->
-     do { -- We push join_cont into the join RHS and the body;
-          -- and wrap wrap_cont around the whole thing
-        ; let mult   = contHoleScaling cont
-              res_ty = contResultType cont
-        ; (env1, bndr1)    <- simplNonRecJoinBndr env bndr mult res_ty
-        ; (env2, bndr2)    <- addBndrRules env1 bndr bndr1 (Just cont)
-        ; (floats1, env3)  <- simplJoinBind env2 cont bndr bndr2 rhs env
-        ; (floats2, body') <- simplExprF env3 body cont
-        ; return (floats1 `addFloats` floats2, body') }
+       ; let env2 = preInlineJoinWrappers env1 wrappers
+       ; simplExprF env2 body cont }
 
+  | otherwise
+  = wrapJoinCont env cont $ \ env cont ->
+    do { -- We push join_cont into the join RHS and the body;
+         -- and wrap wrap_cont around the whole thing
+       ; let mult   = contHoleScaling cont
+             res_ty = contResultType cont
+       ; (env1, bndr1)    <- simplNonRecJoinBndr env bndr mult res_ty
+       ; (env2, bndr2)    <- addBndrRules env1 bndr bndr1 (Just cont)
+       ; (floats1, env3)  <- simplJoinBind env2 cont bndr bndr2 rhs env
+       ; let env4   = preInlineJoinWrappers env3 wrappers
+       ; (floats2, body') <- simplExprF env4 body cont
+       ; return (floats1 `addFloats` floats2, body') }
 
 ------------------
-simplRecJoinPoint :: SimplEnv -> [(InId, InExpr)]
-                  -> InExpr -> SimplCont
-                  -> SimplM (SimplFloats, OutExpr)
-simplRecJoinPoint env pairs body cont
+simplRecJoinPoint
+  :: SimplEnv -> [(InId, InExpr)] -> [(InId, InExpr)] -> InExpr -> SimplCont
+  -> SimplM (SimplFloats, OutExpr)
+simplRecJoinPoint env pairs wrappers body cont
   = wrapJoinCont env cont $ \ env cont ->
     do { let bndrs  = map fst pairs
              mult   = contHoleScaling cont
@@ -1718,8 +1728,9 @@ simplRecJoinPoint env pairs body cont
        ; env1 <- simplRecJoinBndrs env bndrs mult res_ty
                -- NB: bndrs' don't have unfoldings or rules
                -- We add them as we go down
-       ; (floats1, env2)  <- simplRecBind env1 NotTopLevel (Just cont) pairs
-       ; (floats2, body') <- simplExprF env2 body cont
+       ; let env2   = preInlineJoinWrappers env1 wrappers
+       ; (floats1, env3)  <- simplRecBind env2 NotTopLevel (Just cont) pairs
+       ; (floats2, body') <- simplExprF env3 body cont
        ; return (floats1 `addFloats` floats2, body') }
 
 --------------------


=====================================
compiler/GHC/Core/SimpleOpt.hs
=====================================
@@ -5,6 +5,7 @@
 
 {-# LANGUAGE CPP #-}
 {-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE ViewPatterns #-}
 
 module GHC.Core.SimpleOpt (
         SimpleOpts (..), defaultSimpleOpts,
@@ -13,7 +14,7 @@ module GHC.Core.SimpleOpt (
         simpleOptPgm, simpleOptExpr, simpleOptExprWith,
 
         -- ** Join points
-        joinPointBinding_maybe, joinPointBindings_maybe,
+        tryJoinPointWW, tryJoinPointWWs,
 
         -- ** Predicates on expressions
         exprIsConApp_maybe, exprIsLiteral_maybe, exprIsLambda_maybe,
@@ -36,7 +37,7 @@ import GHC.Core.Opt.OccurAnal( occurAnalyseExpr, occurAnalysePgm )
 import GHC.Types.Literal
 import GHC.Types.Id
 import GHC.Types.Id.Info  ( unfoldingInfo, setUnfoldingInfo, setRuleInfo, IdInfo (..) )
-import GHC.Types.Var      ( isNonCoVarId )
+import GHC.Types.Var      ( isNonCoVarId, setVarType, VarBndr (..) )
 import GHC.Types.Var.Set
 import GHC.Types.Var.Env
 import GHC.Core.DataCon
@@ -44,7 +45,11 @@ import GHC.Types.Demand( etaConvertStrictSig )
 import GHC.Core.Coercion.Opt ( optCoercion, OptCoercionOpts (..) )
 import GHC.Core.Type hiding ( substTy, extendTvSubst, extendCvSubst, extendTvSubstList
                             , isInScope, substTyVarBndr, cloneTyVarBndr )
+import qualified GHC.Core.Type as Type
 import GHC.Core.Coercion hiding ( substCo, substCoVarBndr )
+import GHC.Core.TyCo.Rep ( TyCoBinder (..) )
+import GHC.Core.Multiplicity
+import GHC.Core.Unify ( tcMatchTy )
 import GHC.Builtin.Types
 import GHC.Builtin.Names
 import GHC.Types.Basic
@@ -52,8 +57,10 @@ import GHC.Unit.Module ( Module )
 import GHC.Utils.Outputable
 import GHC.Utils.Panic
 import GHC.Utils.Misc
-import GHC.Data.Maybe       ( orElse )
+import GHC.Utils.Monad      ( mapAccumLM )
+import GHC.Data.Maybe
 import GHC.Data.FastString
+import Data.Bifunctor ( first )
 import Data.List
 import qualified Data.ByteString as BS
 
@@ -172,14 +179,14 @@ simpleOptPgm opts this_mod binds rules =
              -- hence paying just a substitution
 
     do_one (env, binds') bind
-      = case simple_opt_bind env bind TopLevel of
+      = case simple_opt_bind env bind of
           (env', Nothing)    -> (env', binds')
           (env', Just bind') -> (env', bind':binds')
 
 -- In these functions the substitution maps InVar -> OutExpr
 
 ----------------------
-type SimpleClo = (SimpleOptEnv, InExpr)
+type SimpleClo = (SimpleOptEnv, InExpr) -- Like SimplSR's ContEx
 
 data SimpleOptEnv
   = SOE { soe_co_opt_opts :: !OptCoercionOpts
@@ -191,7 +198,8 @@ data SimpleOptEnv
         , soe_inl   :: IdEnv SimpleClo
              -- ^ Deals with preInlineUnconditionally; things
              -- that occur exactly once and are inlined
-             -- without having first been simplified
+             -- without having first been simplified or
+             -- substituted, thus the domain is InBndrs
 
         , soe_subst :: Subst
              -- ^ Deals with cloning; includes the InScopeSet
@@ -247,9 +255,10 @@ simple_opt_expr env expr
     go (Lit lit)        = Lit lit
     go (Tick tickish e) = mkTick (substTickish subst tickish) (go e)
     go (Cast e co)      = mk_cast (go e) (go_co co)
-    go (Let bind body)  = case simple_opt_bind env bind NotTopLevel of
-                             (env', Nothing)   -> simple_opt_expr env' body
-                             (env', Just bind) -> Let bind (simple_opt_expr env' body)
+    go (Let bind body)  =
+      case simple_opt_local_bind env (exprType body) bind of
+        (env', Nothing)   -> simple_opt_expr env' body
+        (env', Just bind) -> Let bind (simple_opt_expr env' body)
 
     go lam@(Lam {})     = go_lam env [] lam
     go (Case e b ty as)
@@ -351,7 +360,7 @@ simple_app env (Tick t e) as
 -- However, do /not/ do this transformation for join points
 --    See Note [simple_app and join points]
 simple_app env (Let bind body) args
-  = case simple_opt_bind env bind NotTopLevel of
+  = case simple_opt_local_bind env (exprType body) bind of
       (env', Nothing)   -> simple_app env' body args
       (env', Just bind')
         | isJoinBind bind' -> finish_app env expr' args
@@ -369,29 +378,87 @@ finish_app env fun (arg:args)
   = finish_app env (App fun (simple_opt_clo env arg)) args
 
 ----------------------
-simple_opt_bind :: SimpleOptEnv -> InBind -> TopLevelFlag
-                -> (SimpleOptEnv, Maybe OutBind)
-simple_opt_bind env (NonRec b r) top_level
-  = (env', case mb_pr of
-            Nothing    -> Nothing
-            Just (b,r) -> Just (NonRec b r))
+extendInlEnv :: SimpleOptEnv -> InBndr -> SimpleClo -> SimpleOptEnv
+-- Like GHC.Core.Opt.Simplify.Env.extendIdSubst
+extendInlEnv env@(SOE { soe_inl = inl_env }) bndr clo
+  = ASSERT2( isId bndr && not (isCoVar bndr), ppr bndr )
+    env { soe_inl = extendVarEnv inl_env bndr clo }
+
+extendInScopeEnv :: SimpleOptEnv -> [InBndr] -> SimpleOptEnv
+extendInScopeEnv env@(SOE { soe_subst = Subst in_scope ids tvs cos }) bndrs
+  = env { soe_subst = Subst (extendInScopeSetList in_scope bndrs) ids tvs cos }
+
+tryJoinPointWWs :: InScopeSet -> Type -> [(InBndr, InExpr)] -> Maybe ([(InBndr, InExpr)], [(InBndr, InExpr)])
+tryJoinPointWWs in_scope body_ty binds
+  = foldMap go <$> joinPointBindings_maybe in_scope body_ty binds -- Nothing unless every binding qualifies
   where
-    (b', r') = joinPointBinding_maybe b r `orElse` (b, r)
-    (env', mb_pr) = simple_bind_pair env b' Nothing (env,r') top_level
+    go jph = ([(join_bndr jph, join_rhs jph)], join_wrapper jph) -- (join pairs, wrappers to pre-inline)
+    join_wrapper jph@JoinPointAfterMono{} -- Rare:   A join point after we inline a wrapper
+      = [(join_wrapper_bndr jph, join_wrapper_body jph)]
+    join_wrapper DefinitelyJoinPoint{}    -- Common: Regular join point. No wrapper
+      = []
+
+
+tryJoinPointWW :: InScopeSet -> Type -> InBndr -> InExpr -> Maybe (InBndr, InExpr, [(InBndr, InExpr)])
+tryJoinPointWW in_scope body_ty b r
+  | Just ([(b', r')], wrappers) <- tryJoinPointWWs in_scope body_ty [(b, r)]
+  = ASSERT( wrappers `lengthAtMost` 1 )
+    Just (b', r', wrappers)
+  | otherwise
+  = Nothing
 
-simple_opt_bind env (Rec prs) top_level
-  = (env'', res_bind)
+pair_to_non_rec
+  :: (SimpleOptEnv, Maybe (OutBndr, OutExpr))
+  -> (SimpleOptEnv, Maybe OutBind)
+pair_to_non_rec (env, mb_pr) = (env, uncurry NonRec <$> mb_pr)
+
+simple_opt_local_bind
+  :: SimpleOptEnv -> Type -> InBind -> (SimpleOptEnv, Maybe OutBind)
+simple_opt_local_bind env body_ty (NonRec b r)
+  | (b', r', wrappers) <- tryJoinPointWW (substInScope (soe_subst env) `extendInScopeSet` b) body_ty b r `orElse` (b, r, [])
+  -- , null wrappers || pprTrace "simple_opt_local_bind:join" (ppr b <+> ppr (idType b) <+> ppr body_ty) True
+  = -- pprTraceWith "simple_opt_local_bind" (\(env', mb_bind) -> ppr b <+> (case mb_bind of Nothing -> text "inlined" $$ ppr env'; Just _ -> text "not inlined")) $
+    first (pre_inline_join_wrappers wrappers)
+  $ pair_to_non_rec
+  $ simple_bind_pair env b' Nothing (env,r') NotTopLevel
+
+simple_opt_local_bind env body_ty (Rec prs)
+  --- | null wrappers || pprTrace "simple_opt_local_bind:joinrec" (ppr prs <+> ppr body_ty) True
+  = (env3, res_bind)
   where
-    res_bind          = Just (Rec (reverse rev_prs'))
-    prs'              = joinPointBindings_maybe prs `orElse` prs
-    (env', bndrs')    = subst_opt_bndrs env (map fst prs')
-    (env'', rev_prs') = foldl' do_pr (env', []) (prs' `zip` bndrs')
-    do_pr (env, prs) ((b,r), b')
+    res_bind         = Just (Rec (reverse rev_prs'))
+    (prs', wrappers) = tryJoinPointWWs (substInScope (soe_subst env)) body_ty prs `orElse` (prs, [])
+    (env1, bndrs')   = subst_opt_bndrs env (map fst prs')
+    env2             = pre_inline_join_wrappers wrappers env1
+    (env3, rev_prs') = foldl' simpl_pr (env2, []) (prs' `zip` bndrs')
+    simpl_pr (env, prs) ((b,r), b')
        = (env', case mb_pr of
                   Just pr -> pr : prs
                   Nothing -> prs)
        where
-         (env', mb_pr) = simple_bind_pair env b (Just b') (env,r) top_level
+         (env', mb_pr) = simple_bind_pair env b (Just b') (env,r) NotTopLevel
+
+pre_inline_join_wrappers :: [(InBndr, InExpr)] -> SimpleOptEnv -> SimpleOptEnv
+pre_inline_join_wrappers binds env
+  = foldl' (\env (b,r) -> extendInlEnv env b (env, r)) env' binds
+  where
+    env' = extendInScopeEnv env (map fst binds)
+
+simple_opt_bind :: SimpleOptEnv -> InBind -> (SimpleOptEnv, Maybe OutBind)
+simple_opt_bind env (NonRec b r)
+  = pair_to_non_rec (simple_bind_pair env b Nothing (env,r) TopLevel)
+
+simple_opt_bind env (Rec prs)
+  = (env'', res_bind)
+  where
+    res_bind          = Just (Rec (reverse rev_prs'))
+    (env', bndrs')    = subst_opt_bndrs env (map fst prs)
+    (env'', rev_prs') = foldl' simpl_pr (env', []) (prs `zip` bndrs')
+    simpl_pr (env, prs) ((b,r), b')
+      = (env', case mb_pr of
+                 Just pr -> pr : prs
+                 Nothing -> prs)
+      where
+        (env', mb_pr) = simple_bind_pair env b (Just b') (env,r) TopLevel
 
 ----------------------
 simple_bind_pair :: SimpleOptEnv
@@ -796,39 +863,261 @@ A more common case is when
 and again its arity increases (#15517)
 -}
 
-
--- | Returns Just (bndr,rhs) if the binding is a join point:
--- If it's a JoinId, just return it
+-- | Indicates that a binding can be transformed into a join point, and how.
+data JoinPointHood
+  = DefinitelyJoinPoint -- ^ Already a join point, or becomes one without monomorphisation
+    { join_bndr :: !InBndr
+    , join_rhs :: !InExpr }
+  | JoinPointAfterMono
+  -- ^ A join point after we have instantiated the forall binders occurring in
+  -- the result type. See Note [Join point worker/wrapper].
+    { join_bndr :: !InBndr             -- the fresh worker binder (a JoinId)
+    , join_rhs :: !InExpr              -- the worker RHS
+    , join_wrapper_bndr :: !InBndr
+    -- ^ the bndr of a wrapper that needs to be inlined unconditionally
+    , join_wrapper_body :: !InExpr }
+
+-- | An element of the result list of 'matchJoinResTy'.
+-- Corresponds to a join binder of what is going to be the new join point.
+-- See Note [Join point worker/wrapper].
+data JoinWorkerBinder
+  = InstBinder !Type
+  -- ^ A binder that could be instantiated to the given type by matching against
+  -- the res ty. The corresponding binder will be dropped for the new join point
+  | SubstBinder !TyCoBinder
+  -- ^ A join binder that was not instantiated by matching against the res ty.
+  -- But since other join binders might have been instantiated, the binder's
+  -- type might have changed.
+
+instance Outputable JoinWorkerBinder where
+  ppr (InstBinder ty)    = text "Inst" <+> ppr ty
+  ppr (SubstBinder bndr) = text "Subst" <+> ppr bndr
+
+isSubstBinder :: JoinWorkerBinder -> Bool
+isSubstBinder SubstBinder{} = True
+isSubstBinder _             = False
+
+-- | Returns @Just jphs@ (one per binding) if every binding is a join point:
+-- If it's a JoinId, just return @DefinitelyJoinPoint bndr rhs@.
 -- If it's not yet a JoinId but is always tail-called,
 --    make it into a JoinId and return it.
 -- In the latter case, eta-expand the RHS if necessary, to make the
--- lambdas explicit, as is required for join points
+-- lambdas explicit, as is required for join points.
+-- If the join point is not result type polymorphic, return
+-- @DefinitelyJoinPoint bndr rhs@.
+-- If the join point is result type polymorphic, monomorphise it first,
+-- returning @JoinPointAfterMono worker_bndr worker_rhs bndr wrapper_rhs@.
+-- Call sites then have to unconditionally inline the wrapper @bndr@/@wrapper_rhs@.
+-- See Note [Join point worker/wrapper].
 --
 -- Precondition: the InBndr has been occurrence-analysed,
 --               so its OccInfo is valid
-joinPointBinding_maybe :: InBndr -> InExpr -> Maybe (InBndr, InExpr)
-joinPointBinding_maybe bndr rhs
-  | not (isId bndr)
-  = Nothing
-
-  | isJoinId bndr
-  = Just (bndr, rhs)
+joinPointBindings_maybe :: InScopeSet -> Type -> [(InBndr, InExpr)] -> Maybe [JoinPointHood]
+-- See Note [Join point worker/wrapper].
+joinPointBindings_maybe in_scope body_type binds
+  = snd <$> mapAccumLM go (extendInScopeSetList in_scope (map fst binds)) binds
+  where
+    go :: InScopeSet -> (InBndr, InExpr) -> Maybe (InScopeSet, JoinPointHood)
+    go in_scope (bndr, rhs)
+      | not (isId bndr)
+      = Nothing
+
+      | isJoinId bndr
+      = Just (in_scope, DefinitelyJoinPoint bndr rhs)
+
+      | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
+      , not (exprIsTrivial rhs)
+      , (lam_bndrs, rhs') <- etaExpandToJoinPoint join_arity rhs
+      , let eta_rhs'       = mkLams lam_bndrs rhs'
+      , let inst_tys       = matchJoinResTy join_arity (idType bndr) body_type
+      , let new_join_arity = count isSubstBinder inst_tys
+      , let no_mono        = new_join_arity == join_arity
+      , let worker_body    = mk_worker_body lam_bndrs inst_tys eta_rhs'
+      -- we need an in-scope set as if the worker was defined inside the RHS of the wrapper (as is the case with SAT)
+      , let in_scope'      = extendInScopeSetList in_scope lam_bndrs
+      , let new_bndr       = uniqAway in_scope' bndr -- only used in else branch
+                                `setIdType` exprType worker_body
+      , let wrapper_body   = mk_wrapper_body new_bndr lam_bndrs inst_tys
+      , let wrapper_bndr   = bndr
+      -- , no_mono || pprTrace "always tail called:" (vcat [ppr in_scope', ppr bndr, ppr (idType bndr), ppr body_type, ppr rhs, ppr new_bndr, ppr (exprType worker_body), ppr join_arity, ppr inst_tys, ppr new_bndr, ppr wrapper_body, ppr worker_body]) True
+      = Just $! if no_mono
+          then ( in_scope
+               , DefinitelyJoinPoint
+                   { join_bndr = adjust_id_info bndr lam_bndrs join_arity
+                   , join_rhs = eta_rhs' } )
+          else ( extendInScopeSet in_scope new_bndr
+               , JoinPointAfterMono
+                   { join_bndr = adjust_id_info new_bndr lam_bndrs new_join_arity
+                   , join_rhs = worker_body
+                   , join_wrapper_bndr = wrapper_bndr
+                   , join_wrapper_body = wrapper_body } )
 
-  | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
-  , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs
-  , let str_sig   = idStrictness bndr
-        str_arity = count isId bndrs  -- Strictness demands are for Ids only
-        join_bndr = bndr `asJoinId`        join_arity
-                         `setIdStrictness` etaConvertStrictSig str_arity str_sig
-  = Just (join_bndr, mkLams bndrs body)
+      | otherwise
+      = Nothing
+
+    adjust_id_info :: InBndr -> [InBndr] -> JoinArity -> InBndr
+    adjust_id_info bndr lam_bndrs join_arity = zapStableUnfolding $ -- TODO: Discuss! Type errors otherwise.
+      let str_sig   = idStrictness bndr
+          str_arity = count isId lam_bndrs  -- Strictness demands are for Ids only
+      in bndr `asJoinId`        join_arity
+              `setIdStrictness` etaConvertStrictSig str_arity str_sig
+
+    mk_wrapper_body :: InBndr -> [InBndr] -> [JoinWorkerBinder] -> InExpr
+    -- Builds the wrapper RHS. See Note [Join point worker/wrapper].
+    mk_wrapper_body new_bndr lam_bndrs inst_tys
+      = ASSERT( lam_bndrs `equalLength` inst_tys )
+        -- pprTraceWith "mk_wrapper_body" (\e -> ppr lam_bndrs $$ ppr inst_tys $$ ppr e) $
+        go (Var new_bndr) $ zipEqual "mk_wrapper_body" lam_bndrs inst_tys
+      where
+        go e [] = e
+        go e ((lb,SubstBinder{}):prs) -- non-instantiated parameter: rebind and pass on
+          | isId lb                     --            value parameter xs
+          = Lam lb (go (App e (Var lb)) prs) -- NOTE(review): isId may hold for CoVars; confirm none reach here
+          | otherwise                   --            type  parameter @b
+          = Lam lb (go (App e (Type (mkTyVarTy lb))) prs)
+        go e ((lb,InstBinder{}):prs)     --     instantiated parameter, @a or @c: rebind only
+          = ASSERT( isTyVar lb )
+            Lam lb (go e prs)
+
+    mk_worker_body :: [InBndr] -> [JoinWorkerBinder] -> InExpr -> InExpr
+    -- Builds the worker RHS. See Note [Join point worker/wrapper].
+    mk_worker_body lam_bndrs inst_tys rhs
+      = -- pprTraceWith "mk_worker_body" (\e -> ppr e) $
+        go rhs $ zipEqual "mk_worker_body" lam_bndrs inst_tys
+      where
+        go e [] = e
+        go e ((lb,SubstBinder bndr):prs)     -- non-instantiated parameter: bind with substituted type
+          | Anon _ (Scaled _ ty) <- bndr     --            value parameter xs
+          , let lb' = lb `setIdType` ty
+          = Lam lb' (go (App e (Var lb')) prs)
+          | Named (binderVar -> tcv) <- bndr --            type  parameter @b
+          = Lam tcv (go (App e (Type (mkTyVarTy tcv))) prs)
+        go e ((_ ,InstBinder ty):prs)         --     instantiated parameter, @a or @c: apply to ty
+          = go (App e (Type ty)) prs
+
+-- | Figures out how to monomorphise the result type of a join point.
+--
+-- @matchJoinResTy ja join_ty body_ty@ computes the result type of @join_ty@ by
+-- skipping @ja@ binders and then matches it against @body_ty@.
+-- If a forall binder @a@ is mentioned in the resulting substitution @subst@,
+-- the corresponding entry in the returned list is @InstBinder (subst a)@.
+-- See Note [Join point worker/wrapper].
+--
+-- Postcondition: The returned list has length @ja@.
+matchJoinResTy
+  :: JoinArity          -- ^ Number of binders to skip
+  -> Type               -- ^ Type of the join point
+  -> Type               -- ^ Type of the join body
+  -> [JoinWorkerBinder] -- ^ An entry for each join binder,
+                        -- InstBinder ty <=> instantiates corresponding forall to ty
+matchJoinResTy orig_ar orig_ty body_ty = snd (go init_in_scope orig_ar orig_ty)
+  where
+    init_in_scope = mkInScopeSet $ tyCoVarsOfType body_ty `unionVarSet` tyCoVarsOfType orig_ty
 
-  | otherwise
-  = Nothing
+    go :: InScopeSet -> Int -> Type -> (TCvSubst, [JoinWorkerBinder])
+    go in_scope 0 res_ty = (TCvSubst in_scope tvs cvs, [])
+      where
+        TCvSubst _ tvs cvs = expectJust "matchJoinResTy" $ tcMatchTy res_ty body_ty
+
+    go in_scope n ty
+      | Just (arg_bndr, res_ty) <- splitPiTy_maybe ty
+      = case arg_bndr of
+          Anon f (Scaled m ty)
+            | (subst, inst_tys) <- go in_scope (n-1) res_ty
+            -> (subst, SubstBinder (Anon f (Scaled m (Type.substTy subst ty))):inst_tys)
+          Named (Bndr tcv vis)
+            | isTyVar tcv', Just ty <- lookupTyVar subst tcv'
+            -> (subst', InstBinder ty: inst_tys)
+            | otherwise
+            -> (subst', SubstBinder (Named (Bndr subst_tcv vis)) : inst_tys)
+            where
+              tcv'      = uniqAway in_scope tcv
+              in_scope' = extendInScopeSet in_scope tcv'
+              (subst, inst_tys) = go in_scope' (n-1) res_ty
+              subst'    = delTCvSubst subst tcv'
+              subst_tcv = tcv' `setVarType` Type.substTy subst' (varType tcv')
 
-joinPointBindings_maybe :: [(InBndr, InExpr)] -> Maybe [(InBndr, InExpr)]
-joinPointBindings_maybe bndrs
-  = mapM (uncurry joinPointBinding_maybe) bndrs
+    go _ _ _ = pprPanic "matchJoinResTy" (ppr orig_ar <+> ppr orig_ty)
 
+{- Note [Join point worker/wrapper]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Some let bindings that are 'AlwaysTailCalled' still need a bit of work to become
+a join point. Here's an example:
+
+  let f  :: forall a b. [a] -> forall c. b -> Maybe c -> [(a,c)]
+      f  = <rhs>
+  in (<body> :: [(Bool, Char)])
+
+Suppose @f@ is 'AlwaysTailCalled' in <body>. Its result type is polymorphic,
+which is exactly the situation as in
+Historic Note [The polymorphism rule of join points].
+That Note also explains why we can't just turn @f@ into a join point untouched.
+
+So we need a transformation that monomorphises @f@ for its result type. Since we
+have available the type of (the soon to be join-) <body>, we can match @[(a,c)]@
+against @[(Bool, Char)]@ to get a substitution @[a ↦ Bool, c ↦ Char]@. We then
+apply this substitution to the tyco binders in the type of @f@, as we ascend
+from the result type:
+
+  * @Maybe c@ is substituted to @Maybe Char@ ('SubstBinder')
+  * @b@ is substituted to @b@ (unchanged, likewise 'SubstBinder')
+  * @forall c@ is in the domain of the substitution and thus will be
+    instantiated ('InstBinder')
+  * @[a]@ is substituted to @[Bool]@ ('SubstBinder')
+  * @forall b@ is substituted (note that in general its kind might mention @a@)
+    to @forall b@ ('SubstBinder')
+  * @forall a@ is in the domain of the substitution and thus will be
+    instantiated ('InstBinder')
+
+Figuring out this list of 'JoinWorkerBinder's (which is the
+'SubstBinder'/'InstBinder') is the job of 'matchJoinResTy'.
+In the simple, non-polymorphic case, it returns a list of 'SubstBinder's, one
+for each join binder. Otherwise, there is at least one 'InstBinder' that
+indicates monomorphisation of a polymorphic join result type.
+
+The challenge is in rewriting all call sites of @f@ to match its new type,
+dropping the instantiated type arguments. A typical use case for the
+worker/wrapper transformation. Thus, we make @f@ a wrapper that rewrites to the
+new worker join point f':
+
+  let f  :: forall a b. [a] -> forall c. b -> Maybe c -> [(a,c)]
+      f  = \@a @b (xs :: [a]) @c (y :: b) (mc :: Maybe c) -> f' @b xs y mc
+      f' :: forall b. [Bool] -> b -> Maybe Char -> [(Bool, Char)]
+      f' = \@b (xs :: [Bool]) (y :: b) (mc :: Maybe Char) -> <rhs> @Bool @b xs @Char y mc
+  in (<body> :: [(Bool, Char)])
+
+Take note that @f@'s type did not change, but its new RHS is now actually
+ill-typed. This doesn't matter as long as we manage to inline the wrapper
+unconditionally at its call sites in <body>, where the arguments for @a@ and @c@
+will always be @Bool@ and @Char@.
+
+The join point worker @f'@ similarly instantiates @a@ and @c@ to @Bool@ and
+@Char@. Its result type is monomorphic and it can be made into a join point.
+
+The worker/wrapper split is carried out by 'joinPointBindings_maybe', but only if
+there are any 'InstBinder's at all (in which case it returns the result
+@JoinPointAfterMono@).
+Equipped with the 'matchJoinResTy' result (InstBinder = I, SubstBinder = S)
+
+  [I Bool, S (b::*), S (_::[Bool]), I Char, S (_::b), S (_::Maybe Char)]
+
+It builds the wrapper body of @f@ by applying the new worker binder @f'@ to
+
+  * Nothing if the corresponding 'JoinWorkerBinder' is @I _@
+  * @b@ if the corresponding 'JoinWorkerBinder' is @S _@ and @b@ is the old
+    lambda binder
+
+It builds the worker body of @f'@ by applying the <rhs> to
+
+  * @ty@ if the corresponding 'JoinWorkerBinder' is @I ty@
+  * @tv@ if the corresponding 'JoinWorkerBinder' is @S (tv::..)@ (Named binder)
+  * @b@  if the corresponding 'JoinWorkerBinder' is @S (_::ty)@  (Anon  binder)
+    and @b@ is the old lambda binder with its type updated to @ty at .
+
+The result of @joinPointBindings_maybe@ is ultimately exported via @tryJoinPointWW@
+and @tryJoinPointWWs@, which are used in the simple optimiser as well as the
+Simplifier; both inline the join point wrapper unconditionally (if present).
+-}
 
 {- *********************************************************************
 *                                                                      *
@@ -1342,5 +1631,3 @@ exprIsLambda_maybe (in_scope_set, id_unf) e
 exprIsLambda_maybe _ _e
     = -- pprTrace "exprIsLambda_maybe:Fail" (vcat [ppr _e])
       Nothing
-
-


=====================================
compiler/GHC/Core/TyCo/Subst.hs
=====================================
@@ -19,7 +19,7 @@ module GHC.Core.TyCo.Subst
         getTvSubstEnv,
         getCvSubstEnv, getTCvInScope, getTCvSubstRangeFVs,
         isInScope, notElemTCvSubst,
-        setTvSubstEnv, setCvSubstEnv, zapTCvSubst,
+        setTvSubstEnv, setCvSubstEnv, zapTCvSubst, delTCvSubst,
         extendTCvInScope, extendTCvInScopeList, extendTCvInScopeSet,
         extendTCvSubst, extendTCvSubstWithClone,
         extendCvSubst, extendCvSubstWithClone,
@@ -308,6 +308,23 @@ setCvSubstEnv (TCvSubst in_scope tenv _) cenv = TCvSubst in_scope tenv cenv
 zapTCvSubst :: TCvSubst -> TCvSubst
 zapTCvSubst (TCvSubst in_scope _ _) = TCvSubst in_scope emptyVarEnv emptyVarEnv
 
+delTCvSubst :: TCvSubst -> Var -> TCvSubst
+delTCvSubst subst v
+  | isTyVar v
+  = delTvSubst subst v
+  | isCoVar v
+  = delCvSubst subst v
+  | otherwise
+  = pprPanic "delTCvSubst" (ppr v)
+
+delTvSubst :: TCvSubst -> TyVar -> TCvSubst
+delTvSubst (TCvSubst in_scope tenv cenv) tv
+  = TCvSubst in_scope (delVarEnv tenv tv) cenv
+
+delCvSubst :: TCvSubst -> CoVar -> TCvSubst
+delCvSubst (TCvSubst in_scope tenv cenv) cv
+  = TCvSubst in_scope tenv (delVarEnv cenv cv)
+
 extendTCvInScope :: TCvSubst -> Var -> TCvSubst
 extendTCvInScope (TCvSubst in_scope tenv cenv) var
   = TCvSubst (extendInScopeSet in_scope var) tenv cenv


=====================================
compiler/GHC/Core/Type.hs
=====================================
@@ -193,7 +193,7 @@ module GHC.Core.Type (
         zipTCvSubst,
         notElemTCvSubst,
         getTvSubstEnv, setTvSubstEnv,
-        zapTCvSubst, getTCvInScope, getTCvSubstRangeFVs,
+        zapTCvSubst, delTCvSubst, getTCvInScope, getTCvSubstRangeFVs,
         extendTCvInScope, extendTCvInScopeList, extendTCvInScopeSet,
         extendTCvSubst, extendCvSubst,
         extendTvSubst, extendTvSubstBinderAndInScope,


=====================================
compiler/GHC/Core/Utils.hs
=====================================
@@ -276,7 +276,7 @@ applyTypeToArgs e op_ty args
     go op_ty (Coercion co : args) = go_ty_args op_ty [mkCoercionTy co] args
     go op_ty (_ : args)           | Just (_, _, res_ty) <- splitFunTy_maybe op_ty
                                   = go res_ty args
-    go _ args = pprPanic "applyTypeToArgs" (panic_msg args)
+    go op_ty args = pprPanic "applyTypeToArgs" (panic_msg op_ty args)
 
     -- go_ty_args: accumulate type arguments so we can
     -- instantiate all at once with piResultTys
@@ -287,8 +287,9 @@ applyTypeToArgs e op_ty args
     go_ty_args op_ty rev_tys args
        = go (piResultTys op_ty (reverse rev_tys)) args
 
-    panic_msg as = vcat [ text "Expression:" <+> pprCoreExpr e
+    panic_msg ot as = vcat [ text "Expression:" <+> pprCoreExpr e
                      , text "Type:" <+> ppr op_ty
+                     , text "Type':" <+> ppr ot
                      , text "Args:" <+> ppr args
                      , text "Args':" <+> ppr as ]
 
@@ -2622,4 +2623,3 @@ isUnsafeEqualityProof e
   = idName v == unsafeEqualityProofName
   | otherwise
   = False
-


=====================================
compiler/GHC/Driver/Ppr.hs
=====================================
@@ -15,6 +15,7 @@ module GHC.Driver.Ppr
    , pprTraceWithFlags
    , pprTraceM
    , pprTraceDebug
+   , pprTraceWith
    , pprTraceIt
    , pprSTrace
    , pprTraceException



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/e6ea5c5acff2c58d2b09fe18a4cc45a9f46ecd70

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/e6ea5c5acff2c58d2b09fe18a4cc45a9f46ecd70
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20200923/ffb66c08/attachment-0001.html>


More information about the ghc-commits mailing list