[Haskell-cafe] Suggestion on how to solve a type issue with Numeric.AD

Frederic Cogny frederic.cogny at gmail.com
Sat Apr 1 17:51:28 UTC 2017


Thanks a lot. This worked perfectly.
I had seen the auto function, but since the documentation says "embed a
constant" (as opposed to a variable), I thought AD would not differentiate
with respect to anything lifted with it.
Thanks again for the prompt and complete answer
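
A minimal sketch of the distinction, assuming the `ad` package's `Numeric.AD` module (`scaledProduct` is a hypothetical name): `auto` lifts a plain value into whatever AD number type `grad` picks, and the lifted value behaves as a constant with zero derivative. The variables are still the elements of the list that `grad` itself feeds to the function, so lifting the tree's constants and parameters with `auto` does not stop differentiation with respect to those variables.

```haskell
import Numeric.AD (auto, grad)

-- f(xs) = c * product xs, where c is a fixed coefficient.
-- `auto c` embeds c as a constant (its derivative is zero), while the
-- elements of xs are the variables grad differentiates against.
scaledProduct :: Double -> [Double] -> [Double]
scaledProduct c = grad (\xs -> auto c * product xs)

main :: IO ()
main = print (scaledProduct 2 [3, 5])  -- [10.0,6.0]
```

Here f = 2 * x * y, so the partial derivatives are 2y = 10 and 2x = 6; the coefficient itself is never differentiated.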

On Sat, Apr 1, 2017, 03:44 Li-yao Xia <lysxia at gmail.com> wrote:

> Hi Frederic,
>
> As you have already noticed, it is best not to change the representation
> of the numbers in a function you pass to grad, as this destroys the
> metadata built up by the ad framework to compute derivatives. Revert
> exprTreeToListFun. After deriving Functor for ExprTree, convert your
> tree and parameters using auto before applying exprTreeToListFun. This
> leaves grad free to pick the right type.
>
>      gradList = AD.grad (exprTreeToListFun (fmap auto tree) (fmap auto paramDict)) varVals
>
> Cheers,
> Li-yao
>
>
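A self-contained sketch of the fix described above, under some assumptions: `Expr` is a hypothetical, cut-down stand-in for `ExprTree` (only constants, parameters, variables, addition and multiplication), names are plain `String`s instead of `ParamName`/`VarName`, and the helper takes the variable names explicitly. `fmap auto` (available thanks to the derived `Functor` instance) lifts the tree and the parameters into grad's number type, so the function handed to `grad` is polymorphic again. If the real `ExprTree` derives `Functor`, its `Cond` type will need a `Functor` instance too, since `CondNode` holds a `Cond a`. The sketch labels the gradient with `Map.keys vars`, the same list used to build the variable map; the original helper zips values with `getVarNames tree`, which only lines up with `Map.keys varDict` if both return the names in the same order.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Data.Map (Map)
import qualified Data.Map as Map
import Numeric.AD (auto, grad)

-- Hypothetical cut-down expression tree. Deriving Functor is what
-- makes `fmap auto` available below.
data Expr a
  = Const a
  | Param String
  | Var String
  | Add (Expr a) (Expr a)
  | Mul (Expr a) (Expr a)
  deriving (Show, Functor)

-- Evaluate under (parameters, variables); polymorphic in the number type,
-- so it runs on plain Doubles or on grad's internal AD type.
evaluate :: Num a => Expr a -> (Map String a, Map String a) -> a
evaluate e ctx@(params, vars) = case e of
  Const c -> c
  Param p -> params Map.! p
  Var v   -> vars Map.! v
  Add l r -> evaluate l ctx + evaluate r ctx
  Mul l r -> evaluate l ctx * evaluate r ctx

-- Parameters are fixed; the list supplies the values of the named
-- variables, in the order given by `names`.
exprToListFun :: Num a => Expr a -> Map String a -> [String] -> [a] -> a
exprToListFun tree params names vals =
  evaluate tree (params, Map.fromList (zip names vals))

-- Gradient via reverse-mode AD. Tree and parameters are lifted with
-- `fmap auto` inside grad's argument, so they take grad's number type.
gradThroughAD :: Num a => Expr a -> (Map String a, Map String a) -> Map String a
gradThroughAD tree (params, vars) = Map.fromList (zip names gradList)
  where
    names    = Map.keys vars
    gradList = grad (exprToListFun (fmap auto tree) (fmap auto params) names)
                    (Map.elems vars)

main :: IO ()
main = print (gradThroughAD f (Map.fromList [("k", 2)], Map.fromList [("x", 5), ("y", 7)]))
  where
    -- f = k*x*y + 3*x, so df/dx = k*y + 3 = 17 and df/dy = k*x = 10
    f :: Expr Double
    f = Add (Mul (Param "k") (Mul (Var "x") (Var "y"))) (Mul (Const 3) (Var "x"))
-- prints: fromList [("x",17.0),("y",10.0)]
```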
> On 03/31/2017 08:33 AM, Frederic Cogny wrote:
> > Hello Café
> >
> > *Issue:*
> > I'm having trouble using the AD package, presumably due to a wrong use of
> > types (sorry, I know it's a bit vague).
> > Any help or pointer on how to solve this would be greatly appreciated.
> > Thanks in advance
> >
> > *Background:*
> > I've implemented a symbolic expression tree [full code here:
> > https://pastebin.com/FDpSFuRM]
> > -- | Expression Tree
> > data ExprTree a = Const a                                     -- ^ number
> >                 | ParamNode ParamName                         -- ^ parameter
> >                 | VarNode VarName                             -- ^ variable
> >                 | UnaryNode MonoOp (ExprTree a)               -- ^ operator of arity 1
> >                 | BinaryNode DualOp (ExprTree a) (ExprTree a) -- ^ operator of arity 2
> >                 | CondNode (Cond a) (ExprTree a) (ExprTree a) -- ^ conditional node
> >                 deriving (Eq, Show, Generic)
> >
> > An evaluation function on it
> > -- | Evaluates an expression tree in its context (Map ParamName a, Map VarName a)
> > evaluate :: (Eq a, Ord a, Floating a) => ExprTree a -> Context a -> a
> >
> > And a few instances on it:
> > instance Num a => Default (ExprTree a) where ...
> > instance Num a => Num (ExprTree a) where ...
> > instance Fractional a => Fractional (ExprTree a) where ...
> > instance Floating a => Floating (ExprTree a) where ...
> > instance (Arbitrary a) => Arbitrary (ExprTree a) where ...
> >
> > This allows me to easily compute (using differentiation rules) the
> > derivative(s) of such a tree with respect to its variable(s):
> > diff    :: (Eq a, Floating a) => ExprTree a -> VarName -> ExprTree a
> > grad    :: (Eq a, Floating a) => ExprTree a -> Map VarName (ExprTree a)
> > hessian :: (Eq a, Floating a) => ExprTree a -> Map VarName (Map VarName (ExprTree a))
> >
> > So far, so good ...
> >
> > Now, to gain confidence in my implementation, I wanted to check the
> > derivatives against the Numeric.AD module,
> > so I created the following two functions:
> > -- | helper for AD usage
> > exprTreeToListFun :: (RealFloat a)
> >                    => ExprTree a      -- ^ tree
> >                    -> Map ParamName a -- ^ paramDict
> >                    -> ([a] -> a)      -- fun from var values to their evaluation
> > exprTreeToListFun tree paramDict vals = res
> >    where
> >      res            = evaluate tree (paramDict, varDict)
> >      varDict        = Map.fromList $ zip (getVarNames tree) vals
> >
> >
> > gradThroughAD :: RealFloat a => ExprTree a -> Context a -> Map VarName a
> > gradThroughAD tree (paramDict, varDict) = res
> >    where
> >      varNames = Map.keys varDict
> >      varVals  = Map.elems varDict
> >      gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
> >      res      = Map.fromList $ zip varNames gradList
> >
> >
> > Unfortunately, it does not type check; the error is:
> >
> >     • Couldn't match type ‘a’ with ‘Numeric.AD.Internal.Reverse.Reverse s a’
> >       ‘a’ is a rigid type variable bound by
> >         the type signature for:
> >           gradThroughAD :: forall a. RealFloat a => ExprTree a -> Context a -> Map VarName a
> >         at src/SymbolicExpression.hs:452:18
> >       Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
> >                      -> Numeric.AD.Internal.Reverse.Reverse s a
> >         Actual type: [a] -> a
> >     • In the first argument of ‘AD.grad’, namely ‘(exprTreeToListFun tree paramDict)’
> >       In the expression: AD.grad (exprTreeToListFun tree paramDict) varVals
> >       In an equation for ‘gradList’:
> >           gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
> >     • Relevant bindings include
> >         gradList :: [a] (bound at src/SymbolicExpression.hs:457:5)
> >         varVals :: [a] (bound at src/SymbolicExpression.hs:456:5)
> >         varDict :: Map VarName a (bound at src/SymbolicExpression.hs:453:32)
> >         paramDict :: Map ParamName a (bound at src/SymbolicExpression.hs:453:21)
> >         tree :: ExprTree a (bound at src/SymbolicExpression.hs:453:15)
> >         gradThroughAD :: ExprTree a -> Context a -> Map VarName a
> >           (bound at src/SymbolicExpression.hs:453:1)
> >
> >
> > I was a bit surprised (naively, I guess) by this, since to me 'a' is
> > polymorphic with a RealFloat constraint, and that is "supposed" to
> > work with AD.
> >
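The error arises because `grad` does not accept an arbitrary `[a] -> a`. In the `ad` package its argument has a rank-2 type, for lists roughly `forall s. Reifies s Tape => [Reverse s a] -> Reverse s a` (the "Expected type" in the message above): `grad` itself chooses the element type and needs the function to work at that type. `exprTreeToListFun tree paramDict` is already fixed to the caller's `a`, because `tree` and `paramDict` contain `a`s, so it only has type `[a] -> a`. A minimal illustration, with hypothetical function names:

```haskell
import Numeric.AD (grad)

-- Polymorphic in the number type: grad can instantiate b at its
-- internal reverse-mode type, so this is accepted.
sumOfSquares :: Num b => [b] -> b
sumOfSquares = sum . map (^ (2 :: Int))

-- The same function fixed to Double. Passing it to grad is rejected with
-- a type error of the same kind as the one quoted above:
--   grad sumOfSquaresD [1, 2, 3]
sumOfSquaresD :: [Double] -> Double
sumOfSquaresD = sumOfSquares

main :: IO ()
main = print (grad sumOfSquares [1, 2, 3 :: Double])  -- [2.0,4.0,6.0]
```

A RealFloat constraint on the caller's `a` does not help here: what matters is whether the function passed to `grad` can be used at grad's own choice of number type.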
> > So anyway, I tried to modify it as follows:
> > {-# LANGUAGE Rank2Types #-}
> >
> > -- | helper for AD usage
> > exprTreeToListFun :: (RealFloat a)
> >                    => ExprTree a                -- ^ tree
> >                    -> Map ParamName a           -- ^ paramDict
> >                    -> (RealFloat b => [b] -> b) -- fun from var values to their evaluation (changed)
> > exprTreeToListFun tree paramDict vals = res
> >    where
> >      res            = realToFrac $ evaluate tree (paramDict, varDict)              -- changed
> >      varDict        = Map.fromList $ zip (getVarNames tree) $ map realToFrac vals -- changed
> >
> > This now type checks and runs, but going through AD returns *all
> > derivatives as zero*, as if the variables (which passed through
> > realToFrac) were treated as constants, so it is not much help either.
> >
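A minimal sketch of why the derivatives come out as zero, assuming `ad`'s reverse mode (function names are hypothetical): converting a variable to a plain `Double` with `realToFrac` keeps only its numeric value and drops the information `grad` records, so the value converted back is a constant as far as `grad` is concerned. This matches the all-zero result reported above.

```haskell
import Numeric.AD (grad)

-- Stays in the caller's number type throughout.
squareSum :: Num b => [b] -> b
squareSum = sum . map (^ (2 :: Int))

-- Same computation, but the inputs are converted to Double and the result
-- converted back, mirroring the realToFrac calls in the modified helper.
-- The Double in the middle carries no derivative information.
squareSumViaDouble :: (Real b, Fractional b) => [b] -> b
squareSumViaDouble xs = realToFrac (squareSum (map realToFrac xs :: [Double]))

main :: IO ()
main = do
  print (grad squareSum [3 :: Double])          -- [6.0]
  print (grad squareSumViaDouble [3 :: Double]) -- [0.0]
```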
> > Any idea how this can be solved?
> >
> > Apologies if this question already has an answer somewhere. Poking
> > around SO <http://stackoverflow.com/search?q=%5Bhaskell%5D+ad+package>,
> > it looks like I'm not the only one having similar issues; unfortunately,
> > none of the answers I found really helped me. If anything, they confirm
> > the same issue of null derivatives:
> >
> > http://stackoverflow.com/questions/36878083/ad-type-unification-error-with-constrained-type-vector
> >
> > Thanks again
> >
> >
> >
> > _______________________________________________
> > Haskell-Cafe mailing list
> > To (un)subscribe, modify options or view archives go to:
> > http://mail.haskell.org/cgi-bin/mailman/listinfo/haskell-cafe
> > Only members subscribed via the mailman list are allowed to post.
>
> --
Frederic Cogny
+33 7 83 12 61 69