# [Haskell-cafe] Suggestion on how to solve a type issue with Numeric.AD

Frederic Cogny frederic.cogny at gmail.com
Fri Mar 31 12:33:16 UTC 2017

```
Hello Café

*Issue: *
I'm having trouble using the AD package, presumably due to a wrong use of
types (sorry, I know that's a bit vague).
Any help or pointer on how to solve this would be greatly appreciated.

*Background: *
I've implemented a symbolic expression tree (full code here:
https://pastebin.com/FDpSFuRM):
-- | Expression Tree
data ExprTree a = Const a                                      -- ^ number
                | ParamNode ParamName                          -- ^ parameter
                | VarNode VarName                              -- ^ variable
                | UnaryNode MonoOp (ExprTree a)                -- ^ operator of arity 1
                | BinaryNode DualOp (ExprTree a) (ExprTree a)  -- ^ operator of arity 2
                | CondNode (Cond a) (ExprTree a) (ExprTree a)  -- ^ conditional node
                deriving (Eq, Show, Generic)

An evaluation function on it:
-- | evaluates an Expression Tree on its Context (Map ParamName a, Map VarName a)
evaluate :: (Eq a, Ord a, Floating a) => ExprTree a -> Context a -> a

And a few instances for it:
instance Num a => Default (ExprTree a) where ...
instance Num a => Num (ExprTree a) where ...
instance Fractional a => Fractional (ExprTree a) where ...
instance Floating a => Floating (ExprTree a) where ...
instance (Arbitrary a) => Arbitrary (ExprTree a) where ...

This allows me to easily create (using differentiation rules) the derivative(s)
of such a tree with respect to its variable(s):
diff    :: (Eq a, Floating a) => ExprTree a -> VarName -> ExprTree a
grad    :: (Eq a, Floating a) => ExprTree a -> Map VarName (ExprTree a)
hessian :: (Eq a, Floating a) => ExprTree a -> Map VarName (Map VarName (ExprTree a))
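
-- A minimal sketch of the derivation-rule approach, assuming a cut-down,
-- hypothetical tree type (Expr, eval and diff are illustrative names, not
-- the code from the pastebin link): each constructor gets its own rule,
-- and the derivative is itself a tree that can be evaluated.
import qualified Data.Map as Map

-- Hypothetical miniature of ExprTree: literals, variables, +, *, sin, cos
data Expr
  = Lit Double
  | Var String
  | Add Expr Expr
  | Mul Expr Expr
  | Sin Expr
  | Cos Expr
  deriving (Eq, Show)

-- Evaluate a tree given the values of its variables
eval :: Map.Map String Double -> Expr -> Double
eval _   (Lit c)   = c
eval env (Var v)   = Map.findWithDefault (error ("unbound variable " ++ v)) v env
eval env (Add a b) = eval env a + eval env b
eval env (Mul a b) = eval env a * eval env b
eval env (Sin a)   = sin (eval env a)
eval env (Cos a)   = cos (eval env a)

-- Symbolic derivative with respect to the variable x
diff :: Expr -> String -> Expr
diff (Lit _)   _ = Lit 0
diff (Var v)   x = Lit (if v == x then 1 else 0)
diff (Add a b) x = Add (diff a x) (diff b x)                   -- sum rule
diff (Mul a b) x = Add (Mul (diff a x) b) (Mul a (diff b x))   -- product rule
diff (Sin a)   x = Mul (Cos a) (diff a x)                      -- chain rule
diff (Cos a)   x = Mul (Mul (Lit (-1)) (Sin a)) (diff a x)     -- chain rule

-- Example: d/dx (x * sin x) = sin x + x * cos x, evaluated at x = 2:
-- eval (Map.fromList [("x", 2)]) (diff (Mul (Var "x") (Sin (Var "x"))) "x")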

So far, so good ...

Now, to gain assurance in my implementation, I wanted to check the [...]
so I created the following two functions:
-- | helper for AD usage
exprTreeToListFun :: (RealFloat a)
                  => ExprTree a        -- ^ tree
                  -> Map ParamName a   -- ^ paramDict
                  -> ([a] -> a)        -- fun from var values to their evaluation
exprTreeToListFun tree paramDict vals = res
  where
    res     = evaluate tree (paramDict, varDict)
    varDict = Map.fromList $ zip (getVarNames tree) vals

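gradThroughAD :: RealFloat a => ExprTree a -> Context a -> Map VarName a
gradThroughAD tree (paramDict, varDict) = res
  where
    varNames = Map.keys varDict
    varVals  = Map.elems varDict
    gradList = grad (exprTreeToListFun tree paramDict) varVals  -- grad from the AD package
    res      = Map.fromList $ zip varNames gradList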

Unfortunately, it does not typecheck; the error message is:
• Couldn't match type ‘a’ with ‘Numeric.AD.Internal.Reverse.Reverse s a’
  ‘a’ is a rigid type variable bound by the type signature for:
    gradThroughAD :: forall a. RealFloat a => ExprTree a -> Context a -> Map VarName a
  at src/ [...]
  [...] Reverse s a] -> Numeric.AD.Internal.Reverse.Reverse s a
  Actual type: [a] -> [...]
  [...] tree paramDict) varVals
• Relevant bindings include
    gradList :: [a] (bound at src/SymbolicExpression.hs:457:5)
    varVals :: [a] (bound at src/SymbolicExpression.hs:456:5)
    varDict :: Map VarName a (bound at src/SymbolicExpression.hs:453:32)
    paramDict :: Map ParamName a (bound at src/SymbolicExpression.hs:453:21)
    tree :: ExprTree a (bound at src/ [...]
    [...] Map VarName a (bound at src/SymbolicExpression.hs:453:1)

I was a bit surprised (naively, I guess) by this, since to me 'a' is
polymorphic with RealFloat as a type constraint, and that is "supposed" to [...]
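
-- A minimal sketch of what the error seems to point at (sumSquares and
-- gradSumSquares are hypothetical names): grad from Numeric.AD runs the
-- function it is given at its own number type, Reverse s a, so that
-- function must be polymorphic in the number type. Here,
-- exprTreeToListFun tree paramDict is presumably pinned to the same 'a'
-- as the tree, i.e. it has the fixed-type shape that grad rejects.
import Numeric.AD (grad)

-- Polymorphic in the number type, so grad can run it at Reverse s Double
sumSquares :: Num b => [b] -> b
sumSquares xs = sum (map (\x -> x * x) xs)

-- Gradient of the sum of squares: 2 * x for each coordinate
gradSumSquares :: [Double] -> [Double]
gradSumSquares = grad sumSquares

-- By contrast, grad (sumSquares :: [Double] -> Double) is rejected with a
-- "Couldn't match type" error of the same shape as the one above, because
-- the argument is pinned to Double.
-- gradSumSquares [1, 2, 3] gives [2.0, 4.0, 6.0]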

So anyway, I tried modifying it as follows:
{-# LANGUAGE Rank2Types #-}

-- | helper for AD usage
exprTreeToListFun :: (RealFloat a)
                  => ExprTree a                 -- ^ tree
                  -> Map ParamName a            -- ^ paramDict
                  -> (RealFloat b => [b] -> b)  -- fun from var values to their evaluation (changed)
exprTreeToListFun tree paramDict vals = res
  where
    res     = realToFrac $ evaluate tree (paramDict, varDict)              -- changed: realToFrac
    varDict = Map.fromList $ zip (getVarNames tree) $ map realToFrac vals  -- changed: map realToFrac

This now typechecks and runs, but going through AD gives me *all
derivatives as zero*, as if the variables (which pass through
*realToFrac*) were treated as constants, so it is not much help either.
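
-- A minimal reproduction of the zero derivatives outside of ExprTree
-- (viaDouble and scaledSquare are hypothetical names). As far as I can
-- tell, realToFrac goes through toRational, which keeps only the primal
-- value of an AD number, so anything rebuilt from it is a constant to AD.
-- Converting only the constants into the AD type, and leaving the
-- variables alone, keeps the derivative; for the tree this would
-- presumably mean evaluating at the AD type with constants and parameter
-- values converted, rather than the variable values.
import Numeric.AD (grad)

-- The variable is squeezed through Double, so AD sees a constant
viaDouble :: RealFloat b => [b] -> b
viaDouble [x] = let x' = realToFrac (realToFrac x :: Double) in x' * x'
viaDouble _   = error "viaDouble: expects exactly one variable"

-- Only the constant c is converted into the AD type; x keeps its derivative
scaledSquare :: RealFloat b => Double -> [b] -> b
scaledSquare c [x] = realToFrac c * x * x
scaledSquare _ _   = error "scaledSquare: expects exactly one variable"

-- grad viaDouble [3]        gives [0.0]
-- grad (scaledSquare 2) [3] gives [12.0]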

Any idea how this can be solved?

Apologies if this question already has an answer somewhere, but poking [...]
looks like I'm not the only one having similar issues. Unfortunately,
none of the answers I found really helped me; if anything, they confirm the
same issue of null derivatives: