[Git][ghc/ghc][wip/romes/eqsat-pmc] Core Equality module pretty much good to go

Rodrigo Mesquita (@alt-romes) gitlab at gitlab.haskell.org
Sun Jun 25 20:15:27 UTC 2023



Rodrigo Mesquita pushed to branch wip/romes/eqsat-pmc at Glasgow Haskell Compiler / GHC


Commits:
4d72410f by Rodrigo Mesquita at 2023-06-25T21:15:15+01:00
Core Equality module pretty much good to go

IPW

- - - - -


3 changed files:

- compiler/GHC/Core/Functor.hs
- compiler/GHC/Core/Map/Type.hs
- compiler/GHC/HsToCore/Pmc/Solver/Types.hs


Changes:

=====================================
compiler/GHC/Core/Functor.hs
=====================================
@@ -1,3 +1,4 @@
+{-# LANGUAGE MagicHash #-}
 {-# LANGUAGE DeriveTraversable #-}
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE ViewPatterns #-}
@@ -7,6 +8,7 @@
 -- ROMES:TODO: Rename to Core.Equality or something
 module GHC.Core.Functor where
 
+import GHC.Exts (dataToTag#, tagToEnum#, (>#), (<#))
 import GHC.Prelude
 
 import GHC.Core
@@ -24,8 +26,7 @@ import Data.Equality.Analysis
 import qualified Data.Equality.Graph.Monad as EGM
 import Data.Equality.Utils (Fix(..))
 
-import GHC.Utils.Misc (all2, equalLength)
-import Data.Functor.Identity (Identity(..))
+import GHC.Utils.Misc (all2)
 
 -- Important to note the binders are also represented by $a$
 -- This is because in the e-graph we will represent binders with the
@@ -38,12 +39,12 @@ import Data.Functor.Identity (Identity(..))
 
+-- | Base functor of a Core case alternative: the alternative's constructor,
+-- its binders, and the position @a@ of its right-hand side. 'Eq' and 'Ord'
+-- are provided for @DeBruijn (CoreAltF a)@ rather than derived here.
 data AltF b a
     = AltF AltCon [b] a
-    deriving (Functor, Foldable, Traversable, Eq, Ord)
+    deriving (Functor, Foldable, Traversable)
 
+-- | Base functor of a Core let-binding: a single non-recursive binding or a
+-- recursive group, with @a@ in each right-hand-side position. 'Eq' and 'Ord'
+-- are provided for @DeBruijn (CoreBindF a)@ rather than derived here.
 data BindF b a
   = NonRecF b a
   | RecF [(b, a)]
-  deriving (Functor, Foldable, Traversable, Eq, Ord)
+  deriving (Functor, Foldable, Traversable)
 
 data ExprF b a
   = VarF Id
@@ -61,16 +62,20 @@ data ExprF b a
 
+-- Core expressions, case alternatives and let-bindings as base functors,
+-- with 'CoreBndr' binders.
 type CoreExprF
   = ExprF CoreBndr
+type CoreAltF
+  = AltF CoreBndr
+type CoreBindF
+  = BindF CoreBndr
 
-instance Eq a => Eq (DeBruijnF CoreExprF a) where
-  (==) = eqDeBruijnExprF
+-- | A functor layer @f a@ paired with the 'CmEnv' it is compared under (via
+-- 'DeBruijn'). The newtype makes @DeBruijnF f@ a functor in @a@, so it can be
+-- used as the e-graph node type, as in @EGraph a (DeBruijnF CoreExprF)@.
+newtype DeBruijnF f a = DF (DeBruijn (f a))
+  deriving (Functor, Foldable, Traversable)
 
 -- ROMES:TODO: This instance is plain wrong. This DeBruijn scheme won't
 -- particularly work for our situation, we'll likely have to have ints instead
 -- of Id binders. Now, ignoring DeBruijn indices, we'll simply get this compile
 -- to get a rougher estimate of performance?
-eqDeBruijnExprF :: forall a. Eq a => DeBruijnF CoreExprF a -> DeBruijnF CoreExprF a -> Bool
-eqDeBruijnExprF (DF (D env1 e1)) (DF (D env2 e2)) = go e1 e2 where
+eqDeBruijnExprF :: forall a. Eq a => DeBruijn (CoreExprF a) -> DeBruijn (CoreExprF a) -> Bool
+eqDeBruijnExprF (D env1 e1) (D env2 e2) = go e1 e2 where
     go :: CoreExprF a -> CoreExprF a -> Bool
     go (VarF v1) (VarF v2)             = eqDeBruijnVar (D env1 v1) (D env2 v2)
     go (LitF lit1)    (LitF lit2)      = lit1 == lit2
@@ -88,37 +93,30 @@ eqDeBruijnExprF (DF (D env1 e1)) (DF (D env2 e2)) = go e1 e2 where
       && D env1 (varMultMaybe b1) == D env2 (varMultMaybe b2)
       && e1 == e2
 
-    go (LetF (NonRecF v1 r1) e1) (LetF (NonRecF v2 r2) e2)
-      = r1 == r2 -- See Note [Alpha-equality for let-bindings]
-      && e1 == e2
-
-    go (LetF (RecF ps1) e1) (LetF (RecF ps2) e2)
-      =
-      -- See Note [Alpha-equality for let-bindings]
-      all2 (\b1 b2 -> eqDeBruijnType (D env1 (varType b1))
-                                     (D env2 (varType b2)))
-           bs1 bs2
-      && rs1 == rs2
+    go (LetF abs e1) (LetF bbs e2)
+      = D env1 abs == D env2 bbs
       && e1 == e2
-      where
-        (bs1,rs1) = unzip ps1
-        (bs2,rs2) = unzip ps2
 
-    go (CaseF e1 b1 t1 a1) (CaseF e2 b2 t2 a2)
+    go (CaseF e1 _b1 t1 a1) (CaseF e2 _b2 t2 a2)
       | null a1   -- See Note [Empty case alternatives]
       = null a2 && e1 == e2 && D env1 t1 == D env2 t2
       | otherwise
-      = e1 == e2 && a1 == a2
+      = e1 == e2 && D env1 a1 == D env2 a2
 
     go _ _ = False
 
-instance Ord a => Ord (DeBruijnF CoreExprF a) where
-  compare a b = if a == b then EQ else LT
--- deriving instance Ord a => Ord (DeBruijnF CoreExprF a)
-
-deriving instance Functor (DeBruijnF CoreExprF)
-deriving instance Foldable (DeBruijnF CoreExprF)
-deriving instance Traversable (DeBruijnF CoreExprF)
+-- ROMES:TODO: This one can be derived automatically, but perhaps it's better
+-- to be explicit here? We don't even really require the DeBruijn context here
+--
+-- | Equality of two case alternatives. They are equal when they match on the
+-- same 'AltCon' (both 'DEFAULT', the same literal, or the same data
+-- constructor) and have equal right-hand sides. Neither the environments nor
+-- the binders are inspected: the right-hand sides were represented under
+-- environments already extended with the binders (see 'representDBAltExpr').
+eqDeBruijnAltF :: forall a. Eq a => DeBruijn (CoreAltF a) -> DeBruijn (CoreAltF a) -> Bool
+eqDeBruijnAltF (D _ (AltF con1 _ rhs1)) (D _ (AltF con2 _ rhs2))
+  = sameAltCon con1 con2 && rhs1 == rhs2
+  where
+    sameAltCon DEFAULT       DEFAULT       = True
+    sameAltCon (LitAlt l1)   (LitAlt l2)   = l1 == l2
+    sameAltCon (DataAlt dc1) (DataAlt dc2) = dc1 == dc2
+    sameAltCon _             _             = False
 
 -- | 'unsafeCoerce' mostly because I'm too lazy to write the boilerplate.
 fromCoreExpr :: CoreExpr -> Fix CoreExprF
@@ -128,8 +126,9 @@ toCoreExpr :: CoreExpr -> Fix CoreExprF
 toCoreExpr = unsafeCoerce
 
 -- | Represents a DeBruijn CoreExpr being careful to correctly debruijnizie the expression as it is represented
--- TODO: Use `Compose DeBruijn CoreExprF` instead
+--
 -- Always represent Ids, at least for now. We're seemingly using inexistent ids
+-- ROMES:TODO: do this all inside EGraphM instead
 representDBCoreExpr :: Analysis a (DeBruijnF CoreExprF)
                     => DeBruijn CoreExpr
                     -> EGraph a (DeBruijnF CoreExprF)
@@ -154,28 +153,205 @@ representDBCoreExpr (D cmenv expr) eg0 = case expr of
   Let (Rec (unzip -> (bs,rs))) e ->
     let cmenv' = extendCMEs cmenv bs
         (bsids, eg1) = EGM.runEGraphM eg0 $
-                         traverse (\r -> state $ representDBCoreExpr (D cmenv' r)) rs
+                         traverse (state . representDBCoreExpr . D cmenv') rs
         (eid, eg2)  = representDBCoreExpr (D cmenv' e) eg1
      in add (Node $ DF (D cmenv (LetF (RecF (zip bs bsids)) eid))) eg2
   Case e b t as -> let (eid, eg1)  = representDBCoreExpr (D cmenv e) eg0
                        (as', eg2) = EGM.runEGraphM eg1 $
-                         traverse (\(Alt cons bs a) -> state $ \s -> let (aid, g) = representDBCoreExpr (D (extendCME cmenv b) a) s in (AltF cons bs aid, g)) as
+                         traverse (state . representDBAltExpr . D (extendCME cmenv b)) as
                     in add (Node $ DF (D cmenv (CaseF eid b t as'))) eg2
 
+-- | Represent one case alternative in the e-graph.
+--
+-- The right-hand side is represented under the environment extended with the
+-- alternative's binders. Its e-class id is placed in the returned 'AltF'; no
+-- e-node is added for the alternative itself. Also returns the updated e-graph.
+representDBAltExpr :: Analysis a (DeBruijnF CoreExprF)
+                   => DeBruijn CoreAlt
+                   -> EGraph a (DeBruijnF CoreExprF)
+                   -> (CoreAltF ClassId, EGraph a (DeBruijnF CoreExprF))
+representDBAltExpr (D cm (Alt cons bs a)) eg0 = (AltF cons bs rhsId, eg1)
+  where
+    rhsEnv       = extendCMEs cm bs
+    (rhsId, eg1) = representDBCoreExpr (D rhsEnv a) eg0
 
--- ROMES:TODO: Instead of DeBruijnF CoreExprF we should have (ExprF (Int,Id))
--- * A represent function that makes up the debruijn indices as it is representing the expressions
--- * An Eq and Ord instance which ignore the Id and rather look at the DeBruijn index.
---
--- TODO
--- * For types, can we use eqDeBruijnType ? I think not, because Lambdas and Lets can bind type variables
---
--- TODO: The Best Alternative:
---
--- Each expression keeps its DeBruijnF CmEnv environment, but the represent
--- function must correctly prepare the debruijn indices, so that each e-node
--- has the debruijn indice it would have in its recursive descent in the Eq instance?
---
--- TODO: We could even probably have Compose DeBruijn CoreExprF in that case!
---
+-- Equality on the debruijnized functors. The 'DeBruijnF' instances simply
+-- unwrap 'DF' and defer to the same comparison functions as the 'DeBruijn'
+-- instances.
+instance Eq a => Eq (DeBruijn (CoreAltF a)) where
+  x == y = eqDeBruijnAltF x y
+
+instance Eq a => Eq (DeBruijn (CoreExprF a)) where
+  x == y = eqDeBruijnExprF x y
+
+instance Eq a => Eq (DeBruijnF CoreExprF a) where
+  DF x == DF y = eqDeBruijnExprF x y
+
+instance Eq a => Eq (DeBruijnF CoreAltF a) where
+  DF x == DF y = eqDeBruijnAltF x y
+
+-- Ordering on 'DeBruijnF' nodes: the derived instance unwraps 'DF' and uses
+-- @Ord (DeBruijn (CoreExprF a))@.
+deriving instance Ord a => Ord (DeBruijnF CoreExprF a)
+
+-- | Ordering on debruijnized expression nodes.
+--
+-- Nodes built with different 'ExprF' constructors are ordered by constructor
+-- tag. Nodes built with the same constructor are compared field by field,
+-- lexicographically:
+--
+--   * child positions (@a@) use their own 'Ord';
+--   * variables, types, coercions, ticks, let-bindings and case alternatives
+--     are compared under their respective 'CmEnv's.
+--
+-- Lambda, let and case binders are not compared. We rely on
+-- 'representDBCoreExpr' having represented the children under environments
+-- already extended with those binders. The 'CmEnv's themselves are never
+-- compared directly; they only affect how the variables, types, coercions,
+-- ticks, bindings and alternatives stored in a node compare.
+--
+-- NOTE(review): 'cmpDeBruijnType' and 'cmpDeBruijnCoercion' currently return
+-- 'EQ' for every pair. So, for instance, two 'TypeF' nodes always compare 'EQ'
+-- even when the 'Eq' instance tells them apart.
+instance Ord a => Ord (DeBruijn (CoreExprF a)) where
+  compare (D cma ea) (D cmb eb)
+    -- Use the tags of the wrapped 'ExprF' nodes. 'D' has a single
+    -- constructor, so the tag of the 'DeBruijn' value itself is the same for
+    -- every node and cannot be used to order them.
+    = case dataToTag# ea of
+        ta -> case dataToTag# eb of
+          tb -> if tagToEnum# (ta <# tb) then LT
+                else if tagToEnum# (ta ># tb) then GT
+                else cmpSameCon ea eb
+    where
+      -- Compare two nodes that are known to use the same constructor.
+      cmpSameCon :: CoreExprF a -> CoreExprF a -> Ordering
+      cmpSameCon (VarF va) (VarF vb)
+        = cmpDeBruijnVar (D cma va) (D cmb vb)
+      cmpSameCon (LitF la) (LitF lb)
+        = compare la lb
+      cmpSameCon (AppF af aarg) (AppF bf barg)
+        = compare af bf `thenCmp` compare aarg barg
+      cmpSameCon (LamF _abndr abody) (LamF _bbndr bbody)
+        = compare abody bbody
+      cmpSameCon (LetF abind abody) (LetF bbind bbody)
+        = compare (D cma abind) (D cmb bbind) `thenCmp` compare abody bbody
+      cmpSameCon (CaseF ax _abndr aty aalts) (CaseF bx _bbndr bty balts)
+        = compare ax bx
+          `thenCmp` cmpDeBruijnType (D cma aty) (D cmb bty)
+          `thenCmp` compare (D cma aalts) (D cmb balts)
+      cmpSameCon (CastF ax aco) (CastF bx bco)
+        = compare ax bx `thenCmp` cmpDeBruijnCoercion (D cma aco) (D cmb bco)
+      cmpSameCon (TickF atick ax) (TickF btick bx)
+        = cmpDeBruijnTickish (D cma atick) (D cmb btick) `thenCmp` compare ax bx
+      cmpSameCon (TypeF aty) (TypeF bty)
+        = cmpDeBruijnType (D cma aty) (D cmb bty)
+      cmpSameCon (CoercionF aco) (CoercionF bco)
+        = cmpDeBruijnCoercion (D cma aco) (D cmb bco)
+      -- Unreachable: equal tags mean equal constructors, and every
+      -- constructor is handled above.
+      cmpSameCon _ _ = EQ
+
+      -- Lexicographic combination. The second comparison is only consulted
+      -- (and evaluated) when the first one is 'EQ'.
+      thenCmp :: Ordering -> Ordering -> Ordering
+      thenCmp EQ next = next
+      thenCmp res _   = res
+
+-- | Equality of let-bindings under their environments.
+--
+-- * A non-recursive binding compares only its right-hand side.
+-- * A recursive group also requires its binders' types to be pairwise equal
+--   (via 'eqDeBruijnType').
+-- * A 'NonRecF' never equals a 'RecF'.
+instance Eq a => Eq (DeBruijn (CoreBindF a)) where
+  D cma lhs == D cmb rhs = case (lhs, rhs) of
+    (NonRecF _ r1, NonRecF _ r2) ->
+      -- See Note [Alpha-equality for let-bindings]
+      r1 == r2
+    (RecF ps1, RecF ps2) ->
+      -- See Note [Alpha-equality for let-bindings]
+      let (bs1, rs1) = unzip ps1
+          (bs2, rs2) = unzip ps2
+          sameBndrTy b1 b2 = eqDeBruijnType (D cma (varType b1))
+                                            (D cmb (varType b2))
+      in all2 sameBndrTy bs1 bs2 && rs1 == rs2
+    _ -> False
+
+
+-- | Ordering on let-bindings: every 'NonRecF' sorts before every 'RecF'.
+-- Within a constructor only the right-hand sides are compared
+-- (lexicographically for recursive groups). Binders are ignored, on the
+-- assumption that they were accounted for when the bindings were represented.
+--
+-- NOTE(review): the 'Eq' instance also compares the binder types of a
+-- recursive group. Two groups that differ only in binder types therefore
+-- compare 'EQ' here without being '=='. Comparing the binder types with
+-- 'cmpDeBruijnType' would restore agreement once that function is
+-- implemented.
+instance Ord a => Ord (DeBruijn (CoreBindF a)) where
+  compare (D _ lhs) (D _ rhs) = case (lhs, rhs) of
+    (NonRecF _ x, NonRecF _ y) -> compare x y
+    (NonRecF {},  RecF {})     -> LT
+    (RecF xs,     RecF ys)     -> compare (map snd xs) (map snd ys)
+    (RecF {},     NonRecF {})  -> GT
+
+
+-- | Ordering on case alternatives: by 'AltCon' first, then by right-hand
+-- side. Environments and binders are not inspected. We assume the binders
+-- were accounted for when the alternatives were represented.
+instance Ord a => Ord (DeBruijn (CoreAltF a)) where
+  compare (D _ (AltF lcon _ lrhs)) (D _ (AltF rcon _ rrhs))
+    = case compare lcon rcon of
+        EQ    -> compare lrhs rrhs
+        other -> other
+
+-- | Ordering on tickish annotations under their environments.
+--
+-- Two breakpoints are compared in this order:
+--
+--   1. by id;
+--   2. by their third fields, compared under the respective environments;
+--   3. by their extension fields.
+--
+-- Every other pair is compared with the plain 'Ord' instance of 'CoreTickish'.
+cmpDeBruijnTickish :: DeBruijn CoreTickish -> DeBruijn CoreTickish -> Ordering
+cmpDeBruijnTickish (D env1 (Breakpoint lext lid lids)) (D env2 (Breakpoint rext rid rids))
+  = case compare lid rid of
+      EQ -> case compare (D env1 lids) (D env2 rids) of
+              EQ    -> compare lext rext
+              other -> other
+      other -> other
+cmpDeBruijnTickish (D _ l) (D _ r) = compare l r
+
+-- ROMES:TODO: DEBRUIJN ORDERING ON TYPES!!!
+--
+-- | Placeholder ordering on types: every pair compares 'EQ'.
+-- NOTE(review): 'eqDeBruijnType' can tell types apart, so 'Ord' instances
+-- built on this placeholder are coarser than the matching 'Eq' instances.
+cmpDeBruijnType :: DeBruijn Type -> DeBruijn Type -> Ordering
+cmpDeBruijnType = \_ _ -> EQ
 
+-- ROMES:TODO: DEBRUIJN ORDERING ON COERCIONS!!!
+--
+-- | Placeholder ordering on coercions: every pair compares 'EQ'.
+cmpDeBruijnCoercion :: DeBruijn Coercion -> DeBruijn Coercion -> Ordering
+cmpDeBruijnCoercion = \_ _ -> EQ


=====================================
compiler/GHC/Core/Map/Type.hs
=====================================
@@ -22,7 +22,8 @@ module GHC.Core.Map.Type (
    -- * Utilities for use by friends only
    TypeMapG, CoercionMapG,
 
-   DeBruijn(..), DeBruijnF(..), deBruijnize, deBruijnizeF, eqDeBruijnType, eqDeBruijnVar,
+   DeBruijn(..), deBruijnize, eqDeBruijnType, eqDeBruijnVar,
+   cmpDeBruijnVar,
 
    BndrMap, xtBndr, lkBndr,
    VarMap, xtVar, lkVar, lkDFreeVar, xtDFreeVar,
@@ -283,6 +284,9 @@ eqDeBruijnType env_t1@(D env1 t1) env_t2@(D env2 t2) =
 instance Eq (DeBruijn Var) where
   (==) = eqDeBruijnVar
 
+-- | Ordering on variables under their environments; see 'cmpDeBruijnVar'.
+instance Ord (DeBruijn Var) where
+  compare = cmpDeBruijnVar
+
 eqDeBruijnVar :: DeBruijn Var -> DeBruijn Var -> Bool
 eqDeBruijnVar (D env1 v1) (D env2 v2) =
     case (lookupCME env1 v1, lookupCME env2 v2) of
@@ -290,6 +294,13 @@ eqDeBruijnVar (D env1 v1) (D env2 v2) =
         (Nothing, Nothing) -> v1 == v2
         _ -> False
 
+-- | Ordering on variables under their environments, following the same case
+-- split as 'eqDeBruijnVar':
+--
+--   * two variables both found by 'lookupCME' compare by what it returns;
+--   * two variables found by neither compare by the ordinary 'Var' order;
+--   * a variable that is not found sorts before one that is.
+cmpDeBruijnVar :: DeBruijn Var -> DeBruijn Var -> Ordering
+cmpDeBruijnVar (D env1 v1) (D env2 v2) =
+    case lookupCME env1 v1 of
+      Just b1 -> case lookupCME env2 v2 of
+                   Just b2 -> compare b1 b2
+                   Nothing -> GT
+      Nothing -> case lookupCME env2 v2 of
+                   Just _  -> LT
+                   Nothing -> compare v1 v2
+
 instance {-# OVERLAPPING #-}
          Outputable a => Outputable (TypeMapG a) where
   ppr m = text "TypeMap elts" <+> ppr (foldTM (:) m [])
@@ -513,9 +524,7 @@ lookupCME (CME { cme_env = env }) v = lookupVarEnv env v
 -- export the constructor.  Make a helper function if you find yourself
 -- needing it.
 data DeBruijn a = D CmEnv a
-  deriving (Functor,Foldable,Traversable) -- romes:TODO: For internal use only!
-
-newtype DeBruijnF f a = DF (DeBruijn (f a))
+  deriving (Functor, Foldable, Traversable) -- romes:TODO: for internal use only!
+  -- The derived Functor/Foldable/Traversable instances act on the wrapped
+  -- value only; the 'CmEnv' is carried along unchanged.
 
 -- | Synthesizes a @DeBruijn a@ from an @a@, by assuming that there are no
 -- bound binders (an empty 'CmEnv').  This is usually what you want if there
@@ -523,16 +532,21 @@ newtype DeBruijnF f a = DF (DeBruijn (f a))
 deBruijnize :: a -> DeBruijn a
 deBruijnize = D emptyCME
 
--- | Like 'deBruijnize' but synthesizes a @DeBruijnF f a@ from an @f a@
-deBruijnizeF :: Functor f => f a -> DeBruijnF f a
-deBruijnizeF = DF . deBruijnize
-
+-- | Lists are equal when they have the same length and are element-wise
+-- equal, each element compared under its own list's environment.
 instance Eq (DeBruijn a) => Eq (DeBruijn [a]) where
     D _   []     == D _    []       = True
     D env (x:xs) == D env' (x':xs') = D env x  == D env' x' &&
                                       D env xs == D env' xs'
     _            == _               = False
 
+-- | Lexicographic order on lists. Elements are compared under their lists'
+-- respective environments, and a proper prefix sorts before the longer list.
+instance Ord (DeBruijn a) => Ord (DeBruijn [a]) where
+    compare (D env xs0) (D env' ys0) = go xs0 ys0
+      where
+        go []     []     = EQ
+        go []     (_:_)  = LT
+        go (_:_)  []     = GT
+        go (x:xs) (y:ys) = case compare (D env x) (D env' y) of
+                             EQ    -> go xs ys
+                             other -> other
+
 instance Eq (DeBruijn a) => Eq (DeBruijn (Maybe a)) where
     D _   Nothing  == D _    Nothing   = True
     D env (Just x) == D env' (Just x') = D env x  == D env' x'


=====================================
compiler/GHC/HsToCore/Pmc/Solver/Types.hs
=====================================
@@ -46,7 +46,6 @@ import GHC.Prelude
 import GHC.Data.Bag
 import GHC.Data.FastString
 import GHC.Types.Id
-import GHC.Types.Var.Set
 import GHC.Types.Unique.DSet
 import GHC.Types.Name
 import GHC.Core.Functor
@@ -62,7 +61,6 @@ import GHC.Core.TyCon
 import GHC.Types.Literal
 import GHC.Core
 import GHC.Core.TyCo.Compare( eqType )
-import GHC.Core.Map.Expr
 import GHC.Core.Map.Type
 import GHC.Core.Utils (exprType)
 import GHC.Builtin.Names
@@ -80,12 +78,9 @@ import Data.Ratio
 import GHC.Real (Ratio(..))
 import qualified Data.Semigroup as Semi
 
-import Data.Functor.Const
 import Data.Functor.Compose
-import Data.Function ((&))
 import Data.Equality.Analysis (Analysis(..))
 import Data.Equality.Graph (EGraph, ClassId)
-import Data.Equality.Utils (Fix(..))
 import Data.Equality.Graph.Lens
 import qualified Data.Equality.Graph as EG
 import Data.IntSet (IntSet)
@@ -843,4 +838,4 @@ instance Outputable PmEquality where
 
+-- | Add the e-node for variable @x@, debruijnized under the empty environment
+-- ('deBruijnize'), to the e-graph. Returns the resulting 'ClassId' and the
+-- updated e-graph.
 representId :: Id -> TmEGraph -> (ClassId, TmEGraph)
 -- ROMES:TODO: bit of a hack to represent binders with `Var`, which is likely wrong (lambda bound vars might get equivalent to global ones?). Will need to justify this well
-representId x = EG.add (EG.Node (deBruijnizeF (VarF x))) -- debruijn things are compared correctly wrt binders, but we can still have a debruijn var w name with no prob
+representId x = EG.add (EG.Node (DF (deBruijnize (VarF x)))) -- debruijn things are compared correctly wrt binders, but we can still have a debruijn var w name with no prob



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/4d72410fe0931fd1a94c8efcd3c8c8d52af2f396

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/4d72410fe0931fd1a94c8efcd3c8c8d52af2f396
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20230625/c87ab9a2/attachment-0001.html>


More information about the ghc-commits mailing list