[Haskell] Typeful symbolic differentiation of compiled functions

oleg at pobox.com
Wed Nov 24 04:05:12 EST 2004


Jacques Carette wrote on LtU on Wed, 11/24/2004

] One quick (cryptic) example: the same difficulties in being able to
] express partial evaluation in a typed setting occurs in a CAS
] [computer algebra system]. Of course I mean to have a partial
] evaluator written in a language X for language X, and have the partial
] evaluator 'never go wrong'.  Cheating by encoding language X as an
] algebraic datastructure in X is counter-productive as it entails huge
] amounts of useless reflection/ reification. One really wants to be
] able to deal with object-level terms simply and directly. But of
] course, that way lies the land of paradoxes (in set theory, type
] theory, logic)
]
] And while I am at it: consider symbolic differentiation. If I call
] that 'function' diff, and have things like diff(sin(x),x) == cos(x),
] what is the type of diff? More interestingly, what if I have D(\x ->
] sin(x) ) == \x -> cos(x) What is the type of D ? Is it implementable
] in Ocaml or Haskell?  [Answer: as far as I know, it is not.  But that
] is because as far as I can tell, D can't even exist in System F.  You
] can't have something like D operating on opaque lambda terms.]. But
] both Maple and Mathematica can. And I can write that in LISP or Scheme
] too.

In this message, we develop a `symbolic' differentiator for a subset
of Haskell functions (covering arithmetic and a bit of
trigonometry). We can write

	test1f x = x * x + fromInteger 1
	test1 = test1f (2.0::Float)
	test2f = diff_fn test1f
	test2 = test2f (3.0::Float)

We can evaluate our functions _numerically_ -- and differentiate them
_symbolically_. Partial derivatives are supported as well. To answer
Jacques Carette's question: the type of the derivative operator (which
is just a regular function) is

  diff_fn :: (Num b, D b) => (forall a. D a => a -> a) -> b -> b

where the class D includes Float; instances for exact reals and
similar types can be added. The key insight is that Haskell98 supports
a sort of reflection -- or, to be precise, type-directed partial
evaluation and hence term reconstruction. The very types that are
assumed to be a great hindrance to computer algebra and reflective
systems turn out to be indispensable for operating on even *compiled*
terms _symbolically_.

We must point out that we specifically do _not_ require the functions
to be differentiated to be written as algebraic datatypes. They are
regular Haskell terms, and can be compiled! That is in stark contrast
with Scheme, for example: although Scheme may permit term
reconstruction under notable restrictions, that ability is not present
in compiled code. In general, one cannot take a _compiled_ Scheme
procedure on floating-point numbers and compute its derivative
symbolically, yielding another such procedure. Incidentally, R5RS does
not guarantee the success of type-directed partial evaluation even in
interpreted code.

Jacques Carette has mentioned `useless reflection/reification'. The
paper `Tag Elimination and Jones-Optimality' by Walid Taha, Henning
Makholm and John Hughes introduced a novel tag elimination analysis
as a way to remove all interpretative overhead. In this message, we do
_not_ use that technique. We exploit a different idea, whose roots can
be traced back to Forth. It is remarkable that Haskell supports that
technique.

Other features of our approach are an extensible database of
differentiation rules and an emulation of GADTs with type classes.

This message is the complete code.

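> {-# LANGUAGE ExistentialQuantification, MultiParamTypeClasses #-}
> {-# LANGUAGE FunctionalDependencies, FlexibleInstances #-}
> {-# LANGUAGE UndecidableInstances, RankNTypes, ScopedTypeVariables #-}
> -- Besides existentials, the code uses multi-parameter classes with
> -- functional dependencies, a rank-2 type (diff_fn) and scoped type
> -- variables. Originally tested with GHC 6.2.1 and 6.3.20041106-snapshot,
> -- where the -fglasgow-exts flag enabled all of these.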
>
> module Diff where
> import Prelude hiding ((+), (-), (*), (/), (^), sin, cos, fromInteger)
> import qualified Prelude

First, we declare the domain of functions that we can differentiate

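> -- Show is a superclass so that reified constants (Const a) can be shown.
> -- The original code obtained Show from its Num constraints: Num implied
> -- Show in Haskell 98, but no longer does in modern GHC.
> class Show a => D a where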
>     (+):: a -> a -> a
>     (*):: a -> a -> a
>     (-):: a -> a -> a
>     (/):: a -> a -> a
>     (^):: a -> Int -> a
>     sin:: a -> a
>     cos:: a -> a
>     fromInteger:: Integer -> a

and inject floats into that domain

> instance D Float where
>     (+) = (Prelude.+)
>     (-) = (Prelude.-)
>     (*) = (Prelude.*)
>     (/) = (Prelude./)
>     (^) = (Prelude.^)
>     sin = Prelude.sin
>     cos = Prelude.cos
>     fromInteger = Prelude.fromInteger


For symbolic manipulation, we need a representation for
(reconstructed) terms

> -- Here, reflect is the tag eliminator -- or `compiler'
> class Term t a | t -> a where
>     reflect :: t -> a -> a

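We should point out that the terms are fully typeful: the type of a
reconstructed term mirrors its structure. For instance, reifying
x * x + fromInteger 1 at type Code Float (as done below) builds a term
of type Add (Mul (Var Float) (Var Float)) (Const Float).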

> newtype Const a = Const a deriving Show
> data Var a   = Var     deriving Show
> data Add x y = Add x y deriving Show
> data Sub x y = Sub x y deriving Show
> data Mul x y = Mul x y deriving Show
> data Div x y = Div x y deriving Show
> data Pow x   = Pow x Int deriving Show
> newtype Sin x   = Sin x   deriving Show
> newtype Cos x   = Cos x   deriving Show


We can now describe the grammar of our term representation in the
following straightforward way: 

> instance Term (Const a) a where reflect (Const a) = const a
>
> instance Term (Var a) a where reflect _ = id
>
> instance (D a, Term x a, Term y a) => Term (Add x y) a 
>     where
>     reflect (Add x y) = \a -> (reflect x a) + (reflect y a)
>
> instance (D a, Term x a) => Term (Sin x) a 
>     where
>     reflect (Sin x) = sin . reflect x 

The other instances are given in the Appendix. This is a
straightforward emulation of GADTs. The function `reflect' removes the
`tags' after the symbolic differentiation. Actually, `Sin' is a
newtype constructor, so there is no run-time tag to eliminate in this
case.

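We must stress that there is no `reify' function. One may say it is
built into Haskell already: applying an overloaded function to an
argument of type Code a makes the D instance for Code rebuild the
term, operation by operation.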
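
A minimal, self-contained sketch of this reification by instantiation,
using the standard Num class instead of D. The type Sym and the
functions square1 and atTwoTypes are hypothetical names chosen for
illustration; Sym records each operation as text, whereas the Code
type in this message builds typed Add/Mul terms. The rank-2 argument
of atTwoTypes plays a role similar to the argument of diff_fn: the
callee itself instantiates it, once to compute and once to reify.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Hypothetical symbolic number type: each operation records the
-- expression as text instead of computing a value.
newtype Sym = Sym String

instance Show Sym where
  show (Sym s) = s

instance Num Sym where
  Sym a + Sym b  = Sym ("(" ++ a ++ " + " ++ b ++ ")")
  Sym a - Sym b  = Sym ("(" ++ a ++ " - " ++ b ++ ")")
  Sym a * Sym b  = Sym ("(" ++ a ++ " * " ++ b ++ ")")
  abs (Sym a)    = Sym ("abs " ++ a)
  signum (Sym a) = Sym ("signum " ++ a)
  fromInteger n  = Sym (show n)

-- An ordinary overloaded function; nothing in it mentions Sym.
square1 :: Num a => a -> a
square1 x = x * x + 1

-- Because the argument is polymorphic, the callee picks the instance:
-- Float to compute a number, Sym to recover the expression.
atTwoTypes :: (forall a. Num a => a -> a) -> (Float, String)
atTwoTypes f = (f 2, show (f (Sym "x")))
```

Loaded into GHCi, atTwoTypes square1 evaluates to
(5.0,"((x * x) + 1)").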

We only need to declare the datatype for the reified code

> data Code a = forall t. (Show t, Term t a, DiffRules t a) => Code t
> instance Show a => Show (Code a) where show (Code t) = show t
> reflect_code (Code c) = reflect c

inject the reified code into the D domain

> instance (Num a, D a) => D (Code a) where
>     Code x + Code y = Code $ Add x y
>     Code x - Code y = Code $ Sub x y
>     Code x * Code y = Code $ Mul x y
>     Code x / Code y = Code $ Div x y
>     (Code x) ^ n    = Code $ Pow x n
>     sin (Code x)    = Code $ Sin x
>     cos (Code x)    = Code $ Cos x
>     fromInteger n = Code $ Const (fromInteger n)

and we're done with the first part:

We can define a function

> test1f x = x * x + fromInteger 1
> test1 = test1f (2.0::Float)

we can even compile it. At any point, we can reify it

> test1c = test1f (Code Var :: Code Float)

and reflect it back:

> test1f' = reflect_code test1c
> test1' = test1f' (2.0::Float)

	*Diff> test1
	5.0
	*Diff> test1'
	5.0
	*Diff> test1c
	Add (Mul Var Var) (Const 1.0)

The differentiation part is quite straightforward. We declare a class
for differentiation rules

> class (Term t a,D a) => DiffRules t a | t -> a where 
>     diff :: t -> Code a

The rules are the instances of the class DiffRules

> instance (Num a, D a) => DiffRules (Const a) a where
>     diff _ = Code $ Const 0
>
> instance (Num a, D a) => DiffRules (Var a) a where
>     diff _ = Code $ Const 1
>
> instance (Show x, Show y, DiffRules x a, DiffRules y a)
>     => DiffRules (Mul x y) a where
>     diff (Mul x y) = case (diff x,diff y) of
> 		       (Code x'::Code a,Code y') ->
> 			   Code $ Add (Mul (x::x) y') (Mul x' (y::y))
>
>
> instance (Num a, Show x, DiffRules x a)
>     => DiffRules (Sin x) a where
>     diff (Sin x) = case diff x of
> 		       (Code x'::Code a) ->
> 			   Code $ Mul x' (Cos x)

The other instances are in the Appendix.
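
For reference, here are the rules encoded by the DiffRules instances
(here and in the Appendix), written as ordinary calculus. u and v
stand for subterms, primes denote their derivatives, and the order of
the factors follows the code:

```latex
\begin{aligned}
c' &= 0 \\
x' &= 1 \\
(u + v)' &= u' + v' \\
(u - v)' &= u' - v' \\
(u \cdot v)' &= u \cdot v' + u' \cdot v \\
(u / v)' &= \frac{u' \cdot v - u \cdot v'}{v^2} \\
(u^n)' &= n \cdot (u' \cdot u^{n-1}) \\
(\sin u)' &= u' \cdot \cos u \\
(\cos u)' &= u' \cdot (0 - \sin u)
\end{aligned}
```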

The rule database is extensible: we may add more rules later, in other
modules, simply by declaring new DiffRules instances.
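
The extensibility can be seen in a condensed, self-contained sketch of
the same idea. It is a simplification, not the code of this message:
it fixes the carrier type to Double, does not reify overloaded
functions, and uses a functional dependency to compute the type of the
derivative instead of hiding it in the Code existential. The names X,
K, Eval, Diff and Exp are hypothetical. The last two instances add a
new node, Exp, without touching the earlier ones, as another module
could.

```haskell
{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies,
             FlexibleInstances, UndecidableInstances #-}

-- Typeful terms: the type of a term mirrors its structure.
data X = X               -- the variable
newtype K = K Double     -- a constant
data Add a b = Add a b
data Mul a b = Mul a b

-- Evaluation of a term at a point (the analogue of reflect).
class Eval t where
  eval :: t -> Double -> Double

instance Eval X where
  eval X x = x
instance Eval K where
  eval (K c) _ = c
instance (Eval a, Eval b) => Eval (Add a b) where
  eval (Add a b) x = eval a x + eval b x
instance (Eval a, Eval b) => Eval (Mul a b) where
  eval (Mul a b) x = eval a x * eval b x

-- The rule database: one instance per node. The dependency t -> d
-- computes the type of the derivative from the type of the term.
class Diff t d | t -> d where
  diff :: t -> d

instance Diff X K where
  diff X = K 1
instance Diff K K where
  diff (K _) = K 0
instance (Diff a da, Diff b db) => Diff (Add a b) (Add da db) where
  diff (Add a b) = Add (diff a) (diff b)
instance (Diff a da, Diff b db)
    => Diff (Mul a b) (Add (Mul da b) (Mul a db)) where
  diff (Mul a b) = Add (Mul (diff a) b) (Mul a (diff b))

-- A new node and its rules, added later (possibly in another module).
newtype Exp a = Exp a

instance Eval a => Eval (Exp a) where
  eval (Exp a) x = exp (eval a x)
instance Diff a da => Diff (Exp a) (Mul da (Exp a)) where
  diff (Exp a) = Mul (diff a) (Exp a)
```

Here diff (Mul X X) has type Add (Mul K X) (Mul X K), and
eval (diff (Mul X X)) 3 evaluates to 6.0.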

And that's about it:

> diff_code (Code c) = diff c
>
> diff_fn :: (Num b, D b) => (forall a. D a => a -> a) -> b -> b
> diff_fn f = 
>     let code = f (Code Var)
>     in reflect_code $ diff_code code

the differentiation operator could not be any simpler.

We can try 

> test2f = diff_fn test1f
> test2 = test2f (3.0::Float)

we can even see the differentiation result, symbolically:

   *Diff> diff_code test1c
   Add (Add (Mul Var (Const 1.0)) (Mul (Const 1.0) Var)) (Const 0.0)

True, simplifications are direly needed. Well, the full computer
algebra system is a little bit too big to be developed over one
evening. Besides, I wanted to go home three hours ago.
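
To give an idea of what such simplification involves, here is a
minimal sketch of a few algebraic identities applied bottom-up. It
works on a hypothetical, untyped Expr datatype, which is exactly the
kind of representation this message avoids, so it illustrates only
the rewriting rules, not how they would be expressed over the typeful
terms.

```haskell
-- Hypothetical untyped expression type, used only for this sketch.
data Expr = Const Double | Var | Add Expr Expr | Mul Expr Expr
  deriving (Show, Eq)

-- One bottom-up pass: simplify the children first, then apply
-- x + 0 = x, x * 1 = x, x * 0 = 0 and constant folding at the node.
simplify :: Expr -> Expr
simplify (Add a b) =
  case (simplify a, simplify b) of
    (Const 0, e)       -> e
    (e, Const 0)       -> e
    (Const m, Const n) -> Const (m + n)
    (e1, e2)           -> Add e1 e2
simplify (Mul a b) =
  case (simplify a, simplify b) of
    (Const 0, _)       -> Const 0
    (_, Const 0)       -> Const 0
    (Const 1, e)       -> e
    (e, Const 1)       -> e
    (Const m, Const n) -> Const (m * n)
    (e1, e2)           -> Mul e1 e2
simplify e = e
```

Applied to the differentiation result shown above, transcribed into
Expr, this yields Add Var Var. Collecting like terms into 2x would
require further rules.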

Here's a slightly more complex example:

> test5f x = sin (fromInteger 5*x) + cos(fromInteger 1/x)
> test5c = test5f (Code Var :: Code Float)
>
> test5 = test5f (pi::Float)
> test5d = diff_code test5c
>
> test6 = diff_fn test5f (pi::Float)

One can evaluate the function test5f numerically, differentiate it
symbolically, check the result of differentiation -- and evaluate it
numerically right away.
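
Worked by hand, the rules applied to test5f give the following. Before
any simplification, 5x is the term Mul (Const 5) Var and 1/x is
Div (Const 1) Var. The last line is the exact value that test6 should
approximate, up to Float rounding:

```latex
\begin{aligned}
f(x) &= \sin(5x) + \cos(1/x) \\
f'(x) &= (5 \cdot 1 + 0 \cdot x)\,\cos(5x)
  + \frac{0 \cdot x - 1 \cdot 1}{x^2}\,\bigl(0 - \sin(1/x)\bigr) \\
  &= 5\cos(5x) + \frac{\sin(1/x)}{x^2} \\
f'(\pi) &= -5 + \frac{\sin(1/\pi)}{\pi^2} \approx -4.968
\end{aligned}
```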

We can even do partial derivatives:

> test3f x y = (x*y + ((fromInteger 5)*(x^2))) / y
>
> test3c1 = test3f (Code Var :: Code Float) (fromInteger 10)
>
> test4x y = diff_fn (\x -> test3f x (fromInteger y))
> test4y x = diff_fn (test3f (fromInteger x))

	*Diff> test4x 1 (2::Float) -- partial derivative with respect to x
	21.0
	*Diff> test4y 5 (5::Float) -- partial derivative with respect to y
	-5.0
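
Both values in the transcript can be checked by hand. test4x fixes y
and differentiates with respect to x; test4y fixes x and
differentiates with respect to y:

```latex
\begin{aligned}
f(x, y) &= \frac{xy + 5x^2}{y} \\
\frac{\partial f}{\partial x} &= \frac{y + 10x}{y},
  \qquad \frac{\partial f}{\partial x}(2, 1) = \frac{1 + 20}{1} = 21 \\
\frac{\partial f}{\partial y} &= \frac{x\,y - (xy + 5x^2)}{y^2} = -\frac{5x^2}{y^2},
  \qquad \frac{\partial f}{\partial y}(5, 5) = -\frac{125}{25} = -5
\end{aligned}
```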


Appendix:


> instance (D a, Term x a, Term y a) => Term (Sub x y) a 
>     where
>     reflect (Sub x y) = \a -> (reflect x a) - (reflect y a)
>
> instance (D a, Term x a, Term y a) => Term (Mul x y) a 
>     where
>     reflect (Mul x y) = \a -> (reflect x a) * (reflect y a)
>
> instance (D a, Term x a, Term y a) => Term (Div x y) a 
>     where
>     reflect (Div x y) = \a -> (reflect x a) / (reflect y a)
>
> instance (D a, Term x a) => Term (Pow x) a 
>     where
>     reflect (Pow x n) = (^ n) . reflect x
>
> instance (D a, Term x a) => Term (Cos x) a 
>     where
>     reflect (Cos x) = cos . reflect x 


> instance (Show x, Show y, DiffRules x a, DiffRules y a)
>     => DiffRules (Add x y) a where
>     diff (Add x y) = case (diff x,diff y) of
> 		       (Code x'::Code a,Code y') ->
> 			   Code $ Add x' y'
>
> instance (Show x, Show y, DiffRules x a, DiffRules y a)
>     => DiffRules (Sub x y) a where
>     diff (Sub x y) = case (diff x,diff y) of
> 		       (Code x'::Code a,Code y') ->
> 			   Code $ Sub x' y'
>
> instance (Num a, Show x, Show y, DiffRules x a, DiffRules y a)
>     => DiffRules (Div x y) a where
>     diff (Div x y) = case (diff x,diff y) of
> 		       (Code x'::Code a,Code y') ->
> 			   Code $ 
> 				Div (Sub (Mul x' y) (Mul x y'))
> 				    (Pow y 2)
>
> instance (Num a, Show x, DiffRules x a)
>     => DiffRules (Pow x) a where
>     diff (Pow x n) = case diff x of
> 		       (Code x'::Code a) ->
> 			   Code $ Mul (Const (fromInteger $ toInteger n))
> 				      (Mul x' (Pow x (n Prelude.- 1)))
> instance (Num a, Show x, DiffRules x a)
>     => DiffRules (Cos x) a where
>     diff (Cos x) = case diff x of
> 		       (Code x'::Code a) ->
> 			   Code $ Mul x' (Sub (Const 0) (Sin x))

