<div dir="ltr">I am trying to understand how I can use <code>Numeric.AD</code> (automatic differentiation) in Haskell.<div class="gmail_quote"><div dir="ltr">
<p>I defined a simple matrix type and a scalar function taking an array
and two matrices as arguments. I want to use AD to get the gradient of
the scoring function with respect to both matrices, but I'm running into
compilation problems. Here is the code:</p><p>-------------------------------</p><pre><code><span>{-# LANGUAGE DeriveTraversable, DeriveFunctor, DeriveFoldable #-}</span><span>
</span><span>import</span><span> Numeric.AD.Mode.Reverse as R
</span><span>import</span><span> Data.Traversable as T
</span><span>import</span><span> Data.Foldable as F
</span><span>--- Non-linear function on "vectors"</span><span>
logistic x </span><span>=</span><span> </span><span>1.0</span><span> </span><span>/</span><span> </span><span>(</span><span>1.0</span><span> </span><span>+</span><span> exp</span><span>(-</span><span>x</span><span>)</span><span> </span><span>)</span><span>
phi v </span><span>=</span><span> map logistic v
phi' </span><span>(</span><span>x</span><span>:</span><span>xs</span><span>)</span><span> </span><span>=</span><span> x </span><span>:</span><span> </span><span>(</span><span>phi xs</span><span>)</span><span>
</span><span>--- dot product</span><span>
dot u v </span><span>=</span><span> foldr </span><span>(+)</span><span> </span><span>0</span><span> </span><span>$</span><span> zipWith </span><span>(*)</span><span> u v
</span><span>--- simple matrix type</span><span>
</span><span>data</span><span> Matrix a </span><span>=</span><span> M </span><span>[[</span><span>a</span><span>]]</span><span> </span><span>deriving</span><span> </span><span>(</span><span>Eq</span><span>,</span><span>Show</span><span>,</span><span>Functor</span><span>,</span><span>F.Foldable</span><span>,</span><span>T.Traversable</span><span>)</span><span>
</span><span>--- action of a matrix on a vector</span><span>
mv </span><span>_</span><span> </span><span>[]</span><span> </span><span>=</span><span> </span><span>[]</span><span>
mv </span><span>(</span><span>M </span><span>[])</span><span> </span><span>_</span><span> </span><span>=</span><span> </span><span>[]</span><span>
mv </span><span>(</span><span> M m </span><span>)</span><span> v </span><span>=</span><span> </span><span>(</span><span> dot </span><span>(</span><span>head m</span><span>)</span><span> v </span><span>)</span><span> </span><span>:</span><span> </span><span>(</span><span>mv </span><span>(</span><span>M </span><span>(</span><span>tail m</span><span>))</span><span> v </span><span>)</span><span>
</span><span>--- two matrices</span><span>
mbW1 </span><span>=</span><span> M </span><span>$</span><span> </span><span>[[</span><span>1</span><span>,</span><span>0</span><span>,</span><span>0</span><span>],[-</span><span>1</span><span>,</span><span>5</span><span>,</span><span>1</span><span>],[</span><span>1</span><span>,</span><span>2</span><span>,-</span><span>3</span><span>]]</span><span>
mbW2 </span><span>=</span><span> M </span><span>$</span><span> </span><span>[[</span><span>0</span><span>,</span><span>0</span><span>,</span><span>0</span><span>],[</span><span>1</span><span>,</span><span>3</span><span>,-</span><span>1</span><span>],[-</span><span>2</span><span>,</span><span>4</span><span>,</span><span>6</span><span>]]</span><span>
</span><span>--- two different scoring functions</span><span>
sc1 v m </span><span>=</span><span> foldr </span><span>(+)</span><span> </span><span>0</span><span> </span><span>$</span><span> </span><span>(</span><span>phi' </span><span>.</span><span> </span><span>(</span><span>mv m</span><span>)</span><span> </span><span>)</span><span> v
sc2 </span><span>::</span><span> Floating a </span><span>=></span><span> </span><span>[</span><span>a</span><span>]</span><span> </span><span>-></span><span> </span><span>[</span><span>Matrix a</span><span>]</span><span> </span><span>-></span><span> a
sc2 v </span><span>[</span><span>m1</span><span>,</span><span> m2</span><span>]</span><span> </span><span>=</span><span> foldr </span><span>(+)</span><span> </span><span>0</span><span> </span><span>$</span><span> </span><span>(</span><span>phi' </span><span>.</span><span> </span><span>(</span><span>mv m2</span><span>)</span><span> </span><span>.</span><span> phi' </span><span>.</span><span> </span><span>(</span><span>mv m1</span><span>)</span><span> </span><span>)</span><span> v
strToInt </span><span>=</span><span> read </span><span>::</span><span> String </span><span>-></span><span> Double
strLToIntL </span><span>=</span><span> map strToInt
</span><span>--- testing</span><span>
main </span><span>=</span><span> </span><span>do</span><span>
    putStrLn </span><span>$</span><span> </span><span>"mbW1:"</span><span> </span><span>++</span><span> </span><span>(</span><span>show mbW1</span><span>)</span><span>
    putStrLn </span><span>$</span><span> </span><span>"mbW2:"</span><span> </span><span>++</span><span> </span><span>(</span><span>show mbW2</span><span>)</span><span>
    rawInput </span><span><-</span><span> readFile </span><span>"/dev/stdin"</span><span>
    </span><span>let</span><span> xin</span><span>=</span><span> strLToIntL </span><span>$</span><span> lines rawInput
    putStrLn </span><span>"sc1 xin mbW1"</span><span>
    print </span><span>$</span><span> sc1 xin mbW1 </span><span>--- ok. = </span><span>
    putStrLn </span><span>"grad (sc1 xin) mbW1"</span><span>
    print </span><span>$</span><span> grad </span><span>(</span><span> sc1 xin</span><span>)</span><span> mbW1 </span><span>-- yields an error: expects xin to be [Reverse s Double] instead of [Double]</span><span>
    putStrLn </span><span>"grad (sc1 [3,5,7]) mbW1"</span><span>
    print </span><span>$</span><span> grad </span><span>(</span><span> sc1 </span><span>[</span><span>3</span><span>,</span><span>5</span><span>,</span><span>7</span><span>])</span><span> mbW1 </span><span>--- ok. =</span><span>
    putStrLn </span><span>"sc2 xin [mbW1,mbW2]"</span><span>
    print </span><span>$</span><span> sc2 xin </span><span>[</span><span>mbW1</span><span>,</span><span> mbW2</span><span>]</span><span>
    putStrLn </span><span>"grad (sc2 [3,5,7]) [mbW1,mbW2]"</span><span>
    print </span><span>$</span><span> grad </span><span>(</span><span> sc2 </span><span>[</span><span>3</span><span>,</span><span>5</span><span>,</span><span>7</span><span>])</span><span> </span><span>[</span><span>mbW1</span><span>,</span><span> mbW2</span><span>]</span><span> </span><span>--- Error: see text</span></code></pre><p>--------------------------------</p><p>The last line (<code>grad</code> on <code>sc2</code>) gives the following error:</p><p>---------------------------------</p><pre><code><span>Couldn't match </span><span>type</span><span> </span><span>‘</span><span>Reverse s </span><span>(</span><span>Matrix Double</span><span>)’</span><span>
with </span><span>‘</span><span>Matrix </span><span>(</span><span>Reverse s </span><span>(</span><span>Matrix Double</span><span>))’</span><span>
Expected </span><span>type</span><span>:</span><span> </span><span>[</span><span>Reverse s </span><span>(</span><span>Matrix Double</span><span>)]</span><span>
</span><span>-></span><span> Reverse s </span><span>(</span><span>Matrix Double</span><span>)</span><span>
Actual </span><span>type</span><span>:</span><span> </span><span>[</span><span>Matrix </span><span>(</span><span>Reverse s </span><span>(</span><span>Matrix Double</span><span>))]</span><span>
</span><span>-></span><span> Reverse s </span><span>(</span><span>Matrix Double</span><span>)</span><span>
In the first argument </span><span>of</span><span> </span><span>‘</span><span>grad</span><span>’,</span><span> namely </span><span>‘(</span><span>sc2 </span><span>[</span><span>3</span><span>,</span><span> </span><span>5</span><span>,</span><span> </span><span>7</span><span>])’</span><span>
In the second argument </span><span>of</span><span> </span><span>‘($)’,</span><span> namely
</span><span>‘</span><span>grad </span><span>(</span><span>sc2 </span><span>[</span><span>3</span><span>,</span><span> </span><span>5</span><span>,</span><span> </span><span>7</span><span>])</span><span> </span><span>[</span><span>mbW1</span><span>,</span><span> mbW2</span><span>]’</span></code></pre><p>---------------------------------<br></p><p>I don't understand where the "Matrix of Matrix" in the actual type comes from. I'm passing <code>grad</code> a partially applied <code>sc2</code>, which makes it a function on a list of <code>Matrix</code> values.</p>
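<p>One possible reading of the mismatch, offered as a sketch rather than a confirmed diagnosis: <code>grad</code> differentiates with respect to a single <code>Traversable</code> container <code>f a</code> of scalars. Given <code>[mbW1, mbW2]</code>, it would pick <code>f = []</code> and <code>a = Matrix Double</code>, so <code>sc2 [3,5,7]</code> is asked to accept <code>[Reverse s (Matrix Double)]</code>, while its own signature makes it want <code>[Matrix b]</code> with <code>b = Reverse s (Matrix Double)</code>. The block below uses only <code>base</code> (not the <code>ad</code> package) to show how <code>Compose [] Matrix</code> turns a list of matrices into one container of scalars; <code>pack</code> and <code>unpack</code> are hypothetical helper names, and the small matrices are made-up values.</p>
<pre><code>{-# LANGUAGE DeriveTraversable #-}
import Data.Foldable (toList)
import Data.Functor.Compose (Compose (..))

-- Same shape as the Matrix type in the post.
data Matrix a = M [[a]] deriving (Eq, Show, Functor, Foldable, Traversable)

-- Hypothetical helper: view a list of matrices as ONE container of scalars.
pack :: [Matrix a] -> Compose [] Matrix a
pack = Compose

-- Hypothetical helper: recover the list of matrices.
unpack :: Compose [] Matrix a -> [Matrix a]
unpack = getCompose

main :: IO ()
main = do
    -- Made-up example values, not the matrices from the post.
    let ws = pack [M [[1, 0], [0, 1]], M [[2, 3]]] :: Compose [] Matrix Double
    -- Folding now reaches every scalar entry of every matrix.
    print (toList ws)
    -- Mapping over the scalars keeps the list-of-matrices structure.
    print (unpack (fmap (* 2) ws))
</code></pre>
<p>Assuming the usual <code>ad</code> API, the gradient call would then presumably look like <code>grad (sc2 [3,5,7] . getCompose) (Compose [mbW1, mbW2])</code>, and the first error might be addressed by lifting the runtime input with <code>map auto xin</code>. Neither has been compiled here.</p>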
<p>With the two offending lines commented out, the program runs without
problems, i.e., the first gradient works and is calculated correctly (I'm
feeding [1,2,3] as input to the program):</p>
<pre><code><span><br>-------------------<br></span></code><br><code><span><code><span>mbW1</span><span>:</span><span>M </span><span>[[</span><span>1.0</span><span>,</span><span>0.0</span><span>,</span><span>0.0</span><span>],[-</span><span>1.0</span><span>,</span><span>5.0</span><span>,</span><span>1.0</span><span>],[</span><span>1.0</span><span>,</span><span>2.0</span><span>,-</span><span>3.0</span><span>]]</span><span>
mbW2</span><span>:</span><span>M </span><span>[[</span><span>0.0</span><span>,</span><span>0.0</span><span>,</span><span>0.0</span><span>],[</span><span>1.0</span><span>,</span><span>3.0</span><span>,-</span><span>1.0</span><span>],[-</span><span>2.0</span><span>,</span><span>4.0</span><span>,</span><span>6.0</span><span>]]</span><span>
sc1 xin mbW1
</span><span>1</span><span>
</span><span>2</span><span>
</span><span>3</span><span>
</span><span>2.0179800657874893</span><span>
grad </span><span>(</span><span>sc1 </span><span>[</span><span>3</span><span>,</span><span>5</span><span>,</span><span>7</span><span>])</span><span> mbW1
M </span><span>[[</span><span>3.0</span><span>,</span><span>5.0</span><span>,</span><span>7.0</span><span>],[</span><span>7.630996942126885e-13</span><span>,</span><span>1.2718328236878141e-12</span><span>,</span><span>1.7805659531629398e-12</span><span>],[</span><span>1.0057130122694228e-3</span><span>,</span><span>1.6761883537823711e-3</span><span>,</span><span>2.3466636952953197e-3</span><span>]]</span><span>
sc2 xin </span><span>[</span><span>mbW1</span><span>,</span><span>mbW2</span><span>]</span><span>
</span><span>1.8733609463863194<br></span></code>-------------------<br><br></span></code></pre><p>Both errors are a problem for me. I want to take the gradient of any such <code>sc2</code>
scoring function, which depends on a list of matrices, evaluated at any
given "point" xin. Clearly, I don't yet understand the AD library
well enough. Any help would be appreciated.</p>
</div>
</div><br></div>