[Haskell-cafe] GHC predictability

Don Stewart dons at galois.com
Tue May 13 16:26:07 EDT 2008

> jeff.polakow:
> >    Hello,
> > 
> >    > For example, the natural and naive way to write Andrew's "mean" function
> >    > doesn't involve tuples at all: simply tail recurse with two accumulator
> >    > parameters, and compute the mean at the end.  GHC's strictness analyser
> >    > does the right thing with this, so there's no need for seq, $!, or the
> >    > like.  It's about 3 lines of code.
> >    >
> >    Is this the code you mean?
> > 
> >        meanNat = go 0 0 where
> >            go s n [] = s / n
> >            go s n (x:xs) = go (s+x) (n+1) xs
> >    If so, bang patterns are still required in ghc-6.8.2 to run
> >    in constant memory:
> > 
> >        meanNat = go 0 0 where
> >            go s n [] = s / n
> >            go !s !n (x:xs) = go (s+x) (n+1) xs
> > 
> >    Is there some other way to write it so that ghc will essentially insert
> >    the bangs for me?
>
> Yes, give a type annotation, constraining 'n' to Int.
>
>     meanNat :: [Double] -> Double
>     meanNat = go 0 0
>       where
>        go :: Double -> Int -> [Double] -> Double
>        go s n []     = s / fromIntegral n
>        go s n (x:xs) = go (s+x) (n+1) xs
>
> And you get this loop:
>
>     M.$wgo :: Double#
>               -> Int#
>               -> [Double]
>               -> Double#
>     M.$wgo =
>       \ (ww_smN :: Double#)
>         (ww1_smR :: Int#)
>         (w_smT :: [Double]) ->
>         case w_smT of wild_B1 {
>           [] -> /## ww_smN (int2Double# ww1_smR);
>           : x_a9k xs_a9l ->
>             case x_a9k of wild1_am7 { D# y_am9 ->
>             M.$wgo (+## ww_smN y_am9) (+# ww1_smR 1) xs_a9l
>             }
>         }
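
For anyone who wants to try the accumulator loop directly, here is a minimal self-contained sketch that combines the type annotation with the bang patterns from the quoted message. The name `meanBang` is made up for illustration. The bangs force both accumulators on every step, so the loop should not depend on the strictness analyser, which only runs when optimisation is switched on.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Hypothetical name: mean of a list via a tail-recursive loop with two
-- accumulators (running sum and element count).
-- Note: an empty list gives 0/0, i.e. NaN.
meanBang :: [Double] -> Double
meanBang = go 0 0
  where
    go :: Double -> Int -> [Double] -> Double
    -- The bangs force s and n before each clause body is entered,
    -- so no chain of (+) thunks builds up in the accumulators.
    go !s !n []     = s / fromIntegral n
    go !s !n (x:xs) = go (s + x) (n + 1) xs
```

Loaded into GHCi, `meanBang [1, 2, 3, 4]` evaluates to 2.5.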

Note this is pretty much identical to the code you get with a foldl' (though
without the unboxed pair return):

    import Data.List
    import Text.Printf
    import Data.Array.Vector

    mean :: [Double] -> Double
    mean arr = b / fromIntegral a
      where
        k (n :*: s) a = (n+1 :*: s+a)
        (a :*: b) = foldl' k (0 :*: 0) arr :: (Int :*: Double)

    main = printf "%f\n" . mean $ [1 .. 1e9]

Note I'm using strict pairs for the accumulator, instead of banging lazy
pairs. The resulting loop:

    $s$wlgo :: [Double]
               -> Double#
               -> Int#
               -> (# Int, Double #)

    $s$wlgo =
      \ (xs1_aMQ :: [Double])
        (sc_sYK :: Double#)
        (sc1_sYL :: Int#) ->
        case xs1_aMQ of wild_aML {
          [] -> (# I# sc1_sYL, D# sc_sYK #);
          : x_aMP xs11_XMX ->
            case x_aMP of wild1_aOg { D# y_aOi ->
            $s$wlgo xs11_XMX (+## sc_sYK y_aOi) (+# sc1_sYL 1)
            }
        }
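
If you would rather not pull in Data.Array.Vector just for the strict pair, a rough equivalent can be sketched with a local data type whose fields are marked strict. The names `Acc`, `meanFold` and `step` are made up for illustration, and I have not compared the Core this produces against the loop above.

```haskell
import Data.List (foldl')

-- Hypothetical strict pair: the strictness annotations mean both fields
-- are evaluated whenever an Acc value itself is evaluated.
data Acc = Acc !Int !Double

-- Mean via a strict left fold over a (count, sum) accumulator.
-- Note: an empty list gives 0/0, i.e. NaN.
meanFold :: [Double] -> Double
meanFold xs = s / fromIntegral n
  where
    Acc n s = foldl' step (Acc 0 0) xs
    -- foldl' forces the accumulator at each step, and the strict fields
    -- then force the new count and sum along with it.
    step (Acc k t) x = Acc (k + 1) (t + x)
```

Loaded into GHCi, `meanFold [1, 2, 3, 4]` evaluates to 2.5.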

-- Don

