Beginner's question: Memo functions in Haskell

Matthias Heiler heiler@rumms.uni-mannheim.de
Wed, 5 Jun 2002 12:09:19 +0200 (MET DST)


Hello,

In a machine learning application I am currently playing with string
kernels, which are recursively defined functions operating on strings.
In Haskell the implementation of these functions is very pleasing as
it is a one-to-one translation of the mathematical definition (see
code below [2] and reference [1], p.424).
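Written out, the recursion that the code in [2] transcribes (following [1], p.424) is as follows. Here x is a single character, |s| is the length of s, t_j is the j-th character of t, and t[1:j-1] is the prefix of t made of its first j-1 characters; K' corresponds to k', K to k, and the normalised kernel \hat{K} to nk.

```latex
\begin{aligned}
K'_0(s,t) &= 1, \\
K'_i(s,t) &= 0 \quad \text{if } \min(|s|,|t|) < i, \\
K_n(s,t)  &= 0 \quad \text{if } \min(|s|,|t|) < n, \\
K'_i(sx,t) &= \lambda\, K'_i(s,t) + \sum_{j:\, t_j = x} \lambda^{|t|-j+2}\, K'_{i-1}\bigl(s, t[1:j-1]\bigr), \\
K_n(sx,t)  &= K_n(s,t) + \sum_{j:\, t_j = x} \lambda^{2}\, K'_{n-1}\bigl(s, t[1:j-1]\bigr), \\
\hat{K}_n(s,t) &= \frac{K_n(s,t)}{\sqrt{K_n(s,s)\, K_n(t,t)}}.
\end{aligned}
```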

However, the runtime performance is less pleasing as certain
subexpressions are computed over and over again (profiling with ghc
showed that the function k' (see code below) is called 1425291 times
in a toy example).

In an imperative implementation I would now set up an auxiliary
data structure (e.g. a hash table) for caching/memoizing some
intermediate results.  How would this be done (elegantly, efficiently,
by a Haskell beginner) in Haskell?
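A fairly direct Haskell analogue of that approach is to pass a finite map from arguments to results through the recursion by hand. The sketch below is an illustration, not an established recipe: it assumes the `containers` package (which ships with GHC) and uses `Data.Map`, a balanced search tree rather than a hash table. The names `kMemo` and `Memo` are hypothetical. It memoises k' from the code in [2], keyed on (i, s, t); lambda is left out of the key on the assumption that it stays fixed for a whole computation.

```haskell
import qualified Data.Map as Map

-- Memo table keyed on the arguments of k' that vary between calls.
-- lambda is deliberately not part of the key: this assumes one table is
-- only ever used with a single value of lambda.
type Memo = Map.Map (Int, String, String) Double

-- Hypothetical memoised variant of k' from [2]: same recursion, but every
-- call first consults the table and records its result after a miss.
-- The table is threaded through the recursion by hand.
kMemo :: Double -> Int -> String -> String -> Memo -> (Double, Memo)
kMemo lambda i s t memo =
  case Map.lookup (i, s, t) memo of
    Just v  -> (v, memo)
    Nothing -> let (v, memo') = compute
               in (v, Map.insert (i, s, t) v memo')
  where
    compute
      | i == 0                        = (1, memo)
      | min (length s) (length t) < i = (0, memo)
      | otherwise =
          let s'      = init s
              -- first term: lambda * k'_i(s', t)
              (a, m1) = kMemo lambda i s' t memo
              -- positions j (1-based) where t_j equals the last char of s
              js      = [ j | j <- [1 .. length t], t !! (j - 1) == last s ]
              -- add one term of the sum, passing the table along
              step (acc, m) j =
                let (v, m') = kMemo lambda (i - 1) s' (take (j - 1) t) m
                in (acc + lambda ^ (length t - j + 2) * v, m')
              (b, m2) = foldl step (0, m1) js
          in (lambda * a + b, m2)
```

Threading the table by hand like this is exactly what a State monad would hide. Note also that with String keys every map comparison costs time proportional to the string length.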

So far, I have seen code using lists to speed up fib(n).  In my case
the arguments of k' that change between calls are an Int and two
Strings, and I don't expect a simple list of tuples (Int, String,
String, RESULT) to be efficient.
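The fib(n) trick mentioned above usually looks something like the sketch below (a standard idiom, not code from this post; `memoFib` and `fibs` are illustrative names). The results live in a top-level lazy list, so each element is evaluated at most once and then shared by every later call.

```haskell
-- Memoised Fibonacci via a lazily evaluated top-level list.
-- Each element of 'fibs' is computed at most once and then shared.
memoFib :: Int -> Integer
memoFib n = fibs !! n

fibs :: [Integer]
fibs = map fib [0 ..]
  where
    -- The recursive calls go back through 'memoFib', i.e. through the
    -- shared list, instead of recursing on 'fib' directly.
    fib 0 = 0
    fib 1 = 1
    fib i = memoFib (i - 1) + memoFib (i - 2)
```

The index here must be a small non-negative Int and each lookup with `!!` walks the list, which is why the idiom does not carry over directly to String-valued arguments.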

Thank you very much for your help,

  Matthias

[1] Huma Lodhi, Craig Saunders, John Shawe-Taylor, Nello Cristianini,
Chris Watkins: "Text Classification using String Kernels", Journal of
Machine Learning Research, 2(Feb):419-444, 2002.  Available online at
http://www.ai.mit.edu/projects/jmlr/papers/volume2.html

[2] My code (actually the first 'real' piece of code I wrote in
Haskell) is the following:

------------------------------------------------------------
module SKernel where

-- k' lambda i s t: the auxiliary kernel K'_i from [1]
k' :: Double -> Int -> String -> String -> Double
k' lambda 0 s t = 1
k' lambda i s t = if min (length s) (length t) < i
                  then 0
                  else lambda * k' lambda i s' t +
                       -- sum over positions j where t_j equals the last char of s
                       sum [ lambda^(length t - j + 2) * k' lambda (i-1) s' t'
                           | j <- [1 .. length t],
                             t !! (j-1) == last s,
                             let t' = take (j-1) t ]
  where s' = init s  -- s without its last character

-- k lambda i s t: the string kernel K_i from [1]
k :: Double -> Int -> String -> String -> Double
k lambda i s t = if min (length s) (length t) < i
                 then 0
                 else k lambda i s' t +
                      sum [ lambda^2 * k' lambda (i-1) s' t'
                          | j <- [1 .. length t],
                            t !! (j-1) == last s,
                            let t' = take (j-1) t ]
  where s' = init s  -- s without its last character

nk :: Double -> Int -> String -> String -> Double
nk lambda n s t = (k lambda n s t) / sqrt ((k lambda n s s) * (k lambda n t t))

-- a toy example would be the call (e.g. typed at the GHCi prompt)
--   nk 0.5 5 "This is a string." "Here we have another string."
------------------------------------------------------------
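One way to memoise this particular recursion without any hashing (a sketch, not part of the original code) rests on an observation about the listing above: every string passed to k' or k is a prefix of the original s or t, because the recursion only ever drops the last character of s and replaces t by take (j-1) t. A prefix is fixed by its length, so each call is identified by the triple (i, length of the s-prefix, length of the t-prefix), and the results can be stored in a lazily evaluated `Data.Array` whose elements refer to each other. The sketch assumes the `array` package that ships with GHC; `kernelMemo` and `nkMemo` are hypothetical names, and n >= 1 is assumed, as in the toy example.

```haskell
import Data.Array

-- Hypothetical memoised version of k from [2] (assumes n >= 1).
-- All strings seen by the recursion are prefixes of s and t, so a call is
-- identified by (i, p, q) = (level, length of s-prefix, length of t-prefix).
kernelMemo :: Double -> Int -> String -> String -> Double
kernelMemo lambda n s t = kTab ! lenS
  where
    lenS = length s
    lenT = length t
    sArr = listArray (1, lenS) s   -- O(1) access to s_p (1-based)
    tArr = listArray (1, lenT) t   -- O(1) access to t_j (1-based)

    -- kp ! (i, p, q) == k' lambda i (take p s) (take q t); the elements
    -- are lazy thunks, so each one is computed at most once, on demand.
    kp :: Array (Int, Int, Int) Double
    kp = array ((0, 0, 0), (n - 1, lenS, lenT))
           [ ((i, p, q), kpVal i p q)
           | i <- [0 .. n - 1], p <- [0 .. lenS], q <- [0 .. lenT] ]

    kpVal 0 _ _ = 1
    kpVal i p q
      | min p q < i = 0
      | otherwise =
          lambda * kp ! (i, p - 1, q)
            + sum [ lambda ^ (q - j + 2) * kp ! (i - 1, p - 1, j - 1)
                  | j <- [1 .. q], tArr ! j == sArr ! p ]

    -- kTab ! p == k lambda n (take p s) t  (k never shortens t)
    kTab :: Array Int Double
    kTab = listArray (0, lenS) [ kVal p | p <- [0 .. lenS] ]

    kVal p
      | min p lenT < n = 0
      | otherwise =
          kTab ! (p - 1)
            + sum [ lambda ^ (2 :: Int) * kp ! (n - 1, p - 1, j - 1)
                  | j <- [1 .. lenT], tArr ! j == sArr ! p ]

-- Normalised kernel, mirroring nk from [2].
nkMemo :: Double -> Int -> String -> String -> Double
nkMemo lambda n s t =
  kernelMemo lambda n s t
    / sqrt (kernelMemo lambda n s s * kernelMemo lambda n t t)
```

Each table entry is evaluated at most once, so the work grows roughly like n * |s| * |t|^2 (every k' entry sums over up to |t| positions) rather than with the number of recursive calls reported by the profiler.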