Pytorchでネットワークの挙動に再現性を持たせたい。しかしPytorchでは、デフォルトではネットワークの重みは乱数で初期化され、乱数のシード値も固定されていない。Pytorchで乱数シードを固定してネットワークの初期値に再現性をもたせる方法を調べてみた。
シードの固定方法
torch.manual_seed()
を使えばよい。下記のように書くだけでネットワーク重みの初期値は毎回固定できる。
import torch torch.manual_seed(0)
出力の例
ipythonでtorchvision.models.mnasnet0_5()
を使って重みを調査してみる。
from torchvision import models models.mnasnet0_5().classifier[1].weight
出力は以下
Parameter containing: tensor([[-0.0167, -0.0166, -0.0266, ..., -0.0541, -0.0324, -0.0536], [ 0.0462, 0.0140, -0.0212, ..., 0.0460, -0.0462, -0.0339], [-0.0483, -0.0484, -0.0401, ..., 0.0393, -0.0112, -0.0014], ..., [ 0.0185, 0.0113, -0.0510, ..., -0.0458, -0.0275, 0.0108], [ 0.0466, -0.0395, 0.0148, ..., -0.0248, 0.0242, 0.0371], [ 0.0362, -0.0265, 0.0336, ..., -0.0454, -0.0443, 0.0408]], requires_grad=True)
もう一度同じコマンドを打つと中身が変わっていることを確認できる。
Parameter containing: tensor([[-0.0400, -0.0021, -0.0153, ..., 0.0244, 0.0512, 0.0133], [-0.0416, -0.0242, 0.0141, ..., -0.0252, -0.0169, -0.0430], [ 0.0299, 0.0126, 0.0344, ..., -0.0131, 0.0430, -0.0427], ..., [-0.0250, 0.0368, -0.0071, ..., 0.0407, -0.0460, 0.0309], [-0.0364, 0.0379, 0.0544, ..., 0.0134, 0.0205, -0.0087], [-0.0055, -0.0214, -0.0508, ..., 0.0316, 0.0391, -0.0261]], requires_grad=True)
seed固定した場合
ipythonで下記のように打ってみる。
import torch from torchvision import models torch.manual_seed(0) models.mnasnet0_5().classifier[1].weight torch.manual_seed(0) models.mnasnet0_5().classifier[1].weight
以下のように出力されるパラメータが同じ値になっていることを確認できる。
Parameter containing: tensor([[-0.0167, -0.0166, -0.0266, ..., -0.0541, -0.0324, -0.0536], [ 0.0462, 0.0140, -0.0212, ..., 0.0460, -0.0462, -0.0339], [-0.0483, -0.0484, -0.0401, ..., 0.0393, -0.0112, -0.0014], ..., [ 0.0185, 0.0113, -0.0510, ..., -0.0458, -0.0275, 0.0108], [ 0.0466, -0.0395, 0.0148, ..., -0.0248, 0.0242, 0.0371], [ 0.0362, -0.0265, 0.0336, ..., -0.0454, -0.0443, 0.0408]], requires_grad=True) Parameter containing: tensor([[-0.0167, -0.0166, -0.0266, ..., -0.0541, -0.0324, -0.0536], [ 0.0462, 0.0140, -0.0212, ..., 0.0460, -0.0462, -0.0339], [-0.0483, -0.0484, -0.0401, ..., 0.0393, -0.0112, -0.0014], ..., [ 0.0185, 0.0113, -0.0510, ..., -0.0458, -0.0275, 0.0108], [ 0.0466, -0.0395, 0.0148, ..., -0.0248, 0.0242, 0.0371], [ 0.0362, -0.0265, 0.0336, ..., -0.0454, -0.0443, 0.0408]], requires_grad=True)
ということで、ネットワーク定義の直前にtorch.manual_seed(0)
を書いておけば重みの初期値は固定できる。