Pytorchでモデルを学習した後にモデルの学習済み重みパラメータを保存して、別のプログラムで読み込みたい。特にGPUで学習したモデルを、GPUを持たないCPUのみのPCで使いたい場合によくハマるのでメモを残しておく。
GPUで学習してGPUで読み込む
GPUで学習したモデルを保存して、GPUで読み込む場合は以下のコマンド。
保存方法(GPUで保存)
model_path = 'model.pth'
torch.save(model.state_dict(), model_path)
読み出し方法
model_path = 'model.pth'
model.load_state_dict(torch.load(model_path))
GPUで学習してCPUで読み込む
GPUで学習したモデルを保存して、CPUで読み込む場合は以下
保存方法(CPUで保存)
model_path = 'model.pth' torch.save(model.to('cpu').state_dict(), model_path)
CPUマシンでのGPUモデルの読み出し
model_path = 'model.pth' model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
保存時はstate_dict()を使う
モデルはtorch.save
で直接保存することもできるが、state_dict
で保存した方が無駄な情報を削れてファイルサイズを小さくできる。公式ページにもstate_dict
で保存することが推奨されている。
pytorch.org
model.state_dict()
は必要な情報のみを保存するので、ネットワーク構造や各レイヤの引数(入出力チャネル数やカーネルサイズなど)などの無駄な情報を保存せずにすむ。
ネットワークの構造などの情報はモデルの定義に書いてあるはずなので、モデルを定義したクラスのスクリプトがあるなら保存する必要はない。例えば、mansnet0_5
を保存するなら下記のように保存する。
from torchvision import models model = models.mnasnet0_5() torch.save(model.to('cpu').state_dict(), 'model.pth')
読み出し方法は以下。重みパラメータのみを保存しているので、モデルを定義し直しておく必要がある。
from torchvision import models model = models.mnasnet0_5() model.load_state_dict(torch.load('model.pth'))
なお、注意すべき点としてmodel.pthを保存した後にクラスの中身を書き換えると保存したパラメータは読み出せなくなる。ネットワーク構造だけでなくメンバ変数名を変えても読み出せなくなるので注意。
state_dict()と直接保存のサイズ比較
上でstate_dict()
を使うとファイルサイズを削減できると行ったが、ファイルサイズをどの程度削減できるのかを、state_dict()
を使った場合と、直接保存する場合とで比較してみる。
例えば、model = nn.Conv2d(1, 2, 3)
を保存した場合、ファイルサイズは下記のようになる。
8.1K direct.pth 498 sate_dict.pth
学習済みモデルは、ネットワークを通して様々なアプリケーションに配布することが多いので、可能な限り保存ファイルのサイズは小さくしたい。なので、state_dict()
を使ってパラメータだけ保存した方がよい。
学習済みモデルをCPUのPCで使う方法
state_dict()
で保存する際は、modelのdevice
も保存されるため、GPUで学習したモデルをGPUが使えないPCで使いたい場合にdevice
周辺の問題で読み出せないことがある。
こういったdevice
問題を避けるには以下の2つの方法がある。
- モデルのdeviceをGPUからCPUに変更してから保存する
- モデルの読み出し時にdeviceをCPUに指定する
モデルのdeviceをGPUからCPUに変更して保存
配布先でのデバッグはやりたくないので、できればGPUマシンで保存するときは必ず以下のようto('cpu')
をやってから保存することをお勧めする。
model = model.to('cpu') torch.save(model.state_dict(), 'model_cpu.pth')
こうすれば、CPUのみのPCでもGPUマシンで学習した学習済みモデルを使うことができる。
すでに保存されているモデルについては、GPUマシンでいったん読み出してからCPUで保存するということもできる。
model.load_state_dict(torch.load('model_gpu.pth')) torch.save(model.state_dict().to('cpu'), 'model_cpu.pth')
モデルの読み出し時にdeviceをCPUに指定
どうしてもCPUマシンでGPUで保存したデータを読みたいときは以下のようにmap_location
をCPUに設定する。
model.load_state_dict(torch.load('model_gpu.pth', map_location=torch.device('cpu')))
読み出しエラーの再現
GPUで学習したモデルをdeviceを変えずに保存し、CPUのみが使えるPCで直接読みだしてみる。こうするとtorch.load()
の部分で一旦GPUメモリを経由するためにエラーが出る。
例えば、GPUが使えるPCで、
from torch import nn model = nn.Conv2d(1,2,3, bias=False).to('cuda')
とすると、model.state_dict()
の中身は以下のようになる。
OrderedDict([('weight', tensor([[[[-0.0855, -0.1221, 0.1030], [-0.3128, -0.2851, -0.3323], [ 0.0387, -0.0047, -0.1450]]], [[[ 0.2419, 0.2015, -0.1305], [ 0.0084, -0.2863, 0.0524], [ 0.1791, -0.1799, -0.0403]]]], device='cuda:0'))])
device='cuda:0'
がついている。ここでいったん保存してみる
torch.save(model.state_dict(), 'model_gpu.pth')
ここで保存したモデルファイルを、GPUが使えるPCでdeviceをCPUとして読みだしてみる。
from torch import nn model = nn.Conv2d(1,2,3, bias=False).to('cpu') model.load_state_dict(torch.load('model_gpu.pth'))
このようにしてからもう一度state_dict
の中身を見ると
OrderedDict([('weight', tensor([[[[-0.0855, -0.1221, 0.1030], [-0.3128, -0.2851, -0.3323], [ 0.0387, -0.0047, -0.1450]]], [[[ 0.2419, 0.2015, -0.1305], [ 0.0084, -0.2863, 0.0524], [ 0.1791, -0.1799, -0.0403]]]]))])
一見、cpu
モードに戻っているように見える。しかし、loadしたsate_dict
を以下のコマンドで見てみると
torch.load('model_gpu.pth')
出力
OrderedDict([('weight', tensor([[[[-0.0855, -0.1221, 0.1030], [-0.3128, -0.2851, -0.3323], [ 0.0387, -0.0047, -0.1450]]], [[[ 0.2419, 0.2015, -0.1305], [ 0.0084, -0.2863, 0.0524], [ 0.1791, -0.1799, -0.0403]]]], device='cuda:0'))])
となり、device='cuda:0'
がいる。この部分でGPUメモリを使う。
したがって、CPUのみのPCでこのモデルを読みだそうとすると下記のエラーが出る。
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
モデルのdeviceを'cpu'に変換してから保存すれば、このエラーは出ずに問題なく読み取れるようになる。