[Haskell-cafe] Set monad
oleg at okmij.org
oleg at okmij.org
Fri Apr 12 12:49:43 CEST 2013
> One problem with such monad implementations is efficiency. Let's define
>
> step :: (MonadPlus m) => Int -> m Int
> step i = choose [i, i + 1]
>
> -- repeated application of step on 0:
> stepN :: (Monad m) => Int -> m (S.Set Int)
> stepN = runSet . f
> where
> f 0 = return 0
> f n = f (n-1) >>= step
>
> Then `stepN`'s time complexity is exponential in its argument. This is
> because `ContT` reorders the chain of computations to right-associative,
> which is correct, but changes the time complexity in this unfortunate way.
> If we used Set directly, constructing a left-associative chain, it produces
> the result immediately:
The example is excellent. And yet, an efficient, genuine Set monad is
possible.
BTW, a simpler way to see the problem with the original CPS monad is to
repeat
choose [1,1] >> choose [1,1] >> choose [1,1] >> return 1
and observe the exponential behavior. But your example is much more
subtle.
Enclosed is an efficient, genuine Set monad. I wrote it in direct
style (it seems to be faster, anyway). The key is to use the optimized
choose function (chooseOrd, below) whenever we can.
{-# LANGUAGE GADTs, TypeSynonymInstances, FlexibleInstances #-}
module SetMonadOpt where
import Control.Applicative (Alternative (..), Applicative (..))
import Control.Monad
import qualified Data.Set as S
-- | A set-valued nondeterminism monad with two interchangeable
-- representations of the same collection of results.
data SetMonad a where
  -- | Results held in a 'S.Set'.  The constructor captures the 'Ord'
  -- dictionary, so code that later matches on it can merge with 'S.union'.
  SMOrd :: Ord a => S.Set a -> SetMonad a
  -- | Results held in a plain list (duplicates possible).  No 'Ord'
  -- instance is needed to build this form.
  SMAny :: [a] -> SetMonad a
-- | Mapping, derived from the monad operations so the instances agree by
-- construction.  'Functor' and 'Applicative' have been superclasses of
-- 'Monad' since GHC 7.10; without these instances the module does not
-- compile on a modern GHC.
instance Functor SetMonad where
  fmap = liftM

-- | 'pure' builds an unordered singleton ('SMAny'), because no 'Ord'
-- dictionary is available for an arbitrary element type.
instance Applicative SetMonad where
  pure x = SMAny [x]
  (<*>) = ap

-- | Sequencing: apply the continuation to every element of @m@ and merge
-- the resulting computations with 'collect'.  'return' is the canonical
-- @pure@.
instance Monad SetMonad where
  return = pure
  m >>= f = collect . map f $ toList m

-- | 'Alternative' is the superclass of 'MonadPlus' in current base, and
-- 'guard' is stated in terms of it.  Choice delegates to 'mplus'.
instance Alternative SetMonad where
  empty = SMAny []
  (<|>) = mplus
-- | Enumerate the results of a 'SetMonad'.  An 'SMOrd' yields its set's
-- elements in ascending order; an 'SMAny' yields its list unchanged,
-- duplicates included.
toList :: SetMonad a -> [a]
toList sm = case sm of
  SMOrd s -> S.toList s
  SMAny xs -> xs
-- | Merge a list of computations into one, combining from the right.  The
-- result stays an unordered 'SMAny' list until an 'SMOrd' operand appears.
-- From that point on, everything is unioned into a single 'S.Set'.  A
-- singleton list is returned as-is; an empty list gives @SMAny []@.
collect :: [SetMonad a] -> SetMonad a
collect [] = SMAny []
collect parts = foldr1 merge parts
  where
    -- Combine one element with the already-merged remainder.  The element
    -- is forced first, then the remainder, as in a direct recursion.
    merge :: SetMonad b -> SetMonad b -> SetMonad b
    merge (SMOrd xs) rest = case rest of
      SMOrd ys -> SMOrd (S.union xs ys)
      SMAny ys -> SMOrd (S.union xs (S.fromList ys))
    merge (SMAny xs) rest = case rest of
      SMOrd ys -> SMOrd (S.union ys (S.fromList xs))
      SMAny ys -> SMAny (xs ++ ys)
-- | Observe a computation's results as a 'S.Set'.  The caller-supplied
-- 'Ord' constraint is what allows an 'SMAny' list to be deduplicated.
runSet :: Ord a => SetMonad a -> S.Set a
runSet sm = case sm of
  SMAny xs -> S.fromList xs
  SMOrd s -> s
-- | Nondeterministic choice.  If either operand is an 'SMOrd', the result
-- is an 'SMOrd' holding the union, so duplicates are dropped early.  Two
-- 'SMAny' operands are simply concatenated.
instance MonadPlus SetMonad where
  mzero = SMAny []
  mplus l r = case (l, r) of
    (SMAny xs, SMAny ys) -> SMAny (xs ++ ys)
    (SMAny xs, SMOrd ys) -> SMOrd (S.union ys (S.fromList xs))
    (SMOrd xs, SMAny ys) -> SMOrd (S.union xs (S.fromList ys))
    (SMOrd xs, SMOrd ys) -> SMOrd (S.union xs ys)
-- | Offer every list element as an alternative in any 'MonadPlus'.  Each
-- element becomes a 'return'ed branch, and the branches are combined with
-- 'msum'.
choose :: MonadPlus m => [a] -> m a
choose alternatives = msum [return x | x <- alternatives]
-- | Every sum @n1 + n2@ below 7 with @n1, n2@ drawn from @[1 .. 5]@.
-- Equal sums collapse in the resulting set.  The signature pins the type
-- that the monomorphism restriction would otherwise default to ('Integer').
test1 :: S.Set Integer
test1 = runSet $ do
  n1 <- choose [1 .. 5]
  n2 <- choose [1 .. 5]
  let n = n1 + n2
  guard (n < 7)
  return n
-- fromList [2,3,4,5,6]
-- | Values to choose from might be higher-order or actions.  Here each
-- choice is itself a 'SetMonad' computation.  That works because 'return'
-- and 'mplus' build the unordered 'SMAny' form, which needs no 'Ord'
-- instance for the (computation-valued) elements.  The signature pins the
-- type the monomorphism restriction would otherwise default to.
test1' :: S.Set Integer
test1' = runSet $ do
  n1 <- choose . map return $ [1 .. 5]
  n2 <- choose . map return $ [1 .. 5]
  n <- liftM2 (+) n1 n2
  guard (n < 7)
  return n
-- fromList [2,3,4,5,6]
-- | Pythagorean triples @(i, j, k)@ with all components in @[1 .. 10]@.
-- The signature pins the type the monomorphism restriction would otherwise
-- default to ('Integer' components).
test2 :: S.Set (Integer, Integer, Integer)
test2 = runSet $ do
  i <- choose [1 .. 10]
  j <- choose [1 .. 10]
  k <- choose [1 .. 10]
  guard (i * i + j * j == k * k)
  return (i, j, k)
-- fromList [(3,4,5),(4,3,5),(6,8,10),(8,6,10)]
-- | The hypotenuses of the triples found by 'test2'.  Duplicates (from
-- @(3,4,5)@ and @(4,3,5)@, for example) collapse in the set.  The signature
-- pins the type the monomorphism restriction would otherwise default to.
test3 :: S.Set Integer
test3 = runSet $ do
  i <- choose [1 .. 10]
  j <- choose [1 .. 10]
  k <- choose [1 .. 10]
  guard (i * i + j * j == k * k)
  return k
-- fromList [5,10]
-- Test by Petr Pudlak
-- First, the general, unoptimized case: 'step' goes through the polymorphic
-- 'choose', so in 'SetMonad' its two successors stay an unordered list.
step :: (MonadPlus m) => Int -> m Int
step i = choose successors
  where
    successors = [i, i + 1]
-- | Repeated application of 'step' on 0, @n@ times, as a left-nested chain
-- of binds.  'step' never produces an 'SMOrd', so duplicates are not removed
-- between steps.  The intermediate list therefore doubles with every bind,
-- and the running time is exponential in @n@.
--
-- A negative @n@ is treated like 0 and yields @fromList [0]@.  Before, the
-- countdown only stopped at exactly 0, so a negative @n@ recursed forever.
stepN :: Int -> S.Set Int
stepN = runSet . go
  where
    go :: Int -> SetMonad Int
    go k
      | k <= 0 = return 0
      | otherwise = go (k - 1) >>= step
-- it works, but clearly exponential
{-
*SetMonad> stepN 14
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]
(0.09 secs, 31465384 bytes)
*SetMonad> stepN 15
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
(0.18 secs, 62421208 bytes)
*SetMonad> stepN 16
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
(0.35 secs, 124876704 bytes)
-}
-- And now the optimization: when an 'Ord' instance is available, build the
-- alternatives directly as an 'SMOrd' set so duplicates are merged away.
chooseOrd :: Ord a => [a] -> SetMonad a
chooseOrd = SMOrd . S.fromList
-- | 'Ord'-aware counterpart of 'step': the values @i@ and @i + 1@ as an
-- 'SMOrd' set.
stepOpt :: Int -> SetMonad Int
stepOpt i = chooseOrd (take 2 (iterate (+ 1) i))
-- | Repeated application of 'stepOpt' on 0, @n@ times.  'stepOpt' yields
-- 'SMOrd' sets, so each bind unions its results and duplicates are dropped
-- immediately.  After @k@ binds the intermediate set is just @0 .. k@.
--
-- A negative @n@ is treated like 0 and yields @fromList [0]@.  Before, the
-- countdown only stopped at exactly 0, so a negative @n@ recursed forever.
stepNOpt :: Int -> S.Set Int
stepNOpt = runSet . go
  where
    go :: Int -> SetMonad Int
    go k
      | k <= 0 = return 0
      | otherwise = go (k - 1) >>= stepOpt
{-
stepNOpt 14
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]
(0.00 secs, 515792 bytes)
stepNOpt 15
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
(0.00 secs, 515680 bytes)
stepNOpt 16
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
(0.00 secs, 515656 bytes)
stepNOpt 30
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
(0.00 secs, 1068856 bytes)
-}
More information about the Haskell-Cafe
mailing list