[Haskell-cafe] Suggestion on how to solve a type issue with Numeric.AD
Li-yao Xia
lysxia at gmail.com
Sat Apr 1 01:44:29 UTC 2017
Hi Frederic,
As you have already noticed, it is best not to change the representation
of the numbers inside a function you pass to grad, as this destroys the
metadata the ad framework builds up to compute derivatives. Revert
exprTreeToListFun to its original type. Then derive Functor for ExprTree
(this needs the DeriveFunctor extension, and Cond will need a Functor
instance too) and convert your tree and parameters using auto before
applying exprTreeToListFun. This leaves grad free to pick the right type.
    gradList = AD.grad (exprTreeToListFun (fmap auto tree) (fmap auto paramDict))
                       varVals
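A self-contained sketch of the same idea, using only base: the evaluator stays polymorphic in its number type, and the tree's constants are lifted with fmap before evaluation instead of converting the variables back and forth. All names here are hypothetical and the tree is far smaller than the one in the thread: Expr is a toy tree, Dual is a hand-rolled forward-mode dual number standing in for ad's Reverse s a, and constD plays the role of auto. ad's grad uses reverse mode, but the typing point is the same.

```haskell
{-# LANGUAGE DeriveFunctor #-}
-- Sketch only: Expr, Dual, constD and gradExpr are hypothetical names.
-- Dual stands in for ad's Reverse s a; constD plays the role of ad's auto.

-- | Toy expression tree; the type parameter appears only in literals.
data Expr a
  = Lit a
  | Var Int                -- index into the list of variable values
  | Add (Expr a) (Expr a)
  | Mul (Expr a) (Expr a)
  deriving (Show, Functor)

-- | Evaluator that is polymorphic in the number type, like 'evaluate'.
eval :: Num a => Expr a -> [a] -> a
eval (Lit c)   _  = c
eval (Var i)   xs = xs !! i
eval (Add l r) xs = eval l xs + eval r xs
eval (Mul l r) xs = eval l xs * eval r xs

-- | Forward-mode dual number: (value, derivative).
data Dual = Dual Double Double
  deriving Show

instance Num Dual where
  Dual x x' + Dual y y' = Dual (x + y) (x' + y')
  Dual x x' * Dual y y' = Dual (x * y) (x * y' + x' * y)  -- product rule
  negate (Dual x x')    = Dual (negate x) (negate x')
  abs (Dual x x')       = Dual (abs x) (x' * signum x)
  signum (Dual x _)     = Dual (signum x) 0
  fromInteger n         = Dual (fromInteger n) 0

-- | Lift a plain constant into the dual numbers with derivative zero.
constD :: Double -> Dual
constD c = Dual c 0

-- | Gradient: lift the tree's constants once with fmap, then run one
-- forward pass per variable with that variable's derivative seeded to 1.
gradExpr :: Expr Double -> [Double] -> [Double]
gradExpr tree vals = map partial [0 .. length vals - 1]
  where
    lifted = fmap constD tree
    partial i =
      let seeded   = [Dual v (if j == i then 1 else 0) | (j, v) <- zip [0 ..] vals]
          Dual _ d = eval lifted seeded
      in d
```

For f(x0, x1) = 3*x0*x1 + x0, `gradExpr (Add (Mul (Mul (Lit 3) (Var 0)) (Var 1)) (Var 0)) [2, 5]` gives [16.0, 6.0], matching df/dx0 = 3*x1 + 1 and df/dx1 = 3*x0. The variables are never converted; only the constants are lifted, so the derivative information flows through eval untouched.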
Cheers,
Li-yao
On 03/31/2017 08:33 AM, Frederic Cogny wrote:
> Hello Café
>
> *Issue:*
> I'm having trouble using the AD package, presumably due to a misuse of
> types (sorry, I know it's a bit vague).
> Any help or pointer on how to solve this would be greatly appreciated.
> Thanks in advance.
>
> *Background:*
> I've implemented a symbolic expression tree [full code here:
> https://pastebin.com/FDpSFuRM]
> -- | Expression Tree
> data ExprTree a = Const a                                     -- ^ number
>                 | ParamNode ParamName                         -- ^ parameter
>                 | VarNode VarName                             -- ^ variable
>                 | UnaryNode MonoOp (ExprTree a)               -- ^ operator of arity 1
>                 | BinaryNode DualOp (ExprTree a) (ExprTree a) -- ^ operator of arity 2
>                 | CondNode (Cond a) (ExprTree a) (ExprTree a) -- ^ conditional node
>                 deriving (Eq, Show, Generic)
>
> An evaluation function on it:
> -- | evaluates an Expression Tree on its Context (Map ParamName a, Map VarName a)
> evaluate :: (Eq a, Ord a, Floating a) => ExprTree a -> Context a -> a
>
> And a few instances for it:
> instance Num a => Default (ExprTree a) where ...
> instance Num a => Num (ExprTree a) where ...
> instance Fractional a => Fractional (ExprTree a) where ...
> instance Floating a => Floating (ExprTree a) where ...
> instance (Arbitrary a) => Arbitrary (ExprTree a) where ...
>
> This allows me to easily create (using derivation rules) the derivative(s)
> of such a tree with respect to its variable(s):
> diff :: (Eq a, Floating a) => ExprTree a -> VarName -> ExprTree a
> grad :: (Eq a, Floating a) => ExprTree a -> Map VarName (ExprTree a)
> hessian :: (Eq a, Floating a) => ExprTree a -> Map VarName (Map VarName (ExprTree a))
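A minimal sketch of symbolic differentiation by derivation rules, on a hypothetical toy tree (Expr, eval, diff and gradSym are illustrative names, not the code from the linked pastebin): each rule rewrites a node into the expression for its derivative, and the result is again a tree that can be evaluated, which is what makes a numeric cross-check against AD possible.

```haskell
-- Sketch only: a toy tree, not the ExprTree from the post.
data Expr a
  = Lit a
  | Var Int                 -- index into the list of variable values
  | Add (Expr a) (Expr a)
  | Mul (Expr a) (Expr a)
  | Sin (Expr a)
  | Cos (Expr a)
  deriving (Show, Eq)

-- | Evaluate a tree given the variable values.
eval :: Floating a => Expr a -> [a] -> a
eval (Lit c)   _  = c
eval (Var i)   xs = xs !! i
eval (Add l r) xs = eval l xs + eval r xs
eval (Mul l r) xs = eval l xs * eval r xs
eval (Sin e)   xs = sin (eval e xs)
eval (Cos e)   xs = cos (eval e xs)

-- | Partial derivative with respect to variable i, one rule per constructor.
diff :: Num a => Expr a -> Int -> Expr a
diff (Lit _)   _ = Lit 0
diff (Var j)   i = Lit (if i == j then 1 else 0)
diff (Add l r) i = Add (diff l i) (diff r i)
diff (Mul l r) i = Add (Mul (diff l i) r) (Mul l (diff r i))   -- product rule
diff (Sin e)   i = Mul (Cos e) (diff e i)                       -- chain rule
diff (Cos e)   i = Mul (Lit (-1)) (Mul (Sin e) (diff e i))      -- chain rule

-- | Symbolic gradient of a tree over n variables.
gradSym :: Num a => Expr a -> Int -> [Expr a]
gradSym e n = [diff e i | i <- [0 .. n - 1]]
```

For example, `map (`eval` [3, 2]) (gradSym (Mul (Mul (Var 0) (Var 0)) (Var 1)) 2)` evaluates the gradient of x0*x0*x1 at (3, 2) and gives [12.0, 9.0]. No simplification is performed, so derivative trees grow quickly; a fuller implementation would likely simplify them.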
>
> So far, so good ...
>
> Now, to gain confidence in my implementation, I wanted to check the
> derivatives against the Numeric.AD module, so I created the following
> two helpers:
> -- | helper for AD usage
> exprTreeToListFun :: (RealFloat a)
>                   => ExprTree a       -- ^ tree
>                   -> Map ParamName a  -- ^ paramDict
>                   -> ([a] -> a)       -- fun from var values to their evaluation
> exprTreeToListFun tree paramDict vals = res
>   where
>     res     = evaluate tree (paramDict, varDict)
>     varDict = Map.fromList $ zip (getVarNames tree) vals
>
>
> gradThroughAD :: RealFloat a => ExprTree a -> Context a -> Map VarName a
> gradThroughAD tree (paramDict, varDict) = res
>   where
>     varNames = Map.keys varDict
>     varVals  = Map.elems varDict
>     gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
>     res      = Map.fromList $ zip varNames gradList
>
>
> Unfortunately, it does not typecheck; GHC reports:
>     • Couldn't match type ‘a’ with ‘Numeric.AD.Internal.Reverse.Reverse s a’
>       ‘a’ is a rigid type variable bound by
>         the type signature for:
>           gradThroughAD :: forall a. RealFloat a => ExprTree a -> Context a -> Map VarName a
>         at src/SymbolicExpression.hs:452:18
>       Expected type: [Numeric.AD.Internal.Reverse.Reverse s a] -> Numeric.AD.Internal.Reverse.Reverse s a
>         Actual type: [a] -> a
>     • In the first argument of ‘AD.grad’, namely ‘(exprTreeToListFun tree paramDict)’
>       In the expression: AD.grad (exprTreeToListFun tree paramDict) varVals
>       In an equation for ‘gradList’:
>           gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
>     • Relevant bindings include
>         gradList :: [a] (bound at src/SymbolicExpression.hs:457:5)
>         varVals :: [a] (bound at src/SymbolicExpression.hs:456:5)
>         varDict :: Map VarName a (bound at src/SymbolicExpression.hs:453:32)
>         paramDict :: Map ParamName a (bound at src/SymbolicExpression.hs:453:21)
>         tree :: ExprTree a (bound at src/SymbolicExpression.hs:453:15)
>         gradThroughAD :: ExprTree a -> Context a -> Map VarName a
>           (bound at src/SymbolicExpression.hs:453:1)
>
>
> I was (perhaps naively) surprised by this, since to me 'a' is
> polymorphic with a RealFloat constraint, and that is "supposed" to
> work with AD.
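The surprise comes from who picks the type. The argument of AD.grad has a rank-2 type: grad itself chooses the number type (Reverse s a, for an s that only grad knows), so the function passed in must work for that choice. `exprTreeToListFun tree paramDict` has type `[a] -> a` for the single a already fixed by the caller of gradThroughAD, which is why GHC calls a "rigid". A base-only analogue, where the hypothetical useAtTwoTypes stands in for grad and instantiates its argument at both Double and Float:

```haskell
{-# LANGUAGE RankNTypes #-}
-- Sketch only: useAtTwoTypes is a hypothetical stand-in for AD.grad, whose
-- argument type is likewise quantified inside the parentheses (rank 2).

-- | The callee, not the caller, chooses b: here it uses f at Double and Float.
useAtTwoTypes :: (forall b. Floating b => [b] -> b) -> [Double] -> (Double, Float)
useAtTwoTypes f xs = (f xs, f (map realToFrac xs))

-- | Accepted: it never commits to a particular number type.
normSq :: Floating b => [b] -> b
normSq = sum . map (^ (2 :: Int))

-- Rejected, with an error analogous to the one above: once the caller has
-- fixed 'a', f :: [a] -> a cannot also be used at Float.
--
-- bad :: Floating a => ([a] -> a) -> [Double] -> (Double, Float)
-- bad f xs = useAtTwoTypes f xs
```

`useAtTwoTypes normSq [3, 4]` gives (25.0, 25.0). The fix suggested in the reply works for the same reason: the expression passed to grad is built inside grad's argument position, so GHC can instantiate it at whatever type grad demands.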
>
> So anyway, I tried to modify it as follows:
> {-# LANGUAGE Rank2Types #-}
>
> -- | helper for AD usage
> exprTreeToListFun :: (RealFloat a)
>                   => ExprTree a                 -- ^ tree
>                   -> Map ParamName a            -- ^ paramDict
>                   -> (RealFloat b => [b] -> b)  -- fun from var values to their evaluation
> exprTreeToListFun tree paramDict vals = res
>   where
>     res     = realToFrac $ evaluate tree (paramDict, varDict)
>     varDict = Map.fromList $ zip (getVarNames tree) $ map realToFrac vals
>
> This now typechecks and runs, but going through AD returns *all
> derivatives as zero*, as if the variables (which now pass through
> realToFrac) were treated as constants, so it does not help much either.
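Why the derivatives come out as zero: realToFrac is defined as `fromRational . toRational` (GHC adds rewrite rules only for some concrete pairs of types), and a Rational has no room for derivative information. The derivative is discarded on the way out, and fromRational creates a fresh constant on the way back. A base-only sketch with a hypothetical forward-mode Dual type (not ad's actual implementation) shows the effect:

```haskell
-- Sketch only: Dual is a hypothetical forward-mode number, not ad's type.
data Dual = Dual Double Double   -- (value, derivative)
  deriving (Show, Eq, Ord)       -- Eq/Ord only because Real requires them

instance Num Dual where
  Dual x x' + Dual y y' = Dual (x + y) (x' + y')
  Dual x x' * Dual y y' = Dual (x * y) (x * y' + x' * y)
  negate (Dual x x')    = Dual (negate x) (negate x')
  abs (Dual x x')       = Dual (abs x) (x' * signum x)
  signum (Dual x _)     = Dual (signum x) 0
  fromInteger n         = Dual (fromInteger n) 0

instance Fractional Dual where
  Dual x x' / Dual y y' = Dual (x / y) ((x' * y - x * y') / (y * y))
  fromRational r        = Dual (fromRational r) 0   -- always a new constant

-- | A Rational has no slot for the derivative, so it is dropped here.
instance Real Dual where
  toRational (Dual x _) = toRational x

-- | Derivative of f at x, seeding dx/dx = 1.
derivAt :: (Dual -> Dual) -> Double -> Double
derivAt f x = let Dual _ d = f (Dual x 1) in d

square :: Num b => b -> b
square v = v * v

-- | Same function, but routed through a concrete Double with realToFrac,
-- as in the modified exprTreeToListFun above.
squareViaDouble :: Dual -> Dual
squareViaDouble v = realToFrac (square (realToFrac v :: Double))
```

`derivAt square 3` gives 6.0, while `derivAt squareViaDouble 3` gives 0.0: the round trip through Double turns the variable into a constant, exactly the symptom described.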
>
> Any idea how this can be solved?
>
> Apologies if this question already has an answer somewhere. Poking
> around SO <http://stackoverflow.com/search?q=%5Bhaskell%5D+ad+package>, it
> looks like I'm not the only one having similar issues; unfortunately,
> none of the answers I found really helped me. If anything, they confirm
> the same issue of null derivatives:
> http://stackoverflow.com/questions/36878083/ad-type-unification-error-with-constrained-type-vector
>
> Thanks again