[commit: ghc] wip/gadtpm: Fixed Record Pattern Translation (62d6edd)

git at git.haskell.org git at git.haskell.org
Tue Jul 7 22:30:40 UTC 2015


Repository : ssh://git@git.haskell.org/ghc

On branch  : wip/gadtpm
Link       : http://ghc.haskell.org/trac/ghc/changeset/62d6edd41cd71f259a1e4b092c51a0499b21078d/ghc

>---------------------------------------------------------------

commit 62d6edd41cd71f259a1e4b092c51a0499b21078d
Author: George Karachalias <george.karachalias at gmail.com>
Date:   Wed Jul 8 00:31:16 2015 +0200

    Fixed Record Pattern Translation


>---------------------------------------------------------------

62d6edd41cd71f259a1e4b092c51a0499b21078d
 compiler/deSugar/Check.hs | 63 +++++++++++++++++++++++++++++++++--------------
 1 file changed, 45 insertions(+), 18 deletions(-)

diff --git a/compiler/deSugar/Check.hs b/compiler/deSugar/Check.hs
index 8e790d3..5150188 100644
--- a/compiler/deSugar/Check.hs
+++ b/compiler/deSugar/Check.hs
@@ -42,7 +42,7 @@ import UniqSupply -- ( UniqSupply
                   -- , splitUniqSupply      -- :: UniqSupply -> (UniqSupply, UniqSupply)
                   -- , listSplitUniqSupply  -- :: UniqSupply -> [UniqSupply]
                   -- , uniqFromSupply )     -- :: UniqSupply -> Unique
-import Control.Monad (liftM3)
+import Control.Monad (liftM3, forM)
 import Data.Maybe (isNothing, fromJust)
 import DsGRHSs (isTrueLHsExpr)
 
@@ -280,7 +280,7 @@ translatePat pat = case pat of
             , pat_tvs     = ex_tvs
             , pat_dicts   = dicts
             , pat_args    = ps } -> do
-    args <- translateConPatVec con ps
+    args <- translateConPatVec arg_tys ex_tvs con ps
     return [mkPmConPat con arg_tys ex_tvs dicts args]
 
   NPat lit mb_neg eq -> do
@@ -319,22 +319,43 @@ translatePat pat = case pat of
 translatePatVec :: [Pat Id] -> UniqSM [PatVec] -- Do not concatenate them (sometimes we need them separately)
 translatePatVec pats = mapM translatePat pats
 
