[Haskell-cafe] Monad Set via GADT
Roberto Zunino
zunino at di.unipi.it
Wed Jan 3 18:38:59 EST 2007
To improve my understanding of GADT, I tried to define a Set datatype,
with the usual operations, so that it can be made a member of the
standard Monad class. Here I report on my experiments.
First, I recap the problem. Data.Set.Set cannot be made a Monad because
of the Ord constraint on its parameter, while e.g. (return :: a -> m a)
allows any type inside the monad m. One solution is an alternative monad
class (a "restricted monad") whose operations actually receive the Ord
context. Instead, I aimed for the standard Monad type class.
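For contrast, the restricted-monad alternative can be sketched as follows. The class name OrdMonad and its methods are hypothetical (not from any library), and the class is hard-wired to Ord rather than being parameterised over the constraint. Since it is not the standard Monad class, do-notation does not apply to it directly.

```haskell
import qualified Data.Set as S

-- Hypothetical "restricted monad": unlike Monad, its operations may
-- demand Ord on the element types.
class OrdMonad m where
  ordReturn :: Ord a => a -> m a
  ordBind   :: (Ord a, Ord b) => m a -> (a -> m b) -> m b

-- Data.Set.Set fits, because Ord is available wherever it is needed.
instance OrdMonad S.Set where
  ordReturn = S.singleton
  -- Apply f to every element and merge the resulting sets.
  ordBind s f = S.unions (map f (S.toList s))
```

For example, binding S.fromList [1, 2, 3] with \x -> S.fromList [x, 2 * x] yields the set {1, 2, 3, 4, 6}.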
To avoid reinventing Data.Set.Set, my datatype is built on top of it.
When an Ord context is available, we use a Set; otherwise, we use a
simple list representation. A list is just enough to implement return
(as (:[])) and (>>=) (via map and (++)): not very efficient, but simple.
The other, non-monadic operations require an Ord context, so that lists
can be turned into Sets.
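The list fallback on its own already supports return and (>>=) without any Ord constraint. The standalone sketch below (the newtype ListSet and the helper toSetL are hypothetical names, separate from the module attached later) shows that Ord is needed only when the list is finally converted into a Data.Set.Set.

```haskell
import qualified Data.Set as S

-- Hypothetical list-backed set: duplicates are kept until an Ord
-- context lets us normalise the list into a Data.Set.Set.
newtype ListSet a = ListSet { unListSet :: [a] }

instance Functor ListSet where
  fmap f (ListSet xs) = ListSet (map f xs)

instance Applicative ListSet where
  -- return/pure is just (:[])
  pure x = ListSet [x]
  ListSet fs <*> ListSet xs = ListSet [f x | f <- fs, x <- xs]

instance Monad ListSet where
  -- (>>=) maps f over the elements and concatenates the results;
  -- no Ord constraint is involved.
  ListSet xs >>= f = ListSet (concat (map (unListSet . f) xs))

-- Only the conversion to a real Set needs Ord.
toSetL :: Ord a => ListSet a -> S.Set a
toSetL = S.fromList . unListSet
```

Binding a three-element ListSet with itself keeps all nine sums, duplicates included, until toSetL collapses them.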
My first shot at this was:
  data SetM a where
    L  :: [a] -> SetM a
    SM :: Ord a => Set.Set a -> SetM a
However, this was not enough to convince the type checker to use the Ord
context stored in SM. Specifically, performing union with
  union (SM m1) (SM m2) = SM (m1 `Set.union` m2)
causes GHC to report "No instance for (Ord a)". After some experiments,
I found the following, using a type equality witness:
  data Teq a b where Teq :: Teq a a
  data SetM a where
    L  :: [a] -> SetM a
    SM :: Ord w => Teq a w -> Set.Set w -> SetM a
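A type equality witness such as Teq works because its only constructor has type Teq a a, so matching on it tells the type checker that the two parameters are the same type. Current versions of base ship the same idea as a :~: b (constructor Refl) in Data.Type.Equality. The sketch below is illustrative only: castTeq, flipTeq, toRefl and fromRefl are hypothetical names, not part of the post.

```haskell
{-# LANGUAGE GADTs, TypeOperators #-}
import Data.Type.Equality ((:~:) (..))

-- The witness from the post.
data Teq a b where
  Teq :: Teq a a

-- After matching Teq, a and b are known to be equal, so x : a can be
-- returned at type b.
castTeq :: Teq a b -> a -> b
castTeq Teq x = x

-- Equality is symmetric: a witness for a ~ b yields one for b ~ a.
flipTeq :: Teq a b -> Teq b a
flipTeq Teq = Teq

-- Teq and base's (:~:) are interconvertible.
toRefl :: Teq a b -> a :~: b
toRefl Teq = Refl

fromRefl :: a :~: b -> Teq a b
fromRefl Refl = Teq
```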
Now if I use
  union (SM p1 m1) (SM p2 m2) =
    case (p1,p2) of
      (Teq,Teq) -> SM Teq (m1 `Set.union` m2)
it typechecks! This raises some questions I cannot answer:
1) Why did the first version not typecheck?
2) Why does the second one?
3) If I replace (Teq a w) with (Teq w a), as in
     SM :: Ord w => Teq w a -> Set.Set w -> SetM a
   then the union above no longer typechecks! Why? I guess the type
   variable unification arising from matching Teq is not symmetric,
   contrary to what I expected...
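For readers on a current compiler, here is the first attempt restated as a sketch, assuming a modern GHC with the GADTs extension. On such releases, pattern matching on SM makes its stored Ord a instance available, so this union is accepted without the Teq witness. The helper elemsOf is a hypothetical addition for inspecting results.

```haskell
{-# LANGUAGE GADTs #-}
import qualified Data.Set as Set

-- The first attempt from the post: the Ord dictionary is stored in SM
-- and there is no type equality witness.
data SetM a where
  L  :: [a] -> SetM a
  SM :: Ord a => Set.Set a -> SetM a

-- Matching on SM brings its stored Ord a into scope; that instance is
-- what Set.union, Set.fromList and the SM constructor need.
union :: SetM a -> SetM a -> SetM a
union (L l1)  (L l2)  = L (l1 ++ l2)
union (SM m1) (SM m2) = SM (m1 `Set.union` m2)
union (L l1)  (SM m2) = SM (m2 `Set.union` Set.fromList l1)
union s1      s2      = union s2 s1

-- Hypothetical helper for inspecting results (not in the post).
elemsOf :: SetM a -> [a]
elemsOf (L l)  = l
elemsOf (SM m) = Set.toList m
```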
Below, I attach the working version. Monad and MonadPlus instances are
provided for SetM. Conversions from/to Set are also provided; they
require an Ord context. "Efficient" variants of return and mzero
(return' and mzero') are included: they force the Set representation to
be used and therefore require Ord (they could also be derived from
fromSet/toSet, however).
Comments are very welcome, of course, as are alternative approaches
that do not rely on GADTs.
Regards,
Roberto Zunino.
============================================================
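\begin{code}
{-# LANGUAGE GADTs #-}
{-# OPTIONS_GHC -Wall #-}
-- Originally built with -fglasgow-exts (GHC of early 2007). Current GHC
-- releases need the GADTs pragma above and the Functor, Applicative and
-- Alternative instances below.
module SetMonad
  ( SetM ()
  , toSet, fromSet
  , union, unions
  , return', mzero'
  ) where

import qualified Data.Set as S
import Data.List (foldl')
import Control.Applicative (Alternative (..))
import Control.Monad (MonadPlus (..), ap, liftM)

-- Type equality witness
data Teq a b where
  Teq :: Teq a a

-- Either a list or a real Set
data SetM a where
  L  :: [a] -> SetM a
  SM :: Ord w => Teq a w -> S.Set w -> SetM a

-- Superclasses of Monad in current GHC, both derived from the Monad
-- instance below.
instance Functor SetM where
  fmap = liftM

instance Applicative SetM where
  pure  = L . (:[])
  (<*>) = ap

instance Monad SetM where
  -- No Ord is needed on the result type: any Ord dictionaries travel
  -- inside the SM values returned by f, and unions combines them.
  m >>= f = case m of
    L l      -> unions (map f l)
    SM Teq s -> unions (map f (S.toList s))

-- Alternative is the superclass of MonadPlus in current GHC;
-- mzero and mplus default to empty and (<|>).
instance Alternative SetM where
  empty = L []
  (<|>) = union

instance MonadPlus SetM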
-- Efficient variants for Ord types
return' :: Ord a => a -> SetM a
return' = SM Teq . S.singleton
mzero' :: Ord a => SetM a
mzero' = SM Teq S.empty
-- Set union: use the best representation
union :: SetM a -> SetM a -> SetM a
union (L l1) (L l2) = L (l1 ++ l2)
union (SM p1 m1) (SM p2 m2) = case (p1, p2) of
  (Teq, Teq) -> SM Teq (m1 `S.union` m2)
union (L l1) (SM p m2) = case p of
  Teq -> SM Teq (m2 `S.union` S.fromList l1)
union s1 s2 = union s2 s1
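-- Move the first SM (if any) to the front before folding, so that the
-- accumulator becomes a Set after the first step and the remaining
-- lists are merged into it rather than concatenated.
unions :: [SetM a] -> SetM a
unions = let isSM (SM _ _) = True
             isSM _        = False
         in foldl' union (L []) . uncurry (flip (++)) . break isSM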
-- Conversion from/to Set requires Ord
toSet :: Ord a => SetM a -> S.Set a
toSet (L l) = S.fromList l
toSet (SM p m) = case p of Teq -> m
fromSet :: Ord a => S.Set a -> SetM a
fromSet = SM Teq
-- Tests
test :: IO ()
test =
  do let l = [1..3] :: [Int]
         s = fromSet (S.fromList l)
         g x = return' x `mplus` return' (x+100)
     print $ S.toList $ toSet $
       do x <- s
          y <- s
          return' (x+y)
     -- [2,3,4,5,6]
     print $ S.toList $ toSet $
       do x <- s
          g x
     -- [1,2,3,101,102,103]
     print $ S.toList $ toSet $
       do x <- s
          y <- g x
          return' y
     -- [1,2,3,101,102,103]
     print $ S.toList $ toSet $
       do x <- s
          y <- g x
          g y
     -- [1,2,3,101,102,103,201,202,203]
     print $ S.toList $ toSet $
       do x <- s
          y <- return (const x) -- no Ord!
          return' (y ())
     -- [1,2,3]
\end{code}