[commit: packages/containers] develop-0.6, develop-0.6-questionable, master, zip-devel: Make zipWith faster (31e1234)

git at git.haskell.org git at git.haskell.org
Fri Jan 23 22:40:29 UTC 2015


Repository : ssh://git@git.haskell.org/containers

On branches: develop-0.6,develop-0.6-questionable,master,zip-devel
Link       : http://git.haskell.org/packages/containers.git/commitdiff/31e1234435ae734bbf3d33a79e9cce89d06ac738

>---------------------------------------------------------------

commit 31e1234435ae734bbf3d33a79e9cce89d06ac738
Author: David Feuer <David.Feuer at gmail.com>
Date:   Tue Dec 2 17:09:49 2014 -0500

    Make zipWith faster
    
    Make `zipWith` build its result with the structure of its first
    argument, splitting up its second argument as it goes. This allows
    fast random access to the elements of the result immediately,
    without having to build large portions of the structure. It also
    seems to be slightly faster than the old implementation when the
    entire result is used, presumably by avoiding rebalancing costs.
    I believe most of this code will also help implement a fast
    `(<*>)`.
    
    Use the same approach to implement `zipWith3` and `zipWith4`.
    
    Clean up a couple warnings.
    
    Many thanks to Carter Schonwald for suggesting that I use the
    structure of the first sequence to structure the result, and for
    helping me come up with the splitTraverse approach.
    
    Benchmarks:
    
    Zipping two 100000-element sequences and extracting the 50000th element
    takes about 11.4ms with the new implementation, as opposed to 88ms with
    the old. Zipping two 10000-element sequences and forcing the result to
    normal form takes 4.0ms now rather than 19.7ms. The indexing gains show
    up for even very short sequences, but the new implementation really
    starts to look good once the size gets to around 1000--presumably it
    handles cache effects better than the old one. Note that the naive
    approach of converting sequences to lists, zipping them, and then
    converting back actually works very well for forcing short sequences to
    normal form, even better than the new implementation. But it starts to
    lose a lot of ground by the time the size gets to around 10000, and its
    performance on the indexing tests is bad.


>---------------------------------------------------------------

31e1234435ae734bbf3d33a79e9cce89d06ac738
 Data/Sequence.hs | 106 +++++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 92 insertions(+), 14 deletions(-)

diff --git a/Data/Sequence.hs b/Data/Sequence.hs
index b54f1e6..10d3a92 100644
--- a/Data/Sequence.hs
+++ b/Data/Sequence.hs
@@ -676,10 +676,10 @@ replicateM n x
 
 -- | @'replicateSeq' n xs@ concatenates @n@ copies of @xs@.
 replicateSeq :: Int -> Seq a -> Seq a
-replicateSeq n xs
+replicateSeq n s
   | n < 0     = error "replicateSeq takes a nonnegative integer argument"
   | n == 0    = empty
-  | otherwise = go n xs
+  | otherwise = go n s
   where
     -- Invariant: k >= 1
     go 1 xs = xs
@@ -1703,6 +1703,75 @@ reverseNode f (Node2 s a b) = Node2 s (f b) (f a)
 reverseNode f (Node3 s a b c) = Node3 s (f c) (f b) (f a)
 
 ------------------------------------------------------------------------
