[Haskell-cafe] Safe forward-mode AD in Haskell?
Björn Buckwalter
bjorn.buckwalter at gmail.com
Tue May 8 18:06:56 EDT 2007
Dear all,
= Introduction =
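In [1] Siskind and Pearlmutter expose the danger of "perturbation
confusion" in forward-mode automatic differentiation. When applications
of the derivative operator are nested, an inner application can mistake
the perturbation introduced by an outer one for its own, and so compute
a wrong derivative. In particular they state: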
"We discuss a potential problem with forward-mode AD common to
many AD systems, including all attempts to integrate a forward-mode
AD operator into Haskell."
This literate Haskell message shows how, using a type system extension
of the Glasgow Haskell Compiler (GHC) [2], we can statically guarantee
that perturbation confusion does not occur.
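The sketch below makes the danger concrete under stated assumptions. It is neither the code of [1] nor the typed code in this message. It is a hypothetical, deliberately naive representation in which numbers are constants or untagged (value, tangent) pairs that may nest, much like the representation a dynamically typed implementation would use. The names 'N', 'C', 'B', 'dNaive', 'constantOneNaive' and 'shouldBeOneNaive' are illustrative only. Nothing in this representation records which application of 'dNaive' introduced a perturbation:

```haskell
-- Hypothetical naive numbers: a constant, or an untagged (value, tangent)
-- pair whose components may themselves be pairs (nested differentiation).
data N = C Double | B N N
  deriving (Eq, Show)

-- Dual-number arithmetic; a constant behaves like a pair with zero tangent.
instance Num N where
  (C a)    + (C b)    = C (a + b)
  (C a)    + (B y y') = B (C a + y) y'
  (B x x') + (C b)    = B (x + C b) x'
  (B x x') + (B y y') = B (x + y) (x' + y')
  (C a)    * (C b)    = C (a * b)
  (C a)    * (B y y') = B (C a * y) (C a * y')
  (B x x') * (C b)    = B (x * C b) (x' * C b)
  (B x x') * (B y y') = B (x * y) (x * y' + x' * y)
  negate (C a)        = C (negate a)
  negate (B x x')     = B (negate x) (negate x')
  abs    _            = error "abs: not needed in this sketch"
  signum _            = error "signum: not needed in this sketch"
  fromInteger n       = C (fromInteger n)

-- Naive derivative operator: seed the tangent with 1 and read it back.
-- The outermost pair's tangent is returned, even if part of it came from
-- a perturbation introduced by an enclosing call.
dNaive :: (N -> N) -> N -> N
dNaive f x = case f (B x 1) of
  B _ x' -> x'
  C _    -> 0  -- f ignored its argument, so the derivative is zero

-- The constant_one / should_be_one_a example, transcribed to this
-- untagged representation.
constantOneNaive :: N -> N
constantOneNaive x = dNaive (\y -> x + y) 1

shouldBeOneNaive :: N
shouldBeOneNaive = dNaive (\x -> x * constantOneNaive x) 1
```

Here 'shouldBeOneNaive' evaluates to 'C 2.0' rather than the correct 'C 1.0'. The 'x' captured by 'constantOneNaive' already carries the outer perturbation, and 'dNaive' cannot tell it apart from its own, so it reports the derivative of 'x + y' as 2. By contrast, 'dNaive (\x -> x * 1) 1' gives 'C 1.0'.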
= Code =
The code below closely follows the Haskell code in the appendices
of [1]. It relies on GHC's arbitrary-rank polymorphism extension
to Haskell 98 (a rank-2 type is all that is needed here).

> {-# LANGUAGE RankNTypes #-}

In our definition of the 'Bundle' data type we add the phantom type
parameter 's', which is the key to distinguishing the bundles created
by different applications of the 'd' operator. A 'Bundle s a' pairs a
primal value with its tangent (derivative).
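
> -- A primal value paired with its tangent; 's' is a phantom tag.
> data Bundle s a = Bundle a a
> instance Show a => Show (Bundle s a) where
>   showsPrec p (Bundle x x') = showsPrec p [x, x']
> -- Equality compares the primal (value) parts only.
> instance Eq a => Eq (Bundle s a) where
>   (Bundle x _) == (Bundle y _) = x == y
> lift :: Num a => a -> Bundle s a
> lift z = Bundle z 0
> instance Num a => Num (Bundle s a) where
>   (Bundle x x') + (Bundle y y') = Bundle (x + y) (x' + y')
>   (Bundle x x') * (Bundle y y') = Bundle (x * y) (x * y' + x' * y)
>   negate (Bundle x x')          = Bundle (negate x) (negate x')
>   abs (Bundle x x')             = Bundle (abs x) (x' * signum x)
>   signum (Bundle x _)           = Bundle (signum x) 0
>   fromInteger z                 = lift (fromInteger z)
> instance Fractional a => Fractional (Bundle s a) where
>   (Bundle x x') / (Bundle y y') = Bundle (x / y) ((x' * y - x * y') / (y * y))
>   fromRational z                = lift (fromRational z)
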
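
The '+' and '*' cases that carry the tangent can be read, in the standard dual-number view of forward-mode AD, as arithmetic on $x + x'\varepsilon$ with $\varepsilon^2 = 0$. The symbol $\varepsilon$ is notation introduced here for illustration only; it does not appear in the code:

```latex
\begin{aligned}
(x + x'\varepsilon) + (y + y'\varepsilon) &= (x + y) + (x' + y')\,\varepsilon \\
(x + x'\varepsilon)\,(y + y'\varepsilon) &= xy + (x y' + x' y)\,\varepsilon + x' y'\,\varepsilon^2
  = xy + (x y' + x' y)\,\varepsilon
\end{aligned}
```

The second line is the product rule that the '*' case implements, and 'lift z' corresponds to $z + 0\,\varepsilon$.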
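We give 'd' a rank-2 type signature in which the phantom type 's' is
universally quantified inside the type of the argument (the same trick
that 'runST' uses). Since the function passed to 'd' must work for every
choice of 's', bundles from different applications of 'd' cannot be mixed.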

> -- Seed the tangent with 1 and read off the derivative.
> d :: Num a => (forall s. Bundle s a -> Bundle s a) -> a -> a
> d f x = let Bundle _ y' = f (Bundle x 1) in y'

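
As a small worked example (the function $x \cdot x$ and the point 3 are arbitrary choices for illustration), 'd' seeds the tangent with 1, so 'd (\x -> x * x) 3' evaluates as follows, writing 'Bundle v t' as $v + t\varepsilon$:

```latex
\begin{aligned}
(\lambda x.\ x \cdot x)\,(3 + 1\,\varepsilon) &= 3 \cdot 3 + (3 \cdot 1 + 1 \cdot 3)\,\varepsilon = 9 + 6\,\varepsilon \\
\texttt{d}\ (\lambda x.\ x \cdot x)\ 3 &= 6 = \left.\tfrac{\mathrm{d}}{\mathrm{d}x}\,x^2\right|_{x=3}
\end{aligned}
```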
This quantification makes the definition
] constant_one' x = d (\y -> x + y) 1
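impossible, since 'x' originates outside the 'd' operator: typing
'x + y' would force the type of 'x' to mention the tag 's', which is
local to this application of 'd'. GHC rejects the definition with the
error "Inferred type is less polymorphic than expected" (newer GHC
versions word this error differently). To be accepted by the compiler,
the definition must be changed to the corrected version from [1], which
applies 'lift' to 'x' explicitly: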

> constant_one :: Num a => a -> a
> constant_one x = d (\y -> (lift x) + y) 1
> should_be_one_a = d (\x -> x * (constant_one x)) 1
> should_be_one_b = d (\x -> x * 1) 1
> -- Evaluates to False: both derivatives are 1.
> violation_of_referential_transparency = should_be_one_a /= should_be_one_b

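
To see why the tagged version gets the right answer, here is a hand evaluation of 'should_be_one_a', writing the outer perturbation as $\varepsilon_1$ and the inner one as $\varepsilon_2$ (notation for illustration only). The tags keep the two apart, so the inner 'd' reads off only the coefficient of $\varepsilon_2$:

```latex
\begin{aligned}
\texttt{constant\_one}\,(1 + \varepsilon_1)
  &= \operatorname{coeff}_{\varepsilon_2}\!\big[(1 + \varepsilon_1) + (1 + \varepsilon_2)\big]
   = \operatorname{coeff}_{\varepsilon_2}\!\big[(2 + \varepsilon_1) + 1\,\varepsilon_2\big] = 1 \\
\texttt{should\_be\_one\_a}
  &= \operatorname{coeff}_{\varepsilon_1}\!\big[(1 + \varepsilon_1) \cdot 1\big] = 1
\end{aligned}
```

If the two perturbations were conflated into a single $\varepsilon$, the inner sum would be $2 + 2\varepsilon$, 'constant_one' would return 2, and 'should_be_one_a' would come out as 2 while 'should_be_one_b' is 1.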
= References =
[1] http://www.bcl.hamilton.ie/~qobi/nesting/papers/ifl2005.pdf
[2] http://haskell.org/ghc/