NumPyなどの基礎的なモジュールのみを用いてニューラルネットワークを自作してみた

2

2023年01月08日 13:03

こんにちにゃんです。
水色桜(みずいろさくら)です。
今回はニューラルネットワークに関する記事を書いていこうと思います。

本記事の目的

本記事ではnumpyなどの基礎的な道具のみを用いてニューラルネットワークを自作していきます。私と同じく、ゼロからニューラルネットワークを自作してみたいという方の参考になれば幸いです。数式に関する解説もしていますので、この記事だけで完結できるようにしていくつもりです。もし記事が良いと思ってもらえたらいいねボタンを押してもらえると今後の励みになります。
今回は特に画像分類問題としてポピュラーなMNISTについて解析を行っていきます。他の記事を見るとライブラリを用いてMNISTの分類を行っている例はたくさんありますが、numpyなどの基礎的なモジュールのみを用いている例は少ないように思えました。

まず今回作成したニューラルネットワークの性能について示します。

img
10回の学習で約89%の精度を達成できています。CNNなどを使った場合はもう少し高い精度が達成できますが、自作したニューラルネットワークとしては十分な精度だと思いました。

ニューラルネットワークとは

人間の脳のしくみ(ニューロン間のあらゆる相互接続)から着想を得たもので、脳機能の特性のいくつかをコンピュータ上で表現するために作られた数学モデルです。(Udemyメディアから引用
image.png
Udemyメディアから引用

ニューラルネットワークは入力層、中間層、出力層の3つから構成されます。本記事で作成するニューラルネットワークのモデルは次のような感じです。
image.png

ニューラルネットワークにおいては、前の層からの出力に重みとバイアスを加え、それをアクティベーション関数と呼ばれる非線形関数にかけます。これにより、複雑な分布も表現することが可能になります。ニューラルネットワークは学習を行うことで出力層で人間が望む結果(正しい答え、正解)が出るようにします。学習によって中間層の重みとバイアスを最適化していきます。

数学的な背景

ニューラルネットワークを構成するうえで、中間層と出力層のアクティベーション関数を選ぶ必要があります。本記事では中間層のアクティベーション関数はLeaky ReLU、出力層のアクティベーション関数はsoftmax関数を用います。
image.png
Leaky ReLU関数は先行研究において優れた成果を出している関数であり、勾配消失(勾配が小さくなり、ニューラルネットワークの学習が進まなくなってしまう問題)が起きにくいという性質を有しています。
softmax関数は出力をすべて足し合わせると1になるように設計された関数であり、分類問題に用いられている関数です。入力に対して0~1の任意の値を返すため、擬似的に確率として扱うことができます。
交差エントロピー誤差はsoftmax関数と一緒に用いると逆伝播が簡単な形に書けるように設計された関数です。一見複雑そうに見えますが、微分すると簡単な形になります。
image.png

誤差逆伝播法はニューラルネットワークにおいて勾配を効率的に求めるための手法です。微分の連鎖率を用いて図のように逆側から勾配を求めていきます。
誤差逆伝播法の導出方法は下の画像の通りです。交差エントロピー誤差を出力で微分した値と、出力を入力で微分した値を掛け合わせると、結果としてとっても簡単な形に書き表すことができます。出力から教師データを引いたものになるので、実装もしやすくなっています(そうなるようにsoftmax関数と交差エントロピー誤差は設計されています。)
image.png

adam

現在最も多く用いらえている最適化手法。2015年にKingma氏らによって発表されました。Adaptive Moment Estimationの略です。

image.png

Fumio-eisanさんの記事より引用

移動平均で振動を抑制するmomentumと 学習率を調整して振動を抑制するRMSPropを組み合わせた関数となっています。
式を見ると、第1式は移動平均をとっています。また、第5式をみると、見かけの学習率α/(√v+ε) (εは十分小さい数)はvが大きいほど小さくなることがわかります。これは勾配▽E(w)が大きいときに、振動が生じにくくなるようにするためです。

adamの実装

では先ほど示した数式を元にadamを実装していきます。xが1次の勾配、yが2次の勾配を表しています。sとtはそれぞれ前の時刻の1次の勾配と2次の勾配を表しています。計算したxとyは次の時刻の計算で用いるために出力させます。dは▽E(w)を、rateは学習率を表しています。またnは時刻を、x1は時刻n-1におけるwを、zは時刻nにおけるwを表しています。
img

neural networkの実装

最後に作成したニューラルネットワーク全体のコードを示します。もしお忙しい方はコピペして使ってみてください。

img

終わりに

今回はNumPyなどの基礎的なモジュールのみを用いてニューラルネットワークの自作を行う方法について書いてきました。
私と同じく一からニューラルネットワークを自作してみたいという方の助けになれば幸いです。
では、ばいにゃん~。

# Python
2

診断を受けるとあなたの現在の業務委託単価を算出します。今後副業やフリーランスで単価を交渉する際の参考になります。また次の単価レンジに到達するためのヒントも確認できます。

目次を見る