24c20cd7 by Sebastian Graf at 2019-05-03T08:34:44Z
Pattern match complex expressions by GVN

By referential transparency, multiple syntactic occurrences of the same
expression evaluate to the same value. Global value numbering (GVN)
assigns each such expression the same unique number (a `Name` in our
case). Two expressions trivially have the same value if they are
assigned the same value number.

Before this patch, the term oracle `TmOracle` of the pattern match checker
could not handle complex expressions. It would simply give up on anything
involving a function application whose head was not a constructor, falling
back to `PmExprOther`. Consequently, it could not determine that the
following example is complete:

  | True <- id True
  = 1
  | False <- id True
  = 2

This is simply because `TmOracle` couldn't figure out that `id True`
always evaluates to the same `Bool`.

In this patch, we desugar such `PmExprOther`s in pattern guards to
`CoreExpr`, so that we can use a `CoreMap Id` for a lightweight GVN pass
that numbers only whole expressions, not their subexpressions.
`TmOracle` then only sees the representative variables, like so:

  x = id True

  | True <- x
  = 1
  | False <- x
  = 2

So `TmOracle` still doesn't need to decide equality of complex
expressions, which allows it to stay dead simple.

- - - - -

1 changed file:

- compiler/deSugar/Check.hs


@@ -32,6 +32,7 @@ import TcHsSyn
 import Id
 import ConLike
 import Name
+import NameEnv
 import FamInstEnv
 import TysPrim (tYPETyCon)
 import TysWiredIn
@@ -42,9 +43,13 @@ import Outputable
 import FastString
 import DataCon
 import PatSyn
-import HscTypes (CompleteMatch(..))
+import HscTypes      (CompleteMatch(..))
+import CoreMap       (CoreMap, emptyCoreMap, lookupCoreMap, extendCoreMap)
+import CoreOpt       (simpleOptExpr)
+import CoreUtils     (exprType)
 import DsMonad
