Stricter WriterT (Part II)

Gabriel Gonzalez gabriel439 at
Sun Mar 17 17:18:13 CET 2013

I previously asked on this mailing list about getting a stricter WriterT 
added to transformers here:

To recap, neither of the two WriterT implementations in transformers 
keeps the accumulator strictly evaluated.  In the previous request, I 
presented four implementations, only one of which runs in constant 
space; all four are in this hpaste:

To summarize the four implementations:

Version #1: WriterT from Control.Monad.Trans.Writer.Strict
Version #2: Same as Version #1, except the monad bind strictly evaluates 
the mappended result
Version #3: WriterT reimplemented as StateT, but no strictness annotations
Version #4: Same as Version #3, except 'tell' strictly evaluates the 
mappended result

Only version #4 works and runs in constant space. Actually, not only 
does it run in constant space, but I failed to realize at the time that 
it also compiles to very efficient core when you add an `Int` type 
annotation to the summand.  For example, if you try to sum 1 billion 
`Int`s in the naive way using Version #4:

main :: IO ()
main = (print =<<) $ runWriterT4 $ replicateM_ 1000000000 $ tell4 $
     Sum (1 :: Int)

... and compile it with -O2, it generates the following very nice core:

$wa1 =
   \ (w_s25b :: Int#)
     (ww_s25e :: Int#)
     (w1_s25g :: State# RealWorld) ->
     case <=# w_s25b 1 of _ {
       False ->
         $wa1 (-# w_s25b 1) (+# 1 ww_s25e) w1_s25g;
       True ->
         (# w1_s25g,
             (I# (+# 1 ww_s25e))
             ) #)

... and runs in 4.6 seconds on my netbook:

time ./writer
((),Sum {getSum = 1000000000})

real    0m4.580s
user    0m4.560s
sys     0m0.008s

... which is about 4.6 nanoseconds per element. This is quite impressive 
when you consider it is factoring everything through the 'IO' monad.  If 
you use `Identity` as the base monad:

main4 = print $ runIdentity $ runWriterT4 $ replicateM_ n $ tell4 $
     Sum (1 :: Int)

... then it gets slightly faster:

real    0m3.678s
user    0m3.668s
sys     0m0.000s

... with an even nicer inner loop:

$wa1 =
   \ (w_s25v :: Int#) (ww_s25y :: Int#) ->
     case <=# w_s25v 1 of _ {
       False ->
         $wa1 (-# w_s25v 1) (+# 1 ww_s25y);
       True ->
         (# (),
            (I# (+# 1 ww_s25y))

The reason this stalled last time is that Edward and I agreed that I 
should first investigate if there is a "smaller" type that gives the 
same behavior.  Now I'm revisiting the issue because I can safely 
conclude that the answer is "no". The StateT implementation is the 
smallest type that gives the correct behavior.

To explain why, it helps to compare the definition of `(>>=)` for both 
WriterT and StateT:

m >>= k  = WriterT $ do
     (a, w)  <- runWriterT m
     (b, w') <- runWriterT (k a)
     return (b, w `mappend` w')

m >>= k  = StateT $ \s -> do
     (a, s') <- runStateT m s
     runStateT (k a) s'

The `WriterT` fails to run in constant space because of the pattern of 
binding the continuation before mappending the results.  This results in 
N nested binds before it can compute even the very first `mappend`.  
This not only leaks space, but also punishes the case where your base 
monad is a free monad, since it builds up a huge chain of 
left-associated binds.
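
To make the nesting concrete, here is a minimal sketch that copies the 
WriterT bind quoted above into a local stand-in type.  The names `W`, 
`bindW`, `tellW` and `three` are hypothetical and exist only for this 
illustration; they are not taken from transformers or from the hpaste.

```haskell
import Data.Functor.Identity (Identity (..))

-- Local stand-in for the tuple-based WriterT (hypothetical name).
newtype W w m a = W { runW :: m (a, w) }

-- The bind quoted above, written as a plain function to avoid
-- instance boilerplate.
bindW :: (Monad m, Monoid w) => W w m a -> (a -> W w m b) -> W w m b
bindW m k = W $ do
    (a, w)  <- runW m
    (b, w') <- runW (k a)       -- the whole continuation runs here ...
    return (b, w `mappend` w')  -- ... before this mappend is reached

tellW :: Monad m => w -> W w m ()
tellW w = W (return ((), w))

-- A right-nested chain, the shape that replicateM_ produces.
three :: W [Int] Identity ()
three = tellW [1] `bindW` \_ -> (tellW [2] `bindW` \_ -> tellW [3])

-- Unfolding bindW by hand:
--
--   runW three
--     = do (a, w)  <- return ((), [1])
--          (b, w') <- do (a2, w2)  <- return ((), [2])
--                        (b2, w2') <- return ((), [3])
--                        return (b2, w2 `mappend` w2')
--          return (b, w `mappend` w')
--
-- The outer mappend waits for the entire inner do block, so N tells
-- give N pending frames before the first mappend can happen.
```

The result is still correct; the problem is only the order in which the 
work is scheduled.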

The canonical solution to avoid this sort of nested bind is to use a 
continuation-passing-style transformation where you pass the second 
`runWriterT` a continuation saying what you want to do with its monoid 
result.  My first draft of such a solution looked like this:

newtype WriterT w m a = WriterT { unWriterT :: (w -> w) -> m (a, w) }

m >>= k  = WriterT $ \f -> do
     (a, w) <- unWriterT m f
     unWriterT (k a) (mappend w)

tell w = WriterT $ \f -> return ((), f w)

runWriterT m = unWriterT m id

... but then I realized that there is no need to pass a general 
function.  I only ever use mappend, so why not just pass in the monoid 
that I want to mappend and let `tell` just supply the `mappend`:

newtype WriterT w m a = WriterT { unWriterT :: w -> m (a, w) }

m >>= k  = WriterT $ \w -> do
     (a, w') <- unWriterT m w
     unWriterT (k a) w'

tell w' = WriterT $ \w -> return ((), mappend w w')

runWriterT m = unWriterT m mempty

Notice that this just reinvents the StateT monad transformer.  In other 
words, StateT *is* the continuation-passing-style transformation of 
WriterT, which is why you can't do any better than to reformulate 
WriterT as StateT internally.
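
As a rough illustration of that equivalence, the same writer behaviour 
can be sketched directly on top of the existing strict StateT from 
transformers.  The helper names `tellS`, `runWriterS` and `example` are 
hypothetical, and the sketch assumes a transformers version that exports 
`modify'`.

```haskell
import Control.Monad.Trans.State.Strict (StateT, modify', runStateT)
import Data.Functor.Identity (Identity (..))

-- Hypothetical helpers: neither name is part of transformers.

-- Append to the accumulator; modify' forces the new state to WHNF.
tellS :: (Monad m, Monoid w) => w -> StateT w m ()
tellS w = modify' (`mappend` w)

-- Start from the empty accumulator, as runWriterT does.
runWriterS :: Monoid w => StateT w m a -> m (a, w)
runWriterS m = runStateT m mempty

-- A list monoid makes the accumulation order visible.
example :: ((), [Int])
example = runIdentity (runWriterS (tellS [1] >> tellS [2] >> tellS [3]))
```

Using StateT directly like this would also expose `get` and `put` to 
callers, which is one reason to wrap it in a separate newtype.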

So I propose that we add an additional stricter WriterT (under, say, 
"Control.Monad.Trans.Writer.Stricter") which is internally implemented 
as StateT, with the constructor hidden so we don't expose the 
implementation:

newtype WriterT w m a = WriterT { unWriterT :: w -> m (a, w) }

instance (Monad m, Monoid w) => Monad (WriterT w m) where
     return a = WriterT $ \w -> return (a, w)
     m >>= f  = WriterT $ \w -> do
         (a, w') <- unWriterT m w
         unWriterT (f a) w'

And define `tell` and `runWriterT` as follows:

tell :: (Monad m, Monoid w) => w -> WriterT w m ()
tell w = WriterT $ \w' ->
     let wt = w' `mappend` w
      in wt `seq` return ((), wt)

runWriterT :: (Monoid w) => WriterT w m a -> m (a, w)
runWriterT m = unWriterT m mempty

If we do that, then WriterT becomes not only usable, but actually 
competitive with expertly tuned code.
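
For anyone who wants to try this without the hpaste, the pieces above 
can be assembled into one self-contained sketch.  Two things here are my 
additions rather than part of the proposal: the Functor and Applicative 
instances, which GHC versions newer than the one used for the timings 
require, and the helper `sumOnes`, a hypothetical name for a scaled-down 
version of the benchmark.  The accumulator is kept on the left of 
`mappend`, so non-commutative monoids keep their order.

```haskell
import Control.Monad (ap, liftM, replicateM_)
import Data.Functor.Identity (Identity (..))
import Data.Monoid (Sum (..))

-- State-passing writer; a real module would not export the constructor.
newtype WriterT w m a = WriterT { unWriterT :: w -> m (a, w) }

-- Assumption: boilerplate instances required by newer GHC versions.
instance Monad m => Functor (WriterT w m) where
    fmap = liftM

instance Monad m => Applicative (WriterT w m) where
    pure a = WriterT $ \w -> return (a, w)
    (<*>)  = ap

instance Monad m => Monad (WriterT w m) where
    m >>= f = WriterT $ \w -> do
        (a, w') <- unWriterT m w
        unWriterT (f a) w'

-- Force the combined accumulator before handing it on.
tell :: (Monad m, Monoid w) => w -> WriterT w m ()
tell w = WriterT $ \acc ->
    let wt = acc `mappend` w
    in wt `seq` return ((), wt)

runWriterT :: Monoid w => WriterT w m a -> m (a, w)
runWriterT m = unWriterT m mempty

-- Hypothetical helper: the benchmark above with a configurable count.
sumOnes :: Int -> ((), Sum Int)
sumOnes n = runIdentity $ runWriterT $ replicateM_ n $ tell $ Sum (1 :: Int)
```

Loading this in GHCi and evaluating `sumOnes 1000000` should print 
`((),Sum {getSum = 1000000})`.  The timings quoted earlier were measured 
with -O2 on the hpaste code, not on this sketch.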
