[Git][ghc/ghc][wip/T18120] Fix specialisation for DFuns

Simon Peyton Jones gitlab at gitlab.haskell.org
Fri May 1 14:29:06 UTC 2020



Simon Peyton Jones pushed to branch wip/T18120 at Glasgow Haskell Compiler / GHC


Commits:
7e77aca9 by Simon Peyton Jones at 2020-05-01T15:28:40+01:00
Fix specialisation for DFuns

When specialising a DFun we must take care to saturate the
unfolding.  See Note [Specialising DFuns] in Specialise.

Fixes #18120

- - - - -


6 changed files:

- compiler/GHC/Core/Opt/Specialise.hs
- compiler/GHC/Core/Unfold.hs
- compiler/GHC/HsToCore/Binds.hs
- compiler/GHC/Tc/Gen/Sig.hs
- + testsuite/tests/simplCore/should_compile/T18120.hs
- testsuite/tests/simplCore/should_compile/all.T


Changes:

=====================================
compiler/GHC/Core/Opt/Specialise.hs
=====================================
@@ -1362,6 +1362,7 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
     inl_prag  = idInlinePragma fn
     inl_act   = inlinePragmaActivation inl_prag
     is_local  = isLocalId fn
+    is_dfun   = isDFunId fn
 
         -- Figure out whether the function has an INLINE pragma
         -- See Note [Inline specialisations]
@@ -1384,22 +1385,34 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
     spec_call :: SpecInfo                         -- Accumulating parameter
               -> CallInfo                         -- Call instance
               -> SpecM SpecInfo
-    spec_call spec_acc@(rules_acc, pairs_acc, uds_acc) (CI { ci_key = call_args })
+    spec_call spec_acc@(rules_acc, pairs_acc, uds_acc) _ci@(CI { ci_key = call_args })
       = -- See Note [Specialising Calls]
-        do { ( useful, rhs_env2, leftover_bndrs
+        do { let all_call_args | is_dfun   = call_args ++ repeat UnspecArg
+                               | otherwise = call_args
+                               -- See Note [Specialising DFuns]
+           ; ( useful, rhs_env2, leftover_bndrs
              , rule_bndrs, rule_lhs_args
-             , spec_bndrs, dx_binds, spec_args) <- specHeader env rhs_bndrs call_args
+             , spec_bndrs1, dx_binds, spec_args) <- specHeader env rhs_bndrs all_call_args
+
+--           ; pprTrace "spec_call" (vcat [ text "call info: " <+> ppr _ci
+--                                        , text "useful:    " <+> ppr useful
+--                                        , text "rule_bndrs:" <+> ppr rule_bndrs
+--                                        , text "lhs_args:  " <+> ppr rule_lhs_args
+--                                        , text "spec_bndrs:" <+> ppr spec_bndrs1
+--                                        , text "spec_args: " <+> ppr spec_args
+--                                        , text "dx_binds:  " <+> ppr dx_binds
+--                                        , text "rhs_env2:  " <+> ppr (se_subst rhs_env2)
+--                                        , ppr dx_binds ]) $
+--             return ()
 
            ; dflags <- getDynFlags
            ; if not useful  -- No useful specialisation
                 || already_covered dflags rules_acc rule_lhs_args
              then return spec_acc
-             else -- pprTrace "spec_call" (vcat [ ppr _call_info, ppr fn, ppr rhs_dict_ids
-                  --                           , text "rhs_env2" <+> ppr (se_subst rhs_env2)
-                  --                           , ppr dx_binds ]) $
+             else
         do { -- Run the specialiser on the specialised RHS
              -- The "1" suffix is before we maybe add the void arg
-           ; (spec_rhs1, rhs_uds) <- specLam rhs_env2 (spec_bndrs ++ leftover_bndrs) rhs_body
+           ; (spec_rhs1, rhs_uds) <- specLam rhs_env2 (spec_bndrs1 ++ leftover_bndrs) rhs_body
            ; let spec_fn_ty1 = exprType spec_rhs1
 
                  -- Maybe add a void arg to the specialised function,
@@ -1407,14 +1420,13 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
                  -- See Note [Specialisations Must Be Lifted]
                  -- C.f. GHC.Core.Opt.WorkWrap.Utils.mkWorkerArgs
                  add_void_arg = isUnliftedType spec_fn_ty1 && not (isJoinId fn)
-                 (spec_rhs, spec_fn_ty, rule_rhs_args)
-                   | add_void_arg = ( Lam        voidArgId  spec_rhs1
-                                    , mkVisFunTy voidPrimTy spec_fn_ty1
-                                    , voidPrimId : spec_bndrs)
-                   | otherwise   = (spec_rhs1, spec_fn_ty1, spec_bndrs)
-
-                 arity_decr      = count isValArg rule_lhs_args - count isId rule_rhs_args
-                 join_arity_decr = length rule_lhs_args - length rule_rhs_args
+                 (spec_bndrs, spec_rhs, spec_fn_ty)
+                   | add_void_arg = ( voidPrimId : spec_bndrs1
+                                    , Lam        voidArgId  spec_rhs1
+                                    , mkVisFunTy voidPrimTy spec_fn_ty1)
+                   | otherwise   = (spec_bndrs1, spec_rhs1, spec_fn_ty1)
+
+                 join_arity_decr = length rule_lhs_args - length spec_bndrs
                  spec_join_arity | Just orig_join_arity <- isJoinId_maybe fn
                                  = Just (orig_join_arity - join_arity_decr)
                                  | otherwise
@@ -1449,7 +1461,7 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
                                   (idName fn)
                                   rule_bndrs
                                   rule_lhs_args
-                                  (mkVarApps (Var spec_fn) rule_rhs_args)
+                                  (mkVarApps (Var spec_fn) spec_bndrs)
 
                 spec_rule
                   = case isJoinId_maybe fn of
@@ -1472,15 +1484,15 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
                   = (inl_prag { inl_inline = NoUserInline }, noUnfolding)
 
                   | otherwise
-                  = (inl_prag, specUnfolding dflags fn spec_bndrs spec_app arity_decr fn_unf)
-
-                spec_app e = e `mkApps` spec_args
+                  = (inl_prag, specUnfolding dflags spec_bndrs (`mkApps` spec_args)
+                                             rule_lhs_args fn_unf)
 
                 --------------------------------------
                 -- Adding arity information just propagates it a bit faster
                 --      See Note [Arity decrease] in GHC.Core.Opt.Simplify
                 -- Copy InlinePragma information from the parent Id.
                 -- So if f has INLINE[1] so does spec_fn
+                arity_decr     = count isValArg rule_lhs_args - count isId spec_bndrs
                 spec_f_w_arity = spec_fn `setIdArity`      max 0 (fn_arity - arity_decr)
                                          `setInlinePragma` spec_inl_prag
                                          `setIdUnfolding`  spec_unf
@@ -1498,8 +1510,19 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
                     , spec_uds           `plusUDs` uds_acc
                     ) } }
 