-translateConPatVec :: DataCon -> HsConPatDetails Id -> UniqSM PatVec
-translateConPatVec _ (PrefixCon ps)   = concat <$> translatePatVec (map unLoc ps)
-translateConPatVec _ (InfixCon p1 p2) = concat <$> translatePatVec (map unLoc [p1,p2])
-translateConPatVec c (RecCon (HsRecFields fs _))
-  | null fs   = mapM mkPmVarSM $ dataConOrigArgTys c
-  | otherwise = concat <$> translatePatVec (map (unLoc . snd) all_pats)
+translateConPatVec :: [Type] -> [TyVar] -> DataCon -> HsConPatDetails Id -> UniqSM PatVec
+translateConPatVec _univ_tys _ex_tvs _ (PrefixCon ps)   = concat <$> translatePatVec (map unLoc ps)
+translateConPatVec _univ_tys _ex_tvs _ (InfixCon p1 p2) = concat <$> translatePatVec (map unLoc [p1,p2])
+translateConPatVec  univ_tys  ex_tvs c (RecCon (HsRecFields fs _))
+  | null fs        = mkPmVarsSM arg_tys                            -- Nothing matched. Make up some fresh variables
+  | null orig_lbls = ASSERT (null matched_lbls) mkPmVarsSM arg_tys -- If it is not a record but uses record syntax it can only be {}. So just like above
+-- It is an optimisation anyway, we can avoid doing it..
+--   | matched_lbls `subsetOf` orig_lbls = do -- Ordered: The easy case (no additional guards)
+--       arg_pats <- zip orig_lbls <$> mkPmVarsSM arg_tys
+--       {- WE'VE GOT WORK TO DO -}
+--       undefined
+-- subsetOf :: Eq a => [a] -> [a] -> Bool
+-- subsetOf []     _  = True
+-- subsetOf (_:_)  [] = False
+-- subsetOf (x:xs) (y:ys)
+--   | x == y         = subsetOf xs     ys
+--   | otherwise      = subsetOf (x:xs) ys
+  | otherwise = do                         -- Not Ordered: We match against all patterns and add (strict) guards to match in the right order
+      arg_var_pats <- mkPmVarsSM arg_tys -- the normal variable patterns -- no forcing yet
+
+      translated_pats <- forM matched_pats $ \(x,pat) -> do
+        pvec <- translatePat pat
+        return (idName x, pvec)
+
+      let zipped = zip orig_lbls [ x | VarAbs x <- arg_var_pats ] -- [(Name, Id)]
+          guards = map (\(name,pvec) -> case lookup name zipped of
+                            Just x -> GBindAbs pvec (PmExprVar x)) translated_pats
+
+      return (arg_var_pats ++ guards)
   where
-    -- TODO: The functions below are ugly and they do not care much about types too
-    field_pats = map (\lbl -> (lbl, noLoc (WildPat (dataConFieldType c lbl)))) (dataConFieldLabels c)
-    all_pats   = foldr (\(L _ (HsRecField id p _)) acc -> insertNm (getName (unLoc id)) p acc)
-                       field_pats fs
+    -- The actual argument types (instantiated)
+    arg_tys = dataConInstOrigArgTys c (univ_tys ++ mkTyVarTys ex_tvs)
 
-    insertNm nm p [] = [(nm,p)]
-    insertNm nm p (x@(n,_):xs)
-      | nm == n    = (nm,p):xs
-      | otherwise  = x : insertNm nm p xs
+    -- Some label information
+    orig_lbls    = dataConFieldLabels c
+    matched_lbls = [idName id | L _ (HsRecField (L _ id) _         _) <- fs]
+    matched_pats = [(id,pat)  | L _ (HsRecField (L _ id) (L _ pat) _) <- fs]
 
 translateMatch :: LMatch Id (LHsExpr Id) -> UniqSM (PatVec,[PatVec])
 translateMatch (L _ (Match lpats _ grhss)) = do
@@ -620,6 +641,9 @@ mkPmVar usupply ty = VarAbs (mkPmId usupply ty)
 mkPmVarSM :: Type -> UniqSM (PmPat abs)
 mkPmVarSM ty = flip mkPmVar ty <$> getUniqueSupplyM
 
+mkPmVarsSM :: [Type] -> UniqSM [PmPat abs]
+mkPmVarsSM tys = mapM mkPmVarSM tys
+
 mkPmId :: UniqSupply -> Type -> Id
 mkPmId usupply ty = mkLocalId name ty
   where
@@ -862,6 +886,11 @@ allTheSame :: Eq a => [a] -> Bool
 allTheSame []     = True
 allTheSame (x:xs) = all (==x) xs
 
+sameArity :: PatVec -> ValSetAbs -> Bool
+sameArity pv vsa = vsaArity pv_a vsa == pv_a
+  where pv_a = patVecArity pv
+
+
 {-
 %************************************************************************
 %*                                                                      *
@@ -1030,7 +1059,6 @@ patVectProc2 :: (PatVec, [PatVec]) -> ValSetAbs -> PmM (Bool, Bool, ValSetAbs) -
 patVectProc2 (vec,gvs) vsa = do
   us <- getUniqueSupplyM
   let (c_def, u_def, d_def) = process_guards us gvs -- default (the continuation)
-
   (usC, usU, usD) <- getUniqueSupplyM3
   mb_c <- anySatValSetAbs (covered2   usC c_def vec vsa)
   mb_d <- anySatValSetAbs (divergent2 usD d_def vec vsa)
@@ -1066,7 +1094,6 @@ checkMatches'2 [] missing = do
 
 checkMatches'2 (m:ms) missing = do
   patterns_n_guards <- liftUs (translateMatch m)
-  -- pprInTcRnIf (ptext (sLit "translated") <+> ppr patterns_n_guards)
   (c,  d,  us ) <- patVectProc2 patterns_n_guards missing
   (rs, is, us') <- checkMatches'2 ms us
   return $ case (c,d) of



More information about the ghc-commits mailing list