[Haskell-cafe] Monad Set via GADT

Roberto Zunino zunino at di.unipi.it
Wed Jan 3 18:38:59 EST 2007

To improve my understanding of GADT, I tried to define a Set datatype, 
with the usual operations, so that it can be made a member of the 
standard Monad class. Here I report on my experiments.

First, I recap the problem. Data.Set.Set cannot be made a Monad because 
of the Ord constraint on its parameter, while e.g. (return :: a -> m a) 
allows any type inside the monad m. This problem can be solved using an 
alternative monad class (a restricted monad), so that the Ord context is 
actually provided to the monad operations. Instead, I aimed for the 
standard Monad typeclass.
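For comparison, a restricted monad class might look like the minimal sketch below. It assumes the simplest design, where the element constraint is fixed to Ord. The class OrdMonad and its methods oreturn and obind are hypothetical names, not an existing library API:

```haskell
import qualified Data.Set as S

-- Hypothetical restricted monad class (illustrative names only):
-- every operation carries the Ord context that Data.Set needs.
class OrdMonad m where
  oreturn :: Ord a => a -> m a
  obind   :: (Ord a, Ord b) => m a -> (a -> m b) -> m b

instance OrdMonad S.Set where
  oreturn = S.singleton
  -- Apply the continuation to each element and merge the results;
  -- S.unions removes duplicates using Ord b.
  obind s f = S.unions (map f (S.toList s))
```

Because obind is not the standard (>>=), ordinary do-notation and functions written against Monad cannot be used with such a class. That is the motivation for targeting the standard class instead.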

To avoid reinventing Data.Set.Set, my datatype is based on it. 
Basically, when an Ord context is known, we use a Set; otherwise, we 
use a simple list representation. A list is just enough to implement 
return (i.e. (:[])) and (>>=) (i.e. map and (++)), so, while not very 
efficient, it is simple. The other, non-monadic operators require an Ord 
context, so that we can turn lists into Sets.
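To illustrate why no Ord is needed until a Set is actually built, here is a minimal sketch of the list-based operations, assuming plain lists as the representation. The names returnL, bindL and toSetL are hypothetical helpers used only here:

```haskell
import qualified Data.Set as S

-- Hypothetical helper: return for the list representation needs no Ord.
returnL :: a -> [a]
returnL = (:[])

-- Hypothetical helper: bind maps the continuation over the elements and
-- appends the resulting lists; duplicates are simply kept.
bindL :: [a] -> (a -> [b]) -> [b]
bindL xs f = foldr (++) [] (map f xs)

-- Hypothetical helper: only the conversion to a real Set needs Ord,
-- and this is where duplicates are finally removed.
toSetL :: Ord a => [a] -> S.Set a
toSetL = S.fromList
```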

My first shot at this was:

data SetM a where
     L :: [a] -> SetM a
     SM :: Ord a => Set.Set a -> SetM a

However, this was not enough to convince the type checker to use the Ord 
context stored in SM. Specifically, performing union with

union (SM m1) (SM m2) = SM (m1 `Set.union` m2)

causes GHC to report "No instance for (Ord a)". After some experimenting, 
I found the following encoding, which uses a type equality witness:

data Teq a b where Teq :: Teq a a
data SetM a where
     L :: [a] -> SetM a
     SM :: Ord w => Teq a w -> Set.Set w -> SetM a
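The witness works because matching on the constructor Teq, whose type is Teq a a, tells the type checker that the two indices are equal inside that branch. A minimal sketch of this, with hypothetical helper names castTeq and symTeq (current versions of base provide the same idea as (:~:), Refl, castWith and sym in Data.Type.Equality):

```haskell
{-# LANGUAGE GADTs #-}

-- The type equality witness: its only constructor forces both
-- indices to be the same type.
data Teq a b where Teq :: Teq a a

-- Hypothetical helper: in the Teq branch the checker knows a ~ b,
-- so a value of type a may be returned at type b.
castTeq :: Teq a b -> a -> b
castTeq Teq x = x

-- Hypothetical helper: the same evidence read in the other direction.
symTeq :: Teq a b -> Teq b a
symTeq Teq = Teq
```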

Now if I use

union (SM p1 m1) (SM p2 m2) =
    case (p1,p2) of
    (Teq,Teq) -> SM Teq (m1 `Set.union` m2)

it typechecks! This raises some questions I cannot answer:

1) Why did the first version not typecheck?
2) Why does the second one?
3) If I replace (Teq a w) with (Teq w a), as in
     SM :: Ord w => Teq w a -> Set.Set w -> SetM a
then the union above does not typecheck! Why? I guess the type variable 
unification arising from matching on Teq is not as symmetric as I 
expected it to be...
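To experiment with questions 1 and 3, here is a self-contained sketch of the first encoding and of the flipped-witness variant. The toList and toList' helpers are added only to inspect results. As far as I can tell, a current GHC with the GADTs extension accepts both definitions: matching on SM makes the stored Ord a available, and the equality learned from Teq can be used in either direction. The compiler used for the experiments above may behave differently.

```haskell
{-# LANGUAGE GADTs #-}

import qualified Data.Set as Set

-- First version: SM stores the Ord dictionary for a directly.
data SetM a where
  L  :: [a] -> SetM a
  SM :: Ord a => Set.Set a -> SetM a

-- Matching on SM brings Ord a into scope on the right-hand side.
union :: SetM a -> SetM a -> SetM a
union (L l1)  (L l2)  = L (l1 ++ l2)
union (SM m1) (SM m2) = SM (m1 `Set.union` m2)
union (L l1)  (SM m2) = SM (m2 `Set.union` Set.fromList l1)
union s1      s2      = union s2 s1

-- Helper for inspecting results (not part of the original code).
toList :: SetM a -> [a]
toList (L l)  = l
toList (SM s) = Set.toList s

-- Variant from question 3: the witness is written Teq w a.
data Teq a b where Teq :: Teq a a

data SetM' a where
  SM' :: Ord w => Teq w a -> Set.Set w -> SetM' a

union' :: SetM' a -> SetM' a -> SetM' a
union' (SM' Teq m1) (SM' Teq m2) = SM' Teq (m1 `Set.union` m2)

-- Helper for inspecting results (not part of the original code).
toList' :: SetM' a -> [a]
toList' (SM' Teq s) = Set.toList s
```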

Below, I attach the working version. Monad and MonadPlus instances are 
provided for SetM. Conversions from/to Set are also provided, requiring 
an Ord context. "Efficient" return and mzero are included, forcing the 
Set representation to be used, and requiring Ord (these could also be 
derived from fromSet/toSet, however).

Comments are very welcome, of course, as are alternative approaches 
that do not rely on GADTs.

Roberto Zunino.

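{-# LANGUAGE GADTs #-}
{-# OPTIONS_GHC -Wall #-}
module SetMonad
     ( SetM()
     , toSet, fromSet
     , union, unions
     , return', mzero'
     ) where

import qualified Data.Set as S
import Data.List hiding (union)
import Control.Applicative (Alternative(..))
import Control.Monad

-- Type equality witness
data Teq a b where Teq :: Teq a a

-- Either a list or a real Set
data SetM a where
     L :: [a] -> SetM a
     SM :: Ord w => Teq a w -> S.Set w -> SetM a

-- GHC 7.10 and later require Functor and Applicative instances for
-- every Monad; both are defined here in terms of the Monad instance.
instance Functor SetM where
     fmap = liftM

instance Applicative SetM where
     pure  = L . (:[])
     (<*>) = ap

-- Bind maps f over the elements (from either representation)
-- and merges the resulting sets with unions.
instance Monad SetM where
     m >>= f = case m of
               L l      -> unions (map f l)
               SM Teq s -> unions (map f (S.toList s))

-- Likewise, MonadPlus now requires an Alternative instance.
instance Alternative SetM where
     empty = L []
     (<|>) = union

-- mzero and mplus default to empty (L []) and (<|>) (union).
instance MonadPlus SetM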

-- Efficient variants for Ord types
return' :: Ord a => a -> SetM a
return' = SM Teq . S.singleton

mzero' :: Ord a => SetM a
mzero' = SM Teq S.empty

-- Set union: use the best representation
union :: SetM a -> SetM a -> SetM a
union (L l1)     (L l2)     = L (l1 ++ l2)
union (SM p1 m1) (SM p2 m2) = case (p1,p2) of
                               (Teq,Teq) -> SM Teq (m1 `S.union` m2)
union (L l1)     (SM p m2)  = case p of
                               Teq -> SM Teq (m2 `S.union` S.fromList l1)
union s1         s2         = union s2 s1

-- Move the first SM (if any) to the front before folding: once the
-- accumulator is a Set, the remaining lists are merged into it
unions :: [SetM a] -> SetM a
unions = let isSM (SM _ _) = True
             isSM _        = False
         in foldl' union (L []) . uncurry (flip (++)) . break isSM

-- Conversion from/to Set requires Ord
toSet :: Ord a => SetM a -> S.Set a
toSet (L l)    = S.fromList l
toSet (SM p m) = case p of Teq -> m

fromSet :: Ord a => S.Set a -> SetM a
fromSet = SM Teq

-- Tests
test :: IO ()
test =
     do let l = [1..3] :: [Int]
            s = fromSet (S.fromList l)
            g x = return' x `mplus` return' (x+100)
        print $ S.toList $ toSet $
              do x <- s
                 y <- s
                 return' (x+y)
        -- [2,3,4,5,6]
        print $ S.toList $ toSet $
              do x <- s
                 g x
        -- [1,2,3,101,102,103]
        print $ S.toList $ toSet $
              do x <- s
                 y <- g x
                 return' y
        -- [1,2,3,101,102,103]
        print $ S.toList $ toSet $
              do x <- s
                 y <- g x
                 g y
        -- [1,2,3,101,102,103,201,202,203]
        print $ S.toList $ toSet $
              do x <- s
                 y <- return (const x) -- no Ord!
                 return' (y ())
        -- [1,2,3]
