[Git][ghc/ghc][master] 2 commits: Fix a tricky specialiser loop
Marge Bot
gitlab at gitlab.haskell.org
Mon Apr 6 17:16:55 UTC 2020
Marge Bot pushed to branch master at Glasgow Haskell Compiler / GHC
Commits:
cec2c71f by Simon Peyton Jones at 2020-04-06T13:16:44-04:00
Fix a tricky specialiser loop
Issue #17151 was a very tricky example of a bug in which the
specialiser accidentally constructs a recursive dictionary,
so that everything turns into bottom.
I have fixed variants of this bug at least twice before:
see Note [Avoiding loops]. It was a bit of a struggle
to isolate the problem, greatly aided by the work that
Alexey Kuleshevich did in distilling a test case.
Once I'd understood the problem, it was not difficult to fix,
though it did lead me to do a bit of refactoring in specImports.
- - - - -
e850d14f by Simon Peyton Jones at 2020-04-06T13:16:44-04:00
Refactoring only
This refactors DictBinds into a data type rather than a pair.
No change in behaviour, just better code
- - - - -
5 changed files:
- compiler/GHC/Core/Op/Specialise.hs
- + testsuite/tests/simplCore/should_run/T17151.hs
- + testsuite/tests/simplCore/should_run/T17151.stdout
- + testsuite/tests/simplCore/should_run/T17151a.hs
- testsuite/tests/simplCore/should_run/all.T
Changes:
=====================================
compiler/GHC/Core/Op/Specialise.hs
=====================================
@@ -589,19 +589,11 @@ specProgram guts@(ModGuts { mg_module = this_mod
-- Specialise the bindings of this module
; (binds', uds) <- runSpecM dflags this_mod (go binds)
- -- Specialise imported functions
- ; hpt_rules <- getRuleBase
- ; let rule_base = extendRuleBaseList hpt_rules local_rules
- ; (new_rules, spec_binds) <- specImports dflags this_mod top_env emptyVarSet
- [] rule_base uds
-
- ; let final_binds
- | null spec_binds = binds'
- | otherwise = Rec (flattenBinds spec_binds) : binds'
- -- Note [Glom the bindings if imported functions are specialised]
+ ; (spec_rules, spec_binds) <- specImports dflags this_mod top_env
+ local_rules uds
- ; return (guts { mg_binds = final_binds
- , mg_rules = new_rules ++ local_rules }) }
+ ; return (guts { mg_binds = spec_binds ++ binds'
+ , mg_rules = spec_rules ++ local_rules }) }
where
-- We need to start with a Subst that knows all the things
-- that are in scope, so that the substitution engine doesn't
@@ -645,72 +637,93 @@ See #10491
* *
********************************************************************* -}
--- | Specialise a set of calls to imported bindings
-specImports :: DynFlags
- -> Module
- -> SpecEnv -- Passed in so that all top-level Ids are in scope
- -> VarSet -- Don't specialise these ones
- -- See Note [Avoiding recursive specialisation]
- -> [Id] -- Stack of imported functions being specialised
- -> RuleBase -- Rules from this module and the home package
- -- (but not external packages, which can change)
- -> UsageDetails -- Calls for imported things, and floating bindings
- -> CoreM ( [CoreRule] -- New rules
- , [CoreBind] ) -- Specialised bindings
- -- See Note [Wrapping bindings returned by specImports]
-specImports dflags this_mod top_env done callers rule_base
+specImports :: DynFlags -> Module -> SpecEnv
+ -> [CoreRule]
+ -> UsageDetails
+ -> CoreM ([CoreRule], [CoreBind])
+specImports dflags this_mod top_env local_rules
(MkUD { ud_binds = dict_binds, ud_calls = calls })
- -- See Note [Disabling cross-module specialisation]
| not $ gopt Opt_CrossModuleSpecialise dflags
- = return ([], [])
+ -- See Note [Disabling cross-module specialisation]
+ = return ([], wrapDictBinds dict_binds [])
| otherwise
- = do { let import_calls = dVarEnvElts calls
- ; (rules, spec_binds) <- go rule_base import_calls
+ = do { hpt_rules <- getRuleBase
+ ; let rule_base = extendRuleBaseList hpt_rules local_rules
+
+ ; (spec_rules, spec_binds) <- spec_imports dflags this_mod top_env
+ [] rule_base
+ dict_binds calls
-- Don't forget to wrap the specialized bindings with
-- bindings for the needed dictionaries.
-- See Note [Wrap bindings returned by specImports]
- ; let spec_binds' = wrapDictBinds dict_binds spec_binds
+ -- and Note [Glom the bindings if imported functions are specialised]
+ ; let final_binds
+ | null spec_binds = wrapDictBinds dict_binds []
+ | otherwise = [Rec $ flattenBinds $
+ wrapDictBinds dict_binds spec_binds]
+
+ ; return (spec_rules, final_binds)
+ }
+
+-- | Specialise a set of calls to imported bindings
+spec_imports :: DynFlags
+ -> Module
+ -> SpecEnv -- Passed in so that all top-level Ids are in scope
+ -> [Id] -- Stack of imported functions being specialised
+ -- See Note [specImport call stack]
+ -> RuleBase -- Rules from this module and the home package
+ -- (but not external packages, which can change)
+ -> Bag DictBind -- Dict bindings, used /only/ for filterCalls
+ -- See Note [Avoiding loops in specImports]
+ -> CallDetails -- Calls for imported things
+ -> CoreM ( [CoreRule] -- New rules
+ , [CoreBind] ) -- Specialised bindings
+spec_imports dflags this_mod top_env
+ callers rule_base dict_binds calls
+ = do { let import_calls = dVarEnvElts calls
+ -- ; debugTraceMsg (text "specImports {" <+>
+ -- vcat [ text "calls:" <+> ppr import_calls
+ -- , text "dict_binds:" <+> ppr dict_binds ])
+ ; (rules, spec_binds) <- go rule_base import_calls
+ -- ; debugTraceMsg (text "End specImports }" <+> ppr import_calls)
- ; return (rules, spec_binds') }
+ ; return (rules, spec_binds) }
where
go :: RuleBase -> [CallInfoSet] -> CoreM ([CoreRule], [CoreBind])
go _ [] = return ([], [])
- go rb (cis@(CIS fn _) : other_calls)
- = do { let ok_calls = filterCalls cis dict_binds
- -- Drop calls that (directly or indirectly) refer to fn
- -- See Note [Avoiding loops]
--- ; debugTraceMsg (text "specImport" <+> vcat [ ppr fn
--- , text "calls" <+> ppr cis
--- , text "ud_binds =" <+> ppr dict_binds
--- , text "dump set =" <+> ppr dump_set
--- , text "filtered calls =" <+> ppr ok_calls ])
- ; (rules1, spec_binds1) <- specImport dflags this_mod top_env
- done callers rb fn ok_calls
+ go rb (cis : other_calls)
+ = do { -- debugTraceMsg (text "specImport {" <+> ppr cis)
+ ; (rules1, spec_binds1) <- spec_import dflags this_mod top_env
+ callers rb dict_binds cis
+ -- ; debugTraceMsg (text "specImport }" <+> ppr cis)
; (rules2, spec_binds2) <- go (extendRuleBaseList rb rules1) other_calls
; return (rules1 ++ rules2, spec_binds1 ++ spec_binds2) }
-specImport :: DynFlags
- -> Module
- -> SpecEnv -- Passed in so that all top-level Ids are in scope
- -> VarSet -- Don't specialise these
- -- See Note [Avoiding recursive specialisation]
- -> [Id] -- Stack of imported functions being specialised
- -> RuleBase -- Rules from this module
- -> Id -> [CallInfo] -- Imported function and calls for it
- -> CoreM ( [CoreRule] -- New rules
- , [CoreBind] ) -- Specialised bindings
-specImport dflags this_mod top_env done callers rb fn calls_for_fn
- | fn `elemVarSet` done
+spec_import :: DynFlags
+ -> Module
+ -> SpecEnv -- Passed in so that all top-level Ids are in scope
+ -> [Id] -- Stack of imported functions being specialised
+ -- See Note [specImport call stack]
+ -> RuleBase -- Rules from this module
+ -> Bag DictBind -- Dict bindings, used /only/ for filterCalls
+ -- See Note [Avoiding loops in specImports]
+ -> CallInfoSet -- Imported function and calls for it
+ -> CoreM ( [CoreRule] -- New rules
+ , [CoreBind] ) -- Specialised bindings
+spec_import dflags this_mod top_env callers
+ rb dict_binds cis@(CIS fn _)
+ | isIn "specImport" fn callers
= return ([], []) -- No warning. This actually happens all the time
-- when specialising a recursive function, because
-- the RHS of the specialised function contains a recursive
-- call to the original function
- | null calls_for_fn -- We filtered out all the calls in deleteCallsMentioning
- = return ([], [])
+ | null good_calls
+ = do { -- debugTraceMsg (text "specImport:no valid calls")
+ ; return ([], []) }
| wantSpecImport dflags unfolding
, Just rhs <- maybeUnfoldingTemplate unfolding
@@ -723,32 +736,37 @@ specImport dflags this_mod top_env done callers rb fn calls_for_fn
; let full_rb = unionRuleBase rb (eps_rule_base eps)
rules_for_fn = getRules (RuleEnv full_rb vis_orphs) fn
- ; (rules1, spec_pairs, uds)
- <- -- pprTrace "specImport1" (vcat [ppr fn, ppr calls_for_fn, ppr rhs]) $
- runSpecM dflags this_mod $
- specCalls (Just this_mod) top_env rules_for_fn calls_for_fn fn rhs
+ ; (rules1, spec_pairs, MkUD { ud_binds = dict_binds1, ud_calls = new_calls })
+ <- do { -- debugTraceMsg (text "specImport1" <+> vcat [ppr fn, ppr good_calls, ppr rhs])
+ ; runSpecM dflags this_mod $
+ specCalls (Just this_mod) top_env rules_for_fn good_calls fn rhs }
; let spec_binds1 = [NonRec b r | (b,r) <- spec_pairs]
-- After the rules kick in we may get recursion, but
-- we rely on a global GlomBinds to sort that out later
-- See Note [Glom the bindings if imported functions are specialised]
-- Now specialise any cascaded calls
- ; (rules2, spec_binds2) <- -- pprTrace "specImport 2" (ppr fn $$ ppr rules1 $$ ppr spec_binds1) $
- specImports dflags this_mod top_env
- (extendVarSet done fn)
- (fn:callers)
- (extendRuleBaseList rb rules1)
- uds
+ -- ; debugTraceMsg (text "specImport 2" <+> (ppr fn $$ ppr rules1 $$ ppr spec_binds1))
+ ; (rules2, spec_binds2) <- spec_imports dflags this_mod top_env
+ (fn:callers)
+ (extendRuleBaseList rb rules1)
+ (dict_binds `unionBags` dict_binds1)
+ new_calls
- ; let final_binds = spec_binds2 ++ spec_binds1
+ ; let final_binds = wrapDictBinds dict_binds1 $
+ spec_binds2 ++ spec_binds1
; return (rules2 ++ rules1, final_binds) }
- | otherwise = do { tryWarnMissingSpecs dflags callers fn calls_for_fn
- ; return ([], [])}
+ | otherwise
+ = do { tryWarnMissingSpecs dflags callers fn good_calls
+ ; return ([], [])}
where
unfolding = realIdUnfolding fn -- We want to see the unfolding even for loop breakers
+ good_calls = filterCalls cis dict_binds
+ -- SUPER IMPORTANT! Drop calls that (directly or indirectly) refer to fn
+ -- See Note [Avoiding loops in specImports]
-- | Returns whether or not to show a missed-spec warning.
-- If -Wall-missed-specializations is on, show the warning.
@@ -790,8 +808,114 @@ wantSpecImport dflags unf
-- inside it that we want to specialise
| otherwise -> False -- Stable, not INLINE, hence INLINABLE
-{- Note [Warning about missed specialisations]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+{- Note [Avoiding loops in specImports]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We must take great care when specialising instance declarations
+(functions like $fOrdList) lest we accidentally build a recursive
+dictionary. See Note [Avoiding loops].
+
+The basic strategy of Note [Avoiding loops] is to use filterCalls
+to discard loopy specialisations. But to do that we must ensure
+that the in-scope dict-binds (passed to filterCalls) contain
+all the needed dictionary bindings. In particular, in the recursive
+call to spec_imports in spec_import, we must include the dict-binds
+from the parent. Lacking this caused #17151, a really nasty bug.
+
+Here is what happened.
+* Class structure:
+ Source is a superclass of Mut
+ Index is a superclass of Source
+
+* We started with these dict binds
+ dSource = $fSourcePix @Int $fIndexInt
+ dIndex = sc_sel dSource
+ dMut = $fMutPix @Int dIndex
+ and these calls to specialise
+ $fMutPix @Int dIndex
+ $fSourcePix @Int $fIndexInt
+
+* We specialised the call ($fMutPix @Int dIndex)
+ ==> new call ($fSourcePix @Int dIndex)
+ (because Source is a superclass of Mut)
+
+* We specialised ($fSourcePix @Int dIndex)
+ ==> produces specialised dict $s$fSourcePix,
+ a record with dIndex as a field
+ plus RULE forall d. ($fSourcePix @Int d) = $s$fSourcePix
+ *** This is the bogus step ***
+
+* Now we decide not to specialise the call
+ $fSourcePix @Int $fIndexInt
+ because we already have a RULE that matches it
+
+* Finally the simplifier rewrites
+ dSource = $fSourcePix @Int $fIndexInt
+ ==> dSource = $s$fSourcePix
+
+Disaster. Having rewritten dSource's RHS to $s$fSourcePix, we now
+have a loop:
+
+ dSource = $s$fSourcePix
+ dIndex = sc_sel dSource
+ $s$fSourcePix = MkSource dIndex ...
+
+Solution: filterCalls should have stopped the bogus step,
+by seeing that dIndex transitively uses $fSourcePix. But
+it can only do that if it sees all the dict_binds. Wow.
+
+--------------
+Here's another example (#13429). Suppose we have
+ class Monoid v => C v a where ...
+
+We start with a call
+ f @ [Integer] @ Integer $fC[]Integer
+
+Specialising call to 'f' gives dict bindings
+ $dMonoid_1 :: Monoid [Integer]
+ $dMonoid_1 = M.$p1C @ [Integer] $fC[]Integer
+
+ $dC_1 :: C [Integer] (Node [Integer] Integer)
+ $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
+
+...plus a recursive call to
+ f @ [Integer] @ (Node [Integer] Integer) $dC_1
+
+Specialising that call gives
+ $dMonoid_2 :: Monoid [Integer]
+ $dMonoid_2 = M.$p1C @ [Integer] $dC_1
+
+ $dC_2 :: C [Integer] (Node [Integer] Integer)
+ $dC_2 = M.$fCvNode @ [Integer] $dMonoid_2
+
+Now we have two calls to the imported function
+ M.$fCvNode :: Monoid v => C v a
+ M.$fCvNode @v @a m = C m some_fun
+
+But we must /not/ use the call (M.$fCvNode @ [Integer] $dMonoid_2)
+for specialisation, else we get:
+
+ $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
+ $dMonoid_2 = M.$p1C @ [Integer] $dC_1
+ $s$fCvNode = C $dMonoid_2 ...
+ RULE M.$fCvNode [Integer] _ _ = $s$fCvNode
+
+Now use the rule to rewrite the call in the RHS of $dC_1
+and we get a loop!
+
+
+Note [specImport call stack]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When specialising an imported function 'f', we may get new calls
+of an imported function 'g', which we want to specialise in turn,
+and similarly specialising 'g' might expose a new call to 'h'.
+
+We track the stack of enclosing functions. So when specialising 'h' we
+have a specImport call stack of [g,f]. We do this for two reasons:
+* Note [Warning about missed specialisations]
+* Note [Avoiding recursive specialisation]
+
+Note [Warning about missed specialisations]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Suppose
* In module Lib, you carefully mark a function 'foo' INLINABLE
* Import Lib(foo) into another module M
@@ -807,6 +931,16 @@ is what Opt_WarnAllMissedSpecs does.
ToDo: warn about missed opportunities for local functions.
+Note [Avoiding recursive specialisation]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When we specialise 'f' we may find new overloaded calls to 'g', 'h' in
+'f's RHS. So we want to specialise g,h. But we don't want to
+specialise f any more! It's possible that f's RHS might have a
+recursive yet-more-specialised call, so we'd diverge in that case.
+And if the call is to the same type, one specialisation is enough.
+Avoiding this recursive specialisation loop is one reason for the
+'callers' stack passed to spec_imports and spec_import.
+
Note [Specialise imported INLINABLE things]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
What imported functions do we specialise? The basic set is
@@ -842,15 +976,6 @@ make sure that f_spec is recursive. Easiest thing is to make all
the specialisations for imported bindings recursive.
-Note [Avoiding recursive specialisation]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-When we specialise 'f' we may find new overloaded calls to 'g', 'h' in
-'f's RHS. So we want to specialise g,h. But we don't want to
-specialise f any more! It's possible that f's RHS might have a
-recursive yet-more-specialised call, so we'd diverge in that case.
-And if the call is to the same type, one specialisation is enough.
-Avoiding this recursive specialisation loop is the reason for the
-'done' VarSet passed to specImports and specImport.
************************************************************************
* *
@@ -992,7 +1117,8 @@ specCase env scrut' case_bndr [(con, args, rhs)]
; (rhs', rhs_uds) <- specExpr env_rhs' rhs
; let scrut_bind = mkDB (NonRec case_bndr_flt scrut')
case_bndr_set = unitVarSet case_bndr_flt
- sc_binds = [(NonRec sc_arg_flt sc_rhs, case_bndr_set)
+ sc_binds = [ DB { db_bind = NonRec sc_arg_flt sc_rhs
+ , db_fvs = case_bndr_set }
| (sc_arg_flt, sc_rhs) <- sc_args_flt `zip` sc_rhss ]
flt_binds = scrut_bind : sc_binds
(free_uds, dumped_dbs) = dumpUDs (case_bndr':args') rhs_uds
@@ -1115,7 +1241,7 @@ specBind rhs_env (NonRec fn rhs) body_uds
else
-- No call in final_uds mentions bound variables,
-- so we can just leave the binding here
- return (map fst final_binds, free_uds) }
+ return (map db_bind final_binds, free_uds) }
specBind rhs_env (Rec pairs) body_uds
@@ -1142,7 +1268,7 @@ specBind rhs_env (Rec pairs) body_uds
; if float_all then
return ([], final_uds `snocDictBind` final_bind)
else
- return ([fst final_bind], final_uds) }
+ return ([db_bind final_bind], final_uds) }
---------------------------
@@ -1621,8 +1747,10 @@ In general, we need only make this Rec if
Note [Avoiding loops]
~~~~~~~~~~~~~~~~~~~~~
When specialising /dictionary functions/ we must be very careful to
-avoid building loops. Here is an example that bit us badly: #3591
+avoid building loops. Such loops have bitten us badly on
+several distinct occasions.
+Here is one: #3591
class Eq a => C a
instance Eq [a] => C [a]
@@ -1637,13 +1765,11 @@ This translates to
None of these definitions is recursive. What happened was that we
generated a specialisation:
-
RULE forall d. dfun T d = dT :: C [T]
dT = (MkD a d (meth d)) [T/a, d1/d]
= MkD T d1 (meth d1)
But now we use the RULE on the RHS of d2, to get
-
d2 = dT = MkD d1 (meth d1)
d1 = $p1 d2
@@ -1660,46 +1786,6 @@ Solution:
(directly or indirectly) on the dfun we are specialising.
This is done by 'filterCalls'
---------------
-Here's another example, this time for an imported dfun, so the call
-to filterCalls is in specImports (#13429). Suppose we have
- class Monoid v => C v a where ...
-
-We start with a call
- f @ [Integer] @ Integer $fC[]Integer
-
-Specialising call to 'f' gives dict bindings
- $dMonoid_1 :: Monoid [Integer]
- $dMonoid_1 = M.$p1C @ [Integer] $fC[]Integer
-
- $dC_1 :: C [Integer] (Node [Integer] Integer)
- $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
-
-...plus a recursive call to
- f @ [Integer] @ (Node [Integer] Integer) $dC_1
-
-Specialising that call gives
- $dMonoid_2 :: Monoid [Integer]
- $dMonoid_2 = M.$p1C @ [Integer] $dC_1
-
- $dC_2 :: C [Integer] (Node [Integer] Integer)
- $dC_2 = M.$fCvNode @ [Integer] $dMonoid_2
-
-Now we have two calls to the imported function
- M.$fCvNode :: Monoid v => C v a
- M.$fCvNode @v @a m = C m some_fun
-
-But we must /not/ use the call (M.$fCvNode @ [Integer] $dMonoid_2)
-for specialisation, else we get:
-
- $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
- $dMonoid_2 = M.$p1C @ [Integer] $dC_1
- $s$fCvNode = C $dMonoid_2 ...
- RULE M.$fCvNode [Integer] _ _ = $s$fCvNode
-
-Now use the rule to rewrite the call in the RHS of $dC_1
-and we get a loop!
-
--------------
Here's yet another example
@@ -2227,7 +2313,7 @@ data UsageDetails
-- | A 'DictBind' is a binding along with a cached set containing its free
-- variables (both type variables and dictionaries)
-type DictBind = (CoreBind, VarSet)
+data DictBind = DB { db_bind :: CoreBind, db_fvs :: VarSet }
{- Note [Floated dictionary bindings]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -2256,6 +2342,11 @@ So the DictBinds in (ud_binds :: Bag DictBind) may contain
non-dictionary bindings too.
-}
+instance Outputable DictBind where
+ ppr (DB { db_bind = bind, db_fvs = fvs })
+ = text "DB" <+> braces (sep [ text "bind:" <+> ppr bind
+ , text "fvs: " <+> ppr fvs ])
+
instance Outputable UsageDetails where
ppr (MkUD { ud_binds = dbs, ud_calls = calls })
= text "MkUD" <+> braces (sep (punctuate comma
@@ -2304,8 +2395,8 @@ ppr_call_key_ty (SpecDict _) = Nothing
ppr_call_key_ty UnspecArg = Nothing
instance Outputable CallInfo where
- ppr (CI { ci_key = key, ci_fvs = fvs })
- = text "CI" <> braces (hsep [ fsep (mapMaybe ppr_call_key_ty key), ppr fvs ])
+ ppr (CI { ci_key = key, ci_fvs = _fvs })
+ = text "CI" <> braces (sep (map ppr key))
unionCalls :: CallDetails -> CallDetails -> CallDetails
unionCalls c1 c2 = plusDVarEnv_C unionCallInfoSet c1 c2
@@ -2491,11 +2582,11 @@ plusUDs (MkUD {ud_binds = db1, ud_calls = calls1})
-----------------------------
_dictBindBndrs :: Bag DictBind -> [Id]
-_dictBindBndrs dbs = foldr ((++) . bindersOf . fst) [] dbs
+_dictBindBndrs dbs = foldr ((++) . bindersOf . db_bind) [] dbs
-- | Construct a 'DictBind' from a 'CoreBind'
mkDB :: CoreBind -> DictBind
-mkDB bind = (bind, bind_fvs bind)
+mkDB bind = DB { db_bind = bind, db_fvs = bind_fvs bind }
-- | Identify the free variables of a 'CoreBind'
bind_fvs :: CoreBind -> VarSet
@@ -2526,17 +2617,18 @@ pair_fvs (bndr, rhs) = exprSomeFreeVars interesting rhs
-- | Flatten a set of "dumped" 'DictBind's, and some other binding
-- pairs, into a single recursive binding.
-recWithDumpedDicts :: [(Id,CoreExpr)] -> Bag DictBind ->DictBind
+recWithDumpedDicts :: [(Id,CoreExpr)] -> Bag DictBind -> DictBind
recWithDumpedDicts pairs dbs
- = (Rec bindings, fvs)
+ = DB { db_bind = Rec bindings, db_fvs = fvs }
where
- (bindings, fvs) = foldr add
- ([], emptyVarSet)
- (dbs `snocBag` mkDB (Rec pairs))
- add (NonRec b r, fvs') (pairs, fvs) =
- ((b,r) : pairs, fvs `unionVarSet` fvs')
- add (Rec prs1, fvs') (pairs, fvs) =
- (prs1 ++ pairs, fvs `unionVarSet` fvs')
+ (bindings, fvs) = foldr add ([], emptyVarSet)
+ (dbs `snocBag` mkDB (Rec pairs))
+ add (DB { db_bind = bind, db_fvs = fvs }) (prs_acc, fvs_acc)
+ = case bind of
+ NonRec b r -> ((b,r) : prs_acc, fvs')
+ Rec prs1 -> (prs1 ++ prs_acc, fvs')
+ where
+ fvs' = fvs_acc `unionVarSet` fvs
snocDictBinds :: UsageDetails -> [DictBind] -> UsageDetails
-- Add ud_binds to the tail end of the bindings in uds
@@ -2556,13 +2648,13 @@ wrapDictBinds :: Bag DictBind -> [CoreBind] -> [CoreBind]
wrapDictBinds dbs binds
= foldr add binds dbs
where
- add (bind,_) binds = bind : binds
+ add (DB { db_bind = bind }) binds = bind : binds
wrapDictBindsE :: Bag DictBind -> CoreExpr -> CoreExpr
wrapDictBindsE dbs expr
= foldr add expr dbs
where
- add (bind,_) expr = Let bind expr
+ add (DB { db_bind = bind }) expr = Let bind expr
----------------------
dumpUDs :: [CoreBndr] -> UsageDetails -> (UsageDetails, Bag DictBind)
@@ -2624,9 +2716,10 @@ filterCalls (CIS fn call_bag) dbs
-- (_,_,dump_set) = splitDictBinds dbs {fn}
-- But this variant is shorter
- go so_far (db,fvs) | fvs `intersectsVarSet` so_far
- = extendVarSetList so_far (bindersOf db)
- | otherwise = so_far
+ go so_far (DB { db_bind = bind, db_fvs = fvs })
+ | fvs `intersectsVarSet` so_far
+ = extendVarSetList so_far (bindersOf bind)
+ | otherwise = so_far
ok_call (CI { ci_fvs = fvs }) = not (fvs `intersectsVarSet` dump_set)
@@ -2643,8 +2736,9 @@ splitDictBinds dbs bndr_set
-- Important that it's foldl' not foldr;
-- we're accumulating the set of dumped ids in dump_set
where
- split_db (free_dbs, dump_dbs, dump_idset) db@(bind, fvs)
- | dump_idset `intersectsVarSet` fvs -- Dump it
+ split_db (free_dbs, dump_dbs, dump_idset) db
+ | DB { db_bind = bind, db_fvs = fvs } <- db
+ , dump_idset `intersectsVarSet` fvs -- Dump it
= (free_dbs, dump_dbs `snocBag` db,
extendVarSetList dump_idset (bindersOf bind))
=====================================
testsuite/tests/simplCore/should_run/T17151.hs
=====================================
@@ -0,0 +1,18 @@
+{-# LANGUAGE MonoLocalBinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+module Main where
+
+import T17151a
+
+main :: IO ()
+main = do
+ let ys :: Array P Int Int
+ ys = computeS (makeArray D 1 (const 5))
+ applyStencil ::
+ (Source P ix Int, Load D ix Int)
+ => Stencil ix Int Int
+ -> Array P ix Int
+ -> Array P ix Int
+ applyStencil s = computeS . mapStencil s
+ print (applyStencil (makeConvolutionStencilFromKernel ys) ys `unsafeIndex` 0)
+ print (applyStencil (makeConvolutionStencilFromKernel ys) ys `unsafeIndex` 0)
=====================================
testsuite/tests/simplCore/should_run/T17151.stdout
=====================================
@@ -0,0 +1,2 @@
+55
+55
=====================================
testsuite/tests/simplCore/should_run/T17151a.hs
=====================================
@@ -0,0 +1,205 @@
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE EmptyDataDecls #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE UnboxedTuples #-}
+module T17151a
+ ( computeS
+ , Stencil
+ , P(..)
+ , D(..)
+ , makeConvolutionStencilFromKernel
+ , mapStencil
+ , Array
+ , Construct(..)
+ , Source(..)
+ , Load(..)
+ , Mutable(..)
+ ) where
+
+import Control.Monad.ST
+import Data.Functor.Identity
+import GHC.STRef
+import GHC.ST
+import GHC.Exts
+import Unsafe.Coerce
+import Data.Kind
+
+---- Hacked up stuff to simulate primitive package
+class Prim e where
+ indexByteArray :: ByteArray -> Int -> e
+ sizeOf :: e ->Int
+instance Prim Int where
+ indexByteArray _ _ = 55
+ sizeOf _ = 99
+
+data ByteArray = BA
+type MutableByteArray s = STRef s Int
+
+class Monad m => PrimMonad m where
+ type PrimState m
+ primitive :: (State# (PrimState m) -> (# State# (PrimState m), a #)) -> m a
+instance PrimMonad (ST s) where
+ type PrimState (ST s) = s
+ primitive = ST
+
+unsafeFreezeByteArray :: PrimMonad m => MutableByteArray (PrimState m) -> m ByteArray
+unsafeFreezeByteArray a = return (unsafeCoerce a)
+
+newByteArray :: PrimMonad m => Int -> m (MutableByteArray (PrimState m))
+newByteArray (I# n#)
+ = primitive (\s# -> case newMutVar# 33 s# of
+ (# s'#, arr# #) -> (# s'#, STRef arr# #))
+
+writeByteArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> e -> m ()
+writeByteArray _ _ _ = return ()
+
+----- End of hacked up stuff
+
+--------------
+newtype Stencil ix e a =
+ Stencil ((ix -> e) -> ix -> a)
+
+mapStencil :: Source r ix e => Stencil ix e a -> Array r ix e -> Array D ix a
+mapStencil (Stencil stencilF) arr = DArray (size arr) (stencilF (unsafeIndex arr))
+{-# INLINE mapStencil #-}
+
+makeConvolutionStencilFromKernel
+ :: (Source r ix e, Num e)
+ => Array r ix e
+ -> Stencil ix e e
+makeConvolutionStencilFromKernel arr = Stencil stencil
+ where
+ sz = size arr
+ sCenter = liftIndex (`quot` 2) sz
+ stencil getVal ix =
+ runIdentity $
+ loopM 0 (< totalElem sz) (+ 1) 0 $ \i a ->
+ pure $ accum a (fromLinearIndex sz i) (unsafeLinearIndex arr i)
+ where
+ ixOff = liftIndex2 (+) ix sCenter
+ accum acc kIx kVal = getVal (liftIndex2 (-) ixOff kIx) * kVal + acc
+ {-# INLINE accum #-}
+ {-# INLINE stencil #-}
+{-# INLINE makeConvolutionStencilFromKernel #-}
+
+
+computeS :: (Mutable r ix e, Load r' ix e) => Array r' ix e -> Array r ix e
+computeS arr = runST $ do
+ marr <- unsafeNew (size arr)
+ unsafeLoadIntoS marr arr
+ unsafeFreeze marr
+{-# INLINE computeS #-}
+
+
+data D = D deriving Show
+
+data instance Array D ix e = DArray{dSize :: ix,
+ dIndex :: ix -> e}
+
+instance Index ix => Construct D ix e where
+ makeArray _ = DArray
+ {-# INLINE makeArray #-}
+
+instance Index ix => Source D ix e where
+ unsafeIndex = dIndex
+ {-# INLINE unsafeIndex #-}
+
+instance Index ix => Load D ix e where
+ size = dSize
+ {-# INLINE size #-}
+ loadArrayM arr = splitLinearlyWith_ (totalElem (size arr)) (unsafeLinearIndex arr)
+ {-# INLINE loadArrayM #-}
+
+
+data P = P deriving Show
+
+data instance Array P ix e = PArray ix ByteArray
+
+instance (Prim e, Index ix) => Construct P ix e where
+ makeArray _ sz f = computeS (makeArray D sz f)
+ {-# INLINE makeArray #-}
+
+instance (Prim e, Index ix) => Source P ix e where
+ unsafeIndex (PArray sz a) = indexByteArray a . toLinearIndex sz
+ {-# INLINE unsafeIndex #-}
+
+instance (Prim e, Index ix) => Mutable P ix e where
+ data MArray s P ix e = MPArray ix (MutableByteArray s)
+ unsafeFreeze (MPArray sz a) = PArray sz <$> unsafeFreezeByteArray a
+ {-# INLINE unsafeFreeze #-}
+ unsafeNew sz = MPArray sz <$> newByteArray (totalElem sz * eSize)
+ where
+ eSize = sizeOf (undefined :: e)
+ {-# INLINE unsafeNew #-}
+ unsafeLinearWrite (MPArray _ ma) = writeByteArray ma
+ {-# INLINE unsafeLinearWrite #-}
+
+
+instance (Prim e, Index ix) => Load P ix e where
+ size (PArray sz _) = sz
+ {-# INLINE size #-}
+ loadArrayM arr = splitLinearlyWith_ (totalElem (size arr)) (unsafeLinearIndex arr)
+ {-# INLINE loadArrayM #-}
+
+
+unsafeLinearIndex :: Source r ix e => Array r ix e -> Int -> e
+unsafeLinearIndex arr = unsafeIndex arr . fromLinearIndex (size arr)
+{-# INLINE unsafeLinearIndex #-}
+
+
+loopM :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> m a) -> m a
+loopM init' condition increment initAcc f = go init' initAcc
+ where
+ go step acc
+ | condition step = f step acc >>= go (increment step)
+ | otherwise = return acc
+{-# INLINE loopM #-}
+
+splitLinearlyWith_ ::
+ Monad m => Int -> (Int -> b) -> (Int -> b -> m ()) -> m ()
+splitLinearlyWith_ totalLength index write =
+ loopM 0 (< totalLength) (+1) () $ \i () -> write i (index i)
+{-# INLINE splitLinearlyWith_ #-}
+
+
+data family Array r ix e :: Type
+
+class Index ix => Construct r ix e where
+ makeArray :: r -> ix -> (ix -> e) -> Array r ix e
+
+class Load r ix e => Source r ix e where
+ unsafeIndex :: Array r ix e -> ix -> e
+
+class Index ix => Load r ix e where
+ size :: Array r ix e -> ix
+ loadArrayM :: Monad m => Array r ix e -> (Int -> e -> m ()) -> m ()
+ unsafeLoadIntoS ::
+ (Mutable r' ix e, PrimMonad m) => MArray (PrimState m) r' ix e -> Array r ix e -> m ()
+ unsafeLoadIntoS marr arr = loadArrayM arr (unsafeLinearWrite marr)
+ {-# INLINE unsafeLoadIntoS #-}
+
+class (Construct r ix e, Source r ix e) => Mutable r ix e where
+ data MArray s r ix e :: Type
+ unsafeFreeze :: PrimMonad m => MArray (PrimState m) r ix e -> m (Array r ix e)
+ unsafeNew :: PrimMonad m => ix -> m (MArray (PrimState m) r ix e)
+ unsafeLinearWrite :: PrimMonad m => MArray (PrimState m) r ix e -> Int -> e -> m ()
+
+
+class (Eq ix, Ord ix, Show ix) =>
+ Index ix
+ where
+ totalElem :: ix -> Int
+ liftIndex2 :: (Int -> Int -> Int) -> ix -> ix -> ix
+ liftIndex :: (Int -> Int) -> ix -> ix
+ toLinearIndex :: ix -> ix -> Int
+ fromLinearIndex :: ix -> Int -> ix
+
+instance Index Int where
+ totalElem = id
+ toLinearIndex _ = id
+ fromLinearIndex _ = id
+ liftIndex f = f
+ liftIndex2 f = f
=====================================
testsuite/tests/simplCore/should_run/all.T
=====================================
@@ -92,3 +92,4 @@ test('T15840', normal, compile_and_run, [''])
test('T15840a', normal, compile_and_run, [''])
test('T16066', exit_code(1), compile_and_run, ['-O1'])
test('T17206', exit_code(1), compile_and_run, [''])
+test('T17151', [], multimod_compile_and_run, ['T17151', ''])
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/dcfe29c8520244764146c7a5f336be1f9700db6c...e850d14ffbeea39ad386b1e888cd97375758d6d6
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/dcfe29c8520244764146c7a5f336be1f9700db6c...e850d14ffbeea39ad386b1e888cd97375758d6d6
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20200406/03391e2d/attachment-0001.html>
More information about the ghc-commits
mailing list