Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

Pytorchで学習のチェックポイントの保存

Pytorchでモデルを保存する場合、モデルのパラメータのみを保存することが多い。しかし、モデルパラメータだけではlossがどれくらいか、optimizerは何を使ったか、何イテレーション学習してあるかなどの情報がわからない。これらがわからないと特に途中から学習を開始するfine tuningや転移学習のときに困ってしまうので、これらの学習に関する情報を一緒に保存する方法をメモしておく。

モデルのみの保存はこちら。

tzmi.hatenablog.com

チェックポイントの保存

pytorchのsaveでは、modelmodel.state_dict()のようなクラスインスタンスの保存に加えて、pythonのdictなどを保存することもできる。これを利用してモデルの学習情報を以下のコードのように保存できる。モデルのみを保存した場合は拡張子に.pt.pthなどをつけることが多いが、これと区別するために拡張子は.cpt(チェックポイント)としておく。

import torch

# 適当なモデルとoptimizerの定義
model = torch.nn.Conv2d(3,3,3)
opt = torch.optim.Adam(model.parameters())

# モデルの入出力(学習データの代わり)
x = torch.randn(1,3,10,10)
y = torch.randn(1,3,8,8)

# 学習のループ
losses = []
for iiter in range(10):
    y_ = model.forward(x)
    loss = torch.sum((y-y_)**2)
    opt.zero_grad()
    loss.backward()
    opt.step()

    losses.append(loss.item())
    print(loss.item())

# 学習情報の保存
outfile = 'out.cpt'
torch.save({'iter': iiter,
            'model_state_dict': model.state_dict(),
            'opt_state_dict': opt.state_dict(),
            'loss': losses,
            }, outfile)

ちなみにAdamは移動平均とるためにイテレーション回数を保存していて、opt.state_dict()の中に'step'という変数がいる。これがiiterと同じ意味を持つので、Adamを使う場合はiterは保存しなくてもよいかもしれない。

逆に途中から学習する場合は、Adamの情報が保存されていないと、内部のイテレーション回数や移動平均などの情報がとれない。つまり、一度Adamをリセットした状態で学習が再開されることになる。特に論文を書く場合は再現性が必要になるので、学習を途中で止めると条件が変わってしまって何かしら問題が起こる。これを防ぐ意味でもoptimzerの状態は一緒に保存しておいた方がいい。

保存したチェックポイントの読み出し

保存したモデルの学習情報は単純なdictを保存した場合と同様に、以下のように読み出せる。ちょうど、numpyのnpz保存を読み出す場合と同じ感じ。

import torch

# 保存したときと同じモデルとoptimizerの定義
model = torch.nn.Conv2d(3,3,3)
opt = torch.optim.Adam(model.parameters())

# 読み出し
cptfile = 'out.cpt'
cpt = torch.load(cptfile)
stdict_m = cpt['model_state_dict']
stdict_o = cpt['opt_state_dict']
model.load_state_dict(stdict_m)
oplt.load_state_dict(stdict_o)

これまでモデルのみを保存していた場合と一緒に使いたい場合は、拡張子や、cpt.__class__ == dictなどで型を確認してif文などで分ければよいかと思う。これで途中まで学習した情報を使って学習を再開できるようになる。

import torch

# 保存したときと同じモデルとoptimizerの定義
model = torch.nn.Conv2d(3,3,3)
opt = torch.optim.Adam(model.parameters())

# 読み出し
cptfile = 'out.cpt'
cpt = torch.load(cptfile)
if cpt.__class__ == dict:
    stdict_m = cpt['model_state_dict']
else:
    stdict_m = cpt

model.load_state_dict(stdict_m)