[Haskell-cafe] Suggestion on how to solve a type issue with Numeric.AD
Frederic Cogny
frederic.cogny at gmail.com
Fri Mar 31 12:33:16 UTC 2017
Hello Café
*Issue:*
I'm having trouble using the AD package, presumably because I'm using types incorrectly (sorry, I know that's a bit vague).
Any help or pointers on how to solve this would be greatly appreciated.
Thanks in advance.
*Background:*
I've implemented a symbolic expression tree (full code here:
https://pastebin.com/FDpSFuRM):
-- | Expression Tree
data ExprTree a = Const a                                      -- ^ number
                | ParamNode ParamName                          -- ^ parameter
                | VarNode VarName                              -- ^ variable
                | UnaryNode MonoOp (ExprTree a)                -- ^ operator of arity 1
                | BinaryNode DualOp (ExprTree a) (ExprTree a)  -- ^ operator of arity 2
                | CondNode (Cond a) (ExprTree a) (ExprTree a)  -- ^ conditional node
                deriving (Eq, Show, Generic)
An evaluation function on it:
-- | evaluates an Expression Tree on its Context (Map ParamName a, Map VarName a)
evaluate :: (Eq a, Ord a, Floating a) => ExprTree a -> Context a -> a
And a few instances for it:
instance Num a => Default (ExprTree a) where ...
instance Num a => Num (ExprTree a) where ...
instance Fractional a => Fractional (ExprTree a) where ...
instance Floating a => Floating (ExprTree a) where ...
instance (Arbitrary a) => Arbitrary (ExprTree a) where ...
This allows me to easily build (using differentiation rules) the derivative(s)
of such a tree with respect to its variable(s):
diff :: (Eq a, Floating a) => ExprTree a -> VarName -> ExprTree a
grad :: (Eq a, Floating a) => ExprTree a -> Map VarName (ExprTree a)
hessian :: (Eq a, Floating a) => ExprTree a -> Map VarName (Map VarName (ExprTree a))
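Differentiating "using differentiation rules" is structural recursion over the tree with the sum, product and chain rules. For readers following along without the pastebin code, here is a minimal sketch on a hypothetical cut-down tree (Expr, evalE and diffE are illustrative names, not the ExprTree code above). The resulting derivative trees are left unsimplified. Load it in GHCi to try it.

```haskell
import qualified Data.Map as Map

-- Hypothetical toy tree, a much smaller cousin of ExprTree.
data Expr a
  = Lit a
  | Var String
  | Add (Expr a) (Expr a)
  | Mul (Expr a) (Expr a)
  | Sin (Expr a)
  | Cos (Expr a)
  deriving (Eq, Show)

-- Evaluate under a variable assignment (errors on an unbound variable).
evalE :: Floating a => Map.Map String a -> Expr a -> a
evalE _   (Lit c)   = c
evalE env (Var v)   = env Map.! v
evalE env (Add l r) = evalE env l + evalE env r
evalE env (Mul l r) = evalE env l * evalE env r
evalE env (Sin e)   = sin (evalE env e)
evalE env (Cos e)   = cos (evalE env e)

-- Symbolic derivative with respect to the variable x.
diffE :: Num a => Expr a -> String -> Expr a
diffE (Lit _)   _ = Lit 0
diffE (Var v)   x = Lit (if v == x then 1 else 0)
diffE (Add l r) x = Add (diffE l x) (diffE r x)                    -- sum rule
diffE (Mul l r) x = Add (Mul (diffE l x) r) (Mul l (diffE r x))    -- product rule
diffE (Sin e)   x = Mul (Cos e) (diffE e x)                        -- chain rule
diffE (Cos e)   x = Mul (Mul (Lit (-1)) (Sin e)) (diffE e x)       -- chain rule
```

For f = x*y + sin x, evaluating diffE f "x" at x = 0, y = 2 gives y + cos x = 3.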
So far, so good...
Now, to gain confidence in my implementation, I wanted to check the
derivatives against the Numeric.AD module, so I created the following two
helpers:
-- | helper for AD usage
exprTreeToListFun :: (RealFloat a)
                  => ExprTree a        -- ^ tree
                  -> Map ParamName a   -- ^ paramDict
                  -> ([a] -> a)        -- ^ fun from var values to their evaluation
exprTreeToListFun tree paramDict vals = res
  where
    res     = evaluate tree (paramDict, varDict)
    varDict = Map.fromList $ zip (getVarNames tree) vals

gradThroughAD :: RealFloat a => ExprTree a -> Context a -> Map VarName a
gradThroughAD tree (paramDict, varDict) = res
  where
    varNames = Map.keys varDict
    varVals  = Map.elems varDict
    gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
    res      = Map.fromList $ zip varNames gradList
Unfortunately, it does not typecheck; the error is:
• Couldn't match type ‘a’ with ‘Numeric.AD.Internal.Reverse.Reverse s a’
  ‘a’ is a rigid type variable bound by
    the type signature for:
      gradThroughAD :: forall a. RealFloat a => ExprTree a -> Context a -> Map VarName a
    at src/SymbolicExpression.hs:452:18
  Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
                 -> Numeric.AD.Internal.Reverse.Reverse s a
    Actual type: [a] -> a
• In the first argument of ‘AD.grad’, namely
    ‘(exprTreeToListFun tree paramDict)’
  In the expression: AD.grad (exprTreeToListFun tree paramDict) varVals
  In an equation for ‘gradList’:
      gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
• Relevant bindings include
    gradList :: [a] (bound at src/SymbolicExpression.hs:457:5)
    varVals :: [a] (bound at src/SymbolicExpression.hs:456:5)
    varDict :: Map VarName a (bound at src/SymbolicExpression.hs:453:32)
    paramDict :: Map ParamName a (bound at src/SymbolicExpression.hs:453:21)
    tree :: ExprTree a (bound at src/SymbolicExpression.hs:453:15)
    gradThroughAD :: ExprTree a -> Context a -> Map VarName a
      (bound at src/SymbolicExpression.hs:453:1)
I was a bit surprised by this (naively, I guess), since to me 'a' is
polymorphic with a RealFloat constraint, and that is "supposed" to work
with AD.
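The mismatch is about who picks the type. In the error, 'a' is rigid: it is chosen by whoever calls gradThroughAD. AD.grad, by contrast, takes a function with a rank-2 type (in the ad versions I know of, roughly forall s. Reifies s Tape => [Reverse s a] -> Reverse s a, matching the "Expected type" line) and chooses the AD type itself, so a function already fixed at [a] -> a cannot be passed, whatever constraint 'a' carries. The base-only toy below reproduces the situation; useAtTwoTypes is a hypothetical stand-in for AD.grad, not part of the ad package.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Hypothetical stand-in for AD.grad: it demands a function that works at
-- *every* type b with a Num instance, and it picks the concrete types
-- itself (here Int and Double).
useAtTwoTypes :: (forall b. Num b => [b] -> b) -> ([Int], [Double]) -> (Int, Double)
useAtTwoTypes f (is, ds) = (f is, f ds)

-- Polymorphic in b, so it is accepted.
sumSquares :: Num b => [b] -> b
sumSquares = sum . map (\v -> v * v)

accepted :: (Int, Double)
accepted = useAtTwoTypes sumSquares ([1, 2], [0.5])

-- Rejected with the same kind of rigid-type-variable error as gradThroughAD:
-- inside 'rejected', 'a' is fixed by rejected's caller, so 'g' only has
-- type [a] -> a and cannot be used at Int or Double.
--
-- rejected :: Num a => ([a] -> a) -> (Int, Double)
-- rejected g = useAtTwoTypes g ([1], [1.0])
```

Here accepted evaluates to (5, 0.25). Uncommenting rejected makes GHC complain that it cannot match the rigid 'a' with the rank-2 'b'.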
So, anyway, I tried modifying it as follows (the changes are the rank-2 result type and the two realToFrac conversions):
{-# LANGUAGE Rank2Types #-}
-- | helper for AD usage
exprTreeToListFun :: (RealFloat a)
                  => ExprTree a                  -- ^ tree
                  -> Map ParamName a             -- ^ paramDict
                  -> (RealFloat b => [b] -> b)   -- ^ fun from var values to their evaluation
exprTreeToListFun tree paramDict vals = res
  where
    res     = realToFrac $ evaluate tree (paramDict, varDict)
    varDict = Map.fromList $ zip (getVarNames tree) $ map realToFrac vals
This now typechecks and runs, but going through AD returns *all
derivatives as zero*, as if the variables (which passed through
realToFrac) were treated as constants, so it is not much help either.
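Zero derivatives are what realToFrac should be expected to produce here: it is defined as fromRational . toRational (GHC adds rewrite rules only for particular base types), and a Rational has no room for derivative information, so whatever comes back is a constant. The hand-rolled forward-mode dual numbers below (hypothetical Dual, derivative and auto', far simpler than the ad package's types) show the effect. auto' mimics the idea behind ad's auto: lift a constant with a zero derivative, leaving the variables untouched. Load it in GHCi to try it.

```haskell
-- Hypothetical minimal forward-mode dual numbers: value and derivative.
data Dual = Dual Double Double
  deriving (Eq, Show)

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' - Dual b b' = Dual (a - b) (a' - b')
  Dual a a' * Dual b b' = Dual (a * b) (a' * b + a * b')
  abs (Dual a a')       = Dual (abs a) (a' * signum a)
  signum (Dual a _)     = Dual (signum a) 0
  fromInteger n         = Dual (fromInteger n) 0    -- constants: zero derivative

instance Fractional Dual where
  Dual a a' / Dual b b' = Dual (a / b) ((a' * b - a * b') / (b * b))
  fromRational r        = Dual (fromRational r) 0   -- also a constant

-- Real requires Ord; compare on the value only.
instance Ord Dual where
  compare (Dual a _) (Dual b _) = compare a b

instance Real Dual where
  toRational (Dual a _) = toRational a   -- the derivative is dropped here

-- Derivative of f at x, seeding dx/dx = 1.
derivative :: (Dual -> Dual) -> Double -> Double
derivative f x = let Dual _ d = f (Dual x 1) in d

-- Lift a constant into the dual type (the role auto plays in ad).
auto' :: Double -> Dual
auto' c = Dual c 0

-- Like the modified helper: the variable goes out to Double and back.
squareViaDouble :: Dual -> Dual
squareViaDouble x = let y = realToFrac (realToFrac x :: Double) in y * y

-- The constant is lifted instead; the variable stays a Dual.
scaledSquare :: Dual -> Dual
scaledSquare x = auto' 2 * x * x
```

derivative (\x -> x * x) 3 gives 6.0, derivative squareViaDouble 3 gives 0.0, and derivative scaledSquare 3 gives 12.0. Converting constants is harmless; converting the variables is what loses the derivatives.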
Any idea how this can be solved?
Apologies if this question already has an answer somewhere, but from poking
around SO <http://stackoverflow.com/search?q=%5Bhaskell%5D+ad+package>, it
looks like I'm not the only one having similar issues, and unfortunately
none of the answers I found really helped me; if anything, they confirm the
same issue of null derivatives:
http://stackoverflow.com/questions/36878083/ad-type-unification-error-with-constrained-type-vector
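For reference, one approach often suggested for this kind of error is to keep the function handed to AD.grad polymorphic and move the *constants* into the AD type with Numeric.AD's auto, instead of moving the *variables* out with realToFrac. A minimal sketch on a hypothetical toy tree (Expr, eval and gradAD are illustrative names; it needs the ad package, and it relies on the tree deriving Functor so that fmap auto can lift every constant):

```haskell
{-# LANGUAGE DeriveFunctor #-}
import qualified Data.Map as Map
import Numeric.AD (auto, grad)

-- Hypothetical toy tree standing in for ExprTree. Deriving Functor is what
-- lets the constants be mapped into the AD type with fmap auto.
data Expr a
  = Lit a
  | Var String
  | Add (Expr a) (Expr a)
  | Mul (Expr a) (Expr a)
  deriving (Show, Functor)

-- Works at any Num type, including the AD type that grad chooses.
eval :: Num a => Map.Map String a -> Expr a -> a
eval _   (Lit c)   = c
eval env (Var v)   = env Map.! v
eval env (Add l r) = eval env l + eval env r
eval env (Mul l r) = eval env l * eval env r

-- Gradient at a point. The lambda is checked against grad's rank-2
-- argument type, so xs holds AD values; the tree's constants are lifted
-- with auto (zero derivative) rather than converting the variables.
gradAD :: Num a => Expr a -> Map.Map String a -> Map.Map String a
gradAD tree point = Map.fromList (zip names gs)
  where
    names = Map.keys point
    gs    = grad (\xs -> eval (Map.fromList (zip names xs)) (fmap auto tree))
                 (Map.elems point)
```

For x*y + 3*x at x = 2, y = 5 this should give 8 for x and 2 for y. Translated to the code above, the idea would be roughly AD.grad (exprTreeToListFun (fmap auto tree) (fmap auto paramDict)) varVals with the original, non-rank-2 exprTreeToListFun; that is untested and assumes ExprTree (and Cond) get Functor instances.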
Thanks again
--
Frederic Cogny
+33 7 83 12 61 69