[Git][ghc/ghc][wip/T14620] WIP: Fix #14620 by introducing WW to detect more join points
Sebastian Graf
gitlab at gitlab.haskell.org
Tue Sep 22 15:52:34 UTC 2020
Sebastian Graf pushed to branch wip/T14620 at Glasgow Haskell Compiler / GHC
Commits:
ee89bb7b by Sebastian Graf at 2020-09-22T17:52:26+02:00
WIP: Fix #14620 by introducing WW to detect more join points
- - - - -
7 changed files:
- compiler/GHC/Core/Opt/OccurAnal.hs
- compiler/GHC/Core/Opt/Simplify.hs
- compiler/GHC/Core/SimpleOpt.hs
- compiler/GHC/Core/TyCo/Subst.hs
- compiler/GHC/Core/Type.hs
- compiler/GHC/Core/Utils.hs
- compiler/GHC/Driver/Ppr.hs
Changes:
=====================================
compiler/GHC/Core/Opt/OccurAnal.hs
=====================================
@@ -2984,9 +2984,6 @@ decideJoinPointHood NotTopLevel usage bndrs
-- Invariant 2a: stable unfoldings
-- See Note [Join points and INLINE pragmas]
, ok_unfolding arity (realIdUnfolding bndr)
-
- -- Invariant 4: Satisfies polymorphism rule
- , isValidJoinPointType arity (idType bndr)
= True
| otherwise
=====================================
compiler/GHC/Core/Opt/Simplify.hs
=====================================
@@ -53,7 +53,7 @@ import GHC.Core.Utils
import GHC.Core.Opt.Arity ( ArityType(..), arityTypeArity, isBotArityType
, pushCoTyArg, pushCoValArg
, idArityType, etaExpandAT )
-import GHC.Core.SimpleOpt ( joinPointBinding_maybe, joinPointBindings_maybe )
+import GHC.Core.SimpleOpt ( tryJoinPointWW, tryJoinPointWWs )
import GHC.Core.FVs ( mkRuleInfo )
import GHC.Core.Rules ( lookupRule, getRules, initRuleOpts )
import GHC.Types.Basic
@@ -1052,8 +1052,9 @@ simplExprF1 env (Case scrut bndr _ alts) cont
, sc_env = env, sc_cont = cont })
simplExprF1 env (Let (Rec pairs) body) cont
- | Just pairs' <- joinPointBindings_maybe pairs
- = {-#SCC "simplRecJoinPoin" #-} simplRecJoinPoint env pairs' body cont
+ | Just (pairs', wrappers) <- tryJoinPointWWs (getInScope env) (exprType body) pairs
+ -- , null wrappers || pprTrace "simple join Rec" (ppr pairs'<+> ppr (exprType body)) True
+ = {-#SCC "simplRecJoinPoin" #-} simplRecJoinPoint env pairs' wrappers body cont
| otherwise
= {-#SCC "simplRecE" #-} simplRecE env pairs body cont
@@ -1065,8 +1066,9 @@ simplExprF1 env (Let (NonRec bndr rhs) body) cont
do { ty' <- simplType env ty
; simplExprF (extendTvSubst env bndr ty') body cont }
- | Just (bndr', rhs') <- joinPointBinding_maybe bndr rhs
- = {-#SCC "simplNonRecJoinPoint" #-} simplNonRecJoinPoint env bndr' rhs' body cont
+ | Just (bndr', rhs', wrappers) <- tryJoinPointWW (getInScope env) (exprType body) bndr rhs
+ -- , null wrappers || pprTrace "simple join NonRec" (ppr bndr' $$ ppr (idType bndr') $$ ppr (exprType body) $$ ppr (isJoinId bndr) $$ ppr wrappers) True
+ = {-#SCC "simplNonRecJoinPoint" #-} simplNonRecJoinPoint env bndr' rhs' wrappers body cont
| otherwise
= {-#SCC "simplNonRecE" #-} simplNonRecE env bndr (rhs, env) ([], body) cont
@@ -1684,33 +1686,41 @@ type MaybeJoinCont = Maybe SimplCont
-- Just k => This is a join binding with continuation k
-- See Note [Rules and unfolding for join points]
-simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr
- -> InExpr -> SimplCont
- -> SimplM (SimplFloats, OutExpr)
-simplNonRecJoinPoint env bndr rhs body cont
+-- | Pre-inline join-point wrappers unconditionally: every wrapper binder is
+-- added to the in-scope set, and its Id substitution is extended to map it to
+-- its unsimplified RHS, closed over the environment accumulated so far (so a
+-- later wrapper's closure also sees the earlier ones). The wrappers are the
+-- ones produced by 'tryJoinPointWW' / 'tryJoinPointWWs'.
+preInlineJoinWrappers :: SimplEnv -> [(InId, InExpr)] -> SimplEnv
+preInlineJoinWrappers env binds
+ = foldl' (\env (b,r) -> extendIdSubst env b (mkContEx env r)) env' binds
+ where
+ env' = addNewInScopeIds env (map fst binds)
+
+-- | Simplify a non-recursive join binding @bndr = rhs@ scoping over @body@.
+-- @wrappers@ has at most one element (asserted below); it is pre-inlined into
+-- the environment used for @body@ only, after the join binder itself has been
+-- either pre-inlined unconditionally or simplified via 'simplJoinBind'.
+simplNonRecJoinPoint
+ :: SimplEnv -> InId -> InExpr -> [(InId, InExpr)] -> InExpr -> SimplCont
+ -> SimplM (SimplFloats, OutExpr)
+simplNonRecJoinPoint env bndr rhs wrappers body cont
| ASSERT( isJoinId bndr ) True
- , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs env
+ , ASSERT( wrappers `lengthAtMost` 1 ) True
+ , Just env1 <- preInlineUnconditionally env NotTopLevel bndr rhs env
= do { tick (PreInlineUnconditionally bndr)
- ; simplExprF env' body cont }
-
- | otherwise
- = wrapJoinCont env cont $ \ env cont ->
- do { -- We push join_cont into the join RHS and the body;
- -- and wrap wrap_cont around the whole thing
- ; let mult = contHoleScaling cont
- res_ty = contResultType cont
- ; (env1, bndr1) <- simplNonRecJoinBndr env bndr mult res_ty
- ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (Just cont)
- ; (floats1, env3) <- simplJoinBind env2 cont bndr bndr2 rhs env
- ; (floats2, body') <- simplExprF env3 body cont
- ; return (floats1 `addFloats` floats2, body') }
+ -- env1 already substitutes the (worker) join binder; the wrapper closure
+ -- is built on top of it, so it sees that substitution.
+ ; let env2 = preInlineJoinWrappers env1 wrappers
+ ; simplExprF env2 body cont }
+ | otherwise
+ = wrapJoinCont env cont $ \ env cont ->
+ do { -- We push join_cont into the join RHS and the body;
+ -- and wrap wrap_cont around the whole thing
+ ; let mult = contHoleScaling cont
+ res_ty = contResultType cont
+ ; (env1, bndr1) <- simplNonRecJoinBndr env bndr mult res_ty
+ ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (Just cont)
+ ; (floats1, env3) <- simplJoinBind env2 cont bndr bndr2 rhs env
+ -- The wrapper is only needed in the body, so it is pre-inlined
+ -- after the join RHS has been simplified.
+ ; let env4 = preInlineJoinWrappers env3 wrappers
+ ; (floats2, body') <- simplExprF env4 body cont
+ ; return (floats1 `addFloats` floats2, body') }
------------------
-simplRecJoinPoint :: SimplEnv -> [(InId, InExpr)]
- -> InExpr -> SimplCont
- -> SimplM (SimplFloats, OutExpr)
-simplRecJoinPoint env pairs body cont
+simplRecJoinPoint
+ :: SimplEnv -> [(InId, InExpr)] -> [(InId, InExpr)] -> InExpr -> SimplCont
+ -> SimplM (SimplFloats, OutExpr)
+simplRecJoinPoint env pairs wrappers body cont
= wrapJoinCont env cont $ \ env cont ->
do { let bndrs = map fst pairs
mult = contHoleScaling cont
@@ -1718,8 +1728,9 @@ simplRecJoinPoint env pairs body cont
; env1 <- simplRecJoinBndrs env bndrs mult res_ty
-- NB: bndrs' don't have unfoldings or rules
-- We add them as we go down
- ; (floats1, env2) <- simplRecBind env1 NotTopLevel (Just cont) pairs
- ; (floats2, body') <- simplExprF env2 body cont
+ ; let env2 = preInlineJoinWrappers env1 wrappers
+ ; (floats1, env3) <- simplRecBind env2 NotTopLevel (Just cont) pairs
+ ; (floats2, body') <- simplExprF env3 body cont
; return (floats1 `addFloats` floats2, body') }
--------------------
=====================================
compiler/GHC/Core/SimpleOpt.hs
=====================================
@@ -5,6 +5,7 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE ViewPatterns #-}
module GHC.Core.SimpleOpt (
SimpleOpts (..), defaultSimpleOpts,
@@ -13,7 +14,7 @@ module GHC.Core.SimpleOpt (
simpleOptPgm, simpleOptExpr, simpleOptExprWith,
-- ** Join points
- joinPointBinding_maybe, joinPointBindings_maybe,
+ tryJoinPointWW, tryJoinPointWWs,
-- ** Predicates on expressions
exprIsConApp_maybe, exprIsLiteral_maybe, exprIsLambda_maybe,
@@ -36,7 +37,7 @@ import GHC.Core.Opt.OccurAnal( occurAnalyseExpr, occurAnalysePgm )
import GHC.Types.Literal
import GHC.Types.Id
import GHC.Types.Id.Info ( unfoldingInfo, setUnfoldingInfo, setRuleInfo, IdInfo (..) )
-import GHC.Types.Var ( isNonCoVarId )
+import GHC.Types.Var ( isNonCoVarId, setVarType, VarBndr (..) )
import GHC.Types.Var.Set
import GHC.Types.Var.Env
import GHC.Core.DataCon
@@ -44,7 +45,11 @@ import GHC.Types.Demand( etaConvertStrictSig )
import GHC.Core.Coercion.Opt ( optCoercion, OptCoercionOpts (..) )
import GHC.Core.Type hiding ( substTy, extendTvSubst, extendCvSubst, extendTvSubstList
, isInScope, substTyVarBndr, cloneTyVarBndr )
+import qualified GHC.Core.Type as Type
import GHC.Core.Coercion hiding ( substCo, substCoVarBndr )
+import GHC.Core.TyCo.Rep ( TyCoBinder (..) )
+import GHC.Core.Multiplicity
+import GHC.Core.Unify ( tcMatchTy )
import GHC.Builtin.Types
import GHC.Builtin.Names
import GHC.Types.Basic
@@ -52,8 +57,10 @@ import GHC.Unit.Module ( Module )
import GHC.Utils.Outputable
import GHC.Utils.Panic
import GHC.Utils.Misc
-import GHC.Data.Maybe ( orElse )
+import GHC.Utils.Monad ( mapAccumLM )
+import GHC.Data.Maybe
import GHC.Data.FastString
+import Data.Bifunctor ( first )
import Data.List
import qualified Data.ByteString as BS
@@ -172,14 +179,14 @@ simpleOptPgm opts this_mod binds rules =
-- hence paying just a substitution
do_one (env, binds') bind
- = case simple_opt_bind env bind TopLevel of
+ = case simple_opt_bind env bind of
(env', Nothing) -> (env', binds')
(env', Just bind') -> (env', bind':binds')
-- In these functions the substitution maps InVar -> OutExpr
----------------------
-type SimpleClo = (SimpleOptEnv, InExpr)
+type SimpleClo = (SimpleOptEnv, InExpr) -- Like SimplSR's ContEx
data SimpleOptEnv
= SOE { soe_co_opt_opts :: !OptCoercionOpts
@@ -191,7 +198,8 @@ data SimpleOptEnv
, soe_inl :: IdEnv SimpleClo
-- ^ Deals with preInlineUnconditionally; things
-- that occur exactly once and are inlined
- -- without having first been simplified
+ -- without having first been simplified or
+ -- substituted, thus the domain is InBndrs
, soe_subst :: Subst
-- ^ Deals with cloning; includes the InScopeSet
@@ -247,9 +255,10 @@ simple_opt_expr env expr
go (Lit lit) = Lit lit
go (Tick tickish e) = mkTick (substTickish subst tickish) (go e)
go (Cast e co) = mk_cast (go e) (go_co co)
- go (Let bind body) = case simple_opt_bind env bind NotTopLevel of
- (env', Nothing) -> simple_opt_expr env' body
- (env', Just bind) -> Let bind (simple_opt_expr env' body)
+ go (Let bind body) =
+ case simple_opt_local_bind env (exprType body) bind of
+ (env', Nothing) -> simple_opt_expr env' body
+ (env', Just bind) -> Let bind (simple_opt_expr env' body)
go lam@(Lam {}) = go_lam env [] lam
go (Case e b ty as)
@@ -351,7 +360,7 @@ simple_app env (Tick t e) as
-- However, do /not/ do this transformation for join points
-- See Note [simple_app and join points]
simple_app env (Let bind body) args
- = case simple_opt_bind env bind NotTopLevel of
+ = case simple_opt_local_bind env (exprType body) bind of
(env', Nothing) -> simple_app env' body args
(env', Just bind')
| isJoinBind bind' -> finish_app env expr' args
@@ -369,29 +378,87 @@ finish_app env fun (arg:args)
= finish_app env (App fun (simple_opt_clo env arg)) args
----------------------
-simple_opt_bind :: SimpleOptEnv -> InBind -> TopLevelFlag
- -> (SimpleOptEnv, Maybe OutBind)
-simple_opt_bind env (NonRec b r) top_level
- = (env', case mb_pr of
- Nothing -> Nothing
- Just (b,r) -> Just (NonRec b r))
+-- | Add @bndr :-> clo@ to 'soe_inl', the pre-inlining environment
+-- (whose domain is InBndrs). Only non-coercion Ids are allowed.
+extendInlEnv :: SimpleOptEnv -> InBndr -> SimpleClo -> SimpleOptEnv
+-- Like GHC.Core.Opt.Simplify.Env.extendIdSubst
+extendInlEnv env@(SOE { soe_inl = inl_env }) bndr clo
+ = ASSERT2( isId bndr && not (isCoVar bndr), ppr bndr )
+ env { soe_inl = extendVarEnv inl_env bndr clo }
+
+-- | Add @bndrs@ to the in-scope set of 'soe_subst', leaving the
+-- substitution maps themselves unchanged.
+extendInScopeEnv :: SimpleOptEnv -> [InBndr] -> SimpleOptEnv
+extendInScopeEnv env@(SOE { soe_subst = Subst in_scope ids tvs cos }) bndrs
+ = env { soe_subst = Subst (extendInScopeSetList in_scope bndrs) ids tvs cos }
+
+-- | Try to turn a binding group into join points, given the current in-scope
+-- set and the type of the let body. Returns 'Nothing' unless every binding
+-- qualifies. Otherwise returns the join bindings to bind, together with the
+-- wrapper bindings that must be pre-inlined rather than bound. The wrapper
+-- list is empty unless some binding had to be split into a worker with a
+-- monomorphised result type ('JoinPointAfterMono').
+tryJoinPointWWs :: InScopeSet -> Type -> [(InBndr, InExpr)] -> Maybe ([(InBndr, InExpr)], [(InBndr, InExpr)])
+tryJoinPointWWs in_scope body_ty binds
+ = foldMap go <$> joinPointBindings_maybe in_scope body_ty binds
where
- (b', r') = joinPointBinding_maybe b r `orElse` (b, r)
- (env', mb_pr) = simple_bind_pair env b' Nothing (env,r') top_level
+ go jph = ([(join_bndr jph, join_rhs jph)], join_wrapper jph)
+ join_wrapper jph@JoinPointAfterMono{} -- Rare: A join point after we inline a wrapper
+ = [(join_wrapper_bndr jph, join_wrapper_body jph)]
+ join_wrapper DefinitelyJoinPoint{} -- Common: Regular join point. No wrapper
+ = []
+
+-- | Single-binding version of 'tryJoinPointWWs'; yields at most one wrapper.
+tryJoinPointWW :: InScopeSet -> Type -> InBndr -> InExpr -> Maybe (InBndr, InExpr, [(InBndr, InExpr)])
+tryJoinPointWW in_scope body_ty b r
+ | Just ([(b', r')], wrappers) <- tryJoinPointWWs in_scope body_ty [(b, r)]
+ = ASSERT( wrappers `lengthAtMost` 1 )
+ Just (b', r', wrappers)
+ | otherwise
+ = Nothing
-simple_opt_bind env (Rec prs) top_level
- = (env'', res_bind)
+-- | Wrap the optional pair produced by 'simple_bind_pair' as a 'NonRec'.
+pair_to_non_rec
+ :: (SimpleOptEnv, Maybe (OutBndr, OutExpr))
+ -> (SimpleOptEnv, Maybe OutBind)
+pair_to_non_rec (env, mb_pr) = (env, uncurry NonRec <$> mb_pr)
+
+-- | Optimise a local (NotTopLevel) binding whose let body has type @body_ty@.
+-- Bindings that can become join points are first rewritten by
+-- 'tryJoinPointWW' / 'tryJoinPointWWs'; any resulting wrappers are
+-- pre-inlined into the returned environment instead of being bound.
+simple_opt_local_bind
+ :: SimpleOptEnv -> Type -> InBind -> (SimpleOptEnv, Maybe OutBind)
+simple_opt_local_bind env body_ty (NonRec b r)
+ | (b', r', wrappers) <- tryJoinPointWW (substInScope (soe_subst env) `extendInScopeSet` b) body_ty b r `orElse` (b, r, [])
+ -- , null wrappers || pprTrace "simple_opt_local_bind:join" (ppr b <+> ppr (idType b) <+> ppr body_ty) True
+ = -- pprTraceWith "simple_opt_local_bind" (\(env', mb_bind) -> ppr b <+> (case mb_bind of Nothing -> text "inlined" $$ ppr env'; Just _ -> text "not inlined")) $
+ first (pre_inline_join_wrappers wrappers)
+ $ pair_to_non_rec
+ $ simple_bind_pair env b' Nothing (env,r') NotTopLevel
+
+simple_opt_local_bind env body_ty (Rec prs)
+ --- | null wrappers || pprTrace "simple_opt_local_bind:joinrec" (ppr prs <+> ppr body_ty) True
+ = (env3, res_bind)
where
- res_bind = Just (Rec (reverse rev_prs'))
- prs' = joinPointBindings_maybe prs `orElse` prs
- (env', bndrs') = subst_opt_bndrs env (map fst prs')
- (env'', rev_prs') = foldl' do_pr (env', []) (prs' `zip` bndrs')
- do_pr (env, prs) ((b,r), b')
+ res_bind = Just (Rec (reverse rev_prs'))
+ (prs', wrappers) = tryJoinPointWWs (substInScope (soe_subst env)) body_ty prs `orElse` (prs, [])
+ (env1, bndrs') = subst_opt_bndrs env (map fst prs')
+ -- Wrappers are registered before the RHSs are optimised: in a
+ -- recursive group the worker RHSs may mention the wrapper binders.
+ env2 = pre_inline_join_wrappers wrappers env1
+ (env3, rev_prs') = foldl' simpl_pr (env2, []) (prs' `zip` bndrs')
+ simpl_pr (env, prs) ((b,r), b')
= (env', case mb_pr of
Just pr -> pr : prs
Nothing -> prs)
where
- (env', mb_pr) = simple_bind_pair env b (Just b') (env,r) top_level
+ (env', mb_pr) = simple_bind_pair env b (Just b') (env,r) NotTopLevel
+
+-- | Add the wrapper binders to the in-scope set, then register each one in
+-- 'soe_inl' as a closure over the environment accumulated so far.
+pre_inline_join_wrappers :: [(InBndr, InExpr)] -> SimpleOptEnv -> SimpleOptEnv
+pre_inline_join_wrappers binds env
+ = foldl' (\env (b,r) -> extendInlEnv env b (env, r)) env' binds
+ where
+ env' = extendInScopeEnv env (map fst binds)
+
+-- | Optimise a top-level binding. Unlike 'simple_opt_local_bind', this does
+-- not attempt join-point detection.
+simple_opt_bind :: SimpleOptEnv -> InBind -> (SimpleOptEnv, Maybe OutBind)
+simple_opt_bind env (NonRec b r)
+ = pair_to_non_rec (simple_bind_pair env b Nothing (env,r) TopLevel)
+
+simple_opt_bind env (Rec prs)
+ = (env'', res_bind)
+ where
+ res_bind = Just (Rec (reverse rev_prs'))
+ (env', bndrs') = subst_opt_bndrs env (map fst prs)
+ (env'', rev_prs') = foldl' simpl_pr (env', []) (prs `zip` bndrs')
+ simpl_pr (env, prs) ((b,r), b')
+ = (env', case mb_pr of
+ Just pr -> pr : prs
+ Nothing -> prs)
+ where
+ (env', mb_pr) = simple_bind_pair env b (Just b') (env,r) TopLevel
----------------------
simple_bind_pair :: SimpleOptEnv
@@ -796,6 +863,27 @@ A more common case is when
and again its arity increases (#15517)
-}
+-- | The outcome of successfully turning one binding into a join point.
+data JoinPointHood
+  -- | The binding is a join point as it stands (possibly after eta-expansion).
+  = DefinitelyJoinPoint
+      { join_bndr :: !InBndr
+      , join_rhs  :: !InExpr
+      }
+  -- | The binding only becomes a join point after monomorphising its result
+  -- type: 'join_bndr' / 'join_rhs' describe the worker, while the wrapper
+  -- binding is meant to be pre-inlined rather than bound.
+  | JoinPointAfterMono
+      { join_bndr         :: !InBndr
+      , join_rhs          :: !InExpr
+      , join_wrapper_bndr :: !InBndr
+      , join_wrapper_body :: !InExpr
+      }
+
+-- | How a single join-point parameter is treated when the join point's
+-- result type is monomorphised (one entry per parameter).
+data Blub
+  = MonoTyArg !Type          -- ^ forall-bound variable fixed to this type; the worker drops it
+  | SubstBinder !TyCoBinder  -- ^ parameter the worker keeps, with the substitution applied
+
+instance Outputable Blub where
+  ppr blub = case blub of
+    MonoTyArg ty     -> text "Mono"  <+> ppr ty
+    SubstBinder bndr -> text "Subst" <+> ppr bndr
+
+-- | 'True' exactly for the parameters the worker keeps (i.e. not 'MonoTyArg').
+isNotMonoTyArg :: Blub -> Bool
+isNotMonoTyArg blub = case blub of
+  MonoTyArg{}   -> False
+  SubstBinder{} -> True
-- | Returns Just (bndr,rhs) if the binding is a join point:
-- If it's a JoinId, just return it
@@ -806,29 +894,140 @@ and again its arity increases (#15517)
--
-- Precondition: the InBndr has been occurrence-analysed,
-- so its OccInfo is valid
-joinPointBinding_maybe :: InBndr -> InExpr -> Maybe (InBndr, InExpr)
-joinPointBinding_maybe bndr rhs
- | not (isId bndr)
- = Nothing
-
- | isJoinId bndr
- = Just (bndr, rhs)
-
- | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
- , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs
- , let str_sig = idStrictness bndr
- str_arity = count isId bndrs -- Strictness demands are for Ids only
- join_bndr = bndr `asJoinId` join_arity
- `setIdStrictness` etaConvertStrictSig str_arity str_sig
- = Just (join_bndr, mkLams bndrs body)
+--
+-- NB: this now works on a whole binding group, given the in-scope set and
+-- the type of the let body. It returns 'Nothing' unless every binding
+-- qualifies; otherwise one 'JoinPointHood' per binding, in order. A binding
+-- for which 'matchJoinResTy' instantiates some forall binders is split into
+-- a worker (the new join point) and a wrapper; see the WW example below.
+joinPointBindings_maybe :: InScopeSet -> Type -> [(InBndr, InExpr)] -> Maybe [JoinPointHood]
+joinPointBindings_maybe in_scope body_type binds
+ = snd <$> mapAccumLM go (extendInScopeSetList in_scope (map fst binds)) binds
+ where
+ go :: InScopeSet -> (InBndr, InExpr) -> Maybe (InScopeSet, JoinPointHood)
+ go in_scope (bndr, rhs)
+ | not (isId bndr)
+ = Nothing
+
+ | isJoinId bndr
+ = Just (in_scope, DefinitelyJoinPoint bndr rhs)
+
+ | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
+ , not (exprIsTrivial rhs)
+ , (lam_bndrs, rhs') <- etaExpandToJoinPoint join_arity rhs
+ , let eta_rhs' = mkLams lam_bndrs rhs'
+ , let inst_tys = matchJoinResTy join_arity (idType bndr) body_type
+ , let new_join_arity = count isNotMonoTyArg inst_tys -- number of parameters the worker keeps
+ , let no_mono = new_join_arity == join_arity
+ , let worker_body = mk_worker_body lam_bndrs inst_tys eta_rhs'
+ -- we need an in-scope set as if the worker was defined inside the RHS of the wrapper (as is the case with SAT)
+ , let in_scope' = extendInScopeSetList in_scope lam_bndrs
+ , let new_bndr = uniqAway in_scope' bndr -- only used in else branch
+ `setIdType` exprType worker_body
+ , let wrapper_body = mk_wrapper_body new_bndr lam_bndrs inst_tys
+ , let wrapper_bndr = bndr
+ -- , no_mono || pprTrace "always tail called:" (vcat [ppr in_scope', ppr bndr, ppr (idType bndr), ppr body_type, ppr rhs, ppr new_bndr, ppr (exprType worker_body), ppr join_arity, ppr inst_tys, ppr new_bndr, ppr wrapper_body, ppr worker_body]) True
+ = Just $! if no_mono
+ then ( in_scope
+ , DefinitelyJoinPoint
+ { join_bndr = adjust_id_info bndr lam_bndrs join_arity
+ , join_rhs = eta_rhs' } )
+ else ( extendInScopeSet in_scope new_bndr
+ , JoinPointAfterMono
+ { join_bndr = adjust_id_info new_bndr lam_bndrs new_join_arity
+ , join_rhs = worker_body
+ , join_wrapper_bndr = wrapper_bndr
+ , join_wrapper_body = wrapper_body } )
- | otherwise
- = Nothing
+ | otherwise
+ = Nothing
+
+ -- Turn bndr into a join Id of the given arity, eta-convert its strictness
+ -- signature to the number of value binders among lam_bndrs, and zap its
+ -- stable unfolding. NOTE(review): per the TODO, keeping the unfolding
+ -- reportedly led to type errors; confirm before merging.
+ adjust_id_info :: InBndr -> [InBndr] -> JoinArity -> InBndr
+ adjust_id_info bndr lam_bndrs join_arity = zapStableUnfolding $ -- TODO: Discuss! Type errors otherwise.
+ let str_sig = idStrictness bndr
+ str_arity = count isId lam_bndrs -- Strictness demands are for Ids only
+ in bndr `asJoinId` join_arity
+ `setIdStrictness` etaConvertStrictSig str_arity str_sig
+
+ -- WW example: result-type polymorphic join point
+ -- f :: forall a b. [a] -> forall c. b -> Maybe c -> [(a,c)]
+ -- f = <rhs>
+ -- We want to monomorphise for (a ~ Bool) and (c ~ Char) from a join body ty
+ -- of [(Bool, Char)]. Then, we want to get a WW split like
+ -- f = \@a @b (xs :: [a]) @c b (mc :: Maybe c) -> f' @b xs b mc
+ -- f' :: forall b. [Bool] -> b -> Maybe Char -> [(Bool, Char)]
+ -- f' = \@b (xs :: [Bool]) b mc -> <rhs> @Bool @b xs @Char b mc
+ -- the RHS of f is ill-typed... But after pre-inlining, we will be fine!
+ -- The inliner is carrying out the necessary transformation, so to speak,
+ -- it's not like a regular inlining decision.
+
+ -- Build the wrapper: abstract over all of lam_bndrs and call new_bndr,
+ -- passing along only the non-instantiated parameters; instantiated type
+ -- binders are still abstracted over but are not passed on.
+ mk_wrapper_body :: InBndr -> [InBndr] -> [Blub] -> InExpr
+ mk_wrapper_body new_bndr lam_bndrs inst_tys
+ = ASSERT( lam_bndrs `equalLength` inst_tys )
+ -- pprTraceWith "mk_wrapper_body" (\e -> ppr lam_bndrs $$ ppr inst_tys $$ ppr e) $
+ go (Var new_bndr) $ zipEqual "mk_wrapper_body" lam_bndrs inst_tys
+ where
+ go e [] = e
+ go e ((lb,SubstBinder{}):prs) -- non-instantiated parameter
+ | isId lb -- value parameter xs
+ = Lam lb (go (App e (Var lb)) prs)
+ | otherwise -- type parameter @b
+ -- NOTE(review): assumes every non-Id binder here is a type variable;
+ -- a coercion binder would need a 'Coercion' argument instead.
+ = Lam lb (go (App e (Type (mkTyVarTy lb))) prs)
+ go e ((lb,MonoTyArg{}):prs) -- instantiated parameter, @a or @c
+ = ASSERT( isTyVar lb )
+ Lam lb (go e prs)
+
+ -- Build the worker RHS: abstract over the kept parameters (value binders
+ -- get their substituted types) and apply the eta-expanded rhs to them,
+ -- supplying the instantiated type for each monomorphised binder.
+ mk_worker_body :: [InBndr] -> [Blub] -> InExpr -> InExpr
+ mk_worker_body lam_bndrs inst_tys rhs
+ = -- pprTraceWith "mk_worker_body" (\e -> ppr e) $
+ go rhs $ zipEqual "mk_worker_body" lam_bndrs inst_tys
+ where
+ go e [] = e
+ go e ((lb,SubstBinder bndr):prs) -- non-instantiated parameter
+ | Anon _ (Scaled _ ty) <- bndr -- value parameter xs
+ , let lb' = lb `setIdType` ty
+ = Lam lb' (go (App e (Var lb')) prs)
+ | Named (binderVar -> tcv) <- bndr -- type parameter @b
+ = Lam tcv (go (App e (Type (mkTyVarTy tcv))) prs)
+ go e ((_ ,MonoTyArg ty):prs) -- instantiated parameter, @a or @c
+ = go (App e (Type ty)) prs
+
+-- | Figures out how to monomorphise the result type of a join point.
+--
+-- @matchJoinResTy ja join_ty body_ty@ computes the result type of @join_ty@ by
+-- skipping @ja@ binders and then matches it against @body_ty@.
+-- If a forall binder @a@ is mentioned in the resulting substitution @subst@,
+-- the corresponding entry in the returned list is @MonoTyArg (subst a)@;
+-- every other binder yields a 'SubstBinder' with @subst@ applied to its type.
+-- Forall binders are renamed away from the in-scope set when necessary, and
+-- that renaming is applied to the remainder of the type as well.
+--
+-- Postcondition: The returned list has length @ja@.
+matchJoinResTy
+ :: JoinArity -- ^ Number of binders to skip
+ -> Type -- ^ Type of the join point
+ -> Type -- ^ Type of the join body
+ -> [Blub] -- ^ An entry for each join binder; @MonoTyArg ty@ <=>
+ -- the corresponding forall is instantiated to @ty@
+matchJoinResTy orig_ar orig_ty body_ty = snd (go init_in_scope orig_ar orig_ty)
+ where
+ init_in_scope = mkInScopeSet $ tyCoVarsOfType body_ty `unionVarSet` tyCoVarsOfType orig_ty
-joinPointBindings_maybe :: [(InBndr, InExpr)] -> Maybe [(InBndr, InExpr)]
-joinPointBindings_maybe bndrs
- = mapM (uncurry joinPointBinding_maybe) bndrs
+ go :: InScopeSet -> Int -> Type -> (TCvSubst, [Blub])
+ go in_scope 0 res_ty = (TCvSubst in_scope tvs cvs, [])
+ where
+ -- The join point is tail-called from the body, so its result type is
+ -- expected to match the body type; a failure here is a compiler bug.
+ TCvSubst _ tvs cvs = expectJust "matchJoinResTy" $ tcMatchTy res_ty body_ty
+
+ go in_scope n ty
+ | Just (arg_bndr, res_ty) <- splitPiTy_maybe ty
+ = case arg_bndr of
+ Anon f (Scaled m arg_ty)
+ | (subst, inst_tys) <- go in_scope (n-1) res_ty
+ -> (subst, SubstBinder (Anon f (Scaled m (Type.substTy subst arg_ty))) : inst_tys)
+ Named (Bndr tcv vis)
+ | isTyVar tcv', Just mono_ty <- lookupTyVar subst tcv'
+ -> (subst', MonoTyArg mono_ty : inst_tys)
+ | otherwise
+ -> (subst', SubstBinder (Named (Bndr subst_tcv vis)) : inst_tys)
+ where
+ -- Rename tcv away from the in-scope set (a no-op unless it clashes)
+ -- and apply that renaming to the rest of the type, so that later
+ -- binders and the matched result type mention tcv' rather than a
+ -- stale tcv; otherwise the lookup of tcv' above would miss.
+ (rn_subst, tcv') = Type.substVarBndr (Type.mkEmptyTCvSubst in_scope) tcv
+ in_scope' = Type.getTCvInScope rn_subst
+ (subst, inst_tys) = go in_scope' (n-1) (Type.substTy rn_subst res_ty)
+ subst' = delTCvSubst subst tcv'
+ subst_tcv = tcv' `setVarType` Type.substTy subst' (varType tcv')
+ go _ _ _ = pprPanic "matchJoinResTy" (ppr orig_ar <+> ppr orig_ty)
{- *********************************************************************
* *
@@ -1342,5 +1541,3 @@ exprIsLambda_maybe (in_scope_set, id_unf) e
exprIsLambda_maybe _ _e
= -- pprTrace "exprIsLambda_maybe:Fail" (vcat [ppr _e])
Nothing
-
-
=====================================
compiler/GHC/Core/TyCo/Subst.hs
=====================================
@@ -19,7 +19,7 @@ module GHC.Core.TyCo.Subst
getTvSubstEnv,
getCvSubstEnv, getTCvInScope, getTCvSubstRangeFVs,
isInScope, notElemTCvSubst,
- setTvSubstEnv, setCvSubstEnv, zapTCvSubst,
+ setTvSubstEnv, setCvSubstEnv, zapTCvSubst, delTCvSubst,
extendTCvInScope, extendTCvInScopeList, extendTCvInScopeSet,
extendTCvSubst, extendTCvSubstWithClone,
extendCvSubst, extendCvSubstWithClone,
@@ -308,6 +308,23 @@ setCvSubstEnv (TCvSubst in_scope tenv _) cenv = TCvSubst in_scope tenv cenv
zapTCvSubst :: TCvSubst -> TCvSubst
zapTCvSubst (TCvSubst in_scope _ _) = TCvSubst in_scope emptyVarEnv emptyVarEnv
+-- | Remove a type or coercion variable from the domain of the substitution;
+-- the in-scope set is left unchanged. Panics if given a term-level Id.
+delTCvSubst :: TCvSubst -> Var -> TCvSubst
+delTCvSubst subst v
+  | isTyVar v = delTvSubst subst v
+  | isCoVar v = delCvSubst subst v
+  | otherwise = pprPanic "delTCvSubst" (ppr v)
+
+-- | Remove a type variable from the type-variable part of the substitution.
+delTvSubst :: TCvSubst -> TyVar -> TCvSubst
+delTvSubst subst tv = setTvSubstEnv subst (getTvSubstEnv subst `delVarEnv` tv)
+
+-- | Remove a coercion variable from the coercion-variable part of the substitution.
+delCvSubst :: TCvSubst -> CoVar -> TCvSubst
+delCvSubst subst cv = setCvSubstEnv subst (getCvSubstEnv subst `delVarEnv` cv)
+
+
extendTCvInScope :: TCvSubst -> Var -> TCvSubst
extendTCvInScope (TCvSubst in_scope tenv cenv) var
= TCvSubst (extendInScopeSet in_scope var) tenv cenv
=====================================
compiler/GHC/Core/Type.hs
=====================================
@@ -193,7 +193,7 @@ module GHC.Core.Type (
zipTCvSubst,
notElemTCvSubst,
getTvSubstEnv, setTvSubstEnv,
- zapTCvSubst, getTCvInScope, getTCvSubstRangeFVs,
+ zapTCvSubst, delTCvSubst, getTCvInScope, getTCvSubstRangeFVs,
extendTCvInScope, extendTCvInScopeList, extendTCvInScopeSet,
extendTCvSubst, extendCvSubst,
extendTvSubst, extendTvSubstBinderAndInScope,
=====================================
compiler/GHC/Core/Utils.hs
=====================================
@@ -276,7 +276,7 @@ applyTypeToArgs e op_ty args
go op_ty (Coercion co : args) = go_ty_args op_ty [mkCoercionTy co] args
go op_ty (_ : args) | Just (_, _, res_ty) <- splitFunTy_maybe op_ty
= go res_ty args
- go _ args = pprPanic "applyTypeToArgs" (panic_msg args)
+ go op_ty args = pprPanic "applyTypeToArgs" (panic_msg op_ty args)
-- go_ty_args: accumulate type arguments so we can
-- instantiate all at once with piResultTys
@@ -287,8 +287,9 @@ applyTypeToArgs e op_ty args
go_ty_args op_ty rev_tys args
= go (piResultTys op_ty (reverse rev_tys)) args
- panic_msg as = vcat [ text "Expression:" <+> pprCoreExpr e
+ panic_msg ot as = vcat [ text "Expression:" <+> pprCoreExpr e
, text "Type:" <+> ppr op_ty
+ , text "Type':" <+> ppr ot
, text "Args:" <+> ppr args
, text "Args':" <+> ppr as ]
@@ -2622,4 +2623,3 @@ isUnsafeEqualityProof e
= idName v == unsafeEqualityProofName
| otherwise
= False
-
=====================================
compiler/GHC/Driver/Ppr.hs
=====================================
@@ -15,6 +15,7 @@ module GHC.Driver.Ppr
, pprTraceWithFlags
, pprTraceM
, pprTraceDebug
+ , pprTraceWith
, pprTraceIt
, pprSTrace
, pprTraceException
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/ee89bb7b0a39644d9455af61204fd85eddce0e23
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/ee89bb7b0a39644d9455af61204fd85eddce0e23
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20200922/e6f6aff9/attachment-0001.html>
More information about the ghc-commits
mailing list