[Git][ghc/ghc][wip/romes/eqsat-pmc] Core Equality module pretty much good to go
Rodrigo Mesquita (@alt-romes)
gitlab at gitlab.haskell.org
Sun Jun 25 20:15:27 UTC 2023
Rodrigo Mesquita pushed to branch wip/romes/eqsat-pmc at Glasgow Haskell Compiler / GHC
Commits:
4d72410f by Rodrigo Mesquita at 2023-06-25T21:15:15+01:00
Core Equality module pretty much good to go
IPW
- - - - -
3 changed files:
- compiler/GHC/Core/Functor.hs
- compiler/GHC/Core/Map/Type.hs
- compiler/GHC/HsToCore/Pmc/Solver/Types.hs
Changes:
=====================================
compiler/GHC/Core/Functor.hs
=====================================
@@ -1,3 +1,4 @@
+{-# LANGUAGE MagicHash #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ViewPatterns #-}
@@ -7,6 +8,7 @@
-- ROMES:TODO: Rename to Core.Equality or something
module GHC.Core.Functor where
+import GHC.Exts (dataToTag#, tagToEnum#, (>#), (<#))
import GHC.Prelude
import GHC.Core
@@ -24,8 +26,7 @@ import Data.Equality.Analysis
import qualified Data.Equality.Graph.Monad as EGM
import Data.Equality.Utils (Fix(..))
-import GHC.Utils.Misc (all2, equalLength)
-import Data.Functor.Identity (Identity(..))
+import GHC.Utils.Misc (all2)
-- Important to note the binders are also represented by $a$
-- This is because in the e-graph we will represent binders with the
@@ -38,12 +39,12 @@ import Data.Functor.Identity (Identity(..))
data AltF b a
= AltF AltCon [b] a
- deriving (Functor, Foldable, Traversable, Eq, Ord)
+ deriving (Functor, Foldable, Traversable)
data BindF b a
= NonRecF b a
| RecF [(b, a)]
- deriving (Functor, Foldable, Traversable, Eq, Ord)
+ deriving (Functor, Foldable, Traversable)
data ExprF b a
= VarF Id
@@ -61,16 +62,20 @@ data ExprF b a
type CoreExprF
= ExprF CoreBndr
+-- | 'AltF' specialised to Core binders.
+type CoreAltF = AltF CoreBndr
+
+-- | 'BindF' specialised to Core binders.
+type CoreBindF = BindF CoreBndr
-instance Eq a => Eq (DeBruijnF CoreExprF a) where
- (==) = eqDeBruijnExprF
+-- | An @f a@ node together with the 'CmEnv' it was built in (via 'DeBruijn').
+-- Used as the node functor of the e-graph, e.g. @EGraph a (DeBruijnF CoreExprF)@.
+newtype DeBruijnF f a
+  = DF (DeBruijn (f a))
+  deriving (Functor, Foldable, Traversable)
-- ROMES:TODO: This instance is plain wrong. This DeBruijn scheme won't
-- particularly work for our situation, we'll likely have to have ints instead
-- of Id binders. Now, ignoring DeBruijn indices, we'll simply get this compile
-- to get a rougher estimate of performance?
-eqDeBruijnExprF :: forall a. Eq a => DeBruijnF CoreExprF a -> DeBruijnF CoreExprF a -> Bool
-eqDeBruijnExprF (DF (D env1 e1)) (DF (D env2 e2)) = go e1 e2 where
+eqDeBruijnExprF :: forall a. Eq a => DeBruijn (CoreExprF a) -> DeBruijn (CoreExprF a) -> Bool
+eqDeBruijnExprF (D env1 e1) (D env2 e2) = go e1 e2 where
go :: CoreExprF a -> CoreExprF a -> Bool
go (VarF v1) (VarF v2) = eqDeBruijnVar (D env1 v1) (D env2 v2)
go (LitF lit1) (LitF lit2) = lit1 == lit2
@@ -88,37 +93,30 @@ eqDeBruijnExprF (DF (D env1 e1)) (DF (D env2 e2)) = go e1 e2 where
&& D env1 (varMultMaybe b1) == D env2 (varMultMaybe b2)
&& e1 == e2
- go (LetF (NonRecF v1 r1) e1) (LetF (NonRecF v2 r2) e2)
- = r1 == r2 -- See Note [Alpha-equality for let-bindings]
- && e1 == e2
-
- go (LetF (RecF ps1) e1) (LetF (RecF ps2) e2)
- =
- -- See Note [Alpha-equality for let-bindings]
- all2 (\b1 b2 -> eqDeBruijnType (D env1 (varType b1))
- (D env2 (varType b2)))
- bs1 bs2
- && rs1 == rs2
+ go (LetF abs e1) (LetF bbs e2)
+ = D env1 abs == D env2 bbs
&& e1 == e2
- where
- (bs1,rs1) = unzip ps1
- (bs2,rs2) = unzip ps2
- go (CaseF e1 b1 t1 a1) (CaseF e2 b2 t2 a2)
+ go (CaseF e1 _b1 t1 a1) (CaseF e2 _b2 t2 a2)
| null a1 -- See Note [Empty case alternatives]
= null a2 && e1 == e2 && D env1 t1 == D env2 t2
| otherwise
- = e1 == e2 && a1 == a2
+ = e1 == e2 && D env1 a1 == D env2 a2
go _ _ = False
-instance Ord a => Ord (DeBruijnF CoreExprF a) where
- compare a b = if a == b then EQ else LT
--- deriving instance Ord a => Ord (DeBruijnF CoreExprF a)
-
-deriving instance Functor (DeBruijnF CoreExprF)
-deriving instance Foldable (DeBruijnF CoreExprF)
-deriving instance Traversable (DeBruijnF CoreExprF)
+-- | Equality of two de Bruijn-annotated case alternatives whose right-hand
+-- sides have already been represented (e.g. as e-class ids).
+--
+-- Two alternatives are equal when their 'AltCon's are equal and their
+-- right-hand sides are equal. The binders are not compared, and neither is the
+-- alternative's own 'CmEnv': 'representDBAltExpr' extends the 'CmEnv' with the
+-- alternative's binders before representing the right-hand side.
+--
+-- ROMES:TODO: This one can be derived automatically, but perhaps it's better
+-- to be explicit here? We don't even really require the DeBruijn context here
+eqDeBruijnAltF :: forall a. Eq a => DeBruijn (CoreAltF a) -> DeBruijn (CoreAltF a) -> Bool
+eqDeBruijnAltF (D _ (AltF con1 _ rhs1)) (D _ (AltF con2 _ rhs2))
+  = con1 == con2 && rhs1 == rhs2
-- | 'unsafeCoerce' mostly because I'm too lazy to write the boilerplate.
fromCoreExpr :: CoreExpr -> Fix CoreExprF
@@ -128,8 +126,9 @@ toCoreExpr :: CoreExpr -> Fix CoreExprF
toCoreExpr = unsafeCoerce
-- | Represents a DeBruijn CoreExpr being careful to correctly debruijnizie the expression as it is represented
--- TODO: Use `Compose DeBruijn CoreExprF` instead
+--
-- Always represent Ids, at least for now. We're seemingly using inexistent ids
+-- ROMES:TODO: do this all inside EGraphM instead
representDBCoreExpr :: Analysis a (DeBruijnF CoreExprF)
=> DeBruijn CoreExpr
-> EGraph a (DeBruijnF CoreExprF)
@@ -154,28 +153,205 @@ representDBCoreExpr (D cmenv expr) eg0 = case expr of
Let (Rec (unzip -> (bs,rs))) e ->
let cmenv' = extendCMEs cmenv bs
(bsids, eg1) = EGM.runEGraphM eg0 $
- traverse (\r -> state $ representDBCoreExpr (D cmenv' r)) rs
+ traverse (state . representDBCoreExpr . D cmenv') rs
(eid, eg2) = representDBCoreExpr (D cmenv' e) eg1
in add (Node $ DF (D cmenv (LetF (RecF (zip bs bsids)) eid))) eg2
Case e b t as -> let (eid, eg1) = representDBCoreExpr (D cmenv e) eg0
(as', eg2) = EGM.runEGraphM eg1 $
- traverse (\(Alt cons bs a) -> state $ \s -> let (aid, g) = representDBCoreExpr (D (extendCME cmenv b) a) s in (AltF cons bs aid, g)) as
+ traverse (state . representDBAltExpr . D (extendCME cmenv b)) as
in add (Node $ DF (D cmenv (CaseF eid b t as'))) eg2
+-- | Represent a de Bruijn-annotated case alternative in the e-graph.
+-- The right-hand side is represented under the 'CmEnv' extended with the
+-- alternative's binders. The result is the alternative with its right-hand
+-- side replaced by that e-class id, paired with the updated e-graph.
+representDBAltExpr :: Analysis a (DeBruijnF CoreExprF)
+                   => DeBruijn CoreAlt
+                   -> EGraph a (DeBruijnF CoreExprF)
+                   -> (CoreAltF ClassId, EGraph a (DeBruijnF CoreExprF))
+representDBAltExpr (D cm (Alt con bndrs rhs)) egraph = (AltF con bndrs rhsId, egraph')
+  where
+    (rhsId, egraph') = representDBCoreExpr (D (extendCMEs cm bndrs) rhs) egraph
--- ROMES:TODO: Instead of DeBruijnF CoreExprF we should have (ExprF (Int,Id))
--- * A represent function that makes up the debruijn indices as it is representing the expressions
--- * An Eq and Ord instance which ignore the Id and rather look at the DeBruijn index.
---
--- TODO
--- * For types, can we use eqDeBruijnType ? I think not, because Lambdas and Lets can bind type variables
---
--- TODO: The Best Alternative:
---
--- Each expression keeps its DeBruijnF CmEnv environment, but the represent
--- function must correctly prepare the debruijn indices, so that each e-node
--- has the debruijn indice it would have in its recursive descent in the Eq instance?
---
--- TODO: We could even probably have Compose DeBruijn CoreExprF in that case!
---
+-- Equality on de Bruijn-annotated alternatives and expression nodes.
+instance Eq a => Eq (DeBruijn (CoreAltF a)) where
+  x == y = eqDeBruijnAltF x y
+
+instance Eq a => Eq (DeBruijn (CoreExprF a)) where
+  x == y = eqDeBruijnExprF x y
+
+-- The 'DeBruijnF' wrappers compare their wrapped values.
+instance Eq a => Eq (DeBruijnF CoreExprF a) where
+  DF x == DF y = eqDeBruijnExprF x y
+
+instance Eq a => Eq (DeBruijnF CoreAltF a) where
+  DF x == DF y = eqDeBruijnAltF x y
+
+-- Ordering of the wrapper delegates to 'Ord (DeBruijn (CoreExprF a))'.
+deriving instance Ord a => Ord (DeBruijnF CoreExprF a)
+
+-- | Total order on de Bruijn-annotated expression nodes whose children are of
+-- some ordered type @a@ (e.g. e-class ids).
+--
+-- Nodes built with different constructors are ordered by constructor:
+-- VarF < LitF < AppF < LamF < LetF < CaseF < CastF < TickF < TypeF < CoercionF.
+-- The numeric 'dataToTag#' tests below assume this is also the declaration
+-- order of 'ExprF' (NOTE(review): confirm against the 'ExprF' declaration).
+-- Nodes built with the same constructor are compared field by field.
+--
+-- Binders are not compared: 'representDBCoreExpr' extends the 'CmEnv' with the
+-- binders before representing sub-expressions, so the children already account
+-- for them. The 'CmEnv's are only consulted when comparing variables, types,
+-- coercions, tickishes, let-bindings and alternatives.
+--
+-- Caveat: 'cmpDeBruijnType' and 'cmpDeBruijnCoercion' currently always return
+-- 'EQ', so this order can equate nodes that 'eqDeBruijnExprF' distinguishes.
+instance Ord a => Ord (DeBruijn (CoreExprF a)) where
+  compare (D cma ea) (D cmb eb)
+    -- NB: take the tag of the /inner/ node @eb@. 'D' is the only constructor of
+    -- 'DeBruijn', so 'dataToTag#' of the wrapper is always 0. That made every
+    -- tag test a no-op: two 'CaseF' nodes always compared GT, and an 'AppF'
+    -- versus a 'LamF' compared GT in both directions.
+    = case dataToTag# eb of
+        tagB -> case ea of
+          VarF va -> case eb of
+            VarF vb -> cmpDeBruijnVar (D cma va) (D cmb vb)
+            _       -> LT
+          LitF la -> case eb of
+            VarF{}  -> GT
+            LitF lb -> la `compare` lb
+            _       -> LT
+          AppF af aarg
+            | tagToEnum# (tagB ># 2#) -> LT
+            | AppF bf barg <- eb      -> compare af bf `andThenCmp` compare aarg barg
+            | otherwise               -> GT
+          LamF _abind abody
+            | tagToEnum# (tagB ># 3#) -> LT
+            | LamF _bbind bbody <- eb -> compare abody bbody
+            | otherwise               -> GT
+          LetF abnd abody
+            | tagToEnum# (tagB ># 4#) -> LT
+            | LetF bbnd bbody <- eb   ->
+                compare (D cma abnd) (D cmb bbnd)
+                  `andThenCmp` compare abody bbody
+            | otherwise               -> GT
+          CaseF cax _cabind catype caalts
+            | tagToEnum# (tagB <# 5#) -> GT
+            | CaseF cbx _cbbind cbtype cbalts <- eb ->
+                -- ROMES:TODO: Consider changing order of comparisons to a more efficient one
+                compare cax cbx
+                  `andThenCmp` cmpDeBruijnType (D cma catype) (D cmb cbtype)
+                  `andThenCmp` compare (D cma caalts) (D cmb cbalts)
+            | otherwise               -> LT
+          CastF cax caco
+            | tagToEnum# (tagB <# 6#) -> GT
+            | CastF cbx cbco <- eb    ->
+                compare cax cbx
+                  `andThenCmp` cmpDeBruijnCoercion (D cma caco) (D cmb cbco)
+            | otherwise               -> LT
+          TickF tatickish tax
+            | tagToEnum# (tagB <# 7#) -> GT
+            | TickF tbtickish tbx <- eb ->
+                cmpDeBruijnTickish (D cma tatickish) (D cmb tbtickish)
+                  `andThenCmp` compare tax tbx
+            | otherwise               -> LT
+          TypeF aty -> case eb of
+            CoercionF{} -> LT
+            TypeF bty   -> cmpDeBruijnType (D cma aty) (D cmb bty)
+            _           -> GT
+          CoercionF aco -> case eb of
+            CoercionF bco -> cmpDeBruijnCoercion (D cma aco) (D cmb bco)
+            _             -> GT
+    where
+      -- Lexicographic combination: the second comparison is only consulted
+      -- (and only evaluated) when the first one is a tie.
+      andThenCmp :: Ordering -> Ordering -> Ordering
+      andThenCmp EQ o = o
+      andThenCmp o  _ = o
+
+-- | Equality of de Bruijn-annotated let-bindings whose right-hand sides have
+-- already been represented. See Note [Alpha-equality for let-bindings].
+-- Binder names are never compared. Non-recursive bindings compare only their
+-- right-hand sides. Recursive groups additionally require the binders' types
+-- to be equal under the respective 'CmEnv's.
+instance Eq a => Eq (DeBruijn (CoreBindF a)) where
+  D cma bind1 == D cmb bind2 = case (bind1, bind2) of
+      (NonRecF _ rhs1, NonRecF _ rhs2) -> rhs1 == rhs2
+      (RecF prs1, RecF prs2) ->
+        all2 sameBndrType (map fst prs1) (map fst prs2)
+          && map snd prs1 == map snd prs2
+      _ -> False
+    where
+      sameBndrType b1 b2 = eqDeBruijnType (D cma (varType b1)) (D cmb (varType b2))
+
+
+-- | Order on de Bruijn-annotated let-bindings: non-recursive bindings sort
+-- before recursive groups. Bindings of the same kind compare their
+-- right-hand sides (for a group, as a list in order). Binders are ignored
+-- because they were accounted for when the bindings were represented.
+instance Ord a => Ord (DeBruijn (CoreBindF a)) where
+  compare (D _ (NonRecF _ x)) (D _ (NonRecF _ y)) = compare x y
+  compare (D _ NonRecF{})     (D _ _)             = LT
+  compare (D _ (RecF xs))     (D _ (RecF ys))     = compare (map snd xs) (map snd ys)
+  compare (D _ RecF{})        (D _ _)             = GT
+
+
+-- | Order on de Bruijn-annotated case alternatives: by 'AltCon', then by
+-- right-hand side. Binders are not looked at; we assume they were accounted
+-- for when the alternative was represented (see 'representDBAltExpr').
+--
+-- NOTE(review): this uses GHC's 'Ord AltCon', which (as far as I know)
+-- compares 'DataAlt's by 'dataConTag' and is only intended for constructors
+-- of a single TyCon. 'eqDeBruijnAltF' instead compares 'DataCon's by identity.
+-- Confirm the two agree for alternatives of different case expressions.
+instance Ord a => Ord (DeBruijn (CoreAltF a)) where
+  compare (D _ (AltF con1 _ rhs1)) (D _ (AltF con2 _ rhs2)) =
+    case compare con1 con2 of
+      EQ    -> compare rhs1 rhs2
+      other -> other
+
+-- | Order on de Bruijn-annotated Core tickishes.
+-- Two 'Breakpoint's are compared lexicographically by their second field,
+-- then their list of variables (interpreted under the respective 'CmEnv's),
+-- then their first field. Every other combination falls back to the plain
+-- 'Ord' instance of the tickish.
+cmpDeBruijnTickish :: DeBruijn CoreTickish -> DeBruijn CoreTickish -> Ordering
+cmpDeBruijnTickish (D env1 (Breakpoint lext lid lids)) (D env2 (Breakpoint rext rid rids))
+  = case compare lid rid of
+      EQ -> case compare (D env1 lids) (D env2 rids) of
+              EQ    -> compare lext rext
+              other -> other
+      other -> other
+cmpDeBruijnTickish (D _ t1) (D _ t2)
+  = compare t1 t2
+
+-- | Order on de Bruijn-annotated types.
+--
+-- ROMES:TODO: DEBRUIJN ORDERING ON TYPES!!!
+-- Placeholder: every pair of types compares 'EQ', which is coarser than
+-- 'eqDeBruijnType' (types it distinguishes still compare 'EQ' here).
+cmpDeBruijnType :: DeBruijn Type -> DeBruijn Type -> Ordering
+cmpDeBruijnType = \_ _ -> EQ
+
+-- | Order on de Bruijn-annotated coercions.
+--
+-- ROMES:TODO: DEBRUIJN ORDERING ON COERCIONS!!!
+-- Placeholder: every pair of coercions compares 'EQ'.
+cmpDeBruijnCoercion :: DeBruijn Coercion -> DeBruijn Coercion -> Ordering
+cmpDeBruijnCoercion = \_ _ -> EQ
=====================================
compiler/GHC/Core/Map/Type.hs
=====================================
@@ -22,7 +22,8 @@ module GHC.Core.Map.Type (
-- * Utilities for use by friends only
TypeMapG, CoercionMapG,
- DeBruijn(..), DeBruijnF(..), deBruijnize, deBruijnizeF, eqDeBruijnType, eqDeBruijnVar,
+ DeBruijn(..), deBruijnize, eqDeBruijnType, eqDeBruijnVar,
+ cmpDeBruijnVar,
BndrMap, xtBndr, lkBndr,
VarMap, xtVar, lkVar, lkDFreeVar, xtDFreeVar,
@@ -283,6 +284,9 @@ eqDeBruijnType env_t1@(D env1 t1) env_t2@(D env2 t2) =
instance Eq (DeBruijn Var) where
(==) = eqDeBruijnVar
+-- | Orders de Bruijn-annotated variables; see 'cmpDeBruijnVar'.
+instance Ord (DeBruijn Var) where
+  compare x y = cmpDeBruijnVar x y
+
eqDeBruijnVar :: DeBruijn Var -> DeBruijn Var -> Bool
eqDeBruijnVar (D env1 v1) (D env2 v2) =
case (lookupCME env1 v1, lookupCME env2 v2) of
@@ -290,6 +294,13 @@ eqDeBruijnVar (D env1 v1) (D env2 v2) =
(Nothing, Nothing) -> v1 == v2
_ -> False
+-- | Total order on de Bruijn-annotated variables, consistent with
+-- 'eqDeBruijnVar':
+--
+--  * variables bound in their respective 'CmEnv's compare by the index that
+--    'lookupCME' returns for them;
+--  * variables free in both environments compare as plain variables;
+--  * a bound variable sorts after a free one.
+cmpDeBruijnVar :: DeBruijn Var -> DeBruijn Var -> Ordering
+cmpDeBruijnVar (D env1 v1) (D env2 v2)
+  = case (lookupCME env1 v1, lookupCME env2 v2) of
+      (Nothing, Nothing) -> compare v1 v2
+      (mb1, mb2)         -> compare mb1 mb2
+
instance {-# OVERLAPPING #-}
Outputable a => Outputable (TypeMapG a) where
ppr m = text "TypeMap elts" <+> ppr (foldTM (:) m [])
@@ -513,9 +524,7 @@ lookupCME (CME { cme_env = env }) v = lookupVarEnv env v
-- export the constructor. Make a helper function if you find yourself
-- needing it.
data DeBruijn a = D CmEnv a
- deriving (Functor,Foldable,Traversable) -- romes:TODO: For internal use only!
-
-newtype DeBruijnF f a = DF (DeBruijn (f a))
+ deriving (Functor, Foldable, Traversable) -- romes:TODO: for internal use only!
-- | Synthesizes a @DeBruijn a@ from an @a@, by assuming that there are no
-- bound binders (an empty 'CmEnv'). This is usually what you want if there
@@ -523,16 +532,21 @@ newtype DeBruijnF f a = DF (DeBruijn (f a))
deBruijnize :: a -> DeBruijn a
deBruijnize = D emptyCME
--- | Like 'deBruijnize' but synthesizes a @DeBruijnF f a@ from an @f a@
-deBruijnizeF :: Functor f => f a -> DeBruijnF f a
-deBruijnizeF = DF . deBruijnize
-
instance Eq (DeBruijn a) => Eq (DeBruijn [a]) where
D _ [] == D _ [] = True
D env (x:xs) == D env' (x':xs') = D env x == D env' x' &&
D env xs == D env' xs'
_ == _ = False
+-- | Lexicographic order on de Bruijn-annotated lists: elements are compared
+-- pairwise, each under its list's environment, and a proper prefix sorts first.
+instance Ord (DeBruijn a) => Ord (DeBruijn [a]) where
+  compare (D env xs) (D env' ys) = case (xs, ys) of
+    ([],       [])       -> EQ
+    ([],       _ : _)    -> LT
+    (_ : _,    [])       -> GT
+    (x : xs',  y : ys')  -> case compare (D env x) (D env' y) of
+      EQ    -> compare (D env xs') (D env' ys')
+      other -> other
+
instance Eq (DeBruijn a) => Eq (DeBruijn (Maybe a)) where
D _ Nothing == D _ Nothing = True
D env (Just x) == D env' (Just x') = D env x == D env' x'
=====================================
compiler/GHC/HsToCore/Pmc/Solver/Types.hs
=====================================
@@ -46,7 +46,6 @@ import GHC.Prelude
import GHC.Data.Bag
import GHC.Data.FastString
import GHC.Types.Id
-import GHC.Types.Var.Set
import GHC.Types.Unique.DSet
import GHC.Types.Name
import GHC.Core.Functor
@@ -62,7 +61,6 @@ import GHC.Core.TyCon
import GHC.Types.Literal
import GHC.Core
import GHC.Core.TyCo.Compare( eqType )
-import GHC.Core.Map.Expr
import GHC.Core.Map.Type
import GHC.Core.Utils (exprType)
import GHC.Builtin.Names
@@ -80,12 +78,9 @@ import Data.Ratio
import GHC.Real (Ratio(..))
import qualified Data.Semigroup as Semi
-import Data.Functor.Const
import Data.Functor.Compose
-import Data.Function ((&))
import Data.Equality.Analysis (Analysis(..))
import Data.Equality.Graph (EGraph, ClassId)
-import Data.Equality.Utils (Fix(..))
import Data.Equality.Graph.Lens
import qualified Data.Equality.Graph as EG
import Data.IntSet (IntSet)
@@ -843,4 +838,4 @@ instance Outputable PmEquality where
representId :: Id -> TmEGraph -> (ClassId, TmEGraph)
-- ROMES:TODO: bit of a hack to represent binders with `Var`, which is likely wrong (lambda bound vars might get equivalent to global ones?). Will need to justify this well
-representId x = EG.add (EG.Node (deBruijnizeF (VarF x))) -- debruijn things are compared correctly wrt binders, but we can still have a debruijn var w name with no prob
+representId x = EG.add (EG.Node (DF (deBruijnize (VarF x)))) -- debruijn things are compared correctly wrt binders, but we can still have a debruijn var w name with no prob
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/4d72410fe0931fd1a94c8efcd3c8c8d52af2f396
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/4d72410fe0931fd1a94c8efcd3c8c8d52af2f396
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20230625/c87ab9a2/attachment-0001.html>
More information about the ghc-commits
mailing list