元バイオ系

元バイオウェット系がデータサイエンスやらを勉強していくブログ。 基本自分用のまとめ。

テンソルを使いこなしたい3(行列の行列微分)

hotoke-x.hatenablog.com

hotoke-x.hatenablog.com

の続き。
本記事ではテンソル微分演算についてまとめる。行列の行列微分がややこしかったので、縮約記法で計算すればスッキリできるのではと思ったのが事の発端。ようやくここまで来た。

準備

スカラーのベクトル微分、ベクトルのスカラー微分、ベクトルのベクトル微分スカラーの行列微分を以下のように定義する。また、それぞれの最右辺は縮約記法での表現を表す。 {} $$ \begin{align} \frac{\mathrm{d}f}{\mathrm{d}\boldsymbol{x}} &= \left( \begin{array}{c} \frac{\mathrm{d}f}{\mathrm{d}x_{1}} \\ \frac{\mathrm{d}f}{\mathrm{d}x_{2}} \\ \vdots \\ \frac{\mathrm{d}f}{\mathrm{d}x_{n}} \end{array} \right) = \partial{}_{i} f \\ \frac{\mathrm{d}\boldsymbol{f}}{\mathrm{d}x} &=\left(\begin{array}{c} \frac{\mathrm{d}f_{1}}{\mathrm{d}x} \\ \frac{\mathrm{d}f_{2}}{\mathrm{d}x} \\ \vdots \\ \frac{\mathrm{d}f_{n}}{\mathrm{d}x} \end{array} \right) = \partial{} f_{i} \\ \frac{\mathrm{d}\boldsymbol{f}}{\mathrm{d}\boldsymbol{x}} &= \left(\begin{array}{cccc} \frac{\mathrm{d}f_{1}}{\mathrm{d}x_{1}} & \frac{\mathrm{d}f_{1}}{\mathrm{d}x_{2}} & \ldots & \frac{\mathrm{d}f_{1}}{\mathrm{d}x_{n}} \\ \frac{\mathrm{d}f_{2}}{\mathrm{d}x_{1}} & \frac{\mathrm{d}f_{2}}{\mathrm{d}x_{2}} & \ldots & \frac{\mathrm{d}f_{2}}{\mathrm{d}x_{n}}\\ \vdots & \vdots & \ddots & \vdots \\ \frac{\mathrm{d}f_{m}}{\mathrm{d}x_{1}} & \frac{\mathrm{d}f_{m}}{\mathrm{d}x_{2}} & \ldots & \frac{\mathrm{d}f_{m}}{\mathrm{d}x_{n}} \end{array} \right) = \partial{}_{j} f_{i} \label{vec_by_vec} \\ \frac{\mathrm{d}f}{\mathrm{d}\boldsymbol{X}} &= \left(\begin{array}{cccc} \frac{\mathrm{d}f}{\mathrm{d}x_{11}} & \frac{\mathrm{d}f}{\mathrm{d}x_{12}} & \ldots & \frac{\mathrm{d}f}{\mathrm{d}x_{1n}} \\ \frac{\mathrm{d}f}{\mathrm{d}x_{21}} & \frac{\mathrm{d}f}{\mathrm{d}x_{22}} & \ldots & \frac{\mathrm{d}f}{\mathrm{d}x_{2n}}\\ \vdots & \vdots & \ddots & \vdots \\ \frac{\mathrm{d}f}{\mathrm{d}x_{m1}} & \frac{\mathrm{d}f}{\mathrm{d}x_{m2}} & \ldots & \frac{\mathrm{d}f}{\mathrm{d}x_{mn}} \end{array} \right) = \partial{}_{ij} f \label{scalar_by_matrix} \end{align} $$

また、\eqref{vec_by_vec}より  {} $$ \begin{align} \partial{}_{j} x_{i}= \delta_{ij} \end{align} $$ である(この後使う)。

なお、\eqref{vec_by_vec}は転置した形として定義されていることもある。行列表現のまま計算を進めようとすると、行列積が可換でないことに注意する必要がある。縮約記法で計算すれば気にしなくて良くなる。

さて、ベクトルの行列微分、行列のベクトル微分、行列の行列微分はどうなるのだろう。行列の行列微分がわかれば他もノリでわかると思うので、上記と同様に縮約記法を使って行列の行列微分の定義から、ニューラルネットワークの逆誤差伝播の式を導出するところまでやってみる。

