Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

Pytorchでのモデルパラメータの確認

Pytorchでニューラルネットワークモデルのパラメータが更新されているかどうかを確認したいときがある。モデルのパラメータを確認する方法はいくつかあるけど、Pytorchはモジュールごとにモデルを作っていくことが多いので、とりあえず簡単に確認する方法をいくつか書いてみることにする。

list(model.parameters())

まずはひとつめの方法。model.parameters()をlist化するだけでOK。書くパラメータにアクセスするためにforループを書く必要もない。どのレイヤのどのmoduleのパラメータかがわからないので大きいネットワークのパラメータチェックには不向きだけど、module単位の小さいネットワークがちゃんと動いているかどうかのチェックはスムーズにできる。

from torch import nn

model = nn.Linear(3, 3)
print(list(model.parameters()))

出力は以下

[Parameter containing:
tensor([[-0.0419, -0.0525, -0.0323],
        [ 0.1902,  0.3475,  0.1459],
        [ 0.5106, -0.1896,  0.1551]], requires_grad=True), 
Parameter containing:
tensor([-0.1259, -0.0232, -0.1674], requires_grad=True)]

nn.Moduleを継承しているクラスであれば、nn.Linearなどでもparameters()という関数は持っているので、自分でニューラルネットのクラスを作っても同じようにできる。例えば以下のようなNetクラスを作って同じことをしてみる。

from torch import nn
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(3,3),
            nn.ReLU(),
            nn.Linear(3,3),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.layers(x)

model = Net()
print(list(model.parameters()))

結果は以下

[Parameter containing:
tensor([[-0.2279, -0.5007,  0.4614],
        [ 0.5661,  0.5210,  0.2435],
        [ 0.2621, -0.4700,  0.4480]], requires_grad=True), 
Parameter containing:
tensor([ 0.2879, -0.1006,  0.1832], requires_grad=True), 
Parameter containing:
tensor([[-0.0751, -0.4191,  0.4684],
        [-0.3355,  0.2323, -0.1926],
        [ 0.4090, -0.5508, -0.0795]], requires_grad=True), 
Parameter containing:
tensor([-0.0533,  0.0500, -0.4279], requires_grad=True)]

model.state_dict()

model.state_dict()はモデルをsaveするときに使うだけではなく、パラメータの可視化にも使える。list(model.parameters()))との大きな違いは、OrderedDict形式になっているためにどのレイヤかを識別できること。

model = Net()
print(model.state_dict())

下記のような出力が得られる。

OrderedDict([('layers.0.weight', tensor([[ 0.4853,  0.3065,  0.4571],
        [-0.3992, -0.1177,  0.3469],
        [-0.4186,  0.2202, -0.2316]])), 
('layers.0.bias', tensor([-0.1896,  0.4691,  0.5213])), 
('layers.2.weight', tensor([[-0.5379,  0.2572, -0.4723],
        [-0.3963, -0.0715, -0.0839],
        [ 0.5445, -0.4799, -0.2490]])), 
('layers.2.bias', tensor([ 0.3794, -0.4309, -0.2685]))])

各パラメータには名前が付いているので、keyで取り出すこともできる。

stdict = model.state_dict()
print(stdict['layers.0.weight'])

出力は以下

tensor([[ 0.4853,  0.3065,  0.4571],
        [-0.3992, -0.1177,  0.3469],
        [-0.4186,  0.2202, -0.2316]])

ネットワークが大きくなるとmodel.state_dict()の出力が膨大になって、keyの確認が難しくなる。keyだけを取り出したいときは下記のようにすればよい

print(stdict.keys())

k = list(stdict.keys())[0]
print(k)
print(stdict[k])

OrderedDictはindexでは要素にアクセスできないので、listにしてからindexにアクセスする。

odict_keys(['layers.0.weight', 'layers.0.bias', 'layers.2.weight', 'layers.2.bias'])

layers.0.weight
tensor([[ 0.4853,  0.3065,  0.4571],
        [-0.3992, -0.1177,  0.3469],
        [-0.4186,  0.2202, -0.2316]])

ここで紹介した2つの方法は、直接メンバをたたいたり、model.modules()などを使ってひとつひとつネットワークを掘ってアクセスする方法よりも便利なんじゃないかな。もっと簡単な方法があったらまた追記します。