[Haskell-cafe] Monad instance for Data.Set
Ryan Ingram
ryani.spam at gmail.com
Tue Mar 25 17:13:45 EDT 2008
I was experimenting with Prompt today and found that you can get a
"restricted monad" style of behavior out of a regular monad using Prompt:
> {-# LANGUAGE GADTs #-}
> module SetTest where
> import qualified Data.Set as S
Prompt is available from
http://hackage.haskell.org/cgi-bin/hackage-scripts/package/MonadPrompt-1.0.0.1
> import Control.Monad.Prompt
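The post relies on two operations from Control.Monad.Prompt. `prompt` issues a request, and `runPromptC` interprets a computation by answering each request through a continuation. As a rough mental model, here is a self-contained sketch of an operational-style Prompt monad. The constructors `Return` and `Bind` and the `Ask`/`sumTwo`/`runAsk` example are hypothetical names for illustration, and the package's actual representation may differ. Only the types of `prompt` and `runPromptC` are meant to match how the post uses them.

```haskell
{-# LANGUAGE GADTs, RankNTypes #-}

-- Hypothetical stand-in for Control.Monad.Prompt (NOT the package's code):
-- an "operational" monad that records each request p a with its continuation.
data Prompt p r where
  Return :: r -> Prompt p r
  Bind   :: p a -> (a -> Prompt p r) -> Prompt p r

instance Functor (Prompt p) where
  fmap f (Return x) = Return (f x)
  fmap f (Bind p k) = Bind p (fmap f . k)

instance Applicative (Prompt p) where
  pure = Return
  mf <*> mx = mf >>= \f -> fmap f mx

instance Monad (Prompt p) where
  Return x >>= f = f x
  Bind p k >>= f = Bind p (\a -> k a >>= f)

-- Issue a single request and return its answer.
prompt :: p a -> Prompt p a
prompt p = Bind p Return

-- Interpret a computation: ret handles the final result, and prm answers
-- each request by deciding how (and how many times) to call the continuation.
runPromptC :: (r -> b) -> (forall a. p a -> (a -> b) -> b) -> Prompt p r -> b
runPromptC ret _   (Return x) = ret x
runPromptC ret prm (Bind p k) = prm p (runPromptC ret prm . k)

-- Hypothetical example request type: "give me an Int".
data Ask a where
  AskInt :: Ask Int

sumTwo :: Prompt Ask Int
sumTwo = do
  x <- prompt AskInt
  y <- prompt AskInt
  return (x + y)

-- Answer every AskInt request with the same number n.
runAsk :: Int -> Prompt Ask r -> r
runAsk n = runPromptC id handle
  where
    handle :: Ask a -> (a -> b) -> b
    handle AskInt k = k n
```

The important detail is that the handler receives the continuation `k` and may call it zero, one, or many times. `runSetM` uses this freedom to fan a computation out over the elements of a Set.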
"OrdP" is a prompt that implements MonadPlus for orderable types:
> data OrdP m a where
>     PZero     :: OrdP m a
>     PRestrict :: Ord a => m a -> OrdP m a
>     PPlus     :: Ord a => m a -> m a -> OrdP m a
> type SetM = RecPrompt OrdP
We can't make SetM an instance of MonadPlus, because mplus would need an Ord
constraint. But as long as we don't import MonadPlus, we can reuse the names
mzero and mplus.
> mzero :: SetM a
> mzero = prompt PZero
> mplus :: Ord a => SetM a -> SetM a -> SetM a
> mplus x y = prompt (PPlus x y)
"mrestrict" can be inserted at various points in a computation to optimize
it: it forces the computation passed to it to complete and uses a Set to
eliminate duplicate outputs. We could also implement mrestrict without an
additional constructor in our prompt datatype, at the cost of some performance:
    mrestrict m = mplus mzero m
> mrestrict :: Ord a => SetM a -> SetM a
> mrestrict x = prompt (PRestrict x)
Finally we need an interpretation function to run the monad and extract a
set from it:
> runSetM :: Ord r => SetM r -> S.Set r
> runSetM = runPromptC ret prm . unRecPrompt where
>     -- ret :: r -> S.Set r
>     ret = S.singleton
>     -- prm :: forall a. OrdP SetM a -> (a -> S.Set r) -> S.Set r
>     prm PZero _ = S.empty
>     prm (PRestrict m) k = unionMap k (runSetM m)
>     prm (PPlus m1 m2) k = unionMap k (runSetM m1 `S.union` runSetM m2)
unionMap is the Set analogue of concatMap for lists.
> unionMap :: Ord b => (a -> S.Set b) -> S.Set a -> S.Set b
> unionMap f = S.foldr (\a r -> f a `S.union` r) S.empty
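A quick standalone check of the concatMap analogy, assuming only the containers package. `unionMap` is repeated from the post but written with `S.foldr`, which performs the same right fold. `unionMapViaList` and `neighbours` are hypothetical helpers added for comparison.

```haskell
import qualified Data.Set as S

-- unionMap as in the post: fold over the set, unioning each element's image.
unionMap :: Ord b => (a -> S.Set b) -> S.Set a -> S.Set b
unionMap f = S.foldr (\a r -> f a `S.union` r) S.empty

-- The same operation phrased through lists, mirroring concatMap:
-- map over the elements, then union the resulting sets.
unionMapViaList :: Ord b => (a -> S.Set b) -> S.Set a -> S.Set b
unionMapViaList f = S.unions . map f . S.toList

-- Hypothetical example function whose results overlap for nearby inputs.
neighbours :: Int -> S.Set Int
neighbours x = S.fromList [x - 1, x + 1]
```

Where `concatMap (S.toList . neighbours) [1, 2, 3]` yields `[0,2,1,3,2,4]` with 2 appearing twice, `unionMap neighbours (S.fromList [1, 2, 3])` collapses it to `fromList [0,1,2,3,4]`. That deduplication is why the result type needs `Ord b`.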
Oleg's test now works without modification:
> test1s_do () = do
>     x <- return "a"
>     return $ "b" ++ x
> olegtest :: S.Set String
> olegtest = runSetM $ test1s_do ()
> -- fromList ["ba"]
> settest :: S.Set Int
> settest = runSetM $ do
>     x <- mplus (mplus mzero (return 2)) (mplus (return 2) (return 3))
>     return (x+3)
> -- fromList [5,6]
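To experiment without installing MonadPrompt, here is a self-contained sketch of the same construction that uses only base and containers. It is a port under stated assumptions, not the package. `Prompt`, `Return`, `Bind` and `runPromptC` are a hypothetical minimal reimplementation, and `SetM` is a plain newtype standing in for `RecPrompt OrdP`. `OrdP`, `mzero`, `mplus`, `mrestrict`, `unionMap`, `runSetM` and the two tests follow the post.

```haskell
{-# LANGUAGE GADTs, RankNTypes, ScopedTypeVariables #-}

import qualified Data.Set as S

-- Hypothetical minimal Prompt monad (not MonadPrompt's actual code).
data Prompt p r where
  Return :: r -> Prompt p r
  Bind   :: p a -> (a -> Prompt p r) -> Prompt p r

instance Functor (Prompt p) where
  fmap f (Return x) = Return (f x)
  fmap f (Bind p k) = Bind p (fmap f . k)

instance Applicative (Prompt p) where
  pure = Return
  mf <*> mx = mf >>= \f -> fmap f mx

instance Monad (Prompt p) where
  Return x >>= f = f x
  Bind p k >>= f = Bind p (\a -> k a >>= f)

runPromptC :: (r -> b) -> (forall a. p a -> (a -> b) -> b) -> Prompt p r -> b
runPromptC ret _   (Return x) = ret x
runPromptC ret prm (Bind p k) = prm p (runPromptC ret prm . k)

-- The prompt type from the post.
data OrdP m a where
  PZero     :: OrdP m a
  PRestrict :: Ord a => m a -> OrdP m a
  PPlus     :: Ord a => m a -> m a -> OrdP m a

-- Stand-in for RecPrompt OrdP: the prompt type refers back to SetM itself.
newtype SetM a = SetM { unSetM :: Prompt (OrdP SetM) a }

instance Functor SetM where
  fmap f (SetM m) = SetM (fmap f m)

instance Applicative SetM where
  pure = SetM . pure
  SetM f <*> SetM x = SetM (f <*> x)

instance Monad SetM where
  SetM m >>= f = SetM (m >>= unSetM . f)

prompt :: OrdP SetM a -> SetM a
prompt p = SetM (Bind p Return)

mzero :: SetM a
mzero = prompt PZero

mplus :: Ord a => SetM a -> SetM a -> SetM a
mplus x y = prompt (PPlus x y)

mrestrict :: Ord a => SetM a -> SetM a
mrestrict x = prompt (PRestrict x)

unionMap :: Ord b => (a -> S.Set b) -> S.Set a -> S.Set b
unionMap f = S.foldr (\a r -> f a `S.union` r) S.empty

-- Each PPlus / PRestrict runs its sub-computations to a Set (using the Ord
-- dictionary packed in the constructor), then runs k once per element.
runSetM :: forall r. Ord r => SetM r -> S.Set r
runSetM = runPromptC S.singleton prm . unSetM
  where
    prm :: forall a. OrdP SetM a -> (a -> S.Set r) -> S.Set r
    prm PZero         _ = S.empty
    prm (PRestrict m) k = unionMap k (runSetM m)
    prm (PPlus m1 m2) k = unionMap k (runSetM m1 `S.union` runSetM m2)

olegtest :: S.Set String
olegtest = runSetM $ do
  x <- return "a"
  return ("b" ++ x)

settest :: S.Set Int
settest = runSetM $ do
  x <- mplus (mplus mzero (return 2)) (mplus (return 2) (return 3))
  return (x + 3)
```

Under these assumptions `settest` evaluates to `fromList [5,6]` and `olegtest` to `fromList ["ba"]`, matching the results above. `mrestrict m` and `mplus mzero m` give the same Set, as the post notes.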
Under the hood, this runs the computation separately for each element of the
set, except at programmer-specified synchronization points, where the result
type is required to be an instance of Ord. A synchronization point occurs at
every "mplus" and "mrestrict": the results computed up to that point are
gathered into a Set, and the remainder of the computation is then dispatched
once for each element of that Set.
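The payoff of these synchronization points can be seen without Prompt at all. The sketch below compares the list monad, which keeps every path separate, with a version that gathers results into a Set after each step. The Set version is roughly what an mrestrict between steps does in SetM. The `step` function is a hypothetical example chosen so that different paths reach the same values.

```haskell
import qualified Data.Set as S

-- Hypothetical nondeterministic step: each value branches into two
-- successors, and different paths frequently reach the same value.
step :: Int -> [Int]
step x = [x + 1, x + 2]

-- List monad: no synchronization, so every path is kept separately
-- and the number of results doubles on every step.
listAfter :: Int -> [Int]
listAfter n = iterate (>>= step) [0] !! n

-- Set version: after every step the results are gathered into a Set,
-- so duplicates are removed before the next step is dispatched.
setAfter :: Int -> S.Set Int
setAfter n = iterate syncStep (S.singleton 0) !! n
  where
    syncStep s = S.unions [S.fromList (step x) | x <- S.toList s]
```

After 10 steps the list holds 1024 results but only 11 distinct values (10 through 20). The Set version never keeps more than those distinct values.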
-- ryan