[Haskell-cafe] Monad instance for Data.Set

Ryan Ingram ryani.spam at gmail.com
Tue Mar 25 17:13:45 EDT 2008


I was experimenting with Prompt today and found that you can get a
"restricted monad" style of behavior out of a regular monad using Prompt:

> {-# LANGUAGE GADTs #-}
> module SetTest where
> import qualified Data.Set as S

Prompt is available from
http://hackage.haskell.org/cgi-bin/hackage-scripts/package/MonadPrompt-1.0.0.1

> import Control.Monad.Prompt

"OrdP" is a prompt type providing MonadPlus-style operations for orderable types:

> data OrdP m a where
>    PZero :: OrdP m a
>    PRestrict :: Ord a => m a -> OrdP m a
>    PPlus :: Ord a => m a -> m a -> OrdP m a

> type SetM = RecPrompt OrdP

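The Ord contexts sit inside the OrdP constructors for a reason. Pattern matching on a GADT constructor that carries a class context brings that instance back into scope, even in a function whose own signature has no constraint. That is what lets the interpreter runSetM recurse into the sub-computations held by PRestrict and PPlus. Here is a minimal standalone illustration. The type `Box` and the function `countOutputs` are hypothetical and are not part of the post or of MonadPrompt:

```haskell
{-# LANGUAGE GADTs #-}
import qualified Data.Set as S

-- Hypothetical type: like PRestrict/PPlus, the Ordered constructor
-- stores an Ord dictionary alongside its payload.
data Box a where
  Plain   :: [a] -> Box a
  Ordered :: Ord a => [a] -> Box a

-- No Ord in this signature. Matching on Ordered recovers the instance
-- that S.fromList needs; Plain has none, so it can only count everything.
countOutputs :: Box a -> Int
countOutputs (Plain xs)   = length xs
countOutputs (Ordered xs) = S.size (S.fromList xs)
```

So `countOutputs (Ordered [1, 1, 2 :: Int])` is 2, while the same list wrapped in `Plain` gives 3.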
We can't make this an instance of MonadPlus, because mplus would need an Ord
constraint that the class method does not allow.  But as long as we don't
import Control.Monad's versions, we can reuse the names mzero and mplus.
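
For comparison, here is what the two operations look like on plain Sets. The class methods are fully polymorphic (`mplus :: m a -> m a -> m a`, `(>>=) :: m a -> (a -> m b) -> m b`). By contrast, `S.union` needs `Ord a`, and a Set bind needs `Ord b` on the result type, so neither fits an instance declaration. The following is a sketch with hypothetical Set-only definitions. It hides Control.Monad's names so that ours can reuse them:

```haskell
import Control.Monad hiding (mplus, mzero)  -- keep the class methods out of scope
import qualified Data.Set as S

-- Hypothetical Set-only versions of the two names.
mzero :: S.Set a
mzero = S.empty

-- S.union needs Ord a, which MonadPlus's mplus has no room for.
mplus :: Ord a => S.Set a -> S.Set a -> S.Set a
mplus = S.union

-- The same problem blocks a Monad instance: bind needs Ord on the result type.
bindSet :: Ord b => S.Set a -> (a -> S.Set b) -> S.Set b
bindSet s f = S.unions (map f (S.toList s))
```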

> mzero :: SetM a
> mzero = prompt PZero
> mplus :: Ord a => SetM a -> SetM a -> SetM a
> mplus x y = prompt (PPlus x y)

"mrestrict" can be inserted at various points in a computation to optimize
it: it forces the passed-in computation to complete and uses a Set to
eliminate duplicate outputs.  We could also implement mrestrict without an
additional constructor in our prompt datatype, at the cost of some performance:
  mrestrict m = mplus mzero m

> mrestrict :: Ord a => SetM a -> SetM a
> mrestrict x = prompt (PRestrict x)

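The ordinary list monad can stand in for SetM to show why a synchronization point pays off. The list monad carries every duplicate into later steps. Collapsing the intermediate results through a Set, as mrestrict does, leaves only distinct values to continue from. Here is a sketch using the hypothetical list helper `restrictL`:

```haskell
import qualified Data.Set as S

-- Hypothetical list analogue of mrestrict: force the computation so far,
-- drop duplicates via a Set, and continue from the distinct values.
restrictL :: Ord a => [a] -> [a]
restrictL = S.toList . S.fromList

-- The list monad keeps every duplicate: 3 * 2 = 6 values reach the last step.
unrestricted :: [Int]
unrestricted = do
  x <- [1, 1, 2]
  y <- [x, x]
  return (y * 10)

-- Synchronizing after the first two steps leaves 2 distinct values,
-- so the last step runs only twice.
restricted :: [Int]
restricted = do
  y <- restrictL (do { x <- [1, 1, 2]; [x, x] })
  return (y * 10)
```

Both produce the same set of values, but `unrestricted` runs its final step six times and `restricted` only twice.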
Finally we need an interpretation function to run the monad and extract a
set from it:

> runSetM :: Ord r => SetM r -> S.Set r
> runSetM = runPromptC ret prm . unRecPrompt where
>    -- ret :: r -> S.Set r
>    ret = S.singleton
>    -- prm :: forall a. OrdP SetM a -> (a -> S.Set r) -> S.Set r
>    prm PZero _ = S.empty
>    prm (PRestrict m) k = unionMap k (runSetM m)
>    prm (PPlus m1 m2) k = unionMap k (runSetM m1 `S.union` runSetM m2)

unionMap is the Set counterpart of concatMap on lists.

> unionMap :: Ord b => (a -> S.Set b) -> S.Set a -> S.Set b
> unionMap f = S.foldr (\a r -> f a `S.union` r) S.empty

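The following sketch makes the correspondence concrete. It uses the same definition of unionMap, written with `S.foldr` (the non-deprecated replacement for `S.fold`). It compares unionMap against a hypothetical `unionMapViaList` that goes through `concatMap` on the list form and lets `S.fromList` remove the duplicates:

```haskell
import qualified Data.Set as S

-- unionMap as in the post, folding with S.foldr.
unionMap :: Ord b => (a -> S.Set b) -> S.Set a -> S.Set b
unionMap f = S.foldr (\a r -> f a `S.union` r) S.empty

-- Hypothetical reference version: concatMap over the elements as a list,
-- then rebuild a Set, which discards duplicates.
unionMapViaList :: Ord b => (a -> S.Set b) -> S.Set a -> S.Set b
unionMapViaList f = S.fromList . concatMap (S.toList . f) . S.toList
```

For example, with `f x = S.fromList [x `mod` 3, x `mod` 4]` over `S.fromList [1 .. 10]`, both give `fromList [0,1,2,3]`.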
Oleg's test now works without modification:

> test1s_do () = do
>   x <- return "a"
>   return $ "b" ++ x

> olegtest :: S.Set String
> olegtest = runSetM $ test1s_do ()
> -- fromList ["ba"]

> settest :: S.Set Int
> settest = runSetM $ do
>    x <- mplus (mplus mzero (return 2)) (mplus (return 2) (return 3))
>    return (x+3)
> -- fromList [5,6]

Under the hood, the computation runs separately on each element of the set,
except at programmer-specified synchronization points, where the result type
must be an instance of the Ord typeclass.

Synchronization points occur at every "mplus" and "mrestrict": the results
computed up to that point are gathered into a Set, and the remainder of the
computation is then dispatched from each element of that Set.
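
The same gather-then-dispatch behavior can be reproduced without MonadPrompt by a continuation-passing encoding. This roughly mirrors how runPromptC hands each prompt a continuation `k`. The sketch below uses a different technique (a codensity-style wrapper) and is not the post's implementation. `SetC`, `lowerC`, `liftC`, `zeroC`, `plusC`, and `restrictC` are hypothetical names. Bind needs no Ord constraint. Only the synchronization points `plusC` and `restrictC` do:

```haskell
{-# LANGUAGE RankNTypes #-}
import qualified Data.Set as S

-- Hypothetical continuation-passing Set monad: a computation is a function
-- from "the rest of the computation" (k) to the final Set.
newtype SetC a = SetC { runC :: forall r. Ord r => (a -> S.Set r) -> S.Set r }

instance Functor SetC where
  fmap f m = SetC (\k -> runC m (k . f))

instance Applicative SetC where
  pure a = SetC (\k -> k a)
  mf <*> mx = SetC (\k -> runC mf (\f -> runC mx (k . f)))

-- Bind only composes continuations, so it needs no Ord constraint;
-- each element flows through the rest of the computation separately.
instance Monad SetC where
  m >>= f = SetC (\k -> runC m (\a -> runC (f a) k))

-- Run to completion, collecting the results into a Set (requires Ord a).
lowerC :: Ord a => SetC a -> S.Set a
lowerC m = runC m S.singleton

-- Resume from a Set: call the continuation once per distinct element
-- and union the results (the role unionMap plays in runSetM).
liftC :: S.Set a -> SetC a
liftC s = SetC (\k -> S.foldr (\a acc -> k a `S.union` acc) S.empty s)

zeroC :: SetC a
zeroC = SetC (const S.empty)

-- Synchronization points: gather into a Set, then dispatch from it.
plusC :: Ord a => SetC a -> SetC a -> SetC a
plusC x y = liftC (lowerC x `S.union` lowerC y)

restrictC :: Ord a => SetC a -> SetC a
restrictC = liftC . lowerC
```

Running the post's settest through this encoding gives the same result. At type `S.Set Int`, `lowerC (do { x <- plusC (plusC zeroC (pure 2)) (plusC (pure 2) (pure 3)); pure (x + 3) })` evaluates to `fromList [5,6]`.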

  -- ryan

