[Git][ghc/ghc][wip/gvn-pmcheck] Pattern match complex expressions by GVN
Sebastian Graf
gitlab at gitlab.haskell.org
Fri May 3 08:58:43 UTC 2019
Sebastian Graf pushed to branch wip/gvn-pmcheck at Glasgow Haskell Compiler / GHC
Commits:
24c20cd7 by Sebastian Graf at 2019-05-03T08:34:44Z
Pattern match complex expressions by GVN
By referential transparency, multiple syntactic occurrences of the same
expression evaluate to the same value. Global value numbering (GVN)
assigns each such expression the same unique number (a `Name` in our
case). Two expressions trivially have the same value if they are
assigned the same value number.
The term oracle `TmOracle` of the pattern match checker couldn't handle
any complex expression before this patch. It would just give up on
anything involving a function application whose head was not a
constructor, by falling back to `PmExprOther`. This means it could not
determine completeness of the following example:
```haskell
foo
| True <- id True
= 1
| False <- id True
= 2
```
This is simply because `TmOracle` couldn't figure out that `id True`
always evaluates to the same `Bool`.
In this patch, we desugar such `PmExprOther`s in pattern guards to
`CoreExpr`. We do so in order to utilise `CoreMap Id` for a
light-weight GVN pass without concern for subexpressions. `TmOracle`
only sees the representing variables, like so:
```haskell
x = id True
foo
| True <- x
= 1
| False <- x
= 2
```
So `TmOracle` still doesn't need to decide equality of complex
expressions, which allows it to stay dead simple.
- - - - -
1 changed file:
- compiler/deSugar/Check.hs
Changes:
=====================================
compiler/deSugar/Check.hs
=====================================
@@ -32,6 +32,7 @@ import TcHsSyn
import Id
import ConLike
import Name
+import NameEnv
import FamInstEnv
import TysPrim (tYPETyCon)
import TysWiredIn
@@ -42,9 +43,13 @@ import Outputable
import FastString
import DataCon
import PatSyn
-import HscTypes (CompleteMatch(..))
+import HscTypes (CompleteMatch(..))
+import CoreMap (CoreMap, emptyCoreMap, lookupCoreMap, extendCoreMap)
+import CoreOpt (simpleOptExpr)
+import CoreUtils (exprType)
import DsMonad
+import {-# SOURCE #-} DsExpr (dsExpr)
import TcSimplify (tcCheckSatisfiability)
import TcType (isStringTy)
import Bag
@@ -60,13 +65,15 @@ import qualified GHC.LanguageExtensions as LangExt
import Data.List (find)
import Data.Maybe (catMaybes, isJust, fromMaybe)
import Control.Monad (forM, when, forM_, zipWithM, filterM)
+import Control.Monad.Trans.State.Strict (StateT (..), evalStateT)
+import Control.Monad.Trans.Class
import Coercion
import TcEvidence
import TcSimplify (tcNormalise)
import IOEnv
import qualified Data.Semigroup as Semi
-import ListT (ListT(..), fold, select)
+import ListT (ListT(..), fold, select)
{-
This module checks pattern matches for:
@@ -140,6 +147,34 @@ getResult ls
go (Just (PmResult _ _ (TypeOfUncovered _) _)) _new
= panic "getResult: No inhabitation candidates"
+data TranslateEnv
+ = TE { te_rep_env :: !(CoreMap Id)
+ -- ^ Representatives for PmExprOther as Core expressions
+ , te_orig_exprs :: NameEnv (HsExpr GhcTc)
+ -- ^ Maps representatives to their represented expression
+ }
+
+initialTE :: TranslateEnv
+initialTE = TE emptyCoreMap emptyNameEnv
+
+-- | Monad in which we translate pattern matches
+type TlM a = StateT TranslateEnv DsM a
+
+representPmExprOther :: PmExpr -> TlM PmExpr
+representPmExprOther (PmExprOther e) = do
+ dflags <- lift getDynFlags
+ core_expr <- simpleOptExpr dflags <$> lift (dsExpr e)
+ StateT $ \env@TE{te_rep_env = cm, te_orig_exprs = origs } -> do
+ (name, env') <-
+ case lookupCoreMap cm core_expr of
+ Just y -> pure (idName y, env)
+ Nothing -> do
+ y <- mkPmId (exprType core_expr)
+ pure (idName y, env { te_rep_env = extendCoreMap cm core_expr y })
+ tracePmD "representPmExprOther" (ppr name <+> text "->" <+> ppr (e, core_expr))
+ pure (PmExprVar name, env' { te_orig_exprs = extendNameEnv origs name e })
+representPmExprOther e = pure e
+
data PatTy = PAT | VA -- Used only as a kind, to index PmPat
-- The *arity* of a PatVec [p1,..,pn] is
@@ -350,9 +385,9 @@ checkSingle dflags ctxt@(DsMatchContext _ locn) var p = do
checkSingle' :: SrcSpan -> Id -> Pat GhcTc -> PmM PmResult
checkSingle' locn var p = do
liftD resetPmIterDs -- set the iter-no to zero
- fam_insts <- liftD dsGetFamInstEnvs
- clause <- liftD $ translatePat fam_insts p
- missing <- mkInitialUncovered [var]
+ fam_insts <- liftD dsGetFamInstEnvs
+ (clause, te) <- liftD $ runStateT (translatePat fam_insts p) initialTE
+ missing <- mkInitialUncovered [var]
tracePm "checkSingle': missing" (vcat (map pprValVecDebug missing))
-- no guards
PartialResult prov cs us ds <- runMany (pmcheckI clause []) missing
@@ -422,8 +457,8 @@ checkMatches' vars matches
go [] missing = return (mempty, [], missing, [])
go (m:ms) missing = do
tracePm "checkMatches': go" (ppr m $$ ppr missing)
- fam_insts <- liftD dsGetFamInstEnvs
- (clause, guards) <- liftD $ translateMatch fam_insts m
+ fam_insts <- liftD dsGetFamInstEnvs
+ ((clause, guards), te) <- liftD $ runStateT (translateMatch fam_insts m) initialTE
r@(PartialResult prov cs missing' ds)
<- runMany (pmcheckI clause guards) missing
tracePm "checkMatches': go: res" (ppr r)
@@ -966,12 +1001,12 @@ mkLitPattern lit = PmLit { pm_lit_lit = PmSLit lit }
-- -----------------------------------------------------------------------
-- * Transform (Pat Id) into of (PmPat Id)
-translatePat :: FamInstEnvs -> Pat GhcTc -> DsM PatVec
+translatePat :: FamInstEnvs -> Pat GhcTc -> TlM PatVec
translatePat fam_insts pat = case pat of
- WildPat ty -> mkPmVars [ty]
+ WildPat ty -> lift $ mkPmVars [ty]
VarPat _ id -> return [PmVar (unLoc id)]
ParPat _ p -> translatePat fam_insts (unLoc p)
- LazyPat _ _ -> mkPmVars [hsPatType pat] -- like a variable
+ LazyPat _ _ -> lift $ mkPmVars [hsPatType pat] -- like a variable
-- ignore strictness annotations for now
BangPat _ p -> translatePat fam_insts (unLoc p)
@@ -991,24 +1026,24 @@ translatePat fam_insts pat = case pat of
| WpCast co <- wrapper, isReflexiveCo co -> translatePat fam_insts p
| otherwise -> do
ps <- translatePat fam_insts p
- (xp,xe) <- mkPmId2Forms ty
+ (xp,xe) <- lift $ mkPmId2Forms ty
g <- mkGuard ps (mkHsWrap wrapper (unLoc xe))
return [xp,g]
-- (n + k) ===> x (True <- x >= k) (n <- x-k)
- NPlusKPat ty (dL->L _ _n) _k1 _k2 _ge _minus -> mkCanFailPmPat ty
+ NPlusKPat ty (dL->L _ _n) _k1 _k2 _ge _minus -> lift $ mkCanFailPmPat ty
-- (fun -> pat) ===> x (pat <- fun x)
ViewPat arg_ty lexpr lpat -> do
ps <- translatePat fam_insts (unLoc lpat)
-- See Note [Guards and Approximation]
- res <- allM cantFailPattern ps
+ res <- lift $ allM cantFailPattern ps
case res of
True -> do
- (xp,xe) <- mkPmId2Forms arg_ty
+ (xp,xe) <- lift $ mkPmId2Forms arg_ty
g <- mkGuard ps (HsApp noExt lexpr xe)
return [xp,g]
- False -> mkCanFailPmPat arg_ty
+ False -> lift $ mkCanFailPmPat arg_ty
-- list
ListPat (ListPatTc ty Nothing) ps -> do
@@ -1017,13 +1052,13 @@ translatePat fam_insts pat = case pat of
-- overloaded list
ListPat (ListPatTc _elem_ty (Just (pat_ty, _to_list))) lpats -> do
- dflags <- getDynFlags
+ dflags <- lift $ getDynFlags
if xopt LangExt.RebindableSyntax dflags
- then mkCanFailPmPat pat_ty
+ then lift $ mkCanFailPmPat pat_ty
else case splitListTyConApp_maybe pat_ty of
Just e_ty -> translatePat fam_insts
(ListPat (ListPatTc e_ty Nothing) lpats)
- Nothing -> mkCanFailPmPat pat_ty
+ Nothing -> lift $ mkCanFailPmPat pat_ty
-- (a) In the presence of RebindableSyntax, we don't know anything about
-- `toList`, we should treat `ListPat` as any other view pattern.
--
@@ -1047,9 +1082,9 @@ translatePat fam_insts pat = case pat of
, pat_tvs = ex_tvs
, pat_dicts = dicts
, pat_args = ps } -> do
- groups <- allCompleteMatches con arg_tys
+ groups <- lift $ allCompleteMatches con arg_tys
case groups of
- [] -> mkCanFailPmPat (conLikeResTy con arg_tys)
+ [] -> lift $ mkCanFailPmPat (conLikeResTy con arg_tys)
_ -> do
args <- translateConPatVec fam_insts arg_tys ex_tvs con ps
return [PmCon { pm_con_con = con
@@ -1178,23 +1213,23 @@ from translation in pattern matcher.
-- | Translate a list of patterns (Note: each pattern is translated
-- to a pattern vector but we do not concatenate the results).
-translatePatVec :: FamInstEnvs -> [Pat GhcTc] -> DsM [PatVec]
+translatePatVec :: FamInstEnvs -> [Pat GhcTc] -> TlM [PatVec]
translatePatVec fam_insts pats = mapM (translatePat fam_insts) pats
-- | Translate a constructor pattern
translateConPatVec :: FamInstEnvs -> [Type] -> [TyVar]
- -> ConLike -> HsConPatDetails GhcTc -> DsM PatVec
+ -> ConLike -> HsConPatDetails GhcTc -> TlM PatVec
translateConPatVec fam_insts _univ_tys _ex_tvs _ (PrefixCon ps)
= concat <$> translatePatVec fam_insts (map unLoc ps)
translateConPatVec fam_insts _univ_tys _ex_tvs _ (InfixCon p1 p2)
= concat <$> translatePatVec fam_insts (map unLoc [p1,p2])
translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _))
-- Nothing matched. Make up some fresh term variables
- | null fs = mkPmVars arg_tys
+ | null fs = lift $ mkPmVars arg_tys
-- The data constructor was not defined using record syntax. For the
-- pattern to be in record syntax it should be empty (e.g. Just {}).
-- So just like the previous case.
- | null orig_lbls = ASSERT(null matched_lbls) mkPmVars arg_tys
+ | null orig_lbls = ASSERT(null matched_lbls) lift $ mkPmVars arg_tys
-- Some of the fields appear, in the original order (there may be holes).
-- Generate a simple constructor pattern and make up fresh variables for
-- the rest of the fields
@@ -1202,13 +1237,13 @@ translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _))
= ASSERT(orig_lbls `equalLength` arg_tys)
let translateOne (lbl, ty) = case lookup lbl matched_pats of
Just p -> translatePat fam_insts p
- Nothing -> mkPmVars [ty]
+ Nothing -> lift $ mkPmVars [ty]
in concatMapM translateOne (zip orig_lbls arg_tys)
-- The fields that appear are not in the correct order. Make up fresh
-- variables for all fields and add guards after matching, to force the
-- evaluation in the correct order.
| otherwise = do
- arg_var_pats <- mkPmVars arg_tys
+ arg_var_pats <- lift $ mkPmVars arg_tys
translated_pats <- forM matched_pats $ \(x,pat) -> do
pvec <- translatePat fam_insts pat
return (x, pvec)
@@ -1239,7 +1274,7 @@ translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _))
-- Translate a single match
translateMatch :: FamInstEnvs -> LMatch GhcTc (LHsExpr GhcTc)
- -> DsM (PatVec,[PatVec])
+ -> TlM (PatVec,[PatVec])
translateMatch fam_insts (dL->L _ (Match { m_pats = lpats, m_grhss = grhss })) =
do
pats' <- concat <$> translatePatVec fam_insts pats
@@ -1258,7 +1293,7 @@ translateMatch _ _ = panic "translateMatch"
-- * Transform source guards (GuardStmt Id) to PmPats (Pattern)
-- | Translate a list of guard statements to a pattern vector
-translateGuards :: FamInstEnvs -> [GuardStmt GhcTc] -> DsM PatVec
+translateGuards :: FamInstEnvs -> [GuardStmt GhcTc] -> TlM PatVec
translateGuards fam_insts guards = do
all_guards <- concat <$> mapM (translateGuard fam_insts) guards
let
@@ -1273,7 +1308,7 @@ translateGuards fam_insts guards = do
| otherwise = allM shouldKeep pv
shouldKeep _other_pat = pure False -- let the rest..
- all_handled <- allM shouldKeep all_guards
+ all_handled <- lift $ allM shouldKeep all_guards
-- It should have been @pure all_guards@ but it is too expressive.
-- Since the term oracle does not handle all constraints we generate,
-- we (hackily) replace all constraints the oracle cannot handle with a
@@ -1283,7 +1318,7 @@ translateGuards fam_insts guards = do
if all_handled
then pure all_guards
else do
- kept <- filterM shouldKeep all_guards
+ kept <- lift $ filterM shouldKeep all_guards
pure (PmFake : kept)
-- | Check whether a pattern can fail to match
@@ -1295,7 +1330,7 @@ cantFailPattern (PmGrd pv _e) = allM cantFailPattern pv
cantFailPattern _ = pure False
-- | Translate a guard statement to Pattern
-translateGuard :: FamInstEnvs -> GuardStmt GhcTc -> DsM PatVec
+translateGuard :: FamInstEnvs -> GuardStmt GhcTc -> TlM PatVec
translateGuard fam_insts guard = case guard of
BodyStmt _ e _ _ -> translateBoolGuard e
LetStmt _ binds -> translateLet (unLoc binds)
@@ -1308,18 +1343,18 @@ translateGuard fam_insts guard = case guard of
XStmtLR {} -> panic "translateGuard RecStmt"
-- | Translate let-bindings
-translateLet :: HsLocalBinds GhcTc -> DsM PatVec
+translateLet :: HsLocalBinds GhcTc -> TlM PatVec
translateLet _binds = return []
-- | Translate a pattern guard
-translateBind :: FamInstEnvs -> LPat GhcTc -> LHsExpr GhcTc -> DsM PatVec
+translateBind :: FamInstEnvs -> LPat GhcTc -> LHsExpr GhcTc -> TlM PatVec
translateBind fam_insts (dL->L _ p) e = do
ps <- translatePat fam_insts p
g <- mkGuard ps (unLoc e)
return [g]
-- | Translate a boolean guard
-translateBoolGuard :: LHsExpr GhcTc -> DsM PatVec
+translateBoolGuard :: LHsExpr GhcTc -> TlM PatVec
translateBoolGuard e
| isJust (isTrueLHsExpr e) = return []
-- The formal thing to do would be to generate (True <- True)
@@ -1663,14 +1698,13 @@ mkOneConFull x con = do
-- * More smart constructors and fresh variable generation
-- | Create a guard pattern
-mkGuard :: PatVec -> HsExpr GhcTc -> DsM Pattern
+mkGuard :: PatVec -> HsExpr GhcTc -> TlM Pattern
mkGuard pv e = do
- res <- allM cantFailPattern pv
- let expr = hsExprToPmExpr e
- tracePmD "mkGuard" (vcat [ppr pv, ppr e, ppr res, ppr expr])
- if | res -> pure (PmGrd pv expr)
- | PmExprOther {} <- expr -> pure PmFake
- | otherwise -> pure (PmGrd pv expr)
+ res <- lift $ allM cantFailPattern pv
+ let expr = hsExprToPmExpr e
+ expr' <- representPmExprOther expr
+ traceTl "mkGuard" (vcat [ppr pv, ppr e, ppr res, ppr expr, ppr expr'])
+ pure (PmGrd pv expr')
-- | Create a term equality of the form: `(False ~ (x ~ lit))`
mkNegEq :: Id -> PmLit -> ComplexEq
@@ -2403,8 +2437,8 @@ genCaseTmCs2 :: Maybe (LHsExpr GhcTc) -- Scrutinee
-> [Id] -- MatchVars (should have length 1)
-> DsM (Bag SimpleEq)
genCaseTmCs2 Nothing _ _ = return emptyBag
-genCaseTmCs2 (Just scr) [p] [var] = do
- fam_insts <- dsGetFamInstEnvs
+genCaseTmCs2 (Just scr) [p] [var] = flip evalStateT initialTE $ do
+ fam_insts <- lift $ dsGetFamInstEnvs
[e] <- map vaToPmExpr . coercePatVec <$> translatePat fam_insts p
let scr_e = lhsExprToPmExpr scr
return $ listToBag [(var, e), (var, scr_e)]
@@ -2719,6 +2753,8 @@ involved.
tracePm :: String -> SDoc -> PmM ()
tracePm herald doc = liftD $ tracePmD herald doc
+traceTl :: String -> SDoc -> TlM ()
+traceTl herald doc = lift $ tracePmD herald doc
tracePmD :: String -> SDoc -> DsM ()
tracePmD herald doc = do
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/commit/24c20cd71fbdd0de588a7ea0e06cbc520fd3a97c
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/commit/24c20cd71fbdd0de588a7ea0e06cbc520fd3a97c
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20190503/ec19d6a4/attachment-0001.html>
More information about the ghc-commits
mailing list