[Git][ghc/ghc][wip/T17151] 2 commits: Fix a tricky specialiser loop

Simon Peyton Jones gitlab at gitlab.haskell.org
Fri Apr 3 12:14:08 UTC 2020



Simon Peyton Jones pushed to branch wip/T17151 at Glasgow Haskell Compiler / GHC


Commits:
c8a02a02 by Simon Peyton Jones at 2020-04-03T13:13:57+01:00
Fix a tricky specialiser loop

Issue #17151 was a very tricky example of a bug in which the
specialiser accidentally constructs a recursive dictionary,
so that everything turns into bottom.

I have fixed variants of this bug at least twice before:
see Note [Avoiding loops].  It was a bit of a struggle
to isolate the problem, greatly aided by the work that
Alexey Kuleshevich did in distilling a test case.

Once I'd understood the problem, it was not difficult to fix,
though it did lead me to do a bit of refactoring in specImports.

- - - - -
fd12d84e by Simon Peyton Jones at 2020-04-03T13:13:57+01:00
Refactoring only

This refactors DictBinds into a data type rather than a pair.
No change in behaviour, just better code

- - - - -


5 changed files:

- compiler/GHC/Core/Op/Specialise.hs
- + testsuite/tests/simplCore/should_run/T17151.hs
- + testsuite/tests/simplCore/should_run/T17151.stdout
- + testsuite/tests/simplCore/should_run/T17151a.hs
- testsuite/tests/simplCore/should_run/all.T


Changes:

=====================================
compiler/GHC/Core/Op/Specialise.hs
=====================================
@@ -589,19 +589,11 @@ specProgram guts@(ModGuts { mg_module = this_mod
              -- Specialise the bindings of this module
        ; (binds', uds) <- runSpecM dflags this_mod (go binds)
 
-             -- Specialise imported functions
-       ; hpt_rules <- getRuleBase
-       ; let rule_base = extendRuleBaseList hpt_rules local_rules
-       ; (new_rules, spec_binds) <- specImports dflags this_mod top_env emptyVarSet
-                                                [] rule_base uds
-
-       ; let final_binds
-               | null spec_binds = binds'
-               | otherwise       = Rec (flattenBinds spec_binds) : binds'
-                   -- Note [Glom the bindings if imported functions are specialised]
+       ; (spec_rules, spec_binds) <- specImports dflags this_mod top_env
+                                                 local_rules uds
 
-       ; return (guts { mg_binds = final_binds
-                      , mg_rules = new_rules ++ local_rules }) }
+       ; return (guts { mg_binds = spec_binds ++ binds'
+                      , mg_rules = spec_rules ++ local_rules }) }
   where
         -- We need to start with a Subst that knows all the things
         -- that are in scope, so that the substitution engine doesn't
@@ -645,72 +637,93 @@ See #10491
 *                                                                      *
 ********************************************************************* -}
 
--- | Specialise a set of calls to imported bindings
-specImports :: DynFlags
-            -> Module
-            -> SpecEnv          -- Passed in so that all top-level Ids are in scope
-            -> VarSet           -- Don't specialise these ones
-                                -- See Note [Avoiding recursive specialisation]
-            -> [Id]             -- Stack of imported functions being specialised
-            -> RuleBase         -- Rules from this module and the home package
-                                -- (but not external packages, which can change)
-            -> UsageDetails     -- Calls for imported things, and floating bindings
-            -> CoreM ( [CoreRule]   -- New rules
-                     , [CoreBind] ) -- Specialised bindings
-                                    -- See Note [Wrapping bindings returned by specImports]
-specImports dflags this_mod top_env done callers rule_base
+specImports :: DynFlags -> Module -> SpecEnv
+            -> [CoreRule]
+            -> UsageDetails
+            -> CoreM ([CoreRule], [CoreBind])
+specImports dflags this_mod top_env local_rules
             (MkUD { ud_binds = dict_binds, ud_calls = calls })
-  -- See Note [Disabling cross-module specialisation]
   | not $ gopt Opt_CrossModuleSpecialise dflags
-  = return ([], [])
+    -- See Note [Disabling cross-module specialisation]
+  = return ([], wrapDictBinds dict_binds [])
 
   | otherwise
-  = do { let import_calls = dVarEnvElts calls
-       ; (rules, spec_binds) <- go rule_base import_calls
+  = do { hpt_rules <- getRuleBase
+       ; let rule_base = extendRuleBaseList hpt_rules local_rules
+
+       ; (spec_rules, spec_binds) <- spec_imports dflags this_mod top_env
+                                                  [] rule_base
+                                                  dict_binds calls
 
              -- Don't forget to wrap the specialized bindings with
              -- bindings for the needed dictionaries.
              -- See Note [Wrap bindings returned by specImports]
-       ; let spec_binds' = wrapDictBinds dict_binds spec_binds
+             -- and Note [Glom the bindings if imported functions are specialised]
+       ; let final_binds
+               | null spec_binds = wrapDictBinds dict_binds []
+               | otherwise       = [Rec $ flattenBinds $
+                                    wrapDictBinds dict_binds spec_binds]
+
+       ; return (spec_rules, final_binds)
+    }
+
+-- | Specialise a set of calls to imported bindings
+spec_imports :: DynFlags
+             -> Module
+             -> SpecEnv          -- Passed in so that all top-level Ids are in scope
+             -> [Id]             -- Stack of imported functions being specialised
+                                 -- See Note [specImport call stack]
+             -> RuleBase         -- Rules from this module and the home package
+                                 -- (but not external packages, which can change)
+             -> Bag DictBind     -- Dict bindings, used /only/ for filterCalls
+                                 -- See Note [Avoiding loops in specImports]
+             -> CallDetails      -- Calls for imported things
+             -> CoreM ( [CoreRule]   -- New rules
+                      , [CoreBind] ) -- Specialised bindings
+spec_imports dflags this_mod top_env
+             callers rule_base dict_binds calls
+  = do { let import_calls = dVarEnvElts calls
+       -- ; debugTraceMsg (text "specImports {" <+>
+       --                  vcat [ text "calls:" <+> ppr import_calls
+       --                       , text "dict_binds:" <+> ppr dict_binds ])
+       ; (rules, spec_binds) <- go rule_base import_calls
+       -- ; debugTraceMsg (text "End specImports }" <+> ppr import_calls)
 
-       ; return (rules, spec_binds') }
+       ; return (rules, spec_binds) }
   where
     go :: RuleBase -> [CallInfoSet] -> CoreM ([CoreRule], [CoreBind])
     go _ [] = return ([], [])
-    go rb (cis@(CIS fn _) : other_calls)
-      = do { let ok_calls = filterCalls cis dict_binds
-                     -- Drop calls that (directly or indirectly) refer to fn
-                     -- See Note [Avoiding loops]
---           ; debugTraceMsg (text "specImport" <+> vcat [ ppr fn
---                                                       , text "calls" <+> ppr cis
---                                                       , text "ud_binds =" <+> ppr dict_binds
---                                                       , text "dump set =" <+> ppr dump_set
---                                                       , text "filtered calls =" <+> ppr ok_calls ])
-           ; (rules1, spec_binds1) <- specImport dflags this_mod top_env
-                                                 done callers rb fn ok_calls
+    go rb (cis : other_calls)
+      = do { -- debugTraceMsg (text "specImport {" <+> ppr cis)
+           ; (rules1, spec_binds1) <- spec_import dflags this_mod top_env
+                                                  callers rb dict_binds cis
+           -- ; debugTraceMsg (text "specImport }" <+> ppr cis)
 
            ; (rules2, spec_binds2) <- go (extendRuleBaseList rb rules1) other_calls
            ; return (rules1 ++ rules2, spec_binds1 ++ spec_binds2) }
 
-specImport :: DynFlags
-           -> Module
-           -> SpecEnv               -- Passed in so that all top-level Ids are in scope
-           -> VarSet                -- Don't specialise these
-                                    -- See Note [Avoiding recursive specialisation]
-           -> [Id]                  -- Stack of imported functions being specialised
-           -> RuleBase              -- Rules from this module
-           -> Id -> [CallInfo]      -- Imported function and calls for it
-           -> CoreM ( [CoreRule]    -- New rules
-                    , [CoreBind] )  -- Specialised bindings
-specImport dflags this_mod top_env done callers rb fn calls_for_fn
-  | fn `elemVarSet` done
+spec_import :: DynFlags
+            -> Module
+            -> SpecEnv               -- Passed in so that all top-level Ids are in scope
+            -> [Id]                  -- Stack of imported functions being specialised
+                                     -- See Note [specImport call stack]
+            -> RuleBase              -- Rules from this module
+            -> Bag DictBind          -- Dict bindings, used /only/ for filterCalls
+                                     -- See Note [Avoiding loops in specImports]
+            -> CallInfoSet           -- Imported function and calls for it
+            -> CoreM ( [CoreRule]    -- New rules
+                     , [CoreBind] )  -- Specialised bindings
+spec_import dflags this_mod top_env callers
+            rb dict_binds cis@(CIS fn _)
+  | isIn "specImport" fn callers
   = return ([], [])     -- No warning.  This actually happens all the time
                         -- when specialising a recursive function, because
                         -- the RHS of the specialised function contains a recursive
                         -- call to the original function
 
-  | null calls_for_fn   -- We filtered out all the calls in deleteCallsMentioning
-  = return ([], [])
+  | null good_calls
+  = do { -- debugTraceMsg (text "specImport:no valid calls")
+       ; return ([], []) }
 
   | wantSpecImport dflags unfolding
   , Just rhs <- maybeUnfoldingTemplate unfolding
@@ -723,32 +736,37 @@ specImport dflags this_mod top_env done callers rb fn calls_for_fn
        ; let full_rb = unionRuleBase rb (eps_rule_base eps)
              rules_for_fn = getRules (RuleEnv full_rb vis_orphs) fn
 
-       ; (rules1, spec_pairs, uds)
-             <- -- pprTrace "specImport1" (vcat [ppr fn, ppr calls_for_fn, ppr rhs]) $
-                runSpecM dflags this_mod $
-                specCalls (Just this_mod) top_env rules_for_fn calls_for_fn fn rhs
+       ; (rules1, spec_pairs, MkUD { ud_binds = dict_binds1, ud_calls = new_calls })
+             <- do { -- debugTraceMsg (text "specImport1" <+> vcat [ppr fn, ppr good_calls, ppr rhs])
+                   ; runSpecM dflags this_mod $
+                     specCalls (Just this_mod) top_env rules_for_fn good_calls fn rhs }
        ; let spec_binds1 = [NonRec b r | (b,r) <- spec_pairs]
              -- After the rules kick in we may get recursion, but
              -- we rely on a global GlomBinds to sort that out later
              -- See Note [Glom the bindings if imported functions are specialised]
 
               -- Now specialise any cascaded calls
-       ; (rules2, spec_binds2) <- -- pprTrace "specImport 2" (ppr fn $$ ppr rules1 $$ ppr spec_binds1) $
-                                  specImports dflags this_mod top_env
-                                              (extendVarSet done fn)
-                                              (fn:callers)
-                                              (extendRuleBaseList rb rules1)
-                                              uds
+       -- ; debugTraceMsg (text "specImport 2" <+> (ppr fn $$ ppr rules1 $$ ppr spec_binds1))
+       ; (rules2, spec_binds2) <- spec_imports dflags this_mod top_env
+                                               (fn:callers)
+                                               (extendRuleBaseList rb rules1)
+                                               (dict_binds `unionBags` dict_binds1)
+                                               new_calls
 
-       ; let final_binds = spec_binds2 ++ spec_binds1
+       ; let final_binds = wrapDictBinds dict_binds1 $
+                           spec_binds2 ++ spec_binds1
 
        ; return (rules2 ++ rules1, final_binds) }
 
-  | otherwise = do { tryWarnMissingSpecs dflags callers fn calls_for_fn
-                   ; return ([], [])}
+  | otherwise
+  = do { tryWarnMissingSpecs dflags callers fn good_calls
+       ; return ([], [])}
 
   where
     unfolding = realIdUnfolding fn   -- We want to see the unfolding even for loop breakers
+    good_calls = filterCalls cis dict_binds
+       -- SUPER IMPORTANT!  Drop calls that (directly or indirectly) refer to fn
+       -- See Note [Avoiding loops in specImports]
 
 -- | Returns whether or not to show a missed-spec warning.
 -- If -Wall-missed-specializations is on, show the warning.
@@ -790,8 +808,114 @@ wantSpecImport dflags unf
                -- inside it that we want to specialise
        | otherwise -> False    -- Stable, not INLINE, hence INLINABLE
 
-{- Note [Warning about missed specialisations]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+{- Note [Avoiding loops in specImports]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We must take great care when specialising instance declarations
+(functions like $fOrdList) lest we accidentally build a recursive
+dictionary. See Note [Avoiding loops].
+
+The basic strategy of Note [Avoiding loops] is to use filterCalls
+to discard loopy specialisations.  But to do that we must ensure
+that the in-scope dict-binds (passed to filterCalls) contain
+all the needed dictionary bindings.  In particular, in the recursive
+call to spec_imports in spec_import, we must include the dict-binds
+from the parent.  Lacking this caused #17151, a really nasty bug.
+
+Here is what happened.
+* Class structure:
+    Source is a superclass of Mut
+    Index is a superclass of Source
+
+* We started with these dict binds
+    dSource = $fSourcePix @Int $fIndexInt
+    dIndex  = sc_sel dSource
+    dMut    = $fMutPix @Int dIndex
+  and these calls to specialise
+    $fMutPix @Int dIndex
+    $fSourcePix @Int $fIndexInt
+
+* We specialised the call ($fMutPix @Int dIndex)
+  ==> new call ($fSourcePix @Int dIndex)
+      (because Source is a superclass of Mut)
+
+* We specialised ($fSourcePix @Int dIndex)
+  ==> produces specialised dict $s$fSourcePix,
+      a record with dIndex as a field
+      plus RULE forall d. ($fSourcePix @Int d) = $s$fSourcePix
+  *** This is the bogus step ***
+
+* Now we decide not to specialise the call
+    $fSourcePix @Int $fIndexInt
+  because we already have a RULE that matches it
+
+* Finally the simplifier rewrites
+    dSource = $fSourcePix @Int $fIndexInt
+    ==>  dSource = $s$fSourcePix
+
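+Disaster. Now we have a recursive group in which dIndex is defined
+as a superclass selection from $s$fSourcePix, whose field is dIndex
+itself; so dIndex (and everything that uses it) is bottom: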
+    dSource = $s$fSourcePix
+    dIndex  = sc_sel dSource
+    $s$fSourcePix = MkSource dIndex ...
+
+Solution: filterCalls should have stopped the bogus step,
+by seeing that dIndex transitively uses $fSourcePix. But
+it can only do that if it sees all the dict_binds.  Wow.
+
+--------------
+Here's another example (#13429).  Suppose we have
+  class Monoid v => C v a where ...
+
+We start with a call
+   f @ [Integer] @ Integer $fC[]Integer
+
+Specialising call to 'f' gives dict bindings
+   $dMonoid_1 :: Monoid [Integer]
+   $dMonoid_1 = M.$p1C @ [Integer] $fC[]Integer
+
+   $dC_1 :: C [Integer] (Node [Integer] Integer)
+   $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
+
+...plus a recursive call to
+   f @ [Integer] @ (Node [Integer] Integer) $dC_1
+
+Specialising that call gives
+   $dMonoid_2  :: Monoid [Integer]
+   $dMonoid_2  = M.$p1C @ [Integer] $dC_1
+
+   $dC_2 :: C [Integer] (Node [Integer] Integer)
+   $dC_2 = M.$fCvNode @ [Integer] $dMonoid_2
+
+Now we have two calls to the imported function
+  M.$fCvNode :: Monoid v => C v a
+  M.$fCvNode @v @a m = C m some_fun
+
+But we must /not/ use the call (M.$fCvNode @ [Integer] $dMonoid_2)
+for specialisation, else we get:
+
+  $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
+  $dMonoid_2 = M.$p1C @ [Integer] $dC_1
+  $s$fCvNode = C $dMonoid_2 ...
+    RULE M.$fCvNode [Integer] _ _ = $s$fCvNode
+
+Now use the rule to rewrite the call in the RHS of $dC_1
+and we get a loop!
+
+
+Note [specImport call stack]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When specialising an imported function 'f', we may get new calls
+of an imported function 'g', which we want to specialise in turn,
+and similarly specialising 'g' might expose a new call to 'h'.
+
+We track the stack of enclosing functions. So when specialising 'h' we
+have a specImport call stack of [g,f]. We do this for two reasons:
+* Note [Warning about missed specialisations]
+* Note [Avoiding recursive specialisation]
+
+Note [Warning about missed specialisations]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Suppose
  * In module Lib, you carefully mark a function 'foo' INLINABLE
  * Import Lib(foo) into another module M
@@ -807,6 +931,16 @@ is what Opt_WarnAllMissedSpecs does.
 
 ToDo: warn about missed opportunities for local functions.
 
+Note [Avoiding recursive specialisation]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When we specialise 'f' we may find new overloaded calls to 'g', 'h' in
+'f's RHS.  So we want to specialise g,h.  But we don't want to
+specialise f any more!  It's possible that f's RHS might have a
+recursive yet-more-specialised call, so we'd diverge in that case.
+And if the call is to the same type, one specialisation is enough.
+Avoiding this recursive specialisation loop is one reason for the
+'callers' stack passed to spec_imports and spec_import.
+
 Note [Specialise imported INLINABLE things]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 What imported functions do we specialise?  The basic set is
@@ -842,15 +976,6 @@ make sure that f_spec is recursive.  Easiest thing is to make all
 the specialisations for imported bindings recursive.
 
 
-Note [Avoiding recursive specialisation]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-When we specialise 'f' we may find new overloaded calls to 'g', 'h' in
-'f's RHS.  So we want to specialise g,h.  But we don't want to
-specialise f any more!  It's possible that f's RHS might have a
-recursive yet-more-specialised call, so we'd diverge in that case.
-And if the call is to the same type, one specialisation is enough.
-Avoiding this recursive specialisation loop is the reason for the
-'done' VarSet passed to specImports and specImport.
 
 ************************************************************************
 *                                                                      *
@@ -992,7 +1117,8 @@ specCase env scrut' case_bndr [(con, args, rhs)]
        ; (rhs', rhs_uds)   <- specExpr env_rhs' rhs
        ; let scrut_bind    = mkDB (NonRec case_bndr_flt scrut')
              case_bndr_set = unitVarSet case_bndr_flt
-             sc_binds      = [(NonRec sc_arg_flt sc_rhs, case_bndr_set)
+             sc_binds      = [ DB { db_bind = NonRec sc_arg_flt sc_rhs
+                                  , db_fvs  = case_bndr_set }
                              | (sc_arg_flt, sc_rhs) <- sc_args_flt `zip` sc_rhss ]
              flt_binds     = scrut_bind : sc_binds
              (free_uds, dumped_dbs) = dumpUDs (case_bndr':args') rhs_uds
@@ -1115,7 +1241,7 @@ specBind rhs_env (NonRec fn rhs) body_uds
          else
              -- No call in final_uds mentions bound variables,
              -- so we can just leave the binding here
-              return (map fst final_binds, free_uds) }
+              return (map db_bind final_binds, free_uds) }
 
 
 specBind rhs_env (Rec pairs) body_uds
@@ -1142,7 +1268,7 @@ specBind rhs_env (Rec pairs) body_uds
        ; if float_all then
               return ([], final_uds `snocDictBind` final_bind)
          else
-              return ([fst final_bind], final_uds) }
+              return ([db_bind final_bind], final_uds) }
 
 
 ---------------------------
@@ -1621,8 +1747,10 @@ In general, we need only make this Rec if
 Note [Avoiding loops]
 ~~~~~~~~~~~~~~~~~~~~~
 When specialising /dictionary functions/ we must be very careful to
-avoid building loops. Here is an example that bit us badly: #3591
+avoid building loops. Here is an example that bit us badly, on
+several distinct occasions.
 
+Here is one: #3591
      class Eq a => C a
      instance Eq [a] => C [a]
 
@@ -1637,13 +1765,11 @@ This translates to
 
 None of these definitions is recursive. What happened was that we
 generated a specialisation:
-
      RULE forall d. dfun T d = dT  :: C [T]
      dT = (MkD a d (meth d)) [T/a, d1/d]
         = MkD T d1 (meth d1)
 
 But now we use the RULE on the RHS of d2, to get
-
     d2 = dT = MkD d1 (meth d1)
     d1 = $p1 d2
 
@@ -1660,46 +1786,6 @@ Solution:
   (directly or indirectly) on the dfun we are specialising.
   This is done by 'filterCalls'
 
---------------
-Here's another example, this time for an imported dfun, so the call
-to filterCalls is in specImports (#13429). Suppose we have
-  class Monoid v => C v a where ...
-
-We start with a call
-   f @ [Integer] @ Integer $fC[]Integer
-
-Specialising call to 'f' gives dict bindings
-   $dMonoid_1 :: Monoid [Integer]
-   $dMonoid_1 = M.$p1C @ [Integer] $fC[]Integer
-
-   $dC_1 :: C [Integer] (Node [Integer] Integer)
-   $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
-
-...plus a recursive call to
-   f @ [Integer] @ (Node [Integer] Integer) $dC_1
-
-Specialising that call gives
-   $dMonoid_2  :: Monoid [Integer]
-   $dMonoid_2  = M.$p1C @ [Integer] $dC_1
-
-   $dC_2 :: C [Integer] (Node [Integer] Integer)
-   $dC_2 = M.$fCvNode @ [Integer] $dMonoid_2
-
-Now we have two calls to the imported function
-  M.$fCvNode :: Monoid v => C v a
-  M.$fCvNode @v @a m = C m some_fun
-
-But we must /not/ use the call (M.$fCvNode @ [Integer] $dMonoid_2)
-for specialisation, else we get:
-
-  $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
-  $dMonoid_2 = M.$p1C @ [Integer] $dC_1
-  $s$fCvNode = C $dMonoid_2 ...
-    RULE M.$fCvNode [Integer] _ _ = $s$fCvNode
-
-Now use the rule to rewrite the call in the RHS of $dC_1
-and we get a loop!
-
 --------------
 Here's yet another example
 
@@ -2227,7 +2313,7 @@ data UsageDetails
 
 -- | A 'DictBind' is a binding along with a cached set containing its free
 -- variables (both type variables and dictionaries)
-type DictBind = (CoreBind, VarSet)
+data DictBind = DB { db_bind :: CoreBind, db_fvs :: VarSet }
 
 {- Note [Floated dictionary bindings]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -2256,6 +2342,11 @@ So the DictBinds in (ud_binds :: Bag DictBind) may contain
 non-dictionary bindings too.
 -}
 
+instance Outputable DictBind where
+  ppr (DB { db_bind = bind, db_fvs = fvs })
+    = text "DB" <+> braces (sep [ text "bind:" <+> ppr bind
+                                , text "fvs: " <+> ppr fvs ])
+
 instance Outputable UsageDetails where
   ppr (MkUD { ud_binds = dbs, ud_calls = calls })
         = text "MkUD" <+> braces (sep (punctuate comma
@@ -2304,8 +2395,8 @@ ppr_call_key_ty (SpecDict _)  = Nothing
 ppr_call_key_ty UnspecArg     = Nothing
 
 instance Outputable CallInfo where
-  ppr (CI { ci_key = key, ci_fvs = fvs })
-    = text "CI" <> braces (hsep [ fsep (mapMaybe ppr_call_key_ty key), ppr fvs ])
+  ppr (CI { ci_key = key, ci_fvs = _fvs })
+    = text "CI" <> braces (sep (map ppr key))
 
 unionCalls :: CallDetails -> CallDetails -> CallDetails
 unionCalls c1 c2 = plusDVarEnv_C unionCallInfoSet c1 c2
@@ -2491,11 +2582,11 @@ plusUDs (MkUD {ud_binds = db1, ud_calls = calls1})
 
 -----------------------------
 _dictBindBndrs :: Bag DictBind -> [Id]
-_dictBindBndrs dbs = foldr ((++) . bindersOf . fst) [] dbs
+_dictBindBndrs dbs = foldr ((++) . bindersOf . db_bind) [] dbs
 
 -- | Construct a 'DictBind' from a 'CoreBind'
 mkDB :: CoreBind -> DictBind
-mkDB bind = (bind, bind_fvs bind)
+mkDB bind = DB { db_bind = bind, db_fvs = bind_fvs bind }
 
 -- | Identify the free variables of a 'CoreBind'
 bind_fvs :: CoreBind -> VarSet
@@ -2526,17 +2617,18 @@ pair_fvs (bndr, rhs) = exprSomeFreeVars interesting rhs
 
 -- | Flatten a set of "dumped" 'DictBind's, and some other binding
 -- pairs, into a single recursive binding.
-recWithDumpedDicts :: [(Id,CoreExpr)] -> Bag DictBind ->DictBind
+recWithDumpedDicts :: [(Id,CoreExpr)] -> Bag DictBind -> DictBind
 recWithDumpedDicts pairs dbs
-  = (Rec bindings, fvs)
+  = DB { db_bind = Rec bindings, db_fvs = fvs }
   where
-    (bindings, fvs) = foldr add
-                               ([], emptyVarSet)
-                               (dbs `snocBag` mkDB (Rec pairs))
-    add (NonRec b r, fvs') (pairs, fvs) =
-      ((b,r) : pairs, fvs `unionVarSet` fvs')
-    add (Rec prs1,   fvs') (pairs, fvs) =
-      (prs1 ++ pairs, fvs `unionVarSet` fvs')
+    (bindings, fvs) = foldr add ([], emptyVarSet)
+                                (dbs `snocBag` mkDB (Rec pairs))
+    add (DB { db_bind = bind, db_fvs = fvs }) (prs_acc, fvs_acc)
+      = case bind of
+          NonRec b r -> ((b,r) : prs_acc, fvs')
+          Rec prs1   -> (prs1 ++ prs_acc, fvs')
+      where
+        fvs' = fvs_acc `unionVarSet` fvs
 
 snocDictBinds :: UsageDetails -> [DictBind] -> UsageDetails
 -- Add ud_binds to the tail end of the bindings in uds
@@ -2556,13 +2648,13 @@ wrapDictBinds :: Bag DictBind -> [CoreBind] -> [CoreBind]
 wrapDictBinds dbs binds
   = foldr add binds dbs
   where
-    add (bind,_) binds = bind : binds
+    add (DB { db_bind = bind }) binds = bind : binds
 
 wrapDictBindsE :: Bag DictBind -> CoreExpr -> CoreExpr
 wrapDictBindsE dbs expr
   = foldr add expr dbs
   where
-    add (bind,_) expr = Let bind expr
+    add (DB { db_bind = bind }) expr = Let bind expr
 
 ----------------------
 dumpUDs :: [CoreBndr] -> UsageDetails -> (UsageDetails, Bag DictBind)
@@ -2624,9 +2716,10 @@ filterCalls (CIS fn call_bag) dbs
       --   (_,_,dump_set) = splitDictBinds dbs {fn}
       -- But this variant is shorter
 
-    go so_far (db,fvs) | fvs `intersectsVarSet` so_far
-                       = extendVarSetList so_far (bindersOf db)
-                       | otherwise = so_far
+    go so_far (DB { db_bind = bind, db_fvs = fvs })
+       | fvs `intersectsVarSet` so_far
+       = extendVarSetList so_far (bindersOf bind)
+       | otherwise = so_far
 
     ok_call (CI { ci_fvs = fvs }) = not (fvs `intersectsVarSet` dump_set)
 
@@ -2643,8 +2736,9 @@ splitDictBinds dbs bndr_set
                 -- Important that it's foldl' not foldr;
                 -- we're accumulating the set of dumped ids in dump_set
    where
-    split_db (free_dbs, dump_dbs, dump_idset) db@(bind, fvs)
-        | dump_idset `intersectsVarSet` fvs     -- Dump it
+    split_db (free_dbs, dump_dbs, dump_idset) db
+        | DB { db_bind = bind, db_fvs = fvs } <- db
+        , dump_idset `intersectsVarSet` fvs     -- Dump it
         = (free_dbs, dump_dbs `snocBag` db,
            extendVarSetList dump_idset (bindersOf bind))
 


=====================================
testsuite/tests/simplCore/should_run/T17151.hs
=====================================
@@ -0,0 +1,18 @@
+{-# LANGUAGE MonoLocalBinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+module Main where
+
+import T17151a
+
+main :: IO ()
+main = do
+  let ys :: Array P Int Int
+      ys = computeS (makeArray D 1 (const 5))
+      applyStencil ::
+           (Source P ix Int, Load D ix Int)
+        => Stencil ix Int Int
+        -> Array P ix Int
+        -> Array P ix Int
+      applyStencil s = computeS . mapStencil s
+  print (applyStencil (makeConvolutionStencilFromKernel ys) ys `unsafeIndex` 0)
+  print (applyStencil (makeConvolutionStencilFromKernel ys) ys `unsafeIndex` 0)


=====================================
testsuite/tests/simplCore/should_run/T17151.stdout
=====================================
@@ -0,0 +1,2 @@
+55
+55


=====================================
testsuite/tests/simplCore/should_run/T17151a.hs
=====================================
@@ -0,0 +1,205 @@
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE EmptyDataDecls #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE UnboxedTuples #-}
+module T17151a
+  ( computeS
+  , Stencil
+  , P(..)
+  , D(..)
+  , makeConvolutionStencilFromKernel
+  , mapStencil
+  , Array
+  , Construct(..)
+  , Source(..)
+  , Load(..)
+  , Mutable(..)
+  ) where
+
+import Control.Monad.ST
+import Data.Functor.Identity
+import GHC.STRef
+import GHC.ST
+import GHC.Exts
+import Unsafe.Coerce
+import Data.Kind
+
+----  Hacked up stuff to simulate primitive package
+class Prim e where
+  indexByteArray :: ByteArray -> Int -> e
+  sizeOf :: e ->Int
+instance Prim Int where
+  indexByteArray _ _ = 55
+  sizeOf _ = 99
+
+data ByteArray = BA
+type MutableByteArray s = STRef s Int
+
+class Monad m => PrimMonad m where
+  type PrimState m
+  primitive :: (State# (PrimState m) -> (# State# (PrimState m), a #)) -> m a
+instance PrimMonad (ST s) where
+  type PrimState (ST s) = s
+  primitive = ST
+
+unsafeFreezeByteArray :: PrimMonad m => MutableByteArray (PrimState m) -> m ByteArray
+unsafeFreezeByteArray a = return (unsafeCoerce a)
+
+newByteArray :: PrimMonad m => Int -> m (MutableByteArray (PrimState m))
+newByteArray (I# n#)
+  = primitive (\s# -> case newMutVar# 33 s# of
+                        (# s'#, arr# #) -> (# s'#, STRef arr# #))
+
+writeByteArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> e -> m ()
+writeByteArray _ _ _ = return ()
+
+-----  End of hacked up stuff
+
+--------------
+newtype Stencil ix e a =
+  Stencil ((ix -> e) -> ix -> a)
+
+mapStencil :: Source r ix e => Stencil ix e a -> Array r ix e -> Array D ix a
+mapStencil (Stencil stencilF) arr = DArray (size arr) (stencilF (unsafeIndex arr))
+{-# INLINE mapStencil #-}
+
+makeConvolutionStencilFromKernel
+  :: (Source r ix e, Num e)
+  => Array r ix e
+  -> Stencil ix e e
+makeConvolutionStencilFromKernel arr = Stencil stencil
+  where
+    sz = size arr
+    sCenter = liftIndex (`quot` 2) sz
+    stencil getVal ix =
+      runIdentity $
+      loopM 0 (< totalElem sz) (+ 1) 0 $ \i a ->
+        pure $ accum a (fromLinearIndex sz i) (unsafeLinearIndex arr i)
+      where
+        ixOff = liftIndex2 (+) ix sCenter
+        accum acc kIx kVal = getVal (liftIndex2 (-) ixOff kIx) * kVal + acc
+        {-# INLINE accum #-}
+    {-# INLINE stencil #-}
+{-# INLINE makeConvolutionStencilFromKernel #-}
+
+
+computeS :: (Mutable r ix e, Load r' ix e) => Array r' ix e -> Array r ix e
+computeS arr = runST $ do
+  marr <- unsafeNew (size arr)
+  unsafeLoadIntoS marr arr
+  unsafeFreeze marr
+{-# INLINE computeS #-}
+
+
+data D = D deriving Show
+
+data instance  Array D ix e = DArray{dSize :: ix,
+                                     dIndex :: ix -> e}
+
+instance Index ix => Construct D ix e where
+  makeArray _ = DArray
+  {-# INLINE makeArray #-}
+
+instance Index ix => Source D ix e where
+  unsafeIndex = dIndex
+  {-# INLINE unsafeIndex #-}
+
+instance Index ix => Load D ix e where
+  size = dSize
+  {-# INLINE size #-}
+  loadArrayM arr = splitLinearlyWith_ (totalElem (size arr)) (unsafeLinearIndex arr)
+  {-# INLINE loadArrayM #-}
+
+
+data P = P deriving Show
+
+data instance Array P ix e = PArray ix ByteArray
+
+instance (Prim e, Index ix) => Construct P ix e where
+  makeArray _ sz f = computeS (makeArray D sz f)
+  {-# INLINE makeArray #-}
+
+instance (Prim e, Index ix) => Source P ix e where
+  unsafeIndex (PArray sz a) = indexByteArray a . toLinearIndex sz
+  {-# INLINE unsafeIndex #-}
+
+instance (Prim e, Index ix) => Mutable P ix e where
+  data MArray s P ix e = MPArray ix (MutableByteArray s)
+  unsafeFreeze (MPArray sz a) = PArray sz <$> unsafeFreezeByteArray a
+  {-# INLINE unsafeFreeze #-}
+  unsafeNew sz = MPArray sz <$> newByteArray (totalElem sz * eSize)
+    where
+      eSize = sizeOf (undefined :: e)
+  {-# INLINE unsafeNew #-}
+  unsafeLinearWrite (MPArray _ ma) = writeByteArray ma
+  {-# INLINE unsafeLinearWrite #-}
+
+
+instance (Prim e, Index ix) => Load P ix e where
+  size (PArray sz _) = sz
+  {-# INLINE size #-}
+  loadArrayM arr = splitLinearlyWith_ (totalElem (size arr)) (unsafeLinearIndex arr)
+  {-# INLINE loadArrayM #-}
+
+
+unsafeLinearIndex :: Source r ix e => Array r ix e -> Int -> e
+unsafeLinearIndex arr = unsafeIndex arr . fromLinearIndex (size arr)
+{-# INLINE unsafeLinearIndex #-}
+
+
+loopM :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> m a) -> m a
+loopM init' condition increment initAcc f = go init' initAcc
+  where
+    go step acc
+      | condition step = f step acc >>= go (increment step)
+      | otherwise = return acc
+{-# INLINE loopM #-}
+
+splitLinearlyWith_ ::
+     Monad m => Int -> (Int -> b) -> (Int -> b -> m ()) -> m ()
+splitLinearlyWith_ totalLength index write =
+  loopM 0 (< totalLength) (+1) () $ \i () -> write i (index i)
+{-# INLINE splitLinearlyWith_ #-}
+
+
+data family Array r ix e :: Type
+
+class Index ix => Construct r ix e where
+  makeArray :: r -> ix -> (ix -> e) -> Array r ix e
+
+class Load r ix e => Source r ix e where
+  unsafeIndex :: Array r ix e -> ix -> e
+
+class Index ix => Load r ix e where
+  size :: Array r ix e -> ix
+  loadArrayM :: Monad m => Array r ix e -> (Int -> e -> m ()) -> m ()
+  unsafeLoadIntoS ::
+       (Mutable r' ix e, PrimMonad m) => MArray (PrimState m) r' ix e -> Array r ix e -> m ()
+  unsafeLoadIntoS marr arr = loadArrayM arr (unsafeLinearWrite marr)
+  {-# INLINE unsafeLoadIntoS #-}
+
+class (Construct r ix e, Source r ix e) => Mutable r ix e where
+  data MArray s r ix e :: Type
+  unsafeFreeze :: PrimMonad m => MArray (PrimState m) r ix e -> m (Array r ix e)
+  unsafeNew :: PrimMonad m => ix -> m (MArray (PrimState m) r ix e)
+  unsafeLinearWrite :: PrimMonad m => MArray (PrimState m) r ix e -> Int -> e -> m ()
+
+
+class (Eq ix, Ord ix, Show ix) =>
+      Index ix
+  where
+  totalElem :: ix -> Int
+  liftIndex2 :: (Int -> Int -> Int) -> ix -> ix -> ix
+  liftIndex :: (Int -> Int) -> ix -> ix
+  toLinearIndex :: ix -> ix -> Int
+  fromLinearIndex :: ix -> Int -> ix
+
+instance Index Int where
+  totalElem = id
+  toLinearIndex _ = id
+  fromLinearIndex _ = id
+  liftIndex f = f
+  liftIndex2 f = f


=====================================
testsuite/tests/simplCore/should_run/all.T
=====================================
@@ -92,3 +92,4 @@ test('T15840', normal, compile_and_run, [''])
 test('T15840a', normal, compile_and_run, [''])
 test('T16066', exit_code(1), compile_and_run, ['-O1'])
 test('T17206', exit_code(1), compile_and_run, [''])
+test('T17151', [], multimod_compile_and_run, ['T17151', ''])



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/8927189a1abc3f107bc2c1e06df3dbe7d5e2bd98...fd12d84eaa2e16a96ccbe4bdcfc8309f1312d4be

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/8927189a1abc3f107bc2c1e06df3dbe7d5e2bd98...fd12d84eaa2e16a96ccbe4bdcfc8309f1312d4be
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20200403/18b4c925/attachment-0001.html>


More information about the ghc-commits mailing list