Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

Pytorchでモデルの保存と読み込み

Pytorchでモデルを学習した後にモデルの学習済み重みパラメータを保存して、別のプログラムで読み込みたい。特にGPUで学習したモデルを、GPUを持たないCPUのみのPCで使いたい場合によくハマるのでメモを残しておく。

GPUで学習してGPUで読み込む

GPUで学習したモデルを保存して、GPUで読み込む場合は以下のコマンド。

保存方法(GPUで保存)

model_path = 'model.pth'
torch.save(model.state_dict(), model_path)

読み出し方法

model_path = 'model.pth'
model.load_state_dict(torch.load(model_path))

GPUで学習してCPUで読み込む

GPUで学習したモデルを保存して、CPUで読み込む場合は以下

保存方法(CPUで保存)

model_path = 'model.pth'
torch.save(model.to('cpu').state_dict(), model_path)

CPUマシンでのGPUモデルの読み出し

model_path = 'model.pth'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

保存時はstate_dict()を使う

モデルはtorch.saveで直接保存することもできるが、state_dictで保存した方が無駄な情報を削れてファイルサイズを小さくできる。公式ページにもstate_dictで保存することが推奨されている。 pytorch.org

model.state_dict()は必要な情報のみを保存するので、ネットワーク構造や各レイヤの引数(入出力チャネル数やカーネルサイズなど)などの無駄な情報を保存せずにすむ。

ネットワークの構造などの情報はモデルの定義に書いてあるはずなので、モデルを定義したクラスのスクリプトがあるなら保存する必要はない。例えば、mansnet0_5を保存するなら下記のように保存する。

from torchvision import models
model = models.mnasnet0_5()
torch.save(model.to('cpu').state_dict(), 'model.pth')

読み出し方法は以下。重みパラメータのみを保存しているので、モデルを定義し直しておく必要がある。

from torchvision import models
model = models.mnasnet0_5()
model.load_state_dict(torch.load('model.pth'))

なお、注意すべき点としてmodel.pthを保存した後にクラスの中身を書き換えると保存したパラメータは読み出せなくなる。ネットワーク構造だけでなくメンバ変数名を変えても読み出せなくなるので注意。

state_dict()と直接保存のサイズ比較

上でstate_dict()を使うとファイルサイズを削減できると行ったが、ファイルサイズをどの程度削減できるのかを、state_dict()を使った場合と、直接保存する場合とで比較してみる。

例えば、model = nn.Conv2d(1, 2, 3)を保存した場合、ファイルサイズは下記のようになる。

8.1K direct.pth
498  sate_dict.pth

学習済みモデルは、ネットワークを通して様々なアプリケーションに配布することが多いので、可能な限り保存ファイルのサイズは小さくしたい。なので、state_dict()を使ってパラメータだけ保存した方がよい。

学習済みモデルをCPUのPCで使う方法

state_dict()で保存する際は、modelのdeviceも保存されるため、GPUで学習したモデルをGPUが使えないPCで使いたい場合にdevice周辺の問題で読み出せないことがある。

こういったdevice問題を避けるには以下の2つの方法がある。

  1. モデルのdeviceをGPUからCPUに変更してから保存する
  2. モデルの読み出し時にdeviceをCPUに指定する

モデルのdeviceをGPUからCPUに変更して保存

配布先でのデバッグはやりたくないので、できればGPUマシンで保存するときは必ず以下のようto('cpu')をやってから保存することをお勧めする。

model = model.to('cpu')
torch.save(model.state_dict(), 'model_cpu.pth')

こうすれば、CPUのみのPCでもGPUマシンで学習した学習済みモデルを使うことができる。

すでに保存されているモデルについては、GPUマシンでいったん読み出してからCPUで保存するということもできる。

model.load_state_dict(torch.load('model_gpu.pth')) 
torch.save(model.state_dict().to('cpu'), 'model_cpu.pth')

モデルの読み出し時にdeviceをCPUに指定

どうしてもCPUマシンでGPUで保存したデータを読みたいときは以下のようにmap_locationをCPUに設定する。

model.load_state_dict(torch.load('model_gpu.pth', map_location=torch.device('cpu')))

読み出しエラーの再現

GPUで学習したモデルをdeviceを変えずに保存し、CPUのみが使えるPCで直接読みだしてみる。こうするとtorch.load()の部分で一旦GPUメモリを経由するためにエラーが出る。

例えば、GPUが使えるPCで、

from torch import nn
model = nn.Conv2d(1,2,3, bias=False).to('cuda')

とすると、model.state_dict()の中身は以下のようになる。

OrderedDict([('weight', tensor([[[[-0.0855, -0.1221,  0.1030],
                        [-0.3128, -0.2851, -0.3323],
                        [ 0.0387, -0.0047, -0.1450]]],
              
              
                      [[[ 0.2419,  0.2015, -0.1305],
                        [ 0.0084, -0.2863,  0.0524],
                        [ 0.1791, -0.1799, -0.0403]]]], device='cuda:0'))])

device='cuda:0'がついている。ここでいったん保存してみる

torch.save(model.state_dict(), 'model_gpu.pth')

ここで保存したモデルファイルを、GPUが使えるPCでdeviceをCPUとして読みだしてみる。

from torch import nn
model = nn.Conv2d(1,2,3, bias=False).to('cpu')
model.load_state_dict(torch.load('model_gpu.pth')) 

このようにしてからもう一度state_dictの中身を見ると

OrderedDict([('weight', tensor([[[[-0.0855, -0.1221,  0.1030],
                        [-0.3128, -0.2851, -0.3323],
                        [ 0.0387, -0.0047, -0.1450]]],
              
              
                      [[[ 0.2419,  0.2015, -0.1305],
                        [ 0.0084, -0.2863,  0.0524],
                        [ 0.1791, -0.1799, -0.0403]]]]))])

一見、cpuモードに戻っているように見える。しかし、loadしたsate_dictを以下のコマンドで見てみると

torch.load('model_gpu.pth')

出力

OrderedDict([('weight', tensor([[[[-0.0855, -0.1221,  0.1030],
                        [-0.3128, -0.2851, -0.3323],
                        [ 0.0387, -0.0047, -0.1450]]],
              
              
                      [[[ 0.2419,  0.2015, -0.1305],
                        [ 0.0084, -0.2863,  0.0524],
                        [ 0.1791, -0.1799, -0.0403]]]], device='cuda:0'))])

となり、device='cuda:0'がいる。この部分でGPUメモリを使う。

したがって、CPUのみのPCでこのモデルを読みだそうとすると下記のエラーが出る。

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

モデルのdeviceを'cpu'に変換してから保存すれば、このエラーは出ずに問題なく読み取れるようになる。