今更ながら誤差逆伝播の確認と、ちょっと混乱した逆伝播の処理についての自分なりの理解
ニューラルネットワークを使った論文を読んだりちょこっと実装したりしているうちに、「これの逆伝播ってどう考えればいいんだっけ?」みたいなことが起こりました。
ついでに最近おまじないみたいにloss.backward()って書いてしまっていたので、逆伝播ってなんだったかを含めて自分なりに復習しましたよ、というお話です。
ニューラルネットワークのどんな処理がピンとこなかったか
ピンとこなかった処理はこんな感じです。
複数の入力をそれぞれまずCNNなどのニューラルネットワークに入力して特徴量を得ます。そうしたら、それぞれの特徴量を結合(concat)して、それをまた全結合層とかに入力して最終的に出力を得ます。
このconcatの部分が最初に見たときピンときませんでした。というかよく考えたらそれ以外の層でもどんな感じで逆伝播が行われるのかの理解があやふやになっていました。
ということで、そもそも誤差逆伝播ってどんな感じの処理だったかということを思い出すところから始めようと思いたち、昔買った「ゼロから作るDeep Learning」を読み直しました。(昔読んだときはピンとこない部分も多かったんですが、読み直したらめっちゃわかりやすかったです。)
誤差逆伝播
そもそも誤差逆伝播って何をするものだったかに最初に触れておきたいと思います。誤差逆伝播は、重みパラメータの勾配を効率よく計算するための手法です。
ここで、なぜ重みパラメータの勾配を効率よく計算しなければならないのか、が問題になってくると思います。
まず、勾配が必要な理由ですが、これはパラメータの更新に必要だからです。SGDやAdamといった手法は、ざっくりいうと重みの勾配に基づいてLossが小さくなる方向に各重みの値を更新する手法です。そのため、パラメータを更新するためには勾配を計算する必要があります。
勾配を計算するだけなら数値微分で押し切ることも可能ですが、層が深くなると処理が膨大になってくるため、押し切るのは得策ではありません。そこで、勾配を効率よく計算する手法として提案されているのが誤差逆伝播法である、ということみたいです。
計算グラフによる誤差逆伝播の理解
誤差逆伝播法の解釈の1つに計算グラフというものがあります。(本の受け売り)これは、グラフを用いて誤差逆伝播法がどのように行われているのかを解釈するというものです。読んでいてわかりやすかったので、ここでもそれを使いたいと思います。まずは一般的な考え方を示して、それから簡単な例を示したいと思います。
一般的にはこんな処理があちこちで行われているイメージです。構築したネットワークがどのように出力を演算するかをグラフで表し、それに対して逆方向の演算を定義することで各変数の勾配がその都度得られるようになっています。
計算グラフのいいところは、複雑な処理を個々の局所的な計算に落とし込めるところ(らしい)です。例えば、この例で言うと、入力はxとyだけではないかもしれなくて別でもっと複雑な計算をしているかもしれないし、出力のzに対してこれからめっちゃ複雑な計算をする必要があるかもしれません。それでも、xとyの勾配を考えるときにはzの勾配と何らかの処理だけを考えればいいです。
もう1つのいいところとしては、計算の過程で各勾配が計算でき、それを共有することができるところらしいです。確かに。
簡単な例として、加算ノードと乗算ノードの場合の計算グラフを図示したいと思います。
加算ノード
加算ノードはこんな感じ。zに対するxとyの偏微分はそれぞれ1なので、結果的に伝播してきたものを素通しにしている感じです。
念の為、実際のフレームワークでも同じことが行われていることを確認してみます。使っているフレームワークはpytorchです。
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True) z = x + y print(z) z.backward(torch.tensor([1.0, 1.0, 1.0])) print(x.grad) print(y.grad)
出力はこんな感じ
tensor([5., 7., 9.], grad_fn=<AddBackward0>) tensor([1., 1., 1.]) tensor([1., 1., 1.])
zはちゃんとxとyの和になっているのが確認できます。また、xとyの勾配はzの勾配(ここでは[1,1,1]ということにしています)を素通ししていることも確認できました。
乗算ノード
やることは加算ノードのときと同じです。
zに対するxとyの偏微分はyとxになるので、もう片方の入力をそのままzの勾配にかけるようになっています。こちらもpytorchでどうなっているか確認してみます。
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True) z = x * y print(z) z.backward(torch.tensor([1.0, 1.0, 1.0])) print(x.grad) print(y.grad)
出力はこんな感じ
tensor([ 4., 10., 18.], grad_fn=<MulBackward0>) tensor([4., 5., 6.]) tensor([1., 2., 3.])
zはちゃんとxとyの積になっていることが確認できます。また、もう一方の入力をかけた結果が勾配になっていることがわかります。
ピンとこなかった処理を改めて考えてみる
以上のことを踏まえるとconcatという操作は以下のように考えることができるのではないかと思います。
concatという操作が行列の形状そのものを変形してしまう処理なので、ちょっとスッキリしない説明になっていますが、基本的にはzの勾配の部分部分をそれぞれ各ノードに渡していくということになるのだと思います。
これもpytorchで確認してみます。
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True) z = torch.cat([x, y]) print(z) z.backward(torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) print(x.grad) print(y.grad)
出力はこんな感じ
tensor([1., 2., 3., 4., 5., 6.], grad_fn=<CatBackward>) tensor([1., 2., 3.]) tensor([4., 5., 6.])
やっぱり上のメモの理解であっているんじゃないでしょうか。あれこれ考えずとも、torch.catを使えばいい感じにフレームワークが処理してくれるようになってたみたいですね。
最後に
ピンとこなかった処理について考えるだけなら計算グラフいらなかった気がする…
でも、ネットワークを組むときに自分が何をしているのかクリアになった気がするので良しとします。