[Haskell-cafe] Suggestion on how to solve a type issue with Numeric.AD

Li-yao Xia lysxia at gmail.com
Sat Apr 1 20:10:14 UTC 2017


Hi Frederic,

For the purposes of the gradient computation, the values contained in 
the tree and the parameters are actually considered constants, hence the 
use of auto. You are differentiating only with respect to the variables, 
which are passed to grad separately.
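
To make this concrete, here is a minimal sketch (not your actual code). Expr, eval and toListFun are hypothetical, cut-down stand-ins for your ExprTree, evaluate and exprTreeToListFun, and I am assuming grad and auto are imported from Numeric.AD in the ad package. The constant 3 stored in the tree is lifted with auto, while x and y reach the function only through the list handed to grad:

```haskell
{-# LANGUAGE DeriveFunctor #-}
module Main where

import qualified Data.Map as Map
import Numeric.AD (auto, grad)

-- Hypothetical, cut-down stand-in for the ExprTree of the thread.
data Expr a
  = Const a                -- numeric literal stored in the tree
  | Var String             -- variable, looked up at evaluation time
  | Add (Expr a) (Expr a)
  | Mul (Expr a) (Expr a)
  deriving (Show, Functor)

-- Evaluate a tree against a map of variable values.
-- Unbound variables default to 0 to keep the sketch short.
eval :: Num a => Map.Map String a -> Expr a -> a
eval _   (Const c) = c
eval env (Var v)   = Map.findWithDefault 0 v env
eval env (Add l r) = eval env l + eval env r
eval env (Mul l r) = eval env l * eval env r

-- Turn a tree into a function of the variable values, given in a fixed order.
toListFun :: Num a => [String] -> Expr a -> [a] -> a
toListFun names tree vals = eval (Map.fromList (zip names vals)) tree

-- f(x, y) = 3 * x * y + x
example :: Expr Double
example = Add (Mul (Const 3) (Mul (Var "x") (Var "y"))) (Var "x")

-- fmap auto lifts the stored constants into whatever numeric type grad
-- picks; the variables are supplied only through the list argument.
gradient :: [Double]
gradient = grad (toListFun ["x", "y"] (fmap auto example)) [2, 5]

main :: IO ()
main = print gradient
```

By hand, df/dx = 3y + 1 = 16 and df/dy = 3x = 6 at (x, y) = (2, 5), which is what gradient should contain. The constant 3 gets no derivative of its own, but the variables still do.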

Regards,
Li-yao


On 04/01/2017 01:51 PM, Frederic Cogny wrote:
> Thanks a lot. This worked perfectly.
> I saw the auto function, but since the documentation says "embed a constant"
> I thought it would treat the value as a constant (as opposed to a variable)
> and not differentiate with respect to it.
> Thanks again for the prompt and complete answer.
>
> On Sat, Apr 1, 2017, 03:44 Li-yao Xia <lysxia at gmail.com> wrote:
>
>> Hi Frederic,
>>
>> As you have already noticed, it is best not to change the representation
>> of the numerals in a function you pass to grad, as this destroys the
>> metadata built up by the ad framework to compute derivatives. Revert
>> exprTreeToListFun. Having derived Functor for ExprTree, convert your
>> tree and parameters using auto before applying exprTreeToListFun. This
>> leaves grad free to pick the right type.
>>
>>       gradList = AD.grad (exprTreeToListFun (fmap auto tree) (fmap auto paramDict)) varVals
>>
>> Cheers,
>> Li-yao
>>
>>
>> On 03/31/2017 08:33 AM, Frederic Cogny wrote:
>>> Hello Café
>>>
>>> *Issue: *
>>> I'm having trouble using the AD package presumably due to a wrong use of
>>> types (Sorry I know it's a bit vague)
>>> Any help or pointer on how to solve this would be greatly appreciated.
>>> Thanks in advance
>>>
>>> *Background: *
>>> I've implemented a symbolic expression tree [full code here:
>>> https://pastebin.com/FDpSFuRM]
>>> -- | Expression Tree
>>> data ExprTree a = Const a                                      -- ^ number
>>>                 | ParamNode ParamName                          -- ^ parameter
>>>                 | VarNode VarName                              -- ^ variable
>>>                 | UnaryNode MonoOp (ExprTree a)                -- ^ operator of arity 1
>>>                 | BinaryNode DualOp (ExprTree a) (ExprTree a)  -- ^ operator of arity 2
>>>                 | CondNode (Cond a) (ExprTree a) (ExprTree a)  -- ^ conditional node
>>>                 deriving (Eq, Show, Generic)
>>>
>>> An evaluation function on it
>>> -- | evaluates an Expression Tree on its Context (Map ParamName a, Map VarName a)
>>> evaluate :: (Eq a, Ord a, Floating a) => ExprTree a -> Context a -> a
>>>
>>> And a few instances on it:
>>> instance Num a => Default (ExprTree a) where ...
>>> instance Num a => Num (ExprTree a) where ...
>>> instance Fractional a => Fractional (ExprTree a) where ...
>>> instance Floating a => Floating (ExprTree a) where ...
>>> instance (Arbitrary a) => Arbitrary (ExprTree a) where ...
>>>
>>> This allows me to easily create (using derivation rules) the derivative(s)
>>> of such a tree with respect to its variable(s):
>>> diff :: (Eq a, Floating a) => ExprTree a -> VarName -> ExprTree a
>>> grad :: (Eq a, Floating a) => ExprTree a -> Map VarName (ExprTree a)
>>> hessian :: (Eq a, Floating a) => ExprTree a -> Map VarName (Map VarName (ExprTree a))
>>>
>>> So far, so good ...
>>>
>>> Now, to gain assurance in my implementation, I wanted to check the
>>> derivatives against the Numeric.AD module,
>>> so I created the following two functions:
>>> -- | helper for AD usage
>>> exprTreeToListFun :: (RealFloat a)
>>>                     => ExprTree a      -- ^ tree
>>>                     -> Map ParamName a -- ^ paramDict
>>>                     -> ([a] -> a)      -- fun from var values to their evaluation
>>> exprTreeToListFun tree paramDict vals = res
>>>     where
>>>       res            = evaluate tree (paramDict, varDict)
>>>       varDict        = Map.fromList $ zip (getVarNames tree) vals
>>>
>>>
>>> gradThroughAD :: RealFloat a => ExprTree a -> Context a -> Map VarName a
>>> gradThroughAD tree (paramDict, varDict) = res
>>>     where
>>>       varNames = Map.keys varDict
>>>       varVals  = Map.elems varDict
>>>       gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
>>>       res      = Map.fromList $ zip varNames gradList
>>>
>>>
>>> it unfortunately does not type check with message:
>>> • Couldn't match type ‘a’ with ‘Numeric.AD.Internal.Reverse.Reverse s a’
>>>   ‘a’ is a rigid type variable bound by
>>>     the type signature for:
>>>       gradThroughAD :: forall a. RealFloat a => ExprTree a -> Context a -> Map VarName a
>>>     at src/SymbolicExpression.hs:452:18
>>>   Expected type: [Numeric.AD.Internal.Reverse.Reverse s a] -> Numeric.AD.Internal.Reverse.Reverse s a
>>>     Actual type: [a] -> a
>>> • In the first argument of ‘AD.grad’, namely ‘(exprTreeToListFun tree paramDict)’
>>>   In the expression: AD.grad (exprTreeToListFun tree paramDict) varVals
>>>   In an equation for ‘gradList’:
>>>       gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
>>> • Relevant bindings include
>>>     gradList :: [a] (bound at src/SymbolicExpression.hs:457:5)
>>>     varVals :: [a] (bound at src/SymbolicExpression.hs:456:5)
>>>     varDict :: Map VarName a (bound at src/SymbolicExpression.hs:453:32)
>>>     paramDict :: Map ParamName a (bound at src/SymbolicExpression.hs:453:21)
>>>     tree :: ExprTree a (bound at src/SymbolicExpression.hs:453:15)
>>>     gradThroughAD :: ExprTree a -> Context a -> Map VarName a (bound at src/SymbolicExpression.hs:453:1)
>>>
>>>
>>> I was a bit surprised (I guess naively) by this, since to me, 'a' is
>>> polymorphic with RealFloat as type constraint and that is "supposed" to
>>> work with AD
>>>
>>> So anyway, I tried to modify it as follows:
>>> {-# LANGUAGE Rank2Types    #-}
>>>
>>> -- | helper for AD usage
>>> exprTreeToListFun :: (RealFloat a)
>>>                     => ExprTree a      -- ^ tree
>>>                     -> Map ParamName a -- ^ paramDict
>>>                     -> (RealFloat b => [b] -> b)      -- fun from var values to their evaluation
>>> exprTreeToListFun tree paramDict vals = res
>>>     where
>>>       res            = realToFrac $ evaluate tree (paramDict, varDict)
>>>       varDict        = Map.fromList $ zip (getVarNames tree) $ map realToFrac vals
>>>
>>> This now typechecks and runs, but going through AD returns *all
>>> derivatives as zero*, as if the variables (which passed through
>>> realToFrac) were treated as constants, so it is not much help either.
>>>
>>> Any idea how this can be solved?
>>>
>>> Apologies if this question already has an answer somewhere, but poking
>>> around SO <http://stackoverflow.com/search?q=%5Bhaskell%5D+ad+package> it
>>> looks like I'm not the only one having similar issues, and unfortunately
>>> none of the answers I found really helped me. If anything, they confirm the
>>> same issue of null derivatives:
>>> http://stackoverflow.com/questions/36878083/ad-type-unification-error-with-constrained-type-vector
>>> Thanks again
>>>
>> --
> Frederic Cogny
> +33 7 83 12 61 69
>


