# [Haskell-cafe] <*> for Data.Sequence: mission accomplished

David Feuer david.feuer at gmail.com
Fri Dec 19 21:02:46 UTC 2014

All right! It's taken a long time, but I finally managed to write a
nice implementation of `<*>` for `Data.Sequence`! It looks nothing
like my original concept, but it accomplishes my incremental
performance goals while also being a little better than the current
implementation when forcing the whole result. Anyone interested can
see the code in the `ap` branch of treeowl/containers on GitHub. There
are still some ugly special cases for sufficiently small
arguments; any ideas for changing that would be most welcome.
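As a reminder of the semantics being optimized: `fs <*> xs` applies every function to every argument, with the functions in the outer loop, so the result has `length fs * length xs` elements. Below is a sketch of a list-based reference in the spirit of the `fromList` approach mentioned later in the thread. `apRef` is a hypothetical name, and this is not the code in the `ap` branch.

```haskell
import Data.Foldable (toList)
import Data.Sequence (Seq)
import qualified Data.Sequence as Seq

-- Hypothetical reference definition: functions in the outer loop,
-- arguments in the inner loop. The result has length fs * length xs
-- elements, and the whole thing goes through an intermediate list.
apRef :: Seq (a -> b) -> Seq a -> Seq b
apRef fs xs = Seq.fromList [f x | f <- toList fs, x <- toList xs]
```

For example, `apRef (Seq.fromList [(+1), (*10)]) (Seq.fromList [1, 2, 3])` yields the elements 2, 3, 4, 10, 20, 30, which is what the library's `<*>` produces too.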

Many thanks to Joachim Breitner for helping implement an earlier
approach, to Ross Paterson for coming up with a solution to a previous
formulation of the problem (thus helping me understand that I had
asked the wrong question), and to Carter Schonwald for putting up with
my stream of consciousness rambling.

On Sun, Nov 23, 2014 at 12:07 AM, David Feuer <david.feuer at gmail.com> wrote:
> OK, sorry for the flood of posts, but I think I've found a way to make that
> work. Specifically, I think I can write a three-Seq append that takes the
> total size and uses it to be as lazy as possible in the second of the three
> Seqs. I'm still working out the details, but I think it will work. It does
> the (possibly avoidable) rebuilding, but I *think* it's at least
> asymptotically optimal. Of course, if Ross Paterson can find something more
> efficient, that'd be even better.
>
> On Sat, Nov 22, 2014 at 10:10 PM, David Feuer <david.feuer at gmail.com> wrote:
>>
>> What I want is close to this, but it won't quite work this way:
>>
>> fs <*> xs = equalJoin $ fmap (<$> xs) fs
>>
>> equalJoin :: Int -> Seq (Seq a) -> Seq a
>> equalJoin n s
>>   | length s <= 2*n = simpleJoin s
>>   | otherwise       = simpleJoin pref ><
>>                       equalJoin (2*n) mid ><
>>                       simpleJoin suff
>>   where (pref, s')  = splitAt n s
>>         (mid, suff) = splitAt (length s - 2*n) s'
>>
>> simpleJoin :: Seq (Seq a) -> Seq a
>> simpleJoin s
>>   | null s = empty
>>   | length s == 1 = index s 0
>>   | otherwise = simpleJoin front >< simpleJoin back
>>   where
>>     (front,back) = splitAt (length s `quot` 2) s
>>
>> I think the reason this doesn't work is that >< is too strict. I believe
>> the only potential way around this is to dig into the FingerTree
>> representation and build the thing top-down. I still don't understand how
>> (if at all) this can be done.
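For experimenting, the sketch above can be made to compile against the public `Data.Sequence` API by qualifying the names. The sketch does not say what chunk size `equalJoin` starts with. The hypothetical `apJoin` below assumes 1, so the chunks peeled off each end hold 1, 2, 4, ... inner sequences.

```haskell
import Data.Sequence (Seq, (><))
import qualified Data.Sequence as Seq

-- Peel n inner sequences off each end, recurse on the middle with 2*n,
-- and glue the three parts together with (><).
equalJoin :: Int -> Seq (Seq a) -> Seq a
equalJoin n s
  | Seq.length s <= 2 * n = simpleJoin s
  | otherwise =
      simpleJoin pref >< equalJoin (2 * n) mid >< simpleJoin suff
  where
    (pref, s') = Seq.splitAt n s
    (mid, suff) = Seq.splitAt (Seq.length s - 2 * n) s'

-- Balanced divide-and-conquer concatenation of the inner sequences.
simpleJoin :: Seq (Seq a) -> Seq a
simpleJoin s
  | Seq.null s = Seq.empty
  | Seq.length s == 1 = Seq.index s 0
  | otherwise = simpleJoin front >< simpleJoin back
  where
    (front, back) = Seq.splitAt (Seq.length s `quot` 2) s

-- Hypothetical wrapper: the initial chunk size 1 is an assumption,
-- since the sketch leaves it out.
apJoin :: Seq (a -> b) -> Seq a -> Seq b
apJoin fs xs = equalJoin 1 (fmap (<$> xs) fs)
```

This only checks that the elements come out in the right order. It does not demonstrate the laziness problem with `><` described above.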
>>
>>
>> On Sat, Nov 22, 2014 at 12:57 PM, David Feuer <david.feuer at gmail.com> wrote:
>>>
>>> The ideal goal, which has taken me forever to identify and which may well
>>> be unattainable, is to get O(log(min{i,mn-i})) access to each element of the
>>> result, while maintaining O(mn) time to force it entirely. Each of these is
>>> possible separately, of course. To get them both, if it's possible, we need
>>> to give up on the list-like approach and start splitting Seqs instead of
>>> lists. As we descend, we want to pass a single thunk to each element of each
>>> Digit to give it just enough to do its thing. Representing the splits
>>> efficiently and/or memoizing them could be a bit of a challenge.
>>>
>>> On Fri, Nov 21, 2014 at 02:00:16PM -0500, David Feuer wrote:
>>> > To be precise, I *think* using the fromList approach for <*> makes us
>>> > create O(n) thunks in order to extract the last element of the result.
>>> > If we build the result inward, I *think* we can avoid this, getting the
>>> > last element of the result in O(1) time and space. But my understanding
>>> > of this data structure remains primitive.
>>>
>>> This modification of the previous should do that.
>>>
>>> mult :: Seq (a -> b) -> Seq a -> Seq b
>>> mult sfs sxs = fromTwoLists (length sfs * length sxs) ys rev_ys
>>>   where
>>>     fs = toList sfs
>>>     rev_fs = toRevList sfs
>>>     xs = toList sxs
>>>     rev_xs = toRevList sxs
>>>     ys = [f x | f <- fs, x <- xs]
>>>     rev_ys = [f x | f <- rev_fs, x <- rev_xs]
>>>
>>> -- toRevList xs = toList (reverse xs)
>>> toRevList :: Seq a -> [a]
>>> toRevList = foldl (flip (:)) []
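`mult` relies on `rev_ys` being exactly the reverse of `ys`, so that `fromTwoLists` can consume the result from both ends. This holds because reversing both inputs of the nested comprehension reverses its output. Below is a quick check, with `toRevList` copied from above and a hypothetical `revInvariant` predicate. Pairs stand in for `f x` so the results can be compared.

```haskell
import Data.Foldable (toList)
import Data.Sequence (Seq)
import qualified Data.Sequence as Seq

-- toRevList xs = toList (reverse xs), as defined above.
toRevList :: Seq a -> [a]
toRevList = foldl (flip (:)) []

-- Hypothetical check of the invariant mult depends on: the comprehension
-- over the reversed inputs equals the reverse of the forward one.
revInvariant :: (Eq a, Eq b) => Seq a -> Seq b -> Bool
revInvariant sfs sxs = rev_ys == reverse ys
  where
    ys = [(f, x) | f <- toList sfs, x <- toList sxs]
    rev_ys = [(f, x) | f <- toRevList sfs, x <- toRevList sxs]
```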
>>>
>>> -- Build a tree lazy in the middle, from a list and its reverse.
>>> --
>>> -- fromTwoLists (length xs) xs (reverse xs) = fromList xs
>>> --
>>> -- Getting the kth element from either end involves forcing the lists
>>> -- to length k.
>>> fromTwoLists :: Int -> [a] -> [a] -> Seq a
>>> fromTwoLists len_xs xs rev_xs =
>>>     Seq $ mkTree2 len_xs 1 (map Elem xs) (map Elem rev_xs)
>>>
>>> -- Construct a fingertree from the first n elements of xs.
>>> -- The arguments must satisfy n <= length xs && rev_xs = reverse xs.
>>> -- Each element of xs has the same size, provided as an argument.
>>> mkTree2 :: Int -> Int -> [a] -> [a] -> FingerTree a
>>> mkTree2 n size xs rev_xs
>>>   | n == 0 = Empty
>>>   -- In recursive calls xs can be longer than n, so bind only its head.
>>>   | n == 1 = let (x1:_) = xs in Single x1
>>>   | n <  6 = let
>>>             nl = n `div` 2
>>>             l = Data.List.take nl xs
>>>             r = Data.List.take (n - nl) rev_xs
>>>         in Deep totalSize (mkDigit l) Empty (mkRevDigit r)
>>>   | otherwise = let
>>>             size' = 3*size
>>>             n' = (n-4) `div` 3
>>>             digits = n - n'*3
>>>             nl = digits `div` 2
>>>             (l, xs') = Data.List.splitAt nl xs
>>>             (r, rev_xs') = Data.List.splitAt (digits - nl) rev_xs
>>>             nodes = mkNodes size' xs'
>>>             rev_nodes = mkRevNodes size' rev_xs'
>>>         in Deep totalSize (mkDigit l) (mkTree2 n' size' nodes rev_nodes)
>>>                 (mkRevDigit r)
>>>   where
>>>     totalSize = n*size
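The `otherwise` branch of `mkTree2` (used when n >= 6) keeps n' = (n-4) `div` 3 nodes for the middle tree and leaves `digits = n - 3*n'` elements for the two outer digits. Since n - 4 = 3n' + t with t in {0, 1, 2}, `digits` is 4, 5 or 6, so each side gets 2 or 3 elements, which `mkDigit` and `mkRevDigit` accept. Below is a small check of that arithmetic. `splitSizes` and `splitOk` are hypothetical names.

```haskell
-- Hypothetical restatement of the split in mkTree2's otherwise branch.
-- Returns (middle node count, left digit size, right digit size).
splitSizes :: Int -> (Int, Int, Int)
splitSizes n = (n', nl, digits - nl)
  where
    n' = (n - 4) `div` 3
    digits = n - n' * 3
    nl = digits `div` 2

-- For n >= 6: non-negative middle, 2 or 3 elements per outer digit,
-- and every element accounted for.
splitOk :: Int -> Bool
splitOk n =
  n' >= 0 && l `elem` [2, 3] && r `elem` [2, 3] && l + 3 * n' + r == n
  where
    (n', l, r) = splitSizes n
```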
>>>
>>> mkDigit :: [a] -> Digit a
>>> mkDigit [x1] = One x1
>>> mkDigit [x1, x2] = Two x1 x2
>>> mkDigit [x1, x2, x3] = Three x1 x2 x3
>>> mkDigit [x1, x2, x3, x4] = Four x1 x2 x3 x4
>>>
>>> -- length xs <= 4 => mkRevDigit xs = mkDigit (reverse xs)
>>> mkRevDigit :: [a] -> Digit a
>>> mkRevDigit [x1] = One x1
>>> mkRevDigit [x2, x1] = Two x1 x2
>>> mkRevDigit [x3, x2, x1] = Three x1 x2 x3
>>> mkRevDigit [x4, x3, x2, x1] = Four x1 x2 x3 x4
>>>
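>>> -- Deliberately partial: the leftover elements at the end of the list
>>> -- (fewer than 3) are never meant to be forced.
>>> mkNodes :: Int -> [a] -> [Node a]
>>> mkNodes size (x1:x2:x3:xs) = Node3 size x1 x2 x3:mkNodes size xs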
>>>
>>> -- length xs `mod` 3 == 0 =>
>>> --    mkRevNodes size xs = reverse (mkNodes size (reverse xs))
>>> mkRevNodes :: Int -> [a] -> [Node a]
>>> mkRevNodes size (x3:x2:x1:xs) = Node3 size x1 x2 x3:mkRevNodes size xs
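To see the construction run on its own, here is a self-contained toy model. It uses finger-tree types without the size fields, plus a `mkTree2` that follows the same split arithmetic but drops the `size` parameter. Every type and function in it is a hypothetical stand-in, not the `Data.Sequence` internals. One deliberate difference: `mkNodes` and `mkRevNodes` get catch-all cases instead of being partial. The block checks the stated property `fromTwoLists (length xs) xs (reverse xs) = fromList xs` in the form "flattening the tree gives back the list".

```haskell
-- Toy finger tree WITHOUT size annotations: simplified stand-ins,
-- not the real containers types.
data Digit a = One a | Two a a | Three a a a | Four a a a a
data Node a = Node3 a a a
data FingerTree a = Empty | Single a | Deep (Digit a) (FingerTree (Node a)) (Digit a)

digitToList :: Digit a -> [a]
digitToList (One a) = [a]
digitToList (Two a b) = [a, b]
digitToList (Three a b c) = [a, b, c]
digitToList (Four a b c d) = [a, b, c, d]

nodeToList :: Node a -> [a]
nodeToList (Node3 a b c) = [a, b, c]

-- Flatten left to right (polymorphic recursion, so the signature is needed).
toListFT :: FingerTree a -> [a]
toListFT Empty = []
toListFT (Single x) = [x]
toListFT (Deep l m r) =
  digitToList l ++ concatMap nodeToList (toListFT m) ++ digitToList r

mkDigit :: [a] -> Digit a
mkDigit [a] = One a
mkDigit [a, b] = Two a b
mkDigit [a, b, c] = Three a b c
mkDigit [a, b, c, d] = Four a b c d
mkDigit _ = error "mkDigit: expected 1 to 4 elements"

-- Digit from a list given in reverse order.
mkRevDigit :: [a] -> Digit a
mkRevDigit = mkDigit . reverse

-- Groups of three, left to right; the catch-all drops a short tail
-- that mkTree2 never asks for.
mkNodes :: [a] -> [Node a]
mkNodes (a : b : c : rest) = Node3 a b c : mkNodes rest
mkNodes _ = []

-- Groups of three from a reversed list; each node keeps forward order.
mkRevNodes :: [a] -> [Node a]
mkRevNodes (c : b : a : rest) = Node3 a b c : mkRevNodes rest
mkRevNodes _ = []

-- Toy mkTree2: tree of the first n elements of xs, given revXs == reverse xs.
mkTree2 :: Int -> [a] -> [a] -> FingerTree a
mkTree2 n xs revXs
  | n == 0 = Empty
  | n == 1 = case xs of
      (x : _) -> Single x
      [] -> error "mkTree2: list too short"
  | n < 6 =
      let nl = n `div` 2
      in Deep (mkDigit (take nl xs)) Empty (mkRevDigit (take (n - nl) revXs))
  | otherwise =
      let n' = (n - 4) `div` 3
          digits = n - 3 * n'
          nl = digits `div` 2
          (l, xs') = splitAt nl xs
          (r, revXs') = splitAt (digits - nl) revXs
      in Deep (mkDigit l) (mkTree2 n' (mkNodes xs') (mkRevNodes revXs')) (mkRevDigit r)
```

The left digit and the nodes come only from `xs`, and the right digit comes only from `revXs`. That split is what lets each end of the result be reached by forcing only a prefix of the corresponding list.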