行列の行列微分

行列 \displaystyle Xの自身による微分は縮約記法を使って以下のように定義される。

 {} $$ \begin{align} \frac{\partial X_{kl}}{\partial X_{ij}} = \delta_{ik}\delta_{li} \end{align} $$

要は、 \displaystyle X_{kl} \displaystyle X_{kl}微分したときは \displaystyle 1それ以外は \displaystyle 0と言っているだけである。「じゃあ全成分 \displaystyle 1になるじゃん」とか一瞬思うかもしれない(自分は思った)がそれは大間違い。 \displaystyle k,lそれぞれについて \displaystyle i,jを走査して計算するので4階テンソルになっている(添字も4つあるし)。

逆誤差伝播

以下のような単純なニューラルネットワークを考える。 Lは損失関数でスカラーである。形式的に書いてるだけなので実態は何でも良い。

f:id:hotoke-X:20190210064732p:plain

これを式で書けば

$$ \begin{align} Y &= WX + B \\ X &= \left(\begin{array}{c} x_{1} \\ \vdots \\ x_{n} \end{array} \right), W = \left(\begin{array}{ccc} w_{11} & \ldots & w_{1n} \\ \vdots & \ddots & \vdots \\ w_{m1} & \ldots & w_{mn} \end{array} \right), B = \left(\begin{array}{c} b_{1} \\ \vdots \\ b_{m} \end{array} \right) \end{align} $$

である。計算グラフにして逆誤差伝播を赤で書き入れると以下のようになる。

f:id:hotoke-X:20190210082429p:plain

ナニコレ。特に \displaystyle \left( \frac{\partial Y}{\partial X}\right)^{\mathrm{T}} \frac{\partial L}{\partial Y}。そりゃ \displaystyle Xの次元に一致しなきゃいけないから逆算すればわかるけど気持ち悪い。一方で、 \displaystyle \frac{\partial L}{\partial Y} \frac{\partial Y}{\partial W}も連鎖律としてはスッキリしているが、ベクトルの行列微分がある。しかも \displaystyle X^{\mathrm{T}}についてもなんで転置されてるんだかよくわからない。

逆誤差伝播(縮約記法でスッキリ ver.)

では縮約記法で \displaystyle \frac{\partial L}{\partial X}, \frac{\partial L}{\partial W}を計算してみよう。

 {} $$ \begin{align} Y_{i} = W_{ij} X_{j} + B_{i} \end{align} $$ として、  {} $$ \begin{align} \frac{\partial L}{\partial X_{j}} &= \frac{\partial L}{\partial Y_{k}} \frac{\partial Y_{k}}{\partial X_{j}} = \frac{\partial L}{\partial Y_{k}} \frac{\partial \left(W_{kl} X_{l} + B_{k} \right)}{\partial X_{j}} = \frac{\partial L}{\partial Y_{k}} W_{kl} \delta_{lj} \\ &= \frac{\partial L}{\partial Y_{k}} W_{kj} = W_{kj} \frac{\partial L}{\partial Y_{k}} = W^{\mathrm{T}} \frac{\partial L}{\partial Y} \\ \frac{\partial L}{\partial W_{ij}} &= \frac{\partial L}{\partial Y_{k}} \frac{\partial Y_{k}}{\partial W_{ij}} = \frac{\partial L}{\partial Y_{k}} \frac{\partial \left(W_{kl} X_{l} + B_{k} \right)}{\partial W_{ij}} = \frac{\partial L}{\partial Y_{k}} X_{l} \delta_{ik}\delta_{lj} \\ &= \frac{\partial L}{\partial Y_{i}} X_{j} = \frac{\partial L}{\partial Y} X^{\mathrm{T}} \end{align} $$

最右辺は行列表現で書いた。今まで悩んでたのが馬鹿らしくなるくらい簡単に計算できた。

まとめ

もう全部これで良くない?

とはいえ、実装と実行速度を考えると行列表現が得られた方が実用的であることは明白。なので、縮約記法で計算してから行列表現に戻してやるのが実務的なやり方なんだと思う。

追記:
前回記事でまとめたエディントンのイプシロンは全く必要なかった(笑)。まぁでも知っていることは良い事だ。

参考

http://www2.imm.dtu.dk/pubdb/views/edoc_download.php/3274/pdf/imm3274.pdf