-{- Note [Specialisation Must Preserve Sharing]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+{- Note [Specialising DFuns]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+DFuns have a special sort of unfolding (DFunUnfolding), and it is
+hard to specialise a DFunUnfolding to give another DFunUnfolding
+unless the DFun is fully applied (#18120).  So, in the case of DFunIds
+we simply extend the CallKey with trailing UnspecArgs, so we'll
+generate a rule that completely saturates the DFun.
+
+There is an ASSERT that checks this (the RULE's LHS args match the DFun's
+binders in number) in the DFunUnfolding case of GHC.Core.Unfold.specUnfolding.
+
+Note [Specialisation Must Preserve Sharing]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Consider a function:
 
     f :: forall a. Eq a => a -> blah
@@ -2089,7 +2112,7 @@ isSpecDict _             = False
 --      -- Specialised function helpers
 --    , [c, i, x]
 --    , [dShow1 = $dfShow dShowT2]
---    , [T1, T2, dEqT1, dShow1]
+--    , [T1, T2, c, i, dEqT1, dShow1]
 --    )
 specHeader
      :: SpecEnv
@@ -2106,12 +2129,13 @@ specHeader
 
                 -- RULE helpers
               , [OutBndr]    -- Binders for the RULE
-              , [CoreArg]    -- Args for the LHS of the rule
+              , [OutExpr]    -- Args for the LHS of the rule
 
                 -- Specialised function helpers
               , [OutBndr]    -- Binders for $sf
               , [DictBind]   -- Auxiliary dictionary bindings
               , [OutExpr]    -- Specialised arguments for unfolding
+                             -- Same length as "args for LHS of rule"
               )
 
 -- We want to specialise on type 'T1', and so we must construct a substitution


=====================================
compiler/GHC/Core/Unfold.hs
=====================================
@@ -173,47 +173,47 @@ mkInlinableUnfolding dflags expr
   where
     expr' = simpleOptExpr dflags expr
 
-specUnfolding :: DynFlags -> Id -> [Var] -> (CoreExpr -> CoreExpr) -> Arity
+specUnfolding :: DynFlags
+              -> [Var] -> (CoreExpr -> CoreExpr)
+              -> [CoreArg]   -- LHS arguments in the RULE
               -> Unfolding -> Unfolding
 -- See Note [Specialising unfoldings]
--- specUnfolding spec_bndrs spec_app arity_decrease unf
---   = \spec_bndrs. spec_app( unf )
+-- specUnfolding spec_bndrs spec_args unf
+--   = \spec_bndrs. unf spec_args
 --
-specUnfolding dflags fn spec_bndrs spec_app arity_decrease
+specUnfolding dflags spec_bndrs spec_app rule_lhs_args
               df@(DFunUnfolding { df_bndrs = old_bndrs, df_con = con, df_args = args })
-  = ASSERT2( arity_decrease == count isId old_bndrs - count isId spec_bndrs
-           , ppr df $$ ppr spec_bndrs $$ ppr (spec_app (Var fn)) $$ ppr arity_decrease )
+  = ASSERT2( rule_lhs_args `equalLength` old_bndrs
+           , ppr df $$ ppr rule_lhs_args )
+           -- For this ASSERT see Note [Specialising DFuns] in GHC.Core.Opt.Specialise
     mkDFunUnfolding spec_bndrs con (map spec_arg args)
-      -- There is a hard-to-check assumption here that the spec_app has
-      -- enough applications to exactly saturate the old_bndrs
       -- For DFunUnfoldings we transform
-      --       \old_bndrs. MkD <op1> ... <opn>
+      --       \obs. MkD <op1> ... <opn>
       -- to
-      --       \new_bndrs. MkD (spec_app(\old_bndrs. <op1>)) ... ditto <opn>
-      -- The ASSERT checks the value part of that
+      --       \sbs. MkD ((\obs. <op1>) spec_args) ... ditto <opn>
   where
-    spec_arg arg = simpleOptExpr dflags (spec_app (mkLams old_bndrs arg))
+    spec_arg arg = simpleOptExpr dflags $
+                   spec_app (mkLams old_bndrs arg)
                    -- The beta-redexes created by spec_app will be
                    -- simplified away by simplOptExpr
 
-specUnfolding dflags _ spec_bndrs spec_app arity_decrease
+specUnfolding dflags spec_bndrs spec_app rule_lhs_args
               (CoreUnfolding { uf_src = src, uf_tmpl = tmpl
                              , uf_is_top = top_lvl
                              , uf_guidance = old_guidance })
  | isStableSource src  -- See Note [Specialising unfoldings]
- , UnfWhen { ug_arity     = old_arity
-           , ug_unsat_ok  = unsat_ok
-           , ug_boring_ok = boring_ok } <- old_guidance
- = let guidance = UnfWhen { ug_arity     = old_arity - arity_decrease
-                          , ug_unsat_ok  = unsat_ok
-                          , ug_boring_ok = boring_ok }
-       new_tmpl = simpleOptExpr dflags (mkLams spec_bndrs (spec_app tmpl))
-                   -- The beta-redexes created by spec_app will be
-                   -- simplified away by simplOptExpr
+ , UnfWhen { ug_arity     = old_arity } <- old_guidance
+ = mkCoreUnfolding src top_lvl new_tmpl
+                   (old_guidance { ug_arity = old_arity - arity_decrease })
+ where
+   new_tmpl = simpleOptExpr dflags $
+              mkLams spec_bndrs    $
+              spec_app tmpl  -- The beta-redexes created by spec_app
+                             -- will be simplified away by simpleOptExpr
+   arity_decrease = count isValArg rule_lhs_args - count isId spec_bndrs
 
-   in mkCoreUnfolding src top_lvl new_tmpl guidance
 
-specUnfolding _ _ _ _ _ _ = noUnfolding
+specUnfolding _ _ _ _ _ = noUnfolding
 
 {- Note [Specialising unfoldings]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


=====================================
compiler/GHC/HsToCore/Binds.hs
=====================================
@@ -694,20 +694,19 @@ dsSpec mb_poly_rhs (L loc (SpecPrag poly_id spec_co spec_inl))
          dflags <- getDynFlags
        ; case decomposeRuleLhs dflags spec_bndrs ds_lhs of {
            Left msg -> do { warnDs NoReason msg; return Nothing } ;
-           Right (rule_bndrs, _fn, args) -> do
+           Right (rule_bndrs, _fn, rule_lhs_args) -> do
 
        { this_mod <- getModule
        ; let fn_unf    = realIdUnfolding poly_id
-             spec_unf  = specUnfolding dflags poly_id spec_bndrs core_app arity_decrease fn_unf
+             spec_unf  = specUnfolding dflags spec_bndrs core_app rule_lhs_args fn_unf
              spec_id   = mkLocalId spec_name spec_ty
                             `setInlinePragma` inl_prag
                             `setIdUnfolding`  spec_unf
-             arity_decrease = count isValArg args - count isId spec_bndrs
 
        ; rule <- dsMkUserRule this_mod is_local_id
                         (mkFastString ("SPEC " ++ showPpr dflags poly_name))
                         rule_act poly_name
-                        rule_bndrs args
+                        rule_bndrs rule_lhs_args
                         (mkVarApps (Var spec_id) spec_bndrs)
 
        ; let spec_rhs = mkLams spec_bndrs (core_app poly_rhs)


=====================================
compiler/GHC/Tc/Gen/Sig.hs
=====================================
@@ -634,7 +634,6 @@ to connect the two, something like
 This wrapper is put in the TcSpecPrag, in the ABExport record of
 the AbsBinds.
 
-
         f :: (Eq a, Ix b) => a -> b -> Bool
         {-# SPECIALISE f :: (Ix p, Ix q) => Int -> (p,q) -> Bool #-}
         f = <poly_rhs>
@@ -662,8 +661,6 @@ Note that
   * The RHS of f_spec, <poly_rhs> has a *copy* of 'binds', so that it
     can fully specialise it.
 
-
-
 From the TcSpecPrag, in GHC.HsToCore.Binds we generate a binding for f_spec and a RULE:
 
    f_spec :: Int -> b -> Int
@@ -702,14 +699,14 @@ Some wrinkles
 
   So we simply do this:
     - Generate a constraint to check that the specialised type (after
-      skolemiseation) is equal to the instantiated function type.
+      skolemisation) is equal to the instantiated function type.
     - But *discard* the evidence (coercion) for that constraint,
       so that we ultimately generate the simpler code
           f_spec :: Int -> F Int
           f_spec = <f rhs> Int dNumInt
 
           RULE: forall d. f Int d = f_spec
-      You can see this discarding happening in
+      You can see this discarding happening in tcSpecPrag
 
 3. Note that the HsWrapper can transform *any* function with the right
    type prefix


=====================================
testsuite/tests/simplCore/should_compile/T18120.hs
=====================================
@@ -0,0 +1,34 @@
+{-# LANGUAGE ConstraintKinds #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE UndecidableSuperClasses #-}
+module Bug where  -- Regression test for #18120 (specialising DFuns); run with -dcore-lint -O
+
+import Data.Kind
+
+type family  -- type-level "all elements satisfy c", unfolded one cons cell at a time
+  AllF (c :: k -> Constraint) (xs :: [k]) :: Constraint where
+  AllF _c '[]       = ()
+  AllF  c (x ': xs) = (c x, All c xs)
+
+class (AllF c xs, SListI xs) => All (c :: k -> Constraint) (xs :: [k]) where
+instance All c '[] where
+instance (c x, All c xs) => All c (x ': xs) where
+
+class Top x
+instance Top x
+
+type SListI = All Top  -- constraint synonym; needs ConstraintKinds
+
+class All SListI (Code a) => Generic (a :: Type) where
+  type Code a :: [[Type]]
+
+data T = MkT Int
+instance Generic T where  -- NOTE(review): presumably its All SListI (Code T) superclass drives the DFun specialisation; confirm
+  type Code T = '[ '[Int] ]


=====================================
testsuite/tests/simplCore/should_compile/all.T
=====================================
@@ -318,3 +318,4 @@ test('T17966',
 test('T17810', normal, multimod_compile, ['T17810', '-fspecialise-aggressively -dcore-lint -O -v0'])
 test('T18013', normal, multimod_compile, ['T18013', '-v0 -O'])
 test('T18098', normal, compile, ['-dcore-lint -O2'])
+test('T18120', normal, compile, ['-dcore-lint -O'])



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/7e77aca97127a65b574417a0fb25d18e8ddc4b1e

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/7e77aca97127a65b574417a0fb25d18e8ddc4b1e
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20200501/692df226/attachment-0001.html>


More information about the ghc-commits mailing list