+import {-# SOURCE #-} DsExpr        (dsExpr)
 import TcSimplify    (tcCheckSatisfiability)
 import TcType        (isStringTy)
 import Bag
@@ -60,13 +65,15 @@ import qualified GHC.LanguageExtensions as LangExt
 import Data.List     (find)
 import Data.Maybe    (catMaybes, isJust, fromMaybe)
 import Control.Monad (forM, when, forM_, zipWithM, filterM)
+import Control.Monad.Trans.State.Strict (StateT (..), evalStateT)
+import Control.Monad.Trans.Class
 import Coercion
 import TcEvidence
 import TcSimplify    (tcNormalise)
 import IOEnv
 import qualified Data.Semigroup as Semi
-import ListT (ListT(..), fold, select)
+import ListT         (ListT(..), fold, select)
 This module checks pattern matches for:
@@ -140,6 +147,34 @@ getResult ls
     go (Just (PmResult _ _ (TypeOfUncovered _) _)) _new
       = panic "getResult: No inhabitation candidates"
+-- | State threaded through the translation of patterns and guards in 'TlM'.
+data TranslateEnv
+  = TE { te_rep_env :: !(CoreMap Id)
+       -- ^ Representatives for PmExprOther as Core expressions
+       , te_orig_exprs :: !(NameEnv (HsExpr GhcTc))
+       -- ^ Maps representatives to their represented expression.
+       -- Strict like 'te_rep_env': it is re-extended on every call to
+       -- representPmExprOther, and a lazy field would pile up thunks.
+       }
+-- | The empty translation environment: no representatives allocated yet.
+initialTE :: TranslateEnv
+initialTE = TE { te_rep_env    = emptyCoreMap
+               , te_orig_exprs = emptyNameEnv }
+-- | The translation monad for pattern matches: 'DsM' extended with a
+-- 'TranslateEnv' state.
+type TlM = StateT TranslateEnv DsM
+-- | Replace a 'PmExprOther' by a variable that represents it.
+--
+-- The wrapped expression is desugared to Core and simplified with
+-- 'simpleOptExpr'. Core expressions that the 'CoreMap' considers equal
+-- share one representative 'PmId', so the term oracle sees the same
+-- 'PmExprVar' for each of them. The representative is also recorded in
+-- 'te_orig_exprs' against the original 'HsExpr'. Any other 'PmExpr' is
+-- returned unchanged.
+representPmExprOther :: PmExpr -> TlM PmExpr
+representPmExprOther (PmExprOther e) = do
+  dflags <- lift getDynFlags
+  core_expr <- simpleOptExpr dflags <$> lift (dsExpr e)
+  -- As-pattern: we need both the whole state and its fields.
+  StateT $ \env@TE{ te_rep_env = cm, te_orig_exprs = origs } -> do
+    (name, env') <-
+      case lookupCoreMap cm core_expr of
+        -- Seen before: reuse the existing representative
+        Just y  -> pure (idName y, env)
+        -- First occurrence: allocate a fresh representative of the same type
+        Nothing -> do
+          y <- mkPmId (exprType core_expr)
+          pure (idName y, env { te_rep_env = extendCoreMap cm core_expr y })
+    tracePmD "representPmExprOther" (ppr name <+> text "->" <+> ppr (e, core_expr))
+    pure (PmExprVar name, env' { te_orig_exprs = extendNameEnv origs name e })
+representPmExprOther e = pure e
 data PatTy = PAT | VA -- Used only as a kind, to index PmPat
 -- The *arity* of a PatVec [p1,..,pn] is
@@ -350,9 +385,9 @@ checkSingle dflags ctxt@(DsMatchContext _ locn) var p = do
 checkSingle' :: SrcSpan -> Id -> Pat GhcTc -> PmM PmResult
 checkSingle' locn var p = do
   liftD resetPmIterDs -- set the iter-no to zero
-  fam_insts <- liftD dsGetFamInstEnvs
-  clause    <- liftD $ translatePat fam_insts p
-  missing   <- mkInitialUncovered [var]
+  fam_insts    <- liftD dsGetFamInstEnvs
+  (clause, te) <- liftD $ runStateT (translatePat fam_insts p) initialTE
+  missing      <- mkInitialUncovered [var]
   tracePm "checkSingle': missing" (vcat (map pprValVecDebug missing))
                                   -- no guards
   PartialResult prov cs us ds <- runMany (pmcheckI clause []) missing
@@ -422,8 +457,8 @@ checkMatches' vars matches
     go []     missing = return (mempty, [], missing, [])
     go (m:ms) missing = do
       tracePm "checkMatches': go" (ppr m $$ ppr missing)
-      fam_insts          <- liftD dsGetFamInstEnvs
-      (clause, guards)   <- liftD $ translateMatch fam_insts m
+      fam_insts              <- liftD dsGetFamInstEnvs
+      ((clause, guards), te) <- liftD $ runStateT (translateMatch fam_insts m) initialTE
       r@(PartialResult prov cs missing' ds)
         <- runMany (pmcheckI clause guards) missing
       tracePm "checkMatches': go: res" (ppr r)
@@ -966,12 +1001,12 @@ mkLitPattern lit = PmLit { pm_lit_lit = PmSLit lit }
 -- -----------------------------------------------------------------------
 -- * Transform (Pat Id) into of (PmPat Id)
-translatePat :: FamInstEnvs -> Pat GhcTc -> DsM PatVec
+translatePat :: FamInstEnvs -> Pat GhcTc -> TlM PatVec
 translatePat fam_insts pat = case pat of
-  WildPat  ty  -> mkPmVars [ty]
+  WildPat  ty  -> lift $ mkPmVars [ty]
   VarPat _ id  -> return [PmVar (unLoc id)]
   ParPat _ p   -> translatePat fam_insts (unLoc p)
-  LazyPat _ _  -> mkPmVars [hsPatType pat] -- like a variable
+  LazyPat _ _  -> lift $ mkPmVars [hsPatType pat] -- like a variable
   -- ignore strictness annotations for now
   BangPat _ p  -> translatePat fam_insts (unLoc p)
@@ -991,24 +1026,24 @@ translatePat fam_insts pat = case pat of
     | WpCast co <-  wrapper, isReflexiveCo co -> translatePat fam_insts p
     | otherwise -> do
         ps      <- translatePat fam_insts p
-        (xp,xe) <- mkPmId2Forms ty
+        (xp,xe) <- lift $ mkPmId2Forms ty
         g <- mkGuard ps (mkHsWrap wrapper (unLoc xe))
         return [xp,g]
   -- (n + k)  ===>   x (True <- x >= k) (n <- x-k)
-  NPlusKPat ty (dL->L _ _n) _k1 _k2 _ge _minus -> mkCanFailPmPat ty
+  NPlusKPat ty (dL->L _ _n) _k1 _k2 _ge _minus -> lift $ mkCanFailPmPat ty
   -- (fun -> pat)   ===>   x (pat <- fun x)
   ViewPat arg_ty lexpr lpat -> do
     ps <- translatePat fam_insts (unLoc lpat)
     -- See Note [Guards and Approximation]
-    res <- allM cantFailPattern ps
+    res <- lift $ allM cantFailPattern ps
     case res of
       True  -> do
-        (xp,xe) <- mkPmId2Forms arg_ty
+        (xp,xe) <- lift $ mkPmId2Forms arg_ty
         g <- mkGuard ps (HsApp noExt lexpr xe)
         return [xp,g]
-      False -> mkCanFailPmPat arg_ty
+      False -> lift $ mkCanFailPmPat arg_ty
   -- list
   ListPat (ListPatTc ty Nothing) ps -> do
@@ -1017,13 +1052,13 @@ translatePat fam_insts pat = case pat of
   -- overloaded list
   ListPat (ListPatTc _elem_ty (Just (pat_ty, _to_list))) lpats -> do
-    dflags <- getDynFlags
+    dflags <- lift $ getDynFlags
     if xopt LangExt.RebindableSyntax dflags
-       then mkCanFailPmPat pat_ty
+       then lift $ mkCanFailPmPat pat_ty
        else case splitListTyConApp_maybe pat_ty of
               Just e_ty -> translatePat fam_insts
                                         (ListPat (ListPatTc e_ty Nothing) lpats)
-              Nothing   -> mkCanFailPmPat pat_ty
+              Nothing   -> lift $ mkCanFailPmPat pat_ty
     -- (a) In the presence of RebindableSyntax, we don't know anything about
     --     `toList`, we should treat `ListPat` as any other view pattern.
@@ -1047,9 +1082,9 @@ translatePat fam_insts pat = case pat of
             , pat_tvs     = ex_tvs
             , pat_dicts   = dicts
             , pat_args    = ps } -> do
-    groups <- allCompleteMatches con arg_tys
+    groups <- lift $ allCompleteMatches con arg_tys
     case groups of
-      [] -> mkCanFailPmPat (conLikeResTy con arg_tys)
+      [] -> lift $ mkCanFailPmPat (conLikeResTy con arg_tys)
       _  -> do
         args <- translateConPatVec fam_insts arg_tys ex_tvs con ps
         return [PmCon { pm_con_con     = con
@@ -1178,23 +1213,23 @@ from translation in pattern matcher.
 -- | Translate a list of patterns (Note: each pattern is translated
 -- to a pattern vector but we do not concatenate the results).
-translatePatVec :: FamInstEnvs -> [Pat GhcTc] -> DsM [PatVec]
+translatePatVec :: FamInstEnvs -> [Pat GhcTc] -> TlM [PatVec]
 translatePatVec fam_insts pats = mapM (translatePat fam_insts) pats
 -- | Translate a constructor pattern
 translateConPatVec :: FamInstEnvs -> [Type] -> [TyVar]
-                   -> ConLike -> HsConPatDetails GhcTc -> DsM PatVec
+                   -> ConLike -> HsConPatDetails GhcTc -> TlM PatVec
 translateConPatVec fam_insts _univ_tys _ex_tvs _ (PrefixCon ps)
   = concat <$> translatePatVec fam_insts (map unLoc ps)
 translateConPatVec fam_insts _univ_tys _ex_tvs _ (InfixCon p1 p2)
   = concat <$> translatePatVec fam_insts (map unLoc [p1,p2])
 translateConPatVec fam_insts  univ_tys  ex_tvs c (RecCon (HsRecFields fs _))
     -- Nothing matched. Make up some fresh term variables
-  | null fs        = mkPmVars arg_tys
+  | null fs        = lift $ mkPmVars arg_tys
     -- The data constructor was not defined using record syntax. For the
     -- pattern to be in record syntax it should be empty (e.g. Just {}).
     -- So just like the previous case.
-  | null orig_lbls = ASSERT(null matched_lbls) mkPmVars arg_tys
+  | null orig_lbls = ASSERT(null matched_lbls) lift $ mkPmVars arg_tys
     -- Some of the fields appear, in the original order (there may be holes).
     -- Generate a simple constructor pattern and make up fresh variables for
     -- the rest of the fields
@@ -1202,13 +1237,13 @@ translateConPatVec fam_insts  univ_tys  ex_tvs c (RecCon (HsRecFields fs _))
   = ASSERT(orig_lbls `equalLength` arg_tys)
       let translateOne (lbl, ty) = case lookup lbl matched_pats of
             Just p  -> translatePat fam_insts p
-            Nothing -> mkPmVars [ty]
+            Nothing -> lift $ mkPmVars [ty]
       in  concatMapM translateOne (zip orig_lbls arg_tys)
     -- The fields that appear are not in the correct order. Make up fresh
     -- variables for all fields and add guards after matching, to force the
     -- evaluation in the correct order.
   | otherwise = do
-      arg_var_pats    <- mkPmVars arg_tys
+      arg_var_pats    <- lift $ mkPmVars arg_tys
       translated_pats <- forM matched_pats $ \(x,pat) -> do
         pvec <- translatePat fam_insts pat
         return (x, pvec)
@@ -1239,7 +1274,7 @@ translateConPatVec fam_insts  univ_tys  ex_tvs c (RecCon (HsRecFields fs _))
 -- Translate a single match
 translateMatch :: FamInstEnvs -> LMatch GhcTc (LHsExpr GhcTc)
-               -> DsM (PatVec,[PatVec])
+               -> TlM (PatVec,[PatVec])
 translateMatch fam_insts (dL->L _ (Match { m_pats = lpats, m_grhss = grhss })) =
   pats'   <- concat <$> translatePatVec fam_insts pats
@@ -1258,7 +1293,7 @@ translateMatch _ _ = panic "translateMatch"
 -- * Transform source guards (GuardStmt Id) to PmPats (Pattern)
 -- | Translate a list of guard statements to a pattern vector
-translateGuards :: FamInstEnvs -> [GuardStmt GhcTc] -> DsM PatVec
+translateGuards :: FamInstEnvs -> [GuardStmt GhcTc] -> TlM PatVec
 translateGuards fam_insts guards = do
   all_guards <- concat <$> mapM (translateGuard fam_insts) guards
@@ -1273,7 +1308,7 @@ translateGuards fam_insts guards = do
       | otherwise          = allM shouldKeep pv
     shouldKeep _other_pat  = pure False -- let the rest..
-  all_handled <- allM shouldKeep all_guards
+  all_handled <- lift $ allM shouldKeep all_guards
   -- It should have been @pure all_guards@ but it is too expressive.
   -- Since the term oracle does not handle all constraints we generate,
   -- we (hackily) replace all constraints the oracle cannot handle with a
@@ -1283,7 +1318,7 @@ translateGuards fam_insts guards = do
   if all_handled
     then pure all_guards
     else do
-      kept <- filterM shouldKeep all_guards
+      kept <- lift $ filterM shouldKeep all_guards
       pure (PmFake : kept)
 -- | Check whether a pattern can fail to match
@@ -1295,7 +1330,7 @@ cantFailPattern (PmGrd pv _e) = allM cantFailPattern pv
 cantFailPattern _             = pure False
 -- | Translate a guard statement to Pattern
-translateGuard :: FamInstEnvs -> GuardStmt GhcTc -> DsM PatVec
+translateGuard :: FamInstEnvs -> GuardStmt GhcTc -> TlM PatVec
 translateGuard fam_insts guard = case guard of
   BodyStmt _   e _ _ -> translateBoolGuard e
   LetStmt  _   binds -> translateLet (unLoc binds)
@@ -1308,18 +1343,18 @@ translateGuard fam_insts guard = case guard of
   XStmtLR         {} -> panic "translateGuard RecStmt"
 -- | Translate let-bindings
-translateLet :: HsLocalBinds GhcTc -> DsM PatVec
+translateLet :: HsLocalBinds GhcTc -> TlM PatVec
 translateLet _binds = return []
 -- | Translate a pattern guard
-translateBind :: FamInstEnvs -> LPat GhcTc -> LHsExpr GhcTc -> DsM PatVec
+translateBind :: FamInstEnvs -> LPat GhcTc -> LHsExpr GhcTc -> TlM PatVec
 translateBind fam_insts (dL->L _ p) e = do
   ps <- translatePat fam_insts p
   g <- mkGuard ps (unLoc e)
   return [g]
 -- | Translate a boolean guard
-translateBoolGuard :: LHsExpr GhcTc -> DsM PatVec
+translateBoolGuard :: LHsExpr GhcTc -> TlM PatVec
 translateBoolGuard e
   | isJust (isTrueLHsExpr e) = return []
     -- The formal thing to do would be to generate (True <- True)
@@ -1663,14 +1698,13 @@ mkOneConFull x con = do
 -- * More smart constructors and fresh variable generation
 -- | Create a guard pattern
-mkGuard :: PatVec -> HsExpr GhcTc -> DsM Pattern
+mkGuard :: PatVec -> HsExpr GhcTc -> TlM Pattern
 mkGuard pv e = do
-  res <- allM cantFailPattern pv
-  let expr = hsExprToPmExpr e
-  tracePmD "mkGuard" (vcat [ppr pv, ppr e, ppr res, ppr expr])
-  if | res                    -> pure (PmGrd pv expr)
-     | PmExprOther {} <- expr -> pure PmFake
-     | otherwise              -> pure (PmGrd pv expr)
+  let expr = hsExprToPmExpr e
+  -- Whether pv can fail no longer affects the result; it is only traced.
+  cant_fail <- lift (allM cantFailPattern pv)
+  -- A PmExprOther is replaced by a representative variable rather than
+  -- turning the whole guard into PmFake.
+  rep_expr <- representPmExprOther expr
+  traceTl "mkGuard" (vcat [ppr pv, ppr e, ppr cant_fail, ppr expr, ppr rep_expr])
+  return (PmGrd pv rep_expr)
 -- | Create a term equality of the form: `(False ~ (x ~ lit))`
 mkNegEq :: Id -> PmLit -> ComplexEq
@@ -2403,8 +2437,8 @@ genCaseTmCs2 :: Maybe (LHsExpr GhcTc) -- Scrutinee
              -> [Id]                  -- MatchVars (should have length 1)
              -> DsM (Bag SimpleEq)
 genCaseTmCs2 Nothing _ _ = return emptyBag
-genCaseTmCs2 (Just scr) [p] [var] = do
-  fam_insts <- dsGetFamInstEnvs
+genCaseTmCs2 (Just scr) [p] [var] = flip evalStateT initialTE $ do
+  fam_insts <- lift $ dsGetFamInstEnvs
   [e] <- map vaToPmExpr . coercePatVec <$> translatePat fam_insts p
   let scr_e = lhsExprToPmExpr scr
   return $ listToBag [(var, e), (var, scr_e)]
@@ -2719,6 +2753,8 @@ involved.
 tracePm :: String -> SDoc -> PmM ()
 tracePm herald doc = liftD $ tracePmD herald doc
+-- | 'tracePmD', lifted into the translation monad 'TlM'.
+traceTl :: String -> SDoc -> TlM ()
+traceTl herald = lift . tracePmD herald
 tracePmD :: String -> SDoc -> DsM ()
 tracePmD herald doc = do

View it on GitLab: https://gitlab.haskell.org/ghc/ghc/commit/24c20cd71fbdd0de588a7ea0e06cbc520fd3a97c
