PyTorch入門
目次
- PyTorchとは
- モデルの構築
- データ整形
- 学習
- モデルのセーブ、ロード
- まとめ
1. PyTorchとは
Facebook AI Research (FAIR)により開発された、Pythonの深層学習ライブラリの1つです。
define by runといって、データを流しながら計算グラフを構築するため、データに合わせて動的にモデルの構造を変えることができるという特徴があります。また、深層学習の順伝播→誤差逆伝播のプロセスを直感的に記述することができます。
ライブラリとして非常に使いやすいため人気が高く、TensorFlowと並んで、Pythonの代表的な深層学習ライブラリとなっています。
2. モデルの構築
PyTorchでのモデルを構築する際には、torch.nn.Moduleクラスを継承したクラスを作成することになります。モデルのクラス内のforwardメソッド内に、順伝播時に行う処理を記述する決まりになっています。
コードとしては大体こんな感じになります。
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
def __init__(
self,
in_features,
dim_hidden,
out_features
):
super().__init__()
self.fc1 = nn.Linear(in_features, dim_hidden)
self.fc2 = nn.Linear(dim_hidden, out_features)
def forward(self, x):
out = self.fc1(x)
out = F.relu(x)
out = self.fc2(x)
out = F.relu(x)
return out
torch.nn.Linear()クラスは、与えられた変数に重みをかけて出力するクラスです。
__init__の中でfc1とかfc2と書いているのが、モデルの1層目と2層目に対応している部分です。モデルの構造的な部分は、基本的に__init__に記述します。
出力するときは、
model = MyModel(10, 10, 10)
out = model(x)
みたいな感じです。
3. データ整形
入力形式に合わせてDatasetクラスを定義することが多いです。
通常、__getitem__()
内に入力形式の整形などの処理を書きます。
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __getitem__(self, index):
# 何かしらの前処理
return {
'input': torch.Tensor,
'target': torch.Tensor
}
このように定義したDatasetは、DataLoaderに渡してあげるとforループの中で取り出すことができて便利です。
from torch.utils.data import DataLoader
dataset = MyDataset()
# batch_sizeを指定して、一定の塊ごとにデータを取り出すことができる
dataloader = DataLoader(dataset, batch_size=32)
for batch in dataloader:
# 何かしらの処理
4. 学習
深層学習における学習というのは、損失関数の値に対する偏微分を計算し、偏微分の値に応じてモデルの各パラメータを更新していく一連の処理を指します。更新前後のパラメータを、、損失関数をとすると、以下のようになります。
は学習率と呼ばれる正の値で、一度に更新する重み量を調整するためのハイパーパラメータです。また、は目標値に対して計算された損失の値で、学習を行う際の指標となります。
この更新を行うためには、損失関数の偏微分の値を求める必要があるわけですが、PyTorchにはそのための仕組みが用意されています。これは、PyTorchが公式に自動微分と呼んでいる機能で、これを利用することで偏微分の値を簡単に求めることができます。
自動微分で求めた偏微分の値は、各パラメータテンソルの、grad
という変数に格納されています。学習段階で、この値を使っていくことになります。PyTorchでは、パラメータ更新を行うための各種アルゴリズムを利用することができます。optimizer = torch.optim.Adam()()
のような形で使用するアルゴリズムを指定し、optimizer.step()
とすればモデルパラメータを更新できます。
具体的には、以下の流れで学習を行っていくことになります。
- モデル、損失関数、オプティマイザーの定義
model = MyModel()
criterion = torch.nn.MSELoss()
# オプティマイザーにはモデルのパラメータを渡すこと
optimizer = torch.optim.Adam(model.parameters())
- 出力を計算する
out = model(x)
- 損失を計算する
loss = criterion(out, target)
- 損失をモデルパラメータで偏微分する
loss.backward()
- モデルパラメータを更新する
optimizer.step()
上記の1.-5.をforループで回していくのが、PyTorchで学習を行う際の基本的な流れです。
5. モデルのセーブ、ロード
学習をしたらモデルを保存しましょう。
モデルのパラメータはstate_dict
という形式で保存することがでます。
ロードする際には、torch.load()
でstate_dict
の形でパラメータを取り出したあと、model.load(state_dict)
としてパラメータを反映します。
# セーブ時
torch.save(model.state_dict(), 'extremely_awesome_model.pth')
# ロード時
state_dict = torch.load('extremely_awesome_model.pth')
model.load(state_dict)
6. まとめ
PyTorchを使ってモデルの構築から学習、保存までのおおまかな流れを解説しました。
直感的に使えるところがこのライブラリのいいところだと思うので、ぜひ触ってみてください。