[Haskell-cafe] How to improve speed? (MersenneTwister is several times slower than C version)

isto isto.aho at dnainternet.net
Wed Nov 1 17:16:55 EST 2006


Hi all,

HaWiki has an announcement of a MersenneTwister implementation by Lennart
Augustsson.  A typical run that computes the 10,000,000th random number
gives the following output (code shown below):

$ time ./testMTla
Testing Mersenne Twister.
Result is [3063349438]

real    0m4.925s
user    0m4.856s


I have been experimenting with the very same algorithm and tried to make
it efficient by using IOUArray; a typical run now looks like this (code
shown below):

$ time ./testMT
Testing Mersenne Twister.
3063349438

real    0m3.032s
user    0m3.004s


The original C version (modified so that only the last number is
printed) typically gives:

$ time ./mt19937ar
outputs of genrand_int32()
3063349438

real    0m0.624s
user    0m0.616s

Results are similar when comparing the 64-bit IOUArray version against the
64-bit C variant.  C seems to be about 5 to 10 times faster in this case.

I have tried different things, but now I'm stuck.  Using unsafeRead
and unsafeWrite improved both the lazy (STUArray) version and the
IOUArray version a little, but not by much.  I took a look at the Core
output, but I'm not sure where boxed values are acceptable.  For example,
should  IOUArray Int Word64  be replaced with something else?

Any hints or comments on how to improve efficiency, or on how to make
everything better in general, would be greatly appreciated!

br, Isto

----------------------------- testMTla.hs (MersenneTwister, see HaWiki)
module Main where

-- ghc -O3 -optc-O3 -optc-ffast-math -fexcess-precision --make testMTla

import MersenneTwister

-- Benchmark driver: seed the lazy generator stream with 100, skip the first
-- 9999999 values and print the next one (as a one-element list).
main = do
    putStrLn "Testing Mersenne Twister."
    let stream = mersenneTwister 100
        result = take 1 . drop 9999999 $ stream
        -- result = take 1 . drop 99 $ stream
    putStrLn ("Result is " ++ show result)
-----------------------------

----------------------------- testMT.hs
module Main where

-- Compile eg with
--   ghc -O3 -optc-O3 -optc-ffast-math -fexcess-precision --make testMT

import Mersenne

-- | Draw @nCnt@ numbers from the 32-bit generator, print only the last one,
-- and return the advanced generator state.
--
-- A non-positive count draws nothing and returns the generator unchanged.
-- Previously such a count skipped past the @1@ base case and kept recursing
-- with ever more negative counts.
genRNums32 :: MT32 -> Int -> IO MT32
genRNums32 mt nCnt = go mt nCnt
  where
    go :: MT32 -> Int -> IO MT32
    go gen remaining
      | remaining <= 0 = return gen
      | remaining == 1 = do
          (r, gen') <- next32 gen
          print r
          return gen'
      | otherwise = do
          -- Intermediate values are discarded; only the state is threaded on.
          (_, gen') <- next32 gen
          go gen' (remaining - 1)


-- Benchmark driver: seed the 32-bit generator with 100, then draw ten
-- million numbers and print the last one.
main = putStrLn "Testing Mersenne Twister."
    >> initialiseGenerator32 100
    >>= \gen -> genRNums32 gen 10000000
-----------------------------

----------------------------- Mersenne.hs (sorry for linewraps)
module Mersenne where

import Data.Bits
import Data.Word
import Data.Array.Base
import Data.Array.MArray
import Data.Array.IO
-- import System.Random


-- | 32-bit generator state: the 624-word state array and the index of the
-- next word to temper and output.  Both fields are strict (and the index is
-- unpacked) so that each @MT32 mt (i+1)@ built per draw stores a plain
-- machine Int instead of an unevaluated thunk.
data MT32 = MT32 !(IOUArray Int Word32) {-# UNPACK #-} !Int

-- | 64-bit generator state: the 312-word state array and the index of the
-- next word to temper and output.  Strict for the same reason as 'MT32'.
data MT64 = MT64 !(IOUArray Int Word64) {-# UNPACK #-} !Int


-- | Mask a word down to its low 32 bits.  For a 'Word32' argument this is
-- the identity, since the type already holds exactly 32 bits.
last32bitsof :: Word32 -> Word32
last32bitsof = (.&. 0xffffffff)

-- | Mask for the low 31 bits of a 32-bit state word.
lm32 :: Word32
lm32 = 0x7fffffff

-- | Mask for the top bit of a 32-bit state word.
um32 :: Word32
um32 = 0x80000000

-- | Twist-matrix constant for the 32-bit generator (2567483615).
mA32 :: Word32
mA32 = 0x9908b0df

-- | Build a 32-bit Mersenne Twister state (624-word array) from a seed.
--
-- Word 0 is the seed truncated to 32 bits.  Each following word is
-- @1812433253 * (prev `xor` (prev `shiftR` 30)) + index@.  One full
-- generation pass is then run, so the returned state has index 0 ready to
-- be tempered by 'next32'.
initialiseGenerator32 :: Int -> IO MT32
initialiseGenerator32 seed = do
    -- Converting to Word32 keeps only the low 32 bits of the seed.
    let s = fromIntegral seed :: Word32
    mt <- newArray (0, 623) 0 :: IO (IOUArray Int Word32)
    unsafeWrite mt 0 s
    fill mt s 1
    mt' <- generateNumbers32 mt
    return (MT32 mt' 0)
  where
    -- Write words i..623, each derived from the previously written word.
    fill :: IOUArray Int Word32 -> Word32 -> Int -> IO ()
    fill arr prev i
      | i >= 624 = return ()
      | otherwise = do
          let cur = 1812433253 * (prev `xor` (prev `shiftR` 30))
                      + fromIntegral i
          unsafeWrite arr i cur
          fill arr cur (i + 1)


-- | Regenerate all 624 words of a 32-bit state in place (the "twist" pass)
-- and return the same, now updated, array.
--
-- Word @i@ is recomputed from the top bit of word @i@, the low 31 bits of
-- word @i+1@, and word @i+397@, with indices taken modulo 624.  Words are
-- rewritten in ascending order, so the final steps read words that were
-- already rewritten earlier in the same pass.
--
-- Index wrap-around uses compare-and-subtract instead of an integer 'mod'
-- on every iteration, and the twist formula lives in a single helper
-- rather than being duplicated for the last word.
generateNumbers32 :: IOUArray Int Word32 -> IO (IOUArray Int Word32)
generateNumbers32 mt = go 0
  where
    stateSize :: Int
    stateSize = 624

    -- Reduce an index in [0, 2*stateSize) into [0, stateSize).
    wrap :: Int -> Int
    wrap j = if j >= stateSize then j - stateSize else j

    go :: Int -> IO (IOUArray Int Word32)
    go i
      | i >= stateSize = return mt
      | otherwise = do
          wi   <- unsafeRead mt i
          wNxt <- unsafeRead mt (wrap (i + 1))
          wFar <- unsafeRead mt (wrap (i + 397))
          unsafeWrite mt i (twist32 wi wNxt wFar)
          go (i + 1)

    -- Join the top bit of @hi@ with the low 31 bits of @lo@, shift right by
    -- one, fold in the matrix constant when the joined word is odd, then
    -- xor into @far@.
    twist32 :: Word32 -> Word32 -> Word32 -> Word32
    twist32 hi lo far =
      let y   = (hi .&. um32) .|. (lo .&. lm32)
          mag = if testBit y 0 then mA32 else 0
      in  far `xor` (y `shiftR` 1) `xor` mag


-- | Produce the next 32-bit output and the advanced generator state.
--
-- If the index has run past the end of the 624-word state, the state is
-- regenerated in place first.  The output is the stored word passed through
-- the tempering transform.  The output word and the next index are forced
-- before returning, so a long chain of calls does not allocate a thunk for
-- each of them on every step.
next32 :: MT32 -> IO (Word32, MT32)
next32 (MT32 mt i)
  | i >= 624 = do
      mt' <- generateNumbers32 mt
      next32 (MT32 mt' (i `mod` 624))
  | otherwise = do
      y <- unsafeRead mt i
      let out = temper y
          i'  = i + 1
      out `seq` i' `seq` return (out, MT32 mt i')
  where
    temper :: Word32 -> Word32
    temper y0 =
      let y1 = y0 `xor` (y0 `shiftR` 11)
          y2 = y1 `xor` ((y1 `shiftL` 7)  .&. 0x9d2c5680) -- == 2636928640
          y3 = y2 `xor` ((y2 `shiftL` 15) .&. 0xefc60000) -- == 4022730752
      in  y3 `xor` (y3 `shiftR` 18)


-- | Twist-matrix constant for the 64-bit generator.
mA64 :: Word64
mA64 = 0xB5026F5AA96619E9

-- | Mask for the upper 33 bits of a 64-bit state word.
um64 :: Word64
um64 = 0xFFFFFFFF80000000

-- | Mask for the low 31 bits of a 64-bit state word.
lm64 :: Word64
lm64 = 0x7FFFFFFF

-- | Build a 64-bit Mersenne Twister state (312-word array) from a seed.
--
-- Word 0 is the seed converted to Word64.  Each following word is
-- @6364136223846793005 * (prev `xor` (prev `shiftR` 62)) + index@.  One
-- full generation pass is then run, so the returned state has index 0 ready
-- to be tempered by 'next64'.
initialiseGenerator64 :: Int -> IO MT64
initialiseGenerator64 seed = do
    let s = fromIntegral seed :: Word64
    mt <- newArray (0, 311) 0 :: IO (IOUArray Int Word64)
    unsafeWrite mt 0 s
    fill mt s 1
    generateNumbers64 mt
    return (MT64 mt 0)
  where
    -- Write words i..311, each derived from the previously written word.
    fill :: IOUArray Int Word64 -> Word64 -> Int -> IO ()
    fill arr prev i
      | i >= 312 = return ()
      | otherwise = do
          let cur = 6364136223846793005 * (prev `xor` (prev `shiftR` 62))
                      + fromIntegral i
          unsafeWrite arr i cur
          fill arr cur (i + 1)

-- | Regenerate all 312 words of a 64-bit state in place (the "twist" pass).
--
-- Word @i@ is recomputed from the upper 33 bits of word @i@, the low 31
-- bits of word @i+1@, and word @i+156@, with indices taken modulo 312.
-- Words are rewritten in ascending order, so the final steps read words
-- that were already rewritten earlier in the same pass.
--
-- Index wrap-around uses compare-and-subtract instead of an integer 'mod'
-- on every iteration, and the twist formula lives in a single helper
-- rather than being duplicated for the last word.
generateNumbers64 :: IOUArray Int Word64 -> IO ()
generateNumbers64 mt = go 0
  where
    stateSize :: Int
    stateSize = 312

    -- Reduce an index in [0, 2*stateSize) into [0, stateSize).
    wrap :: Int -> Int
    wrap j = if j >= stateSize then j - stateSize else j

    go :: Int -> IO ()
    go i
      | i >= stateSize = return ()
      | otherwise = do
          wi   <- unsafeRead mt i
          wNxt <- unsafeRead mt (wrap (i + 1))
          wFar <- unsafeRead mt (wrap (i + 156))
          unsafeWrite mt i (twist64 wi wNxt wFar)
          go (i + 1)

    -- Join the upper bits of @hi@ with the low 31 bits of @lo@, shift right
    -- by one, fold in the matrix constant when the joined word is odd, then
    -- xor into @far@.
    twist64 :: Word64 -> Word64 -> Word64 -> Word64
    twist64 hi lo far =
      let y   = (hi .&. um64) .|. (lo .&. lm64)
          mag = if testBit y 0 then mA64 else 0
      in  far `xor` (y `shiftR` 1) `xor` mag


-- | Produce the next 64-bit output and the advanced generator state.
--
-- If the index has reached or passed the end of the 312-word state, the
-- state is regenerated in place first.  Previously only an index of exactly
-- 312 was caught, so a larger index fell through to an out-of-bounds
-- 'unsafeRead'.  The output is the stored word passed through the tempering
-- transform.  The output and the next index are forced before returning;
-- the previous @return $!@ only forced the pair constructor, not its
-- contents.
next64 :: MT64 -> IO (Word64, MT64)
next64 (MT64 mt i)
  | i >= 312 = do
      generateNumbers64 mt
      next64 (MT64 mt (i `mod` 312))
  | otherwise = do
      y <- unsafeRead mt i
      let out = temper y
          i'  = i + 1
      out `seq` i' `seq` return (out, MT64 mt i')
  where
    temper :: Word64 -> Word64
    temper y0 =
      let y1 = y0 `xor` ((y0 `shiftR` 29) .&. 0x5555555555555555)
          y2 = y1 `xor` ((y1 `shiftL` 17) .&. 0x71D67FFFEDA60000)
          y3 = y2 `xor` ((y2 `shiftL` 37) .&. 0xFFF7EEE000000000)
      in  y3 `xor` (y3 `shiftR` 43)







More information about the Haskell-Cafe mailing list