Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

Pytorchでネットワーク重みの初期値を固定

Pytorchでネットワークの挙動に再現性を持たせたい。しかしPytorchでは、デフォルトではネットワークの重みは乱数で初期化され、乱数のシード値も固定されていない。Pytorchで乱数シードを固定してネットワークの初期値に再現性をもたせる方法を調べてみた。

シードの固定方法

torch.manual_seed()を使えばよい。下記のように書くだけでネットワーク重みの初期値は毎回固定できる。

import torch
torch.manual_seed(0)

出力の例

ipythonでtorchvision.models.mnasnet0_5()を使って重みを調査してみる。

from torchvision import models
models.mnasnet0_5().classifier[1].weight

出力は以下

Parameter containing:
tensor([[-0.0167, -0.0166, -0.0266,  ..., -0.0541, -0.0324, -0.0536],
        [ 0.0462,  0.0140, -0.0212,  ...,  0.0460, -0.0462, -0.0339],
        [-0.0483, -0.0484, -0.0401,  ...,  0.0393, -0.0112, -0.0014],
        ...,
        [ 0.0185,  0.0113, -0.0510,  ..., -0.0458, -0.0275,  0.0108],
        [ 0.0466, -0.0395,  0.0148,  ..., -0.0248,  0.0242,  0.0371],
        [ 0.0362, -0.0265,  0.0336,  ..., -0.0454, -0.0443,  0.0408]],
       requires_grad=True)

もう一度同じコマンドを打つと中身が変わっていることを確認できる。

Parameter containing:
tensor([[-0.0400, -0.0021, -0.0153,  ...,  0.0244,  0.0512,  0.0133],
        [-0.0416, -0.0242,  0.0141,  ..., -0.0252, -0.0169, -0.0430],
        [ 0.0299,  0.0126,  0.0344,  ..., -0.0131,  0.0430, -0.0427],
        ...,
        [-0.0250,  0.0368, -0.0071,  ...,  0.0407, -0.0460,  0.0309],
        [-0.0364,  0.0379,  0.0544,  ...,  0.0134,  0.0205, -0.0087],
        [-0.0055, -0.0214, -0.0508,  ...,  0.0316,  0.0391, -0.0261]],
       requires_grad=True)

seed固定した場合

ipythonで下記のように打ってみる。

import torch
from torchvision import models

torch.manual_seed(0)
models.mnasnet0_5().classifier[1].weight

torch.manual_seed(0)
models.mnasnet0_5().classifier[1].weight

以下のように出力されるパラメータが同じ値になっていることを確認できる。

Parameter containing:
tensor([[-0.0167, -0.0166, -0.0266,  ..., -0.0541, -0.0324, -0.0536],
        [ 0.0462,  0.0140, -0.0212,  ...,  0.0460, -0.0462, -0.0339],
        [-0.0483, -0.0484, -0.0401,  ...,  0.0393, -0.0112, -0.0014],
        ...,
        [ 0.0185,  0.0113, -0.0510,  ..., -0.0458, -0.0275,  0.0108],
        [ 0.0466, -0.0395,  0.0148,  ..., -0.0248,  0.0242,  0.0371],
        [ 0.0362, -0.0265,  0.0336,  ..., -0.0454, -0.0443,  0.0408]],
       requires_grad=True)

Parameter containing:
tensor([[-0.0167, -0.0166, -0.0266,  ..., -0.0541, -0.0324, -0.0536],
        [ 0.0462,  0.0140, -0.0212,  ...,  0.0460, -0.0462, -0.0339],
        [-0.0483, -0.0484, -0.0401,  ...,  0.0393, -0.0112, -0.0014],
        ...,
        [ 0.0185,  0.0113, -0.0510,  ..., -0.0458, -0.0275,  0.0108],
        [ 0.0466, -0.0395,  0.0148,  ..., -0.0248,  0.0242,  0.0371],
        [ 0.0362, -0.0265,  0.0336,  ..., -0.0454, -0.0443,  0.0408]],
       requires_grad=True)

ということで、ネットワーク定義の直前にtorch.manual_seed(0)を書いておけば重みの初期値は固定できる。