[Haskell-cafe] Known Unknowns

Chris Kuklewicz haskell at list.mightyreason.com
Fri Jan 27 05:53:07 EST 2006

Joel Koerwer wrote:
> On 1/26/06, Donald Bruce Stewart <dons at cse.unsw.edu.au> wrote:
>     Ah, i just do: ghc A.hs -O2 -ddump-simpl | less
>     and then read the Core, keeping an eye on the functions I'm interested
>     in, and checking they're compiling to the kind of loops I'd write by
>     hand. This is particularly useful for the kinds of tight numeric loops
>     used in some of the shootout entries.
>     Cheers,
>       Don
> In that case could you describe the kind of loops you'd write by hand?

See below for the pseudo-code loop and the Haskell version.

> Seriously. And perhaps typical problems/fixes when the compiler doesn't
> produce what you want.

We don't have any fixes.

> Thanks,
> Joel

More discussion and code is at http://haskell.org/hawiki/NbodyEntry

Our current best attempt, which works on a 40-element (IOUArray Int Double),
compiles to code that runs 4 times slower than OCaml.  The final program's
speed is very architecture dependent, but more frustrating is that small
referentially transparent changes to the source code produce up to
factor-of-two fluctuations in run time.

The small numeric functions in the shootout, where there is a recursive function
with 1 or 2 parameters (Doubles), perform quite well.  But manipulating this
medium number of Doubles to model the solar system has been too slow.

The main loop for the 5 bodies looks quite simple in pseudo-C:

deltaTime = 0.01
for (i=0; i<5; ++i) {
  "get mass m, position (x,y,z), velocity (vx,vy,vz) of particle number i"

  for (j=(i+1); j<5; ++j) {
    "get mass, position, velocity of particle j"

    dxyz = "position of i" - "position of j"
    mag = deltaTime / (length of dxyz)^3

    "velocity of j" += "mass of i" * mag * dxyz
    "velocity of i" -= "mass of j" * mag * dxyz
  }

  "position of i" += deltaTime * "velocity of i"
}

Note that the inner loop "for j" starts at "j=(i+1)", so each pair of
particles is visited only once.
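
For comparison, here is a minimal sketch of the same step written as plain,
pure Haskell over a list of bodies.  It is only meant as a readable reference
for what the loop computes, not as the shootout entry and not as fast code.
The names Vec, Body, step and the vector helpers are made up for this
illustration.

```haskell
import Data.List (mapAccumL)

-- Hypothetical types for illustration; the real entry uses a flat array.
data Vec = Vec !Double !Double !Double deriving (Show, Eq)

data Body = Body { mass :: !Double, pos :: !Vec, vel :: !Vec }
  deriving (Show, Eq)

vadd, vsub :: Vec -> Vec -> Vec
vadd (Vec a b c) (Vec x y z) = Vec (a + x) (b + y) (c + z)
vsub (Vec a b c) (Vec x y z) = Vec (a - x) (b - y) (c - z)

vscale :: Double -> Vec -> Vec
vscale k (Vec x y z) = Vec (k * x) (k * y) (k * z)

norm :: Vec -> Double
norm (Vec x y z) = sqrt (x * x + y * y + z * z)

-- One time step.  The head of the list plays the role of particle i and
-- the tail holds the particles j > i, so every pair is visited once.
step :: Double -> [Body] -> [Body]
step _  []       = []
step dt (b : bs) = moved : step dt bs'
  where
    -- Thread particle i through all later particles, updating both
    -- velocities for each pair (the inner "for j" loop).
    (b', bs') = mapAccumL pair b bs
    pair bi bj =
      ( bi { vel = vel bi `vsub` vscale (mass bj * mag) d }
      , bj { vel = vel bj `vadd` vscale (mass bi * mag) d } )
      where
        d   = pos bi `vsub` pos bj
        r   = norm d
        mag = dt / (r * r * r)
    -- The velocity of i is now complete, so move particle i.
    moved = b' { pos = pos b' `vadd` vscale dt (vel b') }
```

Calling step 0.01 on the list of bodies advances the system by one time
step.  All pair interactions use the old positions, as in the pseudo-code,
because a particle is moved only after its own inner loop has finished.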

The best performing Haskell code, for this loop, so far is:

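-- This is an excerpt: b (the IOUArray), nbodies and set are not shown here.
-- nbodies is used as the highest body index (the loops run while i <= nbodies).
import Control.Monad (when, liftM, liftM2)
import Data.Array.Base (unsafeRead, unsafeWrite)
import Data.Bits (shift, shiftL, (.|.))

-- Offsets for each field within a body's block of 8 array slots
x = 0; y = 1; z = 2; vx= 3; vy= 4; vz= 5; m = 6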
-- This is the main code. Essentially all the time is spent here
advance n = when (n > 0) $ updateVel 0 >> advance (pred n)

  where updateVel i = when (i <= nbodies) $ do
            let i' = (.|. shiftL i 3)
            im  <- unsafeRead b (i' m)
            ix  <- unsafeRead b (i' x)
            iy  <- unsafeRead b (i' y)
            iz  <- unsafeRead b (i' z)
            ivx <- unsafeRead b (i' vx)
            ivy <- unsafeRead b (i' vy)
            ivz <- unsafeRead b (i' vz)

            let updateVel' ivx ivy ivz j =  ivx `seq` ivy `seq` ivz `seq`
                  if j > nbodies then do
                    unsafeWrite b (i' vx) ivx
                    unsafeWrite b (i' vy) ivy
                    unsafeWrite b (i' vz) ivz
                  else do
                    let j' = (.|. shiftL j 3)
                    jm <- unsafeRead b (j' m)
                    dx <- liftM (ix-) (unsafeRead b (j' x))
                    dy <- liftM (iy-) (unsafeRead b (j' y))
                    dz <- liftM (iz-) (unsafeRead b (j' z))
                    let distance = sqrt (dx*dx+dy*dy+dz*dz)
                        mag = 0.01 / (distance * distance * distance)
                    addScaled3 (3 .|. (shiftL j 3)) ( im*mag) dx dy dz
                    let a = -jm*mag
                        ivx' = ivx+a*dx
                        ivy' = ivy+a*dy
                        ivz' = ivz+a*dz
                    updateVel' ivx' ivy' ivz' $! (j+1)

            updateVel' ivx ivy ivz $! (i+1)
            addScaled (shiftL i 3) 0.01 (3 .|. (shiftL i 3))
            updateVel (i+1)
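
The index arithmetic above may look odd at first.  As far as I can tell from
the 40 element array and the shift by 3, each body gets a block of 8 slots
(7 of them used), so OR-ing a field offset with the shifted body index is
just i*8 + field.  Here is a small sketch of that; slot and slotArith are
names invented for this illustration.

```haskell
import Data.Bits (shiftL, (.|.))

-- Field offsets, as in the code above.
x, y, z, vx, vy, vz, m :: Int
x = 0; y = 1; z = 2; vx = 3; vy = 4; vz = 5; m = 6

-- Array index of a field of body i, written the way the loop writes it.
-- This relies on every field offset being below 8, so the OR never
-- collides with the shifted body index.
slot :: Int -> Int -> Int
slot i field = field .|. (i `shiftL` 3)

-- The same index written with ordinary arithmetic.
slotArith :: Int -> Int -> Int
slotArith i field = i * 8 + field
```

So (i' vx) in the loop is slot i vx, and (3 .|. (shiftL j 3)) is slot j vx,
the start of the three velocity components of body j.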

-- Helper functions

addScaled i a j | i `seq` a `seq` j `seq` False = undefined -- strictify
addScaled i a j = do set i1 =<< liftM2 scale (unsafeRead b i1) (unsafeRead b j1)
                     set i2 =<< liftM2 scale (unsafeRead b i2) (unsafeRead b j2)
                     set i3 =<< liftM2 scale (unsafeRead b i3) (unsafeRead b j3)
    where scale old new = old + a * new
          i1 = i; i2 = succ i1; i3 = succ i2;
          j1 = j; j2 = succ j1; j3 = succ j2;

addScaled3 i a jx jy jz | i `seq` a `seq` jx `seq` jy `seq` jz `seq` False = undefined
addScaled3 i a jx jy jz = do set i1 =<< liftM (scale jx) (unsafeRead b i1)
                             set i2 =<< liftM (scale jy) (unsafeRead b i2)
                             set i3 =<< liftM (scale jz) (unsafeRead b i3)
    where scale new old = a * new + old
          i1 = i; i2 = succ i1; i3 = succ i2;
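
The first equation of each helper is a strictness trick.  The guard forces
the arguments with seq and then always evaluates to False, so that equation
is never chosen and the next one runs with its arguments already evaluated.
Here is a stand-alone sketch of the same idiom on a toy function; sumTo is
invented for this illustration and is not part of the entry.

```haskell
-- Sum the integers from n down to 1 onto an accumulator.
-- The first equation never succeeds (its guard ends in False), but
-- checking the guard forces acc and n before the real equation runs,
-- so no chain of unevaluated (acc + n) thunks builds up.
sumTo :: Int -> Int -> Int
sumTo acc n | acc `seq` n `seq` False = undefined
sumTo acc n
  | n <= 0    = acc
  | otherwise = sumTo (acc + n) (n - 1)
```

For example, sumTo 0 100 gives 5050.  Whether the optimizer would have found
the strictness on its own depends on the function and the flags, which is
part of why we spell it out by hand in the loop above.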
