[Haskell-cafe] Suggestion on how to solve a type issue with Numeric.AD

Frederic Cogny frederic.cogny at gmail.com
Fri Mar 31 12:33:16 UTC 2017


Hello Café

*Issue:*
I'm having trouble using the AD package, presumably because I am using types
incorrectly (sorry, I know that's a bit vague).
Any help or pointers on how to solve this would be greatly appreciated.
Thanks in advance

*Background:*
I've implemented a symbolic expression tree [full code here:
https://pastebin.com/FDpSFuRM]
-- | Expression Tree
data ExprTree a = Const a                                      -- ^ number
                | ParamNode ParamName                          -- ^ parameter
                | VarNode VarName                              -- ^ variable
                | UnaryNode MonoOp (ExprTree a)                -- ^ operator of arity 1
                | BinaryNode DualOp (ExprTree a) (ExprTree a)  -- ^ operator of arity 2
                | CondNode (Cond a) (ExprTree a) (ExprTree a)  -- ^ conditional node
                deriving (Eq, Show, Generic)

An evaluation function on it:
-- | evaluates an Expression Tree on its Context (Map ParamName a, Map VarName a)
evaluate :: (Eq a, Ord a, Floating a) => ExprTree a -> Context a -> a
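For readers without the pastebin, here is a cut-down, self-contained sketch of this kind of tree and evaluator. The String name types, the concrete MonoOp/DualOp constructors, and leaving out CondNode (and with it the Eq/Ord constraints, which the original presumably needs for conditions) are all assumptions for illustration, not the code in the paste.

```haskell
import qualified Data.Map as Map
import Data.Map (Map)

-- Hypothetical stand-ins for the types defined in the pastebin code.
type ParamName = String
type VarName   = String
type Context a = (Map ParamName a, Map VarName a)

data MonoOp = Negate | Exp | Log | Sin | Cos
  deriving (Eq, Show)

data DualOp = Add | Sub | Mul | Div
  deriving (Eq, Show)

-- | Cut-down expression tree (CondNode left out).
data ExprTree a = Const a
                | ParamNode ParamName
                | VarNode VarName
                | UnaryNode MonoOp (ExprTree a)
                | BinaryNode DualOp (ExprTree a) (ExprTree a)
                deriving (Eq, Show)

-- | Evaluate a tree given parameter and variable values.
-- An unbound name is a runtime error in this sketch.
evaluate :: Floating a => ExprTree a -> Context a -> a
evaluate tree (params, vars) = go tree
  where
    go (Const c)           = c
    go (ParamNode p)       = lookupOrFail "parameter" p params
    go (VarNode v)         = lookupOrFail "variable" v vars
    go (UnaryNode op t)    = applyMono op (go t)
    go (BinaryNode op l r) = applyDual op (go l) (go r)

    lookupOrFail kind name m =
      Map.findWithDefault (error ("unbound " ++ kind ++ ": " ++ name)) name m

applyMono :: Floating a => MonoOp -> a -> a
applyMono Negate = negate
applyMono Exp    = exp
applyMono Log    = log
applyMono Sin    = sin
applyMono Cos    = cos

applyDual :: Floating a => DualOp -> a -> a -> a
applyDual Add = (+)
applyDual Sub = (-)
applyDual Mul = (*)
applyDual Div = (/)

-- Example: k * x + exp y, with parameter k and variables x and y.
exampleTree :: ExprTree Double
exampleTree = BinaryNode Add
                (BinaryNode Mul (ParamNode "k") (VarNode "x"))
                (UnaryNode Exp (VarNode "y"))
```

For instance, `evaluate exampleTree (Map.fromList [("k", 2)], Map.fromList [("x", 3), ("y", 0)])` gives `7.0` (2 * 3 + exp 0).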

And a few instances for it:
instance Num a => Default (ExprTree a) where ...
instance Num a => Num (ExprTree a) where ...
instance Fractional a => Fractional (ExprTree a) where ...
instance Floating a => Floating (ExprTree a) where ...
instance (Arbitrary a) => Arbitrary (ExprTree a) where ...
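The Num/Fractional instances are what let trees be written with ordinary arithmetic syntax. A sketch of how they might look on a reduced tree; the constructors are hypothetical, and the real instances may do more (for example fold constants):

```haskell
-- Hypothetical cut-down tree, repeated here so the snippet stands alone.
type VarName = String

data MonoOp = Negate | Abs | Signum | Recip
  deriving (Eq, Show)

data DualOp = Add | Sub | Mul | Div
  deriving (Eq, Show)

data ExprTree a = Const a
                | VarNode VarName
                | UnaryNode MonoOp (ExprTree a)
                | BinaryNode DualOp (ExprTree a) (ExprTree a)
                deriving (Eq, Show)

-- Arithmetic on trees builds new nodes instead of computing a number;
-- numeric literals become Const leaves.
instance Num a => Num (ExprTree a) where
  (+)         = BinaryNode Add
  (-)         = BinaryNode Sub
  (*)         = BinaryNode Mul
  negate      = UnaryNode Negate
  abs         = UnaryNode Abs
  signum      = UnaryNode Signum
  fromInteger = Const . fromInteger

instance Fractional a => Fractional (ExprTree a) where
  (/)          = BinaryNode Div
  recip        = UnaryNode Recip
  fromRational = Const . fromRational

-- With x = VarNode "x" :: ExprTree Double, the expression x * x + 3
-- builds BinaryNode Add (BinaryNode Mul x x) (Const 3.0).
```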

This allows me to easily create (using differentiation rules) the derivative(s)
of such a tree with respect to its variable(s):
diff :: (Eq a, Floating a) => ExprTree a -> VarName -> ExprTree a
grad :: (Eq a, Floating a) => ExprTree a -> Map VarName (ExprTree a)
hessian :: (Eq a, Floating a) => ExprTree a -> Map VarName (Map VarName (ExprTree a))
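A sketch of what rule-based differentiation can look like on a reduced tree (no parameters or conditionals, and no simplification of the resulting tree). The constructors, `varNames` and `eval` are hypothetical; `eval` is only there to check results numerically:

```haskell
import qualified Data.Map as Map
import Data.Map (Map)

type VarName = String

data MonoOp = Negate | Exp | Log | Sin | Cos
  deriving (Eq, Show)

data DualOp = Add | Sub | Mul | Div
  deriving (Eq, Show)

data ExprTree a = Const a
                | VarNode VarName
                | UnaryNode MonoOp (ExprTree a)
                | BinaryNode DualOp (ExprTree a) (ExprTree a)
                deriving (Eq, Show)

-- | Derivative with respect to one variable, by the usual rules.
-- The result is not simplified (it keeps the 0 * ... and 1 * ... terms).
diff :: Floating a => ExprTree a -> VarName -> ExprTree a
diff tree v = go tree
  where
    go (Const _)           = Const 0
    go (VarNode w)         = Const (if w == v then 1 else 0)
    go (UnaryNode op t)    = BinaryNode Mul (outer op t) (go t)  -- chain rule
    go (BinaryNode op l r) = case op of
      Add -> BinaryNode Add (go l) (go r)
      Sub -> BinaryNode Sub (go l) (go r)
      Mul -> BinaryNode Add (BinaryNode Mul (go l) r) (BinaryNode Mul l (go r))
      Div -> BinaryNode Div
               (BinaryNode Sub (BinaryNode Mul (go l) r) (BinaryNode Mul l (go r)))
               (BinaryNode Mul r r)

    -- derivative of the unary operator itself, at its argument t
    outer Negate _ = Const (-1)
    outer Exp    t = UnaryNode Exp t
    outer Log    t = BinaryNode Div (Const 1) t
    outer Sin    t = UnaryNode Cos t
    outer Cos    t = UnaryNode Negate (UnaryNode Sin t)

-- | Variable names occurring in a tree (may contain duplicates).
varNames :: ExprTree a -> [VarName]
varNames (Const _)          = []
varNames (VarNode w)        = [w]
varNames (UnaryNode _ t)    = varNames t
varNames (BinaryNode _ l r) = varNames l ++ varNames r

-- | Gradient: one derivative tree per variable of the tree.
grad :: Floating a => ExprTree a -> Map VarName (ExprTree a)
grad tree = Map.fromList [ (w, diff tree w) | w <- varNames tree ]

-- | Minimal evaluator, only used to check derivatives numerically.
eval :: Floating a => Map VarName a -> ExprTree a -> a
eval _   (Const c)           = c
eval env (VarNode w)         = Map.findWithDefault (error ("unbound: " ++ w)) w env
eval env (UnaryNode op t)    = f (eval env t)
  where
    f = case op of
      Negate -> negate
      Exp    -> exp
      Log    -> log
      Sin    -> sin
      Cos    -> cos
eval env (BinaryNode op l r) = g (eval env l) (eval env r)
  where
    g = case op of
      Add -> (+)
      Sub -> (-)
      Mul -> (*)
      Div -> (/)
```

For f = x * y + sin x at x = 0, y = 2, evaluating the gradient trees gives 3.0 for x (y + cos x) and 0.0 for y (x).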

So far, so good ...

Now, to gain confidence in my implementation, I wanted to check the
derivatives against the Numeric.AD module, so I created the following two
functions:
-- | helper for AD usage
exprTreeToListFun :: (RealFloat a)
                  => ExprTree a      -- ^ tree
                  -> Map ParamName a -- ^ paramDict
                  -> ([a] -> a)      -- ^ fun from var values to their evaluation
exprTreeToListFun tree paramDict vals = res
  where
    res            = evaluate tree (paramDict, varDict)
    varDict        = Map.fromList $ zip (getVarNames tree) vals


gradThroughAD :: RealFloat a => ExprTree a -> Context a -> Map VarName a
gradThroughAD tree (paramDict, varDict) = res
  where
    varNames = Map.keys varDict
    varVals  = Map.elems varDict
    gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
    res      = Map.fromList $ zip varNames gradList


Unfortunately, it does not typecheck; GHC reports:
• Couldn't match type ‘a’ with ‘Numeric.AD.Internal.Reverse.Reverse s a’
  ‘a’ is a rigid type variable bound by
    the type signature for:
      gradThroughAD :: forall a. RealFloat a => ExprTree a -> Context a -> Map VarName a
    at src/SymbolicExpression.hs:452:18
  Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
                 -> Numeric.AD.Internal.Reverse.Reverse s a
    Actual type: [a] -> a
• In the first argument of ‘AD.grad’, namely
    ‘(exprTreeToListFun tree paramDict)’
  In the expression: AD.grad (exprTreeToListFun tree paramDict) varVals
  In an equation for ‘gradList’:
      gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
• Relevant bindings include
    gradList :: [a] (bound at src/SymbolicExpression.hs:457:5)
    varVals :: [a] (bound at src/SymbolicExpression.hs:456:5)
    varDict :: Map VarName a (bound at src/SymbolicExpression.hs:453:32)
    paramDict :: Map ParamName a (bound at src/SymbolicExpression.hs:453:21)
    tree :: ExprTree a (bound at src/SymbolicExpression.hs:453:15)
    gradThroughAD :: ExprTree a -> Context a -> Map VarName a
      (bound at src/SymbolicExpression.hs:453:1)


I was a bit surprised by this (naively, I guess), since to me 'a' is
polymorphic with a RealFloat constraint, and that is "supposed" to work
with AD.
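Why GHC rejects this: in gradThroughAD the caller picks `a`, but grad needs to apply the function at a number type of its own choosing (`Reverse s a`, per the "Expected type" line of the error), so the function must be polymorphic in its number type, i.e. a rank-2 argument. The self-contained sketch below shows the same shape with a hand-rolled forward-mode dual number instead of the ad package; `Dual`, `gradF` and `poly` are hypothetical names, and this is not how Numeric.AD is implemented.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Hand-rolled forward-mode dual number: primal value and tangent.
-- An illustrative stand-in only; Numeric.AD's grad uses its own Reverse s a.
data Dual a = Dual a a
  deriving (Show)

instance Num a => Num (Dual a) where
  Dual x dx + Dual y dy = Dual (x + y) (dx + dy)
  Dual x dx - Dual y dy = Dual (x - y) (dx - dy)
  Dual x dx * Dual y dy = Dual (x * y) (dx * y + x * dy)
  negate (Dual x dx)    = Dual (negate x) (negate dx)
  abs    (Dual x dx)    = Dual (abs x) (dx * signum x)
  signum (Dual x _)     = Dual (signum x) 0
  fromInteger n         = Dual (fromInteger n) 0  -- constants: zero tangent

instance Fractional a => Fractional (Dual a) where
  Dual x dx / Dual y dy = Dual (x / y) ((dx * y - x * dy) / (y * y))
  fromRational r        = Dual (fromRational r) 0

instance Floating a => Floating (Dual a) where
  pi                = Dual pi 0
  exp   (Dual x dx) = Dual (exp x) (dx * exp x)
  log   (Dual x dx) = Dual (log x) (dx / x)
  sin   (Dual x dx) = Dual (sin x) (dx * cos x)
  cos   (Dual x dx) = Dual (cos x) (negate dx * sin x)
  asin  (Dual x dx) = Dual (asin x) (dx / sqrt (1 - x * x))
  acos  (Dual x dx) = Dual (acos x) (negate dx / sqrt (1 - x * x))
  atan  (Dual x dx) = Dual (atan x) (dx / (1 + x * x))
  sinh  (Dual x dx) = Dual (sinh x) (dx * cosh x)
  cosh  (Dual x dx) = Dual (cosh x) (dx * sinh x)
  asinh (Dual x dx) = Dual (asinh x) (dx / sqrt (x * x + 1))
  acosh (Dual x dx) = Dual (acosh x) (dx / sqrt (x * x - 1))
  atanh (Dual x dx) = Dual (atanh x) (dx / (1 - x * x))

-- | Gradient by forward mode, one pass per input. The rank-2 type is the
-- point: gradF itself chooses the number type (Dual a), so the function
-- passed in must work for every Floating b, not just for the caller's a.
gradF :: Floating a => (forall b. Floating b => [b] -> b) -> [a] -> [a]
gradF f xs = [ tangent (f (seed i)) | i <- [0 .. length xs - 1] ]
  where
    seed i = [ Dual x (if j == i then 1 else 0) | (j, x) <- zip [0 ..] xs ]
    tangent (Dual _ dx) = dx

-- Polymorphic in its number type, so it can be passed to gradF.
poly :: Floating b => [b] -> b
poly [x, y] = x * y + sin x
poly _      = error "poly expects exactly two inputs"
```

Here `gradF poly [0, 2]` gives `[3.0, 0.0]`. Passing a monomorphic `[Double] -> Double` instead fails with the same kind of "Couldn't match type" / "rigid type variable" error, because gradF has to call it at `Dual Double`.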

So I tried modifying it as follows (the changes are the rank-2 result type
and the two realToFrac conversions):
{-# LANGUAGE Rank2Types #-}

-- | helper for AD usage
exprTreeToListFun :: (RealFloat a)
                  => ExprTree a                 -- ^ tree
                  -> Map ParamName a            -- ^ paramDict
                  -> (RealFloat b => [b] -> b)  -- ^ fun from var values to their evaluation
exprTreeToListFun tree paramDict vals = res
  where
    res            = realToFrac $ evaluate tree (paramDict, varDict)
    varDict        = Map.fromList $ zip (getVarNames tree) $ map realToFrac vals

This now typechecks and runs, but going through AD returns *all derivatives
as zero*, as if the variables (which pass through realToFrac) were treated
as constants, so it is not much help either.
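Why the derivatives vanish: `realToFrac` goes through `Rational`, and a Rational has no room for a derivative, so only the primal value survives the round trip. The sketch below shows this with a hand-rolled dual number (again an illustrative stand-in, not Numeric.AD's type; `Dual`, `constD`, `derivative`, `viaRealToFrac` and `viaLifting` are hypothetical names). It also shows the direction that keeps derivatives: leave the variable untouched and lift the constant into the AD type instead.

```haskell
-- Hand-rolled forward-mode dual number (primal value, tangent).
-- An illustrative stand-in, not Numeric.AD's own number type.
data Dual a = Dual a a
  deriving (Show)

instance Num a => Num (Dual a) where
  Dual x dx + Dual y dy = Dual (x + y) (dx + dy)
  Dual x dx - Dual y dy = Dual (x - y) (dx - dy)
  Dual x dx * Dual y dy = Dual (x * y) (dx * y + x * dy)
  negate (Dual x dx)    = Dual (negate x) (negate dx)
  abs    (Dual x dx)    = Dual (abs x) (dx * signum x)
  signum (Dual x _)     = Dual (signum x) 0
  fromInteger n         = Dual (fromInteger n) 0

instance Fractional a => Fractional (Dual a) where
  Dual x dx / Dual y dy = Dual (x / y) ((dx * y - x * dy) / (y * y))
  fromRational r        = Dual (fromRational r) 0

-- Comparisons look at the primal value only.
instance Eq a => Eq (Dual a) where
  Dual x _ == Dual y _ = x == y

instance Ord a => Ord (Dual a) where
  compare (Dual x _) (Dual y _) = compare x y

-- toRational can only keep the primal value: the tangent is dropped here.
instance Real a => Real (Dual a) where
  toRational (Dual x _) = toRational x

-- | Lift a plain constant with zero tangent.
constD :: Num a => a -> Dual a
constD c = Dual c 0

-- | Derivative of a one-argument function at a point.
derivative :: Num a => (Dual a -> Dual a) -> a -> a
derivative f x = tangent (f (Dual x 1))
  where tangent (Dual _ dx) = dx

-- k * x with the variable pushed through realToFrac (i.e. through Rational);
-- e.g. realToFrac (Dual 3 1) evaluates to Dual 3.0 0.0.
viaRealToFrac :: Double -> Dual Double -> Dual Double
viaRealToFrac k x = constD k * realToFrac x

-- k * x with the constant lifted and the variable left untouched.
viaLifting :: Double -> Dual Double -> Dual Double
viaLifting k x = constD k * x
```

`derivative (viaRealToFrac 5) 3` gives `0.0`, while `derivative (viaLifting 5) 3` gives `5.0`. Numeric.AD's types presumably lose the derivative under realToFrac for the same reason. With the ad package, the corresponding lifting function is `auto` from Numeric.AD, which could be applied to the tree's constants and to the parameter values (see its documentation for the exact constraints).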

Any idea how this can be solved?

Apologies if this question already has an answer somewhere, but from poking
around SO <http://stackoverflow.com/search?q=%5Bhaskell%5D+ad+package>, it
looks like I'm not the only one having similar issues. Unfortunately, none of
the answers I found really helped me; if anything, they confirm the same
issue of zero derivatives:
http://stackoverflow.com/questions/36878083/ad-type-unification-error-with-constrained-type-vector

Thanks again

-- 
Frederic Cogny
+33 7 83 12 61 69