[Git][ghc/ghc][master] Account for local rules in specImports

Marge Bot (@marge-bot) gitlab at gitlab.haskell.org
Tue Feb 28 16:10:47 UTC 2023



Marge Bot pushed to branch master at Glasgow Haskell Compiler / GHC


Commits:
0c200ab7 by Simon Peyton Jones at 2023-02-28T11:10:31-05:00
Account for local rules in specImports

As #23024 showed, in GHC.Core.Opt.Specialise.specImports, we were
generating a specialisation (a locally-defined function) for an
imported function, and then generating further specialisations for
that locally-defined function.  The RULE for the latter should be
attached to the local Id, not put in the rules-for-imported-ids
set.

The fix is easy, and similar to what happens in
GHC.HsToCore.addExportFlagsAndRules.
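
The shape of the fix can be sketched outside GHC as a toy model. This is
only an illustration under stated assumptions: the Id, Rule and RuleBase
types below are simplified stand-ins (not GHC's real Id, CoreRule or
RuleBase), and finishSpecImports and mkRuleBase are hypothetical names.
Only partition and the overall "partition, then attach local rules" idea
come from the commit itself.

```haskell
import Data.List (partition)
import qualified Data.Map.Strict as Map

-- Toy stand-ins for GHC's Id and CoreRule; NOT the real GHC types.
data Id = Id
  { idName  :: String
  , idRules :: [Rule]     -- rules attached to the Id itself
  } deriving (Eq, Show)

data Rule = Rule
  { ruleHead  :: String   -- name of the function the rule rewrites
  , ruleLocal :: Bool     -- True if that function is bound in this module
  , ruleRhs   :: String   -- name of the specialised function
  } deriving (Eq, Show)

-- Assumption: a rule base is just a map from function name to its rules.
type RuleBase = Map.Map String [Rule]

-- Hypothetical helper: group rules by the function they rewrite,
-- keeping the original order of rules for each function.
mkRuleBase :: [Rule] -> RuleBase
mkRuleBase rules =
  Map.fromListWith (flip (++)) [ (ruleHead r, [r]) | r <- rules ]

-- Mirrors the role of addRulesToId: add any rules found in the rule
-- base to the rules the binder already carries.
addRulesToId :: RuleBase -> Id -> Id
addRulesToId rb bndr =
  case Map.lookup (idName bndr) rb of
    Just rules -> bndr { idRules = idRules bndr ++ rules }
    Nothing    -> bndr

-- Hypothetical name for the last step of specImports: rules for local
-- binders are attached to those binders, and only the rules for
-- imported functions are returned separately.
finishSpecImports :: [Rule] -> [(Id, rhs)] -> ([Rule], [(Id, rhs)])
finishSpecImports specRules binds = (rulesForImps, binds')
  where
    (rulesForLocals, rulesForImps) = partition ruleLocal specRules
    localRuleBase = mkRuleBase rulesForLocals
    binds' = [ (addRulesToId localRuleBase b, rhs) | (b, rhs) <- binds ]
```

In this model, a rule for M.foo stays in the returned list, while a rule
for the local $sfoo ends up in the idRules field of the $sfoo binder.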

- - - - -


6 changed files:

- compiler/GHC/Core/Opt/Specialise.hs
- compiler/GHC/Core/Rules.hs
- compiler/GHC/HsToCore.hs
- + testsuite/tests/simplCore/should_compile/T23024.hs
- + testsuite/tests/simplCore/should_compile/T23024a.hs
- testsuite/tests/simplCore/should_compile/all.T


Changes:

=====================================
compiler/GHC/Core/Opt/Specialise.hs
=====================================
@@ -64,6 +64,7 @@ import GHC.Unit.Module( Module )
 import GHC.Unit.Module.ModGuts
 import GHC.Core.Unfold
 
+import Data.List( partition )
 import Data.List.NonEmpty ( NonEmpty (..) )
 
 {-
@@ -726,6 +727,33 @@ specialisation (see canSpecImport):
     Specialise even INLINE things; it hasn't inlined yet, so perhaps
     it never will.  Moreover it may have calls inside it that we want
     to specialise
+
+Wrinkle (W1): If we specialise an imported Id M.foo, we make a /local/
+binding $sfoo.  But specImports may further specialise $sfoo. So we end up
+with RULES for both M.foo (imported) and $sfoo (local).  Rules for local
+Ids should be attached to the Ids themselves (see GHC.HsToCore
+Note [Attach rules to local ids]); so we must partition the rules and
+attach the local rules.  That is done in specImports, via addRulesToId.
+
+Note [Glom the bindings if imported functions are specialised]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Suppose we have an imported, *recursive*, INLINABLE function
+   f :: Eq a => a -> a
+   f = /\a \d x. ...(f a d)...
+In the module being compiled we have
+   g x = f (x::Int)
+Now we'll make a specialised function
+   f_spec :: Int -> Int
+   f_spec = \x -> ...(f Int dInt)...
+   {-# RULE  f Int _ = f_spec #-}
+   g = \x. f Int dInt x
+Note that f_spec doesn't look recursive
+After rewriting with the RULE, we get
+   f_spec = \x -> ...(f_spec)...
+BUT since f_spec was non-recursive before it'll *stay* non-recursive.
+The occurrence analyser never turns a NonRec into a Rec.  So we must
+make sure that f_spec is recursive.  Easiest thing is to make all
+the specialisations for imported bindings recursive.
 -}
 
 specImports :: SpecEnv
@@ -740,16 +768,24 @@ specImports top_env (MkUD { ud_binds = dict_binds, ud_calls = calls })
   = do { let env_w_dict_bndrs = top_env `bringFloatedDictsIntoScope` dict_binds
        ; (_env, spec_rules, spec_binds) <- spec_imports env_w_dict_bndrs [] dict_binds calls
 
-             -- Don't forget to wrap the specialized bindings with
-             -- bindings for the needed dictionaries.
-             -- See Note [Wrap bindings returned by specImports]
-             -- and Note [Glom the bindings if imported functions are specialised]
-       ; let final_binds
+             -- Make a Rec: see Note [Glom the bindings if imported functions are specialised]
+             --
+             -- wrapDictBinds: don't forget to wrap the specialized bindings with
+             --   bindings for the needed dictionaries.
+             --   See Note [Wrap bindings returned by specImports]
+             --
+             -- addRulesToId: see Wrinkle (W1) in Note [Specialising imported functions]
+             --               c.f. GHC.HsToCore.addExportFlagsAndRules
+       ; let (rules_for_locals, rules_for_imps) = partition isLocalRule spec_rules
+             local_rule_base = extendRuleBaseList emptyRuleBase rules_for_locals
+             final_binds
                | null spec_binds = wrapDictBinds dict_binds []
-               | otherwise       = [Rec $ flattenBinds $
-                                    wrapDictBinds dict_binds spec_binds]
+               | otherwise       = [Rec $ mapFst (addRulesToId local_rule_base) $
+                                          flattenBinds                          $
+                                          wrapDictBinds dict_binds              $
+                                          spec_binds]
 
-       ; return (spec_rules, final_binds)
+       ; return (rules_for_imps, final_binds)
     }
 
 -- | Specialise a set of calls to imported bindings
@@ -1111,27 +1147,6 @@ And if the call is to the same type, one specialisation is enough.
 Avoiding this recursive specialisation loop is one reason for the
 'callers' stack passed to specImports and specImport.
 
-Note [Glom the bindings if imported functions are specialised]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Suppose we have an imported, *recursive*, INLINABLE function
-   f :: Eq a => a -> a
-   f = /\a \d x. ...(f a d)...
-In the module being compiled we have
-   g x = f (x::Int)
-Now we'll make a specialised function
-   f_spec :: Int -> Int
-   f_spec = \x -> ...(f Int dInt)...
-   {-# RULE  f Int _ = f_spec #-}
-   g = \x. f Int dInt x
-Note that f_spec doesn't look recursive
-After rewriting with the RULE, we get
-   f_spec = \x -> ...(f_spec)...
-BUT since f_spec was non-recursive before it'll *stay* non-recursive.
-The occurrence analyser never turns a NonRec into a Rec.  So we must
-make sure that f_spec is recursive.  Easiest thing is to make all
-the specialisations for imported bindings recursive.
-
-
 
 ************************************************************************
 *                                                                      *
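
The reasoning in Note [Glom the bindings if imported functions are
specialised] can be illustrated with a small toy model. Everything here
is an assumption made for illustration: Expr and Bind are simplified
stand-ins for Core, applyRule hard-codes the single rule
"f Int _ = f_spec", and selfRefOk is a hypothetical scoping check, not
GHC's occurrence analyser.

```haskell
-- Toy expression language; NOT GHC Core.
data Expr = Var String | App Expr Expr | Lam String Expr
  deriving (Eq, Show)

data Bind = NonRec String Expr | Rec [(String, Expr)]
  deriving (Eq, Show)

-- Hard-coded model of the rule  f Int _ = f_spec.
applyRule :: Expr -> Expr
applyRule (App (App (Var "f") (Var "Int")) _) = Var "f_spec"
applyRule (App a b) = App (applyRule a) (applyRule b)
applyRule (Lam x e) = Lam x (applyRule e)
applyRule e         = e

-- Does the variable occur free in the expression?
mentions :: String -> Expr -> Bool
mentions v (Var x)   = v == x
mentions v (App a b) = mentions v a || mentions v b
mentions v (Lam x e) = x /= v && mentions v e

-- Rewrite every right-hand side in a binding group.
mapRhs :: (Expr -> Expr) -> Bind -> Bind
mapRhs f (NonRec b e) = NonRec b (f e)
mapRhs f (Rec ps)     = Rec [ (b, f e) | (b, e) <- ps ]

flattenBinds :: [Bind] -> [(String, Expr)]
flattenBinds = concatMap go
  where
    go (NonRec b e) = [(b, e)]
    go (Rec ps)     = ps

-- Hypothetical check: a NonRec binding must not mention its own binder,
-- because the binder is not in scope in its own right-hand side.
selfRefOk :: Bind -> Bool
selfRefOk (NonRec b e) = not (mentions b e)
selfRefOk (Rec _)      = True

-- f_spec = \x -> f Int dInt x   (does not look recursive yet)
fSpec :: Bind
fSpec = NonRec "f_spec"
  (Lam "x" (App (App (App (Var "f") (Var "Int")) (Var "dInt")) (Var "x")))
```

In this model, selfRefOk fSpec holds before rewriting, but
selfRefOk (mapRhs applyRule fSpec) does not, because the rewritten body
now mentions f_spec. Wrapping the bindings first, as in
Rec (flattenBinds [fSpec]), keeps the group well-scoped after the rule
fires, which is why specImports builds a Rec up front.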


=====================================
compiler/GHC/Core/Rules.hs
=====================================
@@ -22,7 +22,7 @@ module GHC.Core.Rules (
 
         -- ** Manipulating 'RuleInfo' rules
         extendRuleInfo, addRuleInfo,
-        addIdSpecialisations,
+        addIdSpecialisations, addRulesToId,
 
         -- ** RuleBase and RuleEnv
 
@@ -349,6 +349,14 @@ addIdSpecialisations id rules
   = setIdSpecialisation id $
     extendRuleInfo (idSpecialisation id) rules
 
+addRulesToId :: RuleBase -> Id -> Id
+-- Add rules in the RuleBase to the rules in the Id
+addRulesToId rule_base bndr
+  | Just rules <- lookupNameEnv rule_base (idName bndr)
+  = bndr `addIdSpecialisations` rules
+  | otherwise
+  = bndr
+
 -- | Gather all the rules for locally bound identifiers from the supplied bindings
 rulesOfBinds :: [CoreBind] -> [CoreRule]
 rulesOfBinds binds = concatMap (concatMap idCoreRules . bindersOf) binds


=====================================
compiler/GHC/HsToCore.hs
=====================================
@@ -362,32 +362,28 @@ deSugarExpr hsc_env tc_expr = do
 addExportFlagsAndRules
     :: Backend -> NameSet -> NameSet -> [CoreRule]
     -> [(Id, t)] -> [(Id, t)]
-addExportFlagsAndRules bcknd exports keep_alive rules = mapFst add_one
+addExportFlagsAndRules bcknd exports keep_alive rules
+  = mapFst (addRulesToId rule_base . add_export_flag)
+        -- addRulesToId: see Note [Attach rules to local ids]
+        -- NB: the binder might have some existing rules,
+        -- arising from specialisation pragmas
+
   where
-    add_one bndr = add_rules name (add_export name bndr)
-       where
-         name = idName bndr
 
     ---------- Rules --------
-        -- See Note [Attach rules to local ids]
-        -- NB: the binder might have some existing rules,
-        -- arising from specialisation pragmas
-    add_rules name bndr
-        | Just rules <- lookupNameEnv rule_base name
-        = bndr `addIdSpecialisations` rules
-        | otherwise
-        = bndr
     rule_base = extendRuleBaseList emptyRuleBase rules
 
     ---------- Export flag --------
     -- See Note [Adding export flags]
-    add_export name bndr
-        | dont_discard name = setIdExported bndr
+    add_export_flag bndr
+        | dont_discard bndr = setIdExported bndr
         | otherwise         = bndr
 
-    dont_discard :: Name -> Bool
-    dont_discard name = is_exported name
+    dont_discard :: Id -> Bool
+    dont_discard bndr = is_exported name
                      || name `elemNameSet` keep_alive
+       where
+         name = idName bndr
 
         -- In interactive mode, we don't want to discard any top-level
         -- entities at all (eg. do not inline them away during


=====================================
testsuite/tests/simplCore/should_compile/T23024.hs
=====================================
@@ -0,0 +1,8 @@
+{-# OPTIONS_GHC -fspecialize-aggressively -fexpose-all-unfoldings  #-}
+{-# LANGUAGE RankNTypes #-}
+module T23024 (testPolyn) where
+
+import T23024a
+
+testPolyn :: (forall r. Tensor r => r) -> Vector Double
+testPolyn f = gradientFromDelta f


=====================================
testsuite/tests/simplCore/should_compile/T23024a.hs
=====================================
@@ -0,0 +1,82 @@
+{-# OPTIONS_GHC -fspecialize-aggressively -fexpose-all-unfoldings -Wno-missing-methods #-}
+{-# LANGUAGE FlexibleInstances, FlexibleContexts, UndecidableInstances,
+             DataKinds, MultiParamTypeClasses, RankNTypes, MonoLocalBinds #-}
+module T23024a where
+
+import System.IO.Unsafe
+import Control.Monad.ST ( ST, runST )
+import Foreign.ForeignPtr
+import Foreign.Storable
+import GHC.ForeignPtr ( unsafeWithForeignPtr )
+
+class MyNum a where
+  fi :: a
+
+class (MyNum a, Eq a) => MyReal a
+
+class (MyReal a) => MyRealFrac a  where
+  fun :: a -> ()
+
+class (MyRealFrac a, MyNum a) => MyRealFloat a
+
+instance MyNum Double
+instance MyReal Double
+instance MyRealFloat Double
+instance MyRealFrac Double
+
+newtype Vector a = Vector (ForeignPtr a)
+
+class GVector v a where
+instance Storable a => GVector Vector a
+
+vunstream :: () -> ST s (v a)
+vunstream () = vunstream ()
+
+empty :: GVector v a => v a
+empty = runST (vunstream ())
+{-# NOINLINE empty #-}
+
+instance (Storable a, Eq a) => Eq (Vector a) where
+  xs == ys = idx xs == idx ys
+
+{-# NOINLINE idx #-}
+idx (Vector fp) = unsafePerformIO
+                        $ unsafeWithForeignPtr fp $ \p ->
+                          peekElemOff p 0
+
+instance MyNum (Vector Double)
+instance (MyNum (Vector a), Storable a, Eq a) => MyReal (Vector a)
+instance (MyNum (Vector a), Storable a, Eq a) => MyRealFrac (Vector a)
+instance (MyNum (Vector a), Storable a, MyRealFloat a) => MyRealFloat (Vector a)
+
+newtype ORArray a = A a
+
+instance (Eq a) => Eq (ORArray a) where
+  A x == A y = x == y
+
+instance (MyNum (Vector a)) => MyNum (ORArray a)
+instance (MyNum (Vector a), Storable a, Eq a) => MyReal (ORArray a)
+instance (MyRealFrac (Vector a), Storable a, Eq a) => MyRealFrac (ORArray a)
+instance (MyRealFloat (Vector a), Storable a, Eq a) => MyRealFloat (ORArray a)
+
+newtype Ast r = AstConst (ORArray r)
+
+instance Eq (Ast a) where
+  (==) = undefined
+
+instance MyNum (ORArray a) => MyNum (Ast a) where
+  fi = AstConst fi
+
+instance MyNum (ORArray a) => MyReal (Ast a)
+instance MyRealFrac (ORArray a) => MyRealFrac (Ast a) where
+  {-# INLINE fun #-}
+  fun x = ()
+  
+instance MyRealFloat (ORArray a) => MyRealFloat (Ast a)
+
+class (MyRealFloat a) => Tensor a
+instance (MyRealFloat a, MyNum (Vector a), Storable a) => Tensor (Ast a)
+
+gradientFromDelta :: Storable a => Ast a -> Vector a
+gradientFromDelta _ = empty
+{-# NOINLINE gradientFromDelta #-}


=====================================
testsuite/tests/simplCore/should_compile/all.T
=====================================
@@ -475,3 +475,4 @@ test('T22761', normal, multimod_compile, ['T22761', '-O2 -v0'])
 test('T23012', normal, compile, ['-O'])
 
 test('RewriteHigherOrderPatterns', normal, compile, ['-O -ddump-rule-rewrites -dsuppress-all -dsuppress-uniques'])
+test('T23024', normal, multimod_compile, ['T23024', '-O -v0'])



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/0c200ab78c814cb5d1efaf426f0d3d91ceab9f4d
