[Haskell-cafe] Safe forward-mode AD in Haskell?

Björn Buckwalter bjorn.buckwalter at gmail.com
Tue May 8 18:06:56 EDT 2007


Dear all,





= Introduction =

In [1] Siskind and Pearlmutter expose the danger of "perturbation
confusion" in forward-mode automatic differentiation. In particular
they state:

    "We discuss a potential problem with forward-mode AD common to
    many AD systems, including all attempts to integrate a forward-mode
    AD operator into Haskell."

This literate Haskell message shows how, using a type system extension
of the Glasgow Haskell Compiler[2] (GHC), we can statically guarantee
that perturbation confusion does not occur.
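
To make the problem concrete, here is a minimal sketch of perturbation
confusion with untagged dual numbers. The names 'Dual', 'liftD',
'naiveD', 'innerSlope' and 'confused' are made up for this illustration
and are not the code from [1]; the sketch only assumes the usual
dual-number arithmetic.

```haskell
-- Illustration only: all names in this fragment are hypothetical.

-- Untagged dual number: a primal value and its perturbation.
data Dual a = Dual a a

-- Embed a constant: zero perturbation.
liftD :: Num a => a -> Dual a
liftD z = Dual z 0

instance Num a => Num (Dual a) where
  Dual x x' + Dual y y' = Dual (x + y) (x' + y')
  Dual x x' * Dual y y' = Dual (x * y) (x * y' + x' * y)
  negate (Dual x x')    = Dual (negate x) (negate x')
  abs (Dual x x')       = Dual (abs x) (x' * signum x)
  signum (Dual x _)     = liftD (signum x)
  fromInteger           = liftD . fromInteger

-- Naive derivative operator: nothing ties a Dual to the call that made it.
naiveD :: Num a => (Dual a -> Dual a) -> a -> a
naiveD f x = let Dual _ y' = f (Dual x 1) in y'

-- Intended meaning: d/dy (x + y) at y = 1, which is 1 for every x.
-- Here x already carries the OUTER perturbation, and the inner naiveD
-- mistakes it for its own.
innerSlope :: Num a => Dual a -> a
innerSlope x = naiveD (\y -> x + y) 1

-- Intended meaning: d/dx (x * 1) at x = 1, which is 1.
confused :: Integer
confused = naiveD (\x -> x * liftD (innerSlope x)) 1
```

Loaded into GHCi, 'confused' evaluates to 2 rather than 1: the inner
call picks up the perturbation that belongs to the outer one.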


= Code =

The code below closely follows the Haskell code in the appendices
of [1]. It relies on GHC's arbitrary-rank polymorphism extension
to Haskell 98 (only rank-2 types are actually needed).

> {-# LANGUAGE RankNTypes #-}

In our definition of the 'Bundle' data type we add the phantom type
's', which will be the key to disambiguating between different
applications of the 'd' operator.

> data Bundle s a = Bundle a a

> instance Show a => Show (Bundle s a) where
>   showsPrec p (Bundle x x') = showsPrec p [x,x']

> instance Eq a => Eq (Bundle s a) where
>   (Bundle x _) == (Bundle y _) = (x == y)

> lift z = Bundle z 0

> instance Num a => Num (Bundle s a) where
>   (Bundle x x') + (Bundle y y') = Bundle (x + y) (x' + y')
>   (Bundle x x') * (Bundle y y') = Bundle (x * y) (x * y' + x' * y)
>   fromInteger z = lift (fromInteger z)

> instance Fractional a => Fractional (Bundle s a) where
>   fromRational z = lift (fromRational z)
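
The 'Fractional' instance above defines only 'fromRational', so an
expression such as 1 / x has no usable definition on bundles. One
possible completion uses the quotient rule. The fragment below is a
self-contained sketch: the '(/)', 'negate', 'abs' and 'signum' clauses
and the example 'reciprocalSlope' are additions made here, not part of
the code in [1], and the types are repeated only so that the fragment
compiles by itself.

```haskell
{-# LANGUAGE RankNTypes #-}

data Bundle s a = Bundle a a

lift :: Num a => a -> Bundle s a
lift z = Bundle z 0

instance Num a => Num (Bundle s a) where
  Bundle x x' + Bundle y y' = Bundle (x + y) (x' + y')
  Bundle x x' * Bundle y y' = Bundle (x * y) (x * y' + x' * y)
  -- The clauses below are assumptions added for this sketch.
  negate (Bundle x x')      = Bundle (negate x) (negate x')
  abs (Bundle x x')         = Bundle (abs x) (x' * signum x)
  signum (Bundle x _)       = lift (signum x)
  fromInteger               = lift . fromInteger

instance Fractional a => Fractional (Bundle s a) where
  -- Quotient rule (added here): (x / y)' = (x' * y - x * y') / y^2
  Bundle x x' / Bundle y y' = Bundle (x / y) ((x' * y - x * y') / (y * y))
  fromRational              = lift . fromRational

d :: Num a => (forall s. Bundle s a -> Bundle s a) -> a -> a
d f x = let Bundle _ y' = f (Bundle x 1) in y'

-- Hypothetical example: d/dx (1 / x) at x = 2, which is -1/4.
reciprocalSlope :: Double
reciprocalSlope = d (\x -> 1 / x) 2
```

In GHCi, 'reciprocalSlope' evaluates to -0.25.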

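We provide a type signature for 'd' in which the phantom type 's' is
universally quantified inside the argument type (a rank-2 type, the
same device that 'runST' uses). The function passed to 'd' must work
for every 's', which prevents mixing of bundles from different 'd'
operators.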

> d :: Num a => (forall s. Bundle s a -> Bundle s a) -> a -> a
> d f x = let (Bundle y y') = f (Bundle x 1) in y'
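
For a single, un-nested derivative the tag is invisible to the caller:
any ordinary 'Num'-polymorphic function can be passed to 'd', because
it can be instantiated at 'Bundle s a' for every 's'. A self-contained
sketch follows; 'poly' and 'polySlope' are made-up examples, and the
'negate', 'abs' and 'signum' clauses are additions that merely complete
the instance.

```haskell
{-# LANGUAGE RankNTypes #-}

data Bundle s a = Bundle a a

instance Num a => Num (Bundle s a) where
  Bundle x x' + Bundle y y' = Bundle (x + y) (x' + y')
  Bundle x x' * Bundle y y' = Bundle (x * y) (x * y' + x' * y)
  -- The clauses below are assumptions added for this sketch.
  negate (Bundle x x')      = Bundle (negate x) (negate x')
  abs (Bundle x x')         = Bundle (abs x) (x' * signum x)
  signum (Bundle x _)       = Bundle (signum x) 0
  fromInteger z             = Bundle (fromInteger z) 0

d :: Num a => (forall s. Bundle s a -> Bundle s a) -> a -> a
d f x = let Bundle _ y' = f (Bundle x 1) in y'

-- Hypothetical example: an ordinary function that knows nothing
-- about Bundle.
poly :: Num a => a -> a
poly x = x * x * x + 2 * x

-- d/dx (x^3 + 2x) at x = 3, which is 3 * 3^2 + 2.
polySlope :: Integer
polySlope = d poly 3
```

In GHCi, 'polySlope' evaluates to 29.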

This quantification makes the definition

] constant_one' x = d (\y -> x + y) 1

ill-typed, since 'x' originates outside the 'd' operator and its type
therefore cannot mention the 's' that is quantified inside 'd's
argument. GHC rejects the definition with an "Inferred type is less
polymorphic than expected" error. In order to pass the compiler the
definition must be changed to the corrected version from [1].

> constant_one x = d (\y -> (lift x) + y) 1
> should_be_one_a = d (\x -> x * (constant_one x)) 1
> should_be_one_b = d (\x -> x * 1) 1

> violation_of_referential_transparency = should_be_one_a /= should_be_one_b
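
The following self-contained sketch shows how these definitions behave
once evaluated. It also shows that nesting 'd' remains possible as long
as no bundle crosses from one application into another. The definitions
of 'Bundle', 'lift', 'd', 'constant_one' and the 'should_be_one'
values are repeated from above; 'secondSlope', the extra 'Num' methods
and the explicit type signatures are assumptions added for this
illustration.

```haskell
{-# LANGUAGE RankNTypes #-}

data Bundle s a = Bundle a a

lift :: Num a => a -> Bundle s a
lift z = Bundle z 0

instance Num a => Num (Bundle s a) where
  Bundle x x' + Bundle y y' = Bundle (x + y) (x' + y')
  Bundle x x' * Bundle y y' = Bundle (x * y) (x * y' + x' * y)
  -- The clauses below are assumptions added for this sketch.
  negate (Bundle x x')      = Bundle (negate x) (negate x')
  abs (Bundle x x')         = Bundle (abs x) (x' * signum x)
  signum (Bundle x _)       = lift (signum x)
  fromInteger               = lift . fromInteger

d :: Num a => (forall s. Bundle s a -> Bundle s a) -> a -> a
d f x = let Bundle _ y' = f (Bundle x 1) in y'

-- The outer value x is lifted explicitly into the inner application.
constant_one :: Num a => a -> a
constant_one x = d (\y -> lift x + y) 1

should_be_one_a, should_be_one_b :: Integer
should_be_one_a = d (\x -> x * constant_one x) 1
should_be_one_b = d (\x -> x * 1) 1

-- Hypothetical example: second derivative of x^3 at x = 2, by nesting d.
-- The inner d works on bundles whose components are outer bundles.
secondSlope :: Integer
secondSlope = d (\x -> d (\y -> y * y * y) x) 2
```

In GHCi, both 'should_be_one_a' and 'should_be_one_b' evaluate to 1, so
the inequality test above is False, and 'secondSlope' evaluates to 12.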


= References =

[1] http://www.bcl.hamilton.ie/~qobi/nesting/papers/ifl2005.pdf
[2] http://haskell.org/ghc/