+-- Traversing with splittable "state"
+------------------------------------------------------------------------
+
+-- For zipping, and probably also for (<*>), it is useful to build a result by
+-- traversing a sequence while splitting up something else.  For zipping, we
+-- traverse the first sequence while splitting up the second [and third [and
+-- fourth]]. For fs <*> xs, we expect soon to traverse
+--
+-- > replicate (length fs * length xs) ()
+--
+-- while splitting something essentially equivalent to
+--
+-- > fmap (\f -> fmap f xs) fs
+--
+-- David Feuer, with excellent guidance from Carter Schonwald, December 2014
+
+class Splittable s where  -- a "state" that is divided up as a structure is traversed
+    splitState :: Int -> s -> (s,s)  -- splitState i s = (share of the first i elements, share of the rest)
+
+instance Splittable (Seq a) where  -- a sequence is split by position
+    splitState = splitAt
+
+instance (Splittable a, Splittable b) => Splittable (a, b) where
+    splitState i (a, b) = ((al, bl), (ar, br))  -- split both components at the same index
+      where  -- this is what lets zipWith3/zipWith4 split several sequences in lockstep
+        (al, ar) = splitState i a
+        (bl, br) = splitState i b
+
+splitTraverseSeq :: (Splittable s) => (s -> a -> b) -> s -> Seq a -> Seq b  -- map f over the Seq; each element gets its own share of s
+splitTraverseSeq f s (Seq xs) = Seq $ splitTraverseTree (\s' (Elem a) -> Elem (f s' a)) s xs  -- the result keeps the input's tree shape
+
+splitTraverseTree :: (Sized a, Splittable s) => (s -> a -> b) -> s -> FingerTree a -> FingerTree b  -- map f, splitting s by subtree sizes
+splitTraverseTree _f _s Empty = Empty
+splitTraverseTree f s (Single xs) = Single $ f s xs  -- a lone element receives all of s
+splitTraverseTree f s (Deep n pr m sf) = Deep n (splitTraverseDigit f prs pr) (splitTraverseTree (splitTraverseNode f) ms m) (splitTraverseDigit f sfs sf)
+  where  -- the size annotation n is reused as-is, so f must not change element sizes
+    (prs, r) = splitState (size pr) s  -- the prefix digit's share
+    (ms, sfs) = splitState (n - size pr - size sf) r  -- the middle tree's share; the remainder goes to the suffix
+
+splitTraverseDigit :: (Sized a, Splittable s) => (s -> a -> b) -> s -> Digit a -> Digit b  -- map f over a digit, splitting s by element sizes
+splitTraverseDigit f s (One a) = One (f s a)  -- a single element receives all of s
+splitTraverseDigit f s (Two a b) = Two (f first a) (f second b)
+  where
+    (first, second) = splitState (size a) s
+splitTraverseDigit f s (Three a b c) = Three (f first a) (f second b) (f third c)
+  where
+    (first, r) = splitState (size a) s
+    (second, third) = splitState (size b) r
+splitTraverseDigit f s (Four a b c d) = Four (f first a) (f second b) (f third c) (f fourth d)
+  where
+    (first, s') = splitState (size a) s
+    (middle, fourth) = splitState (size b + size c) s'  -- take b's and c's shares together...
+    (second, third) = splitState (size b) middle  -- ...then divide that piece between b and c
+
+splitTraverseNode :: (Sized a, Splittable s) => (s -> a -> b) -> s -> Node a -> Node b  -- map f over a node, splitting s by child sizes
+splitTraverseNode f s (Node2 ns a b) = Node2 ns (f first a) (f second b)  -- the node's size ns is kept unchanged
+  where
+    (first, second) = splitState (size a) s
+splitTraverseNode f s (Node3 ns a b c) = Node3 ns (f first a) (f second b) (f third c)
+  where
+    (first, r) = splitState (size a) s
+    (second, third) = splitState (size b) r
+
+getSingleton :: Seq a -> a  -- the element of a Seq whose tree is Single; Empty or any other shape is an error
+getSingleton (Seq (Single (Elem a))) = a
+getSingleton (Seq Empty) = error "getSingleton: Empty"
+getSingleton _ = error "getSingleton: Not a singleton."
+
+------------------------------------------------------------------------
 -- Zipping
 ------------------------------------------------------------------------
 
@@ -1717,17 +1786,11 @@ zip = zipWith (,)
 -- For example, @zipWith (+)@ is applied to two sequences to take the
 -- sequence of corresponding sums.
 zipWith :: (a -> b -> c) -> Seq a -> Seq b -> Seq c
-zipWith f xs ys
-  | length xs <= length ys      = zipWith' f xs ys
-  | otherwise                   = zipWith' (flip f) ys xs
-
--- like 'zipWith', but assumes length xs <= length ys
-zipWith' :: (a -> b -> c) -> Seq a -> Seq b -> Seq c
-zipWith' f xs ys = snd (mapAccumL k ys xs)
+zipWith f s1 s2 = splitTraverseSeq (\s a -> f a (getSingleton s)) s2' s1'  -- shape comes from s1'; each element gets a one-element share of s2'
   where
-    k kys x = case viewl kys of
-           (z :< zs) -> (zs, f x z)
-           EmptyL    -> error "zipWith': unexpected EmptyL"
+    minLen = min (length s1) (length s2)  -- trim both so the shares of s2' line up with the elements of s1'
+    s1' = take minLen s1
+    s2' = take minLen s2
 
 -- | /O(min(n1,n2,n3))/.  'zip3' takes three sequences and returns a
 -- sequence of triples, analogous to 'zip'.
@@ -1738,7 +1801,14 @@ zip3 = zipWith3 (,,)
 -- three elements, as well as three sequences and returns a sequence of
 -- their point-wise combinations, analogous to 'zipWith'.
 zipWith3 :: (a -> b -> c -> d) -> Seq a -> Seq b -> Seq c -> Seq d
-zipWith3 f s1 s2 s3 = zipWith ($) (zipWith f s1 s2) s3
+zipWith3 f s1 s2 s3 = splitTraverseSeq (\s a ->  -- walk s1', splitting the pair (s2', s3') in lockstep
+    case s of
+      (b, c) -> f a (getSingleton b) (getSingleton c)) (s2', s3') s1'
+  where
+    minLen = minimum [length s1, length s2, length s3]  -- trim all three to the shortest length
+    s1' = take minLen s1
+    s2' = take minLen s2
+    s3' = take minLen s3
 
 -- | /O(min(n1,n2,n3,n4))/.  'zip4' takes four sequences and returns a
 -- sequence of quadruples, analogous to 'zip'.
@@ -1749,7 +1819,15 @@ zip4 = zipWith4 (,,,)
 -- four elements, as well as four sequences and returns a sequence of
 -- their point-wise combinations, analogous to 'zipWith'.
 zipWith4 :: (a -> b -> c -> d -> e) -> Seq a -> Seq b -> Seq c -> Seq d -> Seq e
-zipWith4 f s1 s2 s3 s4 = zipWith ($) (zipWith ($) (zipWith f s1 s2) s3) s4
+zipWith4 f s1 s2 s3 s4 = splitTraverseSeq (\s a ->  -- walk s1', splitting the nested pairs (s2', (s3', s4')) in lockstep
+    case s of
+      (b, (c, d)) -> f a (getSingleton b) (getSingleton c) (getSingleton d)) (s2', (s3', s4')) s1'
+  where
+    minLen = minimum [length s1, length s2, length s3, length s4]  -- trim all four to the shortest length
+    s1' = take minLen s1
+    s2' = take minLen s2
+    s3' = take minLen s3
+    s4' = take minLen s4
 
 ------------------------------------------------------------------------
 -- Sorting



More information about the ghc-commits mailing list