<div dir="ltr"><p>I am trying to understand how I can use <code>Numeric.AD</code> (automatic differentiation) in Haskell.</p>
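<p>As a baseline, here is a minimal sketch of how <code>grad</code> is typically called, assuming the <code>ad</code> package's <code>Numeric.AD</code> module. The function handed to <code>grad</code> is run at AD's own internal number type, so any fixed <code>Double</code> constants it closes over have to be lifted into that type with <code>auto</code>. The helper name <code>linear</code> is hypothetical and only for illustration.</p>

```haskell
-- Minimal sketch (assumes the `ad` package is installed): gradient of
-- f(w) = sum_i c_i * w_i with respect to w, for a fixed vector c.
import Numeric.AD (auto, grad)

-- Hypothetical helper. The lambda passed to `grad` is used at AD's
-- internal number type, so the constant Doubles in c are lifted into
-- that type with `auto`.
linear :: [Double] -> [Double] -> [Double]
linear c w0 = grad (\w -> sum (zipWith (*) (map auto c) w)) w0

main :: IO ()
main = print (linear [3, 5, 7] [1, 0, 0]) -- prints [3.0,5.0,7.0]
```

<p>The gradient of a linear function is its coefficient vector, so the result does not depend on <code>w0</code>. The same lifting should apply to <code>grad (sc1 xin) mbW1</code> in the program below: writing <code>grad (sc1 (map auto xin)) mbW1</code> gives <code>sc1</code> a vector of AD numbers instead of a plain <code>[Double]</code>.</p>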

<p>I defined a simple matrix type and two scalar scoring functions: <code>sc1</code> takes a vector and one matrix, and <code>sc2</code> takes a vector and a list of two matrices. I want to use AD to get the gradient of a scoring function with respect to its matrices, but I'm running into compilation problems. Here is the code:</p><p>-------------------------------</p><pre style="" class=""><code><span class="">{-# LANGUAGE DeriveTraversable, DeriveFunctor, DeriveFoldable #-}</span><span class="">
</span><span class="">import</span><span class=""> Numeric.AD.Mode.Reverse as R
</span><span class="">import</span><span class=""> Data.Traversable as T
</span><span class="">import</span><span class=""> Data.Foldable as F

</span><span class="">--- Non-linear function on "vectors"</span><span class="">
logistic x </span><span class="">=</span><span class=""> </span><span class="">1.0</span><span class=""> </span><span class="">/</span><span class=""> </span><span class="">(</span><span class="">1.0</span><span class=""> </span><span class="">+</span><span class=""> exp</span><span class="">(-</span><span class="">x</span><span class="">)</span><span class=""> </span><span class="">)</span><span class="">
phi v </span><span class="">=</span><span class=""> map logistic v
phi' </span><span class="">(</span><span class="">x</span><span class="">:</span><span class="">xs</span><span class="">)</span><span class=""> </span><span class="">=</span><span class=""> x </span><span class="">:</span><span class=""> </span><span class="">(</span><span class="">phi xs</span><span class="">)</span><span class="">

</span><span class="">--- dot product</span><span class="">
dot u v </span><span class="">=</span><span class=""> foldr </span><span class="">(+)</span><span class=""> </span><span class="">0</span><span class=""> </span><span class="">$</span><span class=""> zipWith </span><span class="">(*)</span><span class=""> u v

</span><span class="">--- simple matrix type</span><span class="">
</span><span class="">data</span><span class=""> Matrix a </span><span class="">=</span><span class=""> M </span><span class="">[[</span><span class="">a</span><span class="">]]</span><span class=""> </span><span class="">deriving</span><span class=""> </span><span class="">(</span><span class="">Eq</span><span class="">,</span><span class="">Show</span><span class="">,</span><span class="">Functor</span><span class="">,</span><span class="">F.Foldable</span><span class="">,</span><span class="">T.Traversable</span><span class="">)</span><span class="">

</span><span class="">--- action of a matrix on a vector</span><span class="">
mv </span><span class="">_</span><span class=""> </span><span class="">[]</span><span class=""> </span><span class="">=</span><span class=""> </span><span class="">[]</span><span class="">
mv </span><span class="">(</span><span class="">M </span><span class="">[])</span><span class=""> </span><span class="">_</span><span class=""> </span><span class="">=</span><span class=""> </span><span class="">[]</span><span class="">
mv </span><span class="">(</span><span class=""> M m </span><span class="">)</span><span class=""> v </span><span class="">=</span><span class=""> </span><span class="">(</span><span class=""> dot </span><span class="">(</span><span class="">head m</span><span class="">)</span><span class="">  v </span><span class="">)</span><span class=""> </span><span class="">:</span><span class="">  </span><span class="">(</span><span class="">mv </span><span class="">(</span><span class="">M </span><span class="">(</span><span class="">tail m</span><span class="">))</span><span class=""> v </span><span class="">)</span><span class="">

</span><span class="">--- two matrices</span><span class="">
mbW1 </span><span class="">=</span><span class=""> M </span><span class="">$</span><span class=""> </span><span class="">[[</span><span class="">1</span><span class="">,</span><span class="">0</span><span class="">,</span><span class="">0</span><span class="">],[-</span><span class="">1</span><span class="">,</span><span class="">5</span><span class="">,</span><span class="">1</span><span class="">],[</span><span class="">1</span><span class="">,</span><span class="">2</span><span class="">,-</span><span class="">3</span><span class="">]]</span><span class="">
mbW2 </span><span class="">=</span><span class=""> M </span><span class="">$</span><span class=""> </span><span class="">[[</span><span class="">0</span><span class="">,</span><span class="">0</span><span class="">,</span><span class="">0</span><span class="">],[</span><span class="">1</span><span class="">,</span><span class="">3</span><span class="">,-</span><span class="">1</span><span class="">],[-</span><span class="">2</span><span class="">,</span><span class="">4</span><span class="">,</span><span class="">6</span><span class="">]]</span><span class="">

</span><span class="">--- two different scoring functions</span><span class="">
sc1 v m </span><span class="">=</span><span class=""> foldr </span><span class="">(+)</span><span class=""> </span><span class="">0</span><span class=""> </span><span class="">$</span><span class=""> </span><span class="">(</span><span class="">phi' </span><span class="">.</span><span class=""> </span><span class="">(</span><span class="">mv m</span><span class="">)</span><span class=""> </span><span class="">)</span><span class="">  v  

sc2 </span><span class="">::</span><span class=""> Floating a </span><span class="">=></span><span class=""> </span><span class="">[</span><span class="">a</span><span class="">]</span><span class=""> </span><span class="">-></span><span class=""> </span><span class="">[</span><span class="">Matrix a</span><span class="">]</span><span class=""> </span><span class="">-></span><span class=""> a
sc2 v </span><span class="">[</span><span class="">m1</span><span class="">,</span><span class=""> m2</span><span class="">]</span><span class=""> </span><span class="">=</span><span class=""> foldr </span><span class="">(+)</span><span class=""> </span><span class="">0</span><span class=""> </span><span class="">$</span><span class=""> </span><span class="">(</span><span class="">phi' </span><span class="">.</span><span class=""> </span><span class="">(</span><span class="">mv m2</span><span class="">)</span><span class=""> </span><span class="">.</span><span class=""> phi' </span><span class="">.</span><span class=""> </span><span class="">(</span><span class="">mv m1</span><span class="">)</span><span class=""> </span><span class="">)</span><span class=""> v

strToInt </span><span class="">=</span><span class=""> read </span><span class="">::</span><span class=""> String </span><span class="">-></span><span class=""> Double
strLToIntL </span><span class="">=</span><span class=""> map strToInt
</span><span class="">--- testing</span><span class="">
main </span><span class="">=</span><span class=""> </span><span class="">do</span><span class="">
        putStrLn </span><span class="">$</span><span class=""> </span><span class="">"mbW1:"</span><span class=""> </span><span class="">++</span><span class=""> </span><span class="">(</span><span class="">show mbW1</span><span class="">)</span><span class="">
        putStrLn </span><span class="">$</span><span class=""> </span><span class="">"mbW2:"</span><span class=""> </span><span class="">++</span><span class=""> </span><span class="">(</span><span class="">show mbW2</span><span class="">)</span><span class="">
        rawInput </span><span class=""><-</span><span class="">  readFile </span><span class="">"/dev/stdin"</span><span class="">
        </span><span class="">let</span><span class=""> xin</span><span class="">=</span><span class=""> strLToIntL </span><span class="">$</span><span class=""> lines rawInput
        putStrLn </span><span class="">"sc xin mbW1"</span><span class="">
        print </span><span class="">$</span><span class=""> sc1 xin mbW1  </span><span class="">--- ok. = </span><span class="">
        putStrLn </span><span class="">"grad (sc1 xin) mbW1"</span><span class="">
        print </span><span class="">$</span><span class=""> grad </span><span class="">(</span><span class=""> sc1 xin</span><span class="">)</span><span class=""> mbW1   </span><span class="">-- compile error: xin is expected to be [Reverse s Double], not [Double]</span><span class="">
        putStrLn </span><span class="">"grad (sc1 [3,5,7]) mbW1"</span><span class="">
        print </span><span class="">$</span><span class=""> grad </span><span class="">(</span><span class=""> sc1 </span><span class="">[</span><span class="">3</span><span class="">,</span><span class="">5</span><span class="">,</span><span class="">7</span><span class="">])</span><span class=""> mbW1   </span><span class="">--- ok. =</span><span class="">
        putStrLn </span><span class="">"sc2 xin [mbW1,mbW2]"</span><span class="">
        print </span><span class="">$</span><span class=""> sc2 xin </span><span class="">[</span><span class="">mbW1</span><span class="">,</span><span class=""> mbW2</span><span class="">]</span><span class="">
        putStrLn </span><span class="">"grad (sc2 [3,5,7]) [mbW1,mbW2]"</span><span class="">
        print </span><span class="">$</span><span class=""> grad </span><span class="">(</span><span class=""> sc2 </span><span class="">[</span><span class="">3</span><span class="">,</span><span class="">5</span><span class="">,</span><span class="">7</span><span class="">])</span><span class=""> </span><span class="">[</span><span class="">mbW1</span><span class="">,</span><span class=""> mbW2</span><span class="">]</span><span class="">  </span><span class="">--- Error: see text</span></code></pre><p>--------------------------------</p><p>The last line (grad on sc2) gives the following error:</p><p>---------------------------------</p><pre style="" class=""><code><span class="">Couldn't match </span><span class="">type</span><span class=""> </span><span class="">‘</span><span class="">Reverse s </span><span class="">(</span><span class="">Matrix Double</span><span class="">)’</span><span class="">
               with </span><span class="">‘</span><span class="">Matrix </span><span class="">(</span><span class="">Reverse s </span><span class="">(</span><span class="">Matrix Double</span><span class="">))’</span><span class="">
Expected </span><span class="">type</span><span class="">:</span><span class=""> </span><span class="">[</span><span class="">Reverse s </span><span class="">(</span><span class="">Matrix Double</span><span class="">)]</span><span class="">
               </span><span class="">-></span><span class=""> Reverse s </span><span class="">(</span><span class="">Matrix Double</span><span class="">)</span><span class="">
  Actual </span><span class="">type</span><span class="">:</span><span class=""> </span><span class="">[</span><span class="">Matrix </span><span class="">(</span><span class="">Reverse s </span><span class="">(</span><span class="">Matrix Double</span><span class="">))]</span><span class="">
               </span><span class="">-></span><span class=""> Reverse s </span><span class="">(</span><span class="">Matrix Double</span><span class="">)</span><span class="">
In the first argument </span><span class="">of</span><span class=""> </span><span class="">‘</span><span class="">grad</span><span class="">’,</span><span class=""> namely </span><span class="">‘(</span><span class="">sc2 </span><span class="">[</span><span class="">3</span><span class="">,</span><span class=""> </span><span class="">5</span><span class="">,</span><span class=""> </span><span class="">7</span><span class="">])’</span><span class="">
In the second argument </span><span class="">of</span><span class=""> </span><span class="">‘($)’,</span><span class=""> namely
  </span><span class="">‘</span><span class="">grad </span><span class="">(</span><span class="">sc2 </span><span class="">[</span><span class="">3</span><span class="">,</span><span class=""> </span><span class="">5</span><span class="">,</span><span class=""> </span><span class="">7</span><span class="">])</span><span class=""> </span><span class="">[</span><span class="">mbW1</span><span class="">,</span><span class=""> mbW2</span><span class="">]’</span></code></pre><p>---------------------------------<br></p><p>I don't understand where the "Matrix of Matrix" in the actual type comes from. I'm passing <code>grad</code> a partially applied <code>sc2</code>, which should make it a function on a list of matrices.</p>
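<p>A likely explanation for the "Matrix of Matrix": in <code>Numeric.AD</code>, <code>grad</code> has roughly the type <code>(Traversable f, Num a) =&gt; (forall s. Reifies s Tape =&gt; f (Reverse s a) -&gt; Reverse s a) -&gt; f a -&gt; f a</code>. Given <code>[mbW1, mbW2] :: [Matrix Double]</code>, GHC picks <code>f = []</code> and <code>a = Matrix Double</code>, so each whole matrix is treated as one "scalar" of type <code>Reverse s (Matrix Double)</code>, and <code>sc2</code> then asks for <code>[Matrix (Reverse s (Matrix Double))]</code>. One possible workaround, sketched below, is to view the list of matrices as a single container <code>Compose [] Matrix</code> (from <code>Data.Functor.Compose</code>, in <code>base</code> since 4.9 and in <code>transformers</code> before that), so that <code>a = Double</code>. <code>gradSc2</code> is a hypothetical helper; the other definitions mirror the question's code, with <code>mv</code> and <code>dot</code> written with <code>map</code> and <code>sum</code>.</p>

```haskell
{-# LANGUAGE DeriveTraversable, DeriveFunctor, DeriveFoldable #-}
-- Sketch (assumes the `ad` package): differentiate sc2 with respect to
-- every entry of every matrix by viewing [Matrix Double] as the single
-- Traversable container Compose [] Matrix, whose elements are Doubles.
import Data.Functor.Compose (Compose (..))
import Numeric.AD (auto, grad)

-- Same matrix type as in the question.
data Matrix a = M [[a]] deriving (Eq, Show, Functor, Foldable, Traversable)

logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (negate x))

-- As in the question: keep the first component, squash the others.
phi' :: Floating a => [a] -> [a]
phi' []       = []
phi' (x : xs) = x : map logistic xs

dot :: Num a => [a] -> [a] -> a
dot u v = sum (zipWith (*) u v)

-- Matrix-vector product: one dot product per row.
mv :: Num a => Matrix a -> [a] -> [a]
mv _ []       = []
mv (M rows) v = map (`dot` v) rows

sc2 :: Floating a => [a] -> [Matrix a] -> a
sc2 v [m1, m2] = sum ((phi' . mv m2 . phi' . mv m1) v)
sc2 _ _        = error "sc2: expects exactly two matrices"

-- Hypothetical helper: gradient of (sc2 xin) at the matrices ms.
-- xin is lifted with `auto`, so it is treated as a constant; the
-- result has the same shape as ms.
gradSc2 :: [Double] -> [Matrix Double] -> [Matrix Double]
gradSc2 xin ms =
  getCompose (grad (\(Compose ms') -> sc2 (map auto xin) ms') (Compose ms))

main :: IO ()
main = do
  let mbW1 = M [[1, 0, 0], [-1, 5, 1], [1, 2, -3]] :: Matrix Double
      mbW2 = M [[0, 0, 0], [1, 3, -1], [-2, 4, 6]] :: Matrix Double
      xin  = [1, 2, 3]
  print (sc2 xin [mbW1, mbW2])
  mapM_ print (gradSc2 xin [mbW1, mbW2])
```

<p>Because <code>Compose [] Matrix</code> is <code>Traversable</code> for a list of any length, the same pattern should also work for a scoring function that takes an arbitrary number of matrices.</p>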

<p>With the two offending lines commented out, the program runs without problems, and the gradient of <code>sc1</code> at <code>[3,5,7]</code> is calculated correctly (I'm feeding [1,2,3] as input to the program):</p>

<p>-------------------</p><pre style="" class=""><code>mbW1:M [[1.0,0.0,0.0],[-1.0,5.0,1.0],[1.0,2.0,-3.0]]
mbW2:M [[0.0,0.0,0.0],[1.0,3.0,-1.0],[-2.0,4.0,6.0]]
sc1 xin mbW1
1
2
3
2.0179800657874893
grad (sc1 [3,5,7]) mbW1
M [[3.0,5.0,7.0],[7.630996942126885e-13,1.2718328236878141e-12,1.7805659531629398e-12],[1.0057130122694228e-3,1.6761883537823711e-3,2.3466636952953197e-3]]
sc2 xin [mbW1,mbW2]
1.8733609463863194</code></pre><p>-------------------</p>
<p>Both errors (the one for <code>grad (sc1 xin) mbW1</code> and the one for <code>grad (sc2 [3,5,7]) [mbW1, mbW2]</code>) are a problem for me. I want to take the gradient of any scoring function like <code>sc2</code>, which depends on a list of matrices, evaluated at any given "point" <code>xin</code>. Clearly, I don't yet understand the AD library well enough. Any help would be appreciated.</p><div><div class="gmail_signature"><div dir="ltr"><div><div dir="ltr">--<br>Public key ID: E8FE60D7 <br>Public key server: see, e.g., hkp://<a href="http://keys.gnupg.net" target="_blank">keys.gnupg.net</a> <br></div></div></div></div></div>
</div>