[Haskell] Jones-optimal, typed, symbolic differentiation of (compiled) functions
oleg at pobox.com
Sun Dec 3 20:56:27 EST 2006
We show symbolic differentiation of a wide class of numeric functions
without any interpretative overhead. The functions to symbolically
differentiate can be given to us in a compiled form (in .hi files);
their source code is not needed. We produce a (compiled, if needed)
function that is an exact, algebraically simplified analytic
derivative of the given function. Our approach relies on reifying code
into its `dictionary view', intensional analysis of typed code
expressions, and the use of staging to evaluate under lambda. This
message improves on the one posted on this list in November 2004.
Rather than introduce our own expression data types, we use the one
built into Template Haskell (TH). TH lets us avoid the overhead of
interpreting code expressions and achieve Jones-optimality. The
computed derivative can be compiled down to machine code, so it runs
at full speed, as if it had been written by hand to start with.
In the process, we develop a simple type system for a subset of TH
code expressions (TH is, sadly, completely untyped) -- so that
accidental errors can be detected early. We introduce a few
combinators for the intensional analysis of such typed code
expressions. We also show how to reify an identifier like (+) to a
TH.Name to match on -- by applying TH to itself. Effectively we obtain
more than one stage of computation.
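A small illustration of matching on the Name of (+) in quoted code. It assumes the modern template-haskell API (GHC 9.x), not the 2006 API the post targets. Current TH provides name quotes, so '(+) denotes the Name of (+) directly; the post obtains that Name differently, by applying TH to itself. The helper isPlusApp is hypothetical.

```haskell
{-# LANGUAGE TemplateHaskell #-}
import Language.Haskell.TH

-- Hypothetical helper: does this TH expression apply (+) to two arguments?
-- '(+) is a name quote; it stands for the Name of the (+) in scope here,
-- so comparing Names identifies the operator regardless of how it is written.
isPlusApp :: Exp -> Bool
isPlusApp (InfixE (Just _) (VarE n) (Just _)) = n == '(+)   -- infix: a + b
isPlusApp (AppE (AppE (VarE n) _) _)          = n == '(+)   -- prefix: (+) a b
isPlusApp _                                   = False
```

Quoting [| 1 + 2 |] and running it with runQ in IO yields an InfixE node whose operator is VarE applied to the global Name of (+), so isPlusApp recognises it.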
Conversion from a TH code expression to code is a common operation
built into TH: splicing. The spliced code becomes part of the host
program and can be compiled along with it. The reverse transformation,
from the compiled code to the TH code expression, is much less
common. Once we recover the source code view of a (compiled) function,
we can convert it back to the compiled code with TH splicing. We can
also simplify the code, eliminate common subexpressions by let
insertions, partially evaluate the code -- or, the subject of this
message, differentiate the code and apply a few algebraic
simplifications.
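The `dictionary view' can be sketched without TH. Give a toy expression type a Num instance, then apply the polymorphic function to a symbolic variable. The type Exp, its constructors and reify are hypothetical stand-ins for TH's expression type. They use Num rather than Floating to stay short; the post itself reifies into TH code. Note how the let in test1f disappears, just as in the printed TH output shown later.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Hypothetical toy expression type (the post uses TH's Exp instead).
data Exp = X | Lit Rational | Add Exp Exp | Mul Exp Exp
  deriving (Eq, Show)

-- The Num dictionary for Exp builds syntax instead of computing numbers.
instance Num Exp where
  (+)         = Add
  (*)         = Mul
  negate e    = Mul (Lit (-1)) e
  fromInteger = Lit . fromInteger
  abs         = error "abs: not supported in this sketch"
  signum      = error "signum: not supported in this sketch"

-- Reify a polymorphic function by applying it to the symbolic variable X.
reify :: (forall a. Num a => a -> a) -> Exp
reify f = f X

test1f :: Num a => a -> a
test1f x = let y = x * x in y + 1

test2f :: Num a => a -> a
test2f x = foldl (\z c -> x*z + c) 0 [1,2,3]
```

Here reify test1f is Add (Mul X X) (Lit 1), and reify test2f is the fully unrolled foldl, mirroring the TH output for test2s.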
It should be stressed that we take a function of the type Floating a =>
a -> a and return another (derived) function of that type. The
straightforward approach would be to transform
\x -> body
into
\x -> derive body
We see the obvious drawback: we have to evaluate 'derive body' every
time we apply the derivative function. Our approach is different: to
put it a bit informally, we compute
let body' = derive body in \x -> body'
That approach requires evaluating under lambda, which is the main
benefit of staging. We should make another point: symbolic
differentiation presupposes that the function is represented by a data
type such as data Exp = Var | Add Exp Exp ... Therefore, the result of
symbolic differentiation would be
\x -> interpret (derive body)
where 'interpret' is a function of the type 'Exp -> Float'. That
approach results in the overhead of interpreting Exp on each
invocation of the function. Staging again lets us eliminate such an
overhead. We really operate on the code, in the same form that the
compiler sees it.
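A minimal sketch of the two costs discussed above, on a hypothetical toy Exp type (not TH). derivNaive keeps derive and an interpreter under the lambda. derivOnce derives once, outside the lambda, and shares the result. For the interpretation cost, compile uses closure generation. This is a swapped-in technique, not the post's: it walks the tree once rather than on every call, which reduces interpretive overhead but still pays for closure calls. TH splicing avoids that cost altogether.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Hypothetical toy expression type, standing in for "data Exp = Var | Add Exp Exp ...".
data Exp = X | Lit Rational | Add Exp Exp | Mul Exp Exp
  deriving Show

-- Applying a polymorphic function to X builds its expression tree.
instance Num Exp where
  (+)         = Add
  (*)         = Mul
  negate e    = Mul (Lit (-1)) e
  fromInteger = Lit . fromInteger
  abs         = error "abs: not supported in this sketch"
  signum      = error "signum: not supported in this sketch"

-- Unsimplified symbolic derivative with respect to X.
derive :: Exp -> Exp
derive X         = Lit 1
derive (Lit _)   = Lit 0
derive (Add a b) = Add (derive a) (derive b)
derive (Mul a b) = Add (Mul (derive a) b) (Mul a (derive b))

-- The interpreter ('Exp -> Float' in the text): walks the tree on every call.
interpret :: Exp -> Double -> Double
interpret X         x = x
interpret (Lit c)   _ = fromRational c
interpret (Add a b) x = interpret a x + interpret b x
interpret (Mul a b) x = interpret a x * interpret b x

-- Closure generation: dispatches on the tree once and returns nested closures.
compile :: Exp -> (Double -> Double)
compile X         = id
compile (Lit c)   = const (fromRational c)
compile (Add a b) = let fa = compile a; fb = compile b in \x -> fa x + fb x
compile (Mul a b) = let fa = compile a; fb = compile b in \x -> fa x * fb x

-- \x -> interpret (derive body): derive sits under the lambda, so it is
-- redone per call (unless GHC's full-laziness optimisation floats it out).
derivNaive :: (forall a. Num a => a -> a) -> Double -> Double
derivNaive f = \x -> interpret (derive (f X)) x

-- let body' = derive body in \x -> body': computed once, then shared.
derivOnce :: (forall a. Num a => a -> a) -> Double -> Double
derivOnce f = let body' = compile (derive (f X)) in \x -> body' x

test1f :: Num a => a -> a
test1f x = let y = x * x in y + 1
```

With g = derivOnce test1f, the derived and compiled body' is built on g's first use and reused by later calls, so map g [0, 1, 3] gives [0.0, 2.0, 6.0].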
Let us take a simple example. Given the function
> test1f x = let y = x * x in y + 1
(which can be in a separately compiled file), we can reify it into a
TH code expression and print it:
> test1c = new'diffVar >>= \ (v::Var Float) -> return $ (test1f (var'exp v),v)
> test1r = test1c >>= \ (c,v) -> reflectDF v c
> test1cp = showQC test1r
*Diff> test1cp
\dx_0 -> GHC.Num.+ (GHC.Num.* dx_0 dx_0) 1
The output is produced by TH's pretty-printer -- which, I suspect, is
the same function used by GHC to print expressions in error
messages. Predictably, the let expression in the original function is
`inlined'; what we obtain is a `dictionary' view of the
function. We can splice the obtained code back into the program; we
can also differentiate it symbolically:
> test1d = test1c >>= \ (c,v) -> reflectDF v $ diffC v c
> test1dp = showQC test1d
*Diff> test1dp
\dx_0 -> GHC.Num.+ (GHC.Num.+ (GHC.Num.* 1 dx_0) (GHC.Num.* dx_0 1)) 0
That is not very nice, so we apply a few algebraic simplifications:
> test1ds = test1c >>= \ (c,v) -> reflectDF v $ simpleC v $ diffC v c
> test1dsp = showQC test1ds
*Diff> test1dsp
\dx_0 -> GHC.Num.+ dx_0 dx_0
which is much better. We can splice that code into a Haskell program and
apply the result as a regular Haskell function:
> test1ds' = $(reflectQC test1ds) 2.0
The full code is available from the directory:
http://pobox.com/~oleg/ftp/Haskell/differentiation/
The file differentiation.lhs in that directory is the old message,
posted on the Haskell mailing list in November 2004. The other files
are new. Sorry about too many files: TH demands splitting the code
across modules. If you're using GHC 6.6, please edit the file
TypedCodeAux.hs and comment one definition and uncomment the other
one, as described in the code. In GHC 6.6, TH API has changed a bit.
The file TypedCode.hs introduces the type "Code a" of typed TH code
expressions. The (phantom) type parameter is the expression's
type. The file also defines combinators for building and analyzing
these typed expressions.
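The phantom-typed idea can be sketched as follows, assuming a tiny untyped UExp in place of TH's Exp. The names Code, var, binop, lessThan and on2op are hypothetical, not the definitions in TypedCode.hs. The parameter a never occurs in the representation, yet lessThan x y has type Code Bool, so adding it to a Code Double is rejected at compile time.

```haskell
-- Hypothetical untyped core, standing in for TH's untyped Exp.
data UExp = ULit Rational | UVar String | UApp String [UExp]
  deriving (Eq, Show)

-- Typed wrapper: the phantom parameter a is the expression's type.
newtype Code a = Code { unCode :: UExp }
  deriving (Eq, Show)

var :: String -> Code a
var = Code . UVar

-- Builders only combine expressions of matching types.
binop :: String -> Code a -> Code a -> Code a
binop op (Code x) (Code y) = Code (UApp op [x, y])

-- A comparison changes the phantom type to Bool.
lessThan :: Ord a => Code a -> Code a -> Code Bool
lessThan (Code x) (Code y) = Code (UApp "<" [x, y])

-- With this instance, a function Num a => a -> a can be applied to Code a.
instance Num a => Num (Code a) where
  (+)             = binop "+"
  (*)             = binop "*"
  (-)             = binop "-"
  negate (Code x) = Code (UApp "negate" [x])
  abs (Code x)    = Code (UApp "abs" [x])
  signum (Code x) = Code (UApp "signum" [x])
  fromInteger n   = Code (ULit (fromInteger n))

-- Typed intensional analysis: match a binary application of op.
-- Only sound for operators like + and * whose operands have the result's type.
on2op :: String -> Code a -> Maybe (Code a, Code a)
on2op op (Code (UApp o [x, y])) | o == op = Just (Code x, Code y)
on2op _ _ = Nothing

test1f :: Num a => a -> a
test1f x = let y = x * x in y + 1
```

For example, test1f (var "x") :: Code Double reifies to UApp "+" [UApp "*" [UVar "x", UVar "x"], ULit 1], and on2op "+" splits it back into typed operands.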
The main file is Diff.hs. It gives the reification of code into TH
code, differentiation rules and algebraic simplification rules, all
via the intensional analysis of the typed code. The differentiation
function is as follows:
> diff_fn :: Floating b => (forall a. Floating a => a -> a) -> QCode (b -> b)
> diff_fn f =
>   do
>     v <- new'diffVar
>     let body = f (var'exp v)                  -- reified body of the function
>     reflectDF v . simpleC v . diffC v $ body  -- differentiate and simplify
which closely follows the informal outline described above.
Here is a more involved example:
> test2f x = foldl (\z c -> x*z + c) 0 [1,2,3]
> test2n = test2f (4::Float) -- 27.0
> test2s = show_fn test2f
If we just reflect this function (so we can print its code, test2s),
we obtain
*Diff> test2s
\dx_0 -> GHC.Num.+ (GHC.Num.* dx_0 (GHC.Num.+
(GHC.Num.* dx_0 (GHC.Num.+ (GHC.Num.* dx_0 0) 1)) 2)) 3
Its differentiation gives us
> test2ds = showQC (diff_fn test2f)
*Diff> test2ds
\dx_0 -> GHC.Num.+ (GHC.Num.+ dx_0 2) dx_0
which is not too bad. Again, the result can be spliced in and used as
a regular Haskell function. The result obviously has no interpretative
overhead.
> test2dn = $(reflectQC (diff_fn test2f)) (4::Float)
> -- 10.0
The file Diff.hs has more complex examples, such as differentiating
> test5f x = sin (5*x + pi/2) + cos(1 / x)
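test5f needs the quotient rule and the chain rule for sin and cos. The sketch below shows them on a hypothetical untyped Exp (not the post's typed diffC) and checks the result numerically against the hand-derived derivative 5 cos(5x + pi/2) + sin(1/x)/x^2.

```haskell
-- Hypothetical untyped expression type covering the operations in test5f.
data Exp = X | Lit Double
         | Add Exp Exp | Sub Exp Exp | Mul Exp Exp | Div Exp Exp
         | Sin Exp | Cos Exp
  deriving Show

-- Symbolic derivative with respect to X (no simplification).
diff :: Exp -> Exp
diff X         = Lit 1
diff (Lit _)   = Lit 0
diff (Add a b) = Add (diff a) (diff b)
diff (Sub a b) = Sub (diff a) (diff b)
diff (Mul a b) = Add (Mul (diff a) b) (Mul a (diff b))                 -- product rule
diff (Div a b) = Div (Sub (Mul (diff a) b) (Mul a (diff b))) (Mul b b) -- quotient rule
diff (Sin a)   = Mul (Cos a) (diff a)                                  -- chain rule
diff (Cos a)   = Mul (Sub (Lit 0) (Sin a)) (diff a)                    -- chain rule

-- Evaluate an expression at a given value of X.
eval :: Double -> Exp -> Double
eval x = go
  where
    go X         = x
    go (Lit c)   = c
    go (Add a b) = go a + go b
    go (Sub a b) = go a - go b
    go (Mul a b) = go a * go b
    go (Div a b) = go a / go b
    go (Sin a)   = sin (go a)
    go (Cos a)   = cos (go a)

-- test5f x = sin (5*x + pi/2) + cos (1/x), written out as a tree by hand.
test5e :: Exp
test5e = Add (Sin (Add (Mul (Lit 5) X) (Div (Lit pi) (Lit 2))))
             (Cos (Div (Lit 1) X))
```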
The file also demonstrates partial derivatives.
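One way to handle partial derivatives is sketched below on a hypothetical Exp with named variables. This is not necessarily how Diff.hs does it. Differentiating with respect to one variable treats every other variable as a constant, so for f(x,y) = x*y + y the partial derivatives are y and x + 1.

```haskell
-- Hypothetical expression type with named variables.
data Exp = Var String | Lit Double | Add Exp Exp | Mul Exp Exp
  deriving Show

-- Partial derivative with respect to the variable named v.
diffP :: String -> Exp -> Exp
diffP v (Var w)   = Lit (if v == w then 1 else 0)  -- other variables are constants
diffP _ (Lit _)   = Lit 0
diffP v (Add a b) = Add (diffP v a) (diffP v b)
diffP v (Mul a b) = Add (Mul (diffP v a) b) (Mul a (diffP v b))

-- Evaluate under an environment mapping variable names to values.
evalAt :: [(String, Double)] -> Exp -> Double
evalAt env = go
  where
    go (Var w)   = maybe (error ("unbound variable: " ++ w)) id (lookup w env)
    go (Lit c)   = c
    go (Add a b) = go a + go b
    go (Mul a b) = go a * go b

-- f(x, y) = x*y + y
fxy :: Exp
fxy = Add (Mul (Var "x") (Var "y")) (Var "y")
```

At (x, y) = (3, 5), the partial derivative with respect to x evaluates to 5.0 and the one with respect to y to 4.0.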
To finish this message, we show a sample of differentiation rules
> diffC :: (Floating a, Floating b) => Var b -> Code a -> Code a
> -- differentiating x/y
> diffC v c | Just (x,y) <- on'2opC op'div c =
>     ((diffC v x) * y - x * (diffC v y)) / (y*y)
> ...
The intensional analysis of the code should be obvious. Here is a
sample simplification rule:
> simpleCL :: Floating a => Var b -> Code a -> Maybe (Code a)
> simpleCL v c | Just (x,y) <- on'2opC op'add c =
>     simple'recur op'add sadd v x y
>   where
>     sadd x y | Just 0 <- on'litRationalC x = Just y
>     sadd x y | Just 0 <- on'litRationalC y = Just x
>     -- constant folding
>     sadd x y | (Just x, Just y) <- (on'litRationalC x, on'litRationalC y)
>              = Just (fromRational $ x + y)
>     sadd x y = Nothing
The function simpleCL is repeatedly applied to a code expression until
it reports that no simplifications have been made.
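That fixpoint loop is sketched below on a hypothetical toy Exp; the post's simpleCL works on typed TH code instead. step returns Nothing when no rule fires, and simplify keeps applying it until that happens. The addition rules mirror sadd above (unit elimination and constant folding), with analogous rules for multiplication. Applied to the unsimplified derivative of test1f, (1*x + x*1) + 0, they yield x + x, as in the test1dsp output.

```haskell
import Control.Applicative ((<|>))

-- Hypothetical toy expression type.
data Exp = X | Lit Rational | Add Exp Exp | Mul Exp Exp
  deriving (Eq, Show)

-- Rules for a sum; Nothing means "no simplification".
sadd :: Exp -> Exp -> Maybe Exp
sadd (Lit 0) y       = Just y
sadd x (Lit 0)       = Just x
sadd (Lit a) (Lit b) = Just (Lit (a + b))   -- constant folding
sadd _ _             = Nothing

-- Rules for a product.
smul :: Exp -> Exp -> Maybe Exp
smul (Lit 0) _       = Just (Lit 0)
smul _ (Lit 0)       = Just (Lit 0)
smul (Lit 1) y       = Just y
smul x (Lit 1)       = Just x
smul (Lit a) (Lit b) = Just (Lit (a * b))   -- constant folding
smul _ _             = Nothing

-- One rewrite: try the node's own rules, else recurse into one subterm.
step :: Exp -> Maybe Exp
step (Add a b) = sadd a b <|> (flip Add b <$> step a) <|> (Add a <$> step b)
step (Mul a b) = smul a b <|> (flip Mul b <$> step a) <|> (Mul a <$> step b)
step _         = Nothing

-- Apply step repeatedly until it reports that nothing changed.
simplify :: Exp -> Exp
simplify e = maybe e simplify (step e)
```

Every successful step shrinks the expression, so the loop always terminates.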