[Git][ghc/ghc][master] Enforce invariant of `ListBag` constructor.

Marge Bot (@marge-bot) gitlab at gitlab.haskell.org
Wed Oct 19 14:47:32 UTC 2022



Marge Bot pushed to branch master at Glasgow Haskell Compiler / GHC


Commits:
c3732c62 by M Farkas-Dyck at 2022-10-19T10:47:13-04:00
Enforce invariant of `ListBag` constructor.

- - - - -


3 changed files:

- compiler/GHC/Data/Bag.hs
- compiler/GHC/Tc/Deriv.hs
- compiler/GHC/Utils/Monad.hs


Changes:

=====================================
compiler/GHC/Data/Bag.hs
=====================================
@@ -18,7 +18,7 @@ module GHC.Data.Bag (
         concatBag, catBagMaybes, foldBag,
         isEmptyBag, isSingletonBag, consBag, snocBag, anyBag, allBag,
         listToBag, nonEmptyToBag, bagToList, headMaybe, mapAccumBagL,
-        concatMapBag, concatMapBagPair, mapMaybeBag,
+        concatMapBag, concatMapBagPair, mapMaybeBag, unzipBag,
         mapBagM, mapBagM_,
         flatMapBagM, flatMapBagPairM,
         mapAndUnzipBagM, mapAccumBagLM,
@@ -33,9 +33,10 @@ import GHC.Utils.Misc
 import GHC.Utils.Monad
 import Control.Monad
 import Data.Data
-import Data.Maybe( mapMaybe, listToMaybe )
+import Data.Maybe( mapMaybe )
 import Data.List ( partition, mapAccumL )
 import Data.List.NonEmpty ( NonEmpty(..) )
+import qualified Data.List.NonEmpty as NE
 import qualified Data.Semigroup ( (<>) )
 
 infixr 3 `consBag`
@@ -45,7 +46,7 @@ data Bag a
   = EmptyBag
   | UnitBag a
   | TwoBags (Bag a) (Bag a) -- INVARIANT: neither branch is empty
-  | ListBag [a]             -- INVARIANT: the list is non-empty
+  | ListBag (NonEmpty a)
   deriving (Foldable, Functor, Traversable)
 
 emptyBag :: Bag a
@@ -90,7 +91,7 @@ isSingletonBag :: Bag a -> Bool
 isSingletonBag EmptyBag      = False
 isSingletonBag (UnitBag _)   = True
 isSingletonBag (TwoBags _ _) = False          -- Neither is empty
-isSingletonBag (ListBag xs)  = isSingleton xs
+isSingletonBag (ListBag (_:|xs)) = null xs
 
 filterBag :: (a -> Bool) -> Bag a -> Bag a
 filterBag _    EmptyBag = EmptyBag
@@ -98,7 +99,7 @@ filterBag pred b@(UnitBag val) = if pred val then b else EmptyBag
 filterBag pred (TwoBags b1 b2) = sat1 `unionBags` sat2
     where sat1 = filterBag pred b1
           sat2 = filterBag pred b2
-filterBag pred (ListBag vs)    = listToBag (filter pred vs)
+filterBag pred (ListBag vs)    = listToBag (filter pred (toList vs))
 
 filterBagM :: Monad m => (a -> m Bool) -> Bag a -> m (Bag a)
 filterBagM _    EmptyBag = return EmptyBag
@@ -111,7 +112,7 @@ filterBagM pred (TwoBags b1 b2) = do
   sat2 <- filterBagM pred b2
   return (sat1 `unionBags` sat2)
 filterBagM pred (ListBag vs) = do
-  sat <- filterM pred vs
+  sat <- filterM pred (toList vs)
   return (listToBag sat)
 
 allBag :: (a -> Bool) -> Bag a -> Bool
@@ -135,9 +136,7 @@ anyBagM p (TwoBags b1 b2) = do flag <- anyBagM p b1
 anyBagM p (ListBag xs)    = anyM p xs
 
 concatBag :: Bag (Bag a) -> Bag a
-concatBag bss = foldr add emptyBag bss
-  where
-    add bs rs = bs `unionBags` rs
+concatBag = foldr unionBags emptyBag
 
 catBagMaybes :: Bag (Maybe a) -> Bag a
 catBagMaybes bs = foldr add emptyBag bs
@@ -155,7 +154,7 @@ partitionBag pred (TwoBags b1 b2)
   where (sat1, fail1) = partitionBag pred b1
         (sat2, fail2) = partitionBag pred b2
 partitionBag pred (ListBag vs) = (listToBag sats, listToBag fails)
-  where (sats, fails) = partition pred vs
+  where (sats, fails) = partition pred (toList vs)
 
 
 partitionBagWith :: (a -> Either b c) -> Bag a
@@ -171,7 +170,7 @@ partitionBagWith pred (TwoBags b1 b2)
   where (sat1, fail1) = partitionBagWith pred b1
         (sat2, fail2) = partitionBagWith pred b2
 partitionBagWith pred (ListBag vs) = (listToBag sats, listToBag fails)
-  where (sats, fails) = partitionWith pred vs
+  where (sats, fails) = partitionWith pred (toList vs)
 
 foldBag :: (r -> r -> r) -- Replace TwoBags with this; should be associative
         -> (a -> r)      -- Replace UnitBag with this
@@ -220,7 +219,7 @@ mapMaybeBag f (UnitBag x)     = case f x of
                                   Nothing -> EmptyBag
                                   Just y  -> UnitBag y
 mapMaybeBag f (TwoBags b1 b2) = unionBags (mapMaybeBag f b1) (mapMaybeBag f b2)
-mapMaybeBag f (ListBag xs)    = ListBag (mapMaybe f xs)
+mapMaybeBag f (ListBag xs)    = listToBag $ mapMaybe f (toList xs)
 
 mapBagM :: Monad m => (a -> m b) -> Bag a -> m (Bag b)
 mapBagM _ EmptyBag        = return EmptyBag
@@ -267,7 +266,7 @@ mapAndUnzipBagM f (TwoBags b1 b2) = do (r1,s1) <- mapAndUnzipBagM f b1
                                        (r2,s2) <- mapAndUnzipBagM f b2
                                        return (TwoBags r1 r2, TwoBags s1 s2)
 mapAndUnzipBagM f (ListBag xs)    = do ts <- mapM f xs
-                                       let (rs,ss) = unzip ts
+                                       let (rs,ss) = NE.unzip ts
                                        return (ListBag rs, ListBag ss)
 
 mapAccumBagL ::(acc -> x -> (acc, y)) -- ^ combining function
@@ -298,20 +297,31 @@ mapAccumBagLM f s (ListBag xs)    = do { (s', xs') <- mapAccumLM f s xs
 listToBag :: [a] -> Bag a
 listToBag [] = EmptyBag
 listToBag [x] = UnitBag x
-listToBag vs = ListBag vs
+listToBag (x:xs) = ListBag (x:|xs)
 
 nonEmptyToBag :: NonEmpty a -> Bag a
 nonEmptyToBag (x :| []) = UnitBag x
-nonEmptyToBag (x :| xs) = ListBag (x : xs)
+nonEmptyToBag xs = ListBag xs
 
 bagToList :: Bag a -> [a]
 bagToList b = foldr (:) [] b
 
+unzipBag :: Bag (a, b) -> (Bag a, Bag b)
+unzipBag EmptyBag = (EmptyBag, EmptyBag)
+unzipBag (UnitBag (a, b)) = (UnitBag a, UnitBag b)
+unzipBag (TwoBags xs1 xs2) = (TwoBags as1 as2, TwoBags bs1 bs2)
+  where
+    (as1, bs1) = unzipBag xs1
+    (as2, bs2) = unzipBag xs2
+unzipBag (ListBag xs) = (ListBag as, ListBag bs)
+  where
+    (as, bs) = NE.unzip xs
+
 headMaybe :: Bag a -> Maybe a
 headMaybe EmptyBag = Nothing
 headMaybe (UnitBag v) = Just v
 headMaybe (TwoBags b1 _) = headMaybe b1
-headMaybe (ListBag l) = listToMaybe l
+headMaybe (ListBag (v:|_)) = Just v
 
 instance (Outputable a) => Outputable (Bag a) where
     ppr bag = braces (pprWithCommas ppr (bagToList bag))


=====================================
compiler/GHC/Tc/Deriv.hs
=====================================
@@ -289,8 +289,8 @@ renameDeriv inst_infos bagBinds
         -- Bring the extra deriving stuff into scope
         -- before renaming the instances themselves
         ; traceTc "rnd" (vcat (map (\i -> pprInstInfoDetails i $$ text "") inst_infos))
-        ; (aux_binds, aux_sigs) <- mapAndUnzipBagM return bagBinds
-        ; let aux_val_binds = ValBinds NoAnnSortKey aux_binds (bagToList aux_sigs)
+        ; let (aux_binds, aux_sigs) = unzipBag bagBinds
+              aux_val_binds = ValBinds NoAnnSortKey aux_binds (bagToList aux_sigs)
         -- Importantly, we use rnLocalValBindsLHS, not rnTopBindsLHS, to rename
         -- auxiliary bindings as if they were defined locally.
         -- See Note [Auxiliary binders] in GHC.Tc.Deriv.Generate.


=====================================
compiler/GHC/Utils/Monad.hs
=====================================
@@ -1,3 +1,5 @@
+{-# LANGUAGE MonadComprehensions #-}
+
 -- | Utilities related to Monad and Applicative classes
 --   Mostly for backwards compatibility.
 
@@ -28,8 +30,11 @@ import GHC.Prelude
 import Control.Monad
 import Control.Monad.Fix
 import Control.Monad.IO.Class
+import Control.Monad.Trans.State.Strict (StateT (..))
 import Data.Foldable (sequenceA_, foldlM, foldrM)
 import Data.List (unzip4, unzip5, zipWith4)
+import Data.List.NonEmpty (NonEmpty (..))
+import Data.Tuple (swap)
 
 -------------------------------------------------------------------------------
 -- Common functions
@@ -137,18 +142,28 @@ mapAndUnzip5M f xs =  unzip5 <$> traverse f xs
 -- variant and use it where appropriate.
 
 -- | Monadic version of mapAccumL
-mapAccumLM :: Monad m
+mapAccumLM :: (Monad m, Traversable t)
             => (acc -> x -> m (acc, y)) -- ^ combining function
             -> acc                      -- ^ initial state
-            -> [x]                      -- ^ inputs
-            -> m (acc, [y])             -- ^ final state, outputs
-{-# INLINE mapAccumLM #-}
+            -> t x                      -- ^ inputs
+            -> m (acc, t y)             -- ^ final state, outputs
+{-# INLINE [1] mapAccumLM #-}
 -- INLINE pragma.  mapAccumLM is called in inner loops.  Like 'map',
 -- we inline it so that we can take advantage of knowing 'f'.
 -- This makes a few percent difference (in compiler allocations)
 -- when compiling perf/compiler/T9675
-mapAccumLM f s xs =
-  go s xs
+mapAccumLM f s = fmap swap . flip runStateT s . traverse f'
+  where
+    f' = StateT . (fmap . fmap) swap . flip f
+{-# RULES "mapAccumLM/List" mapAccumLM = mapAccumLM_List #-}
+{-# RULES "mapAccumLM/NonEmpty" mapAccumLM = mapAccumLM_NonEmpty #-}
+
+mapAccumLM_List
+ :: Monad m
+ => (acc -> x -> m (acc, y))
+ -> acc -> [x] -> m (acc, [y])
+{-# INLINE mapAccumLM_List #-}
+mapAccumLM_List f s = go s
   where
     go s (x:xs) = do
       (s1, x')  <- f s x
@@ -156,6 +171,14 @@ mapAccumLM f s xs =
       return    (s2, x' : xs')
     go s [] = return (s, [])
 
+mapAccumLM_NonEmpty
+ :: Monad m
+ => (acc -> x -> m (acc, y))
+ -> acc -> NonEmpty x -> m (acc, NonEmpty y)
+{-# INLINE mapAccumLM_NonEmpty #-}
+mapAccumLM_NonEmpty f s (x:|xs) =
+  [(s2, x':|xs') | (s1, x') <- f s x, (s2, xs') <- mapAccumLM_List f s1 xs]
+
 -- | Monadic version of mapSnd
 mapSndM :: (Applicative m, Traversable f) => (b -> m c) -> f (a,b) -> m (f (a,c))
 mapSndM = traverse . traverse
@@ -174,25 +197,21 @@ mapMaybeM f = foldr g (pure [])
   where g a = liftA2 (maybe id (:)) (f a)
 
 -- | Monadic version of 'any', aborts the computation at the first @True@ value
-anyM :: Monad m => (a -> m Bool) -> [a] -> m Bool
-anyM f xs = go xs
-  where
-    go [] = return False
-    go (x:xs) = do b <- f x
-                   if b then return True
-                        else go xs
+anyM :: (Monad m, Foldable f) => (a -> m Bool) -> f a -> m Bool
+anyM f = foldr (orM . f) (pure False)
 
 -- | Monad version of 'all', aborts the computation at the first @False@ value
-allM :: Monad m => (a -> m Bool) -> [a] -> m Bool
-allM f bs = go bs
-  where
-    go []     = return True
-    go (b:bs) = (f b) >>= (\bv -> if bv then go bs else return False)
+allM :: (Monad m, Foldable f) => (a -> m Bool) -> f a -> m Bool
+allM f = foldr (andM . f) (pure True)
 
 -- | Monadic version of or
 orM :: Monad m => m Bool -> m Bool -> m Bool
 orM m1 m2 = m1 >>= \x -> if x then return True else m2
 
+-- | Monadic version of and
+andM :: Monad m => m Bool -> m Bool -> m Bool
+andM m1 m2 = m1 >>= \x -> if x then m2 else return False
+
 -- | Monadic version of foldl that discards its result
 foldlM_ :: (Monad m, Foldable t) => (a -> b -> m a) -> a -> t b -> m ()
 foldlM_ = foldM_



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/c3732c6210972a992e1153b0667cf8abf0351acd

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/c3732c6210972a992e1153b0667cf8abf0351acd
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20221019/d7457a01/attachment-0001.html>


More information about the ghc-commits mailing list