Pytorchでニューラルネットワークモデルのパラメータが更新されているかどうかを確認したいときがある。モデルのパラメータを確認する方法はいくつかあるけど、Pytorchはモジュールごとにモデルを作っていくことが多いので、とりあえず簡単に確認する方法をいくつか書いてみることにする。
list(model.parameters())
まずはひとつめの方法。model.parameters()
をlist化するだけでOK。書くパラメータにアクセスするためにforループを書く必要もない。どのレイヤのどのmoduleのパラメータかがわからないので大きいネットワークのパラメータチェックには不向きだけど、module単位の小さいネットワークがちゃんと動いているかどうかのチェックはスムーズにできる。
from torch import nn model = nn.Linear(3, 3) print(list(model.parameters()))
出力は以下
[Parameter containing: tensor([[-0.0419, -0.0525, -0.0323], [ 0.1902, 0.3475, 0.1459], [ 0.5106, -0.1896, 0.1551]], requires_grad=True), Parameter containing: tensor([-0.1259, -0.0232, -0.1674], requires_grad=True)]
nn.Moduleを継承しているクラスであれば、nn.Linearなどでもparameters()
という関数は持っているので、自分でニューラルネットのクラスを作っても同じようにできる。例えば以下のようなNetクラスを作って同じことをしてみる。
from torch import nn class Net(nn.Module): def __init__(self): super().__init__() self.layers = nn.Sequential( nn.Linear(3,3), nn.ReLU(), nn.Linear(3,3), nn.Sigmoid(), ) def forward(self, x): return self.layers(x) model = Net() print(list(model.parameters()))
結果は以下
[Parameter containing: tensor([[-0.2279, -0.5007, 0.4614], [ 0.5661, 0.5210, 0.2435], [ 0.2621, -0.4700, 0.4480]], requires_grad=True), Parameter containing: tensor([ 0.2879, -0.1006, 0.1832], requires_grad=True), Parameter containing: tensor([[-0.0751, -0.4191, 0.4684], [-0.3355, 0.2323, -0.1926], [ 0.4090, -0.5508, -0.0795]], requires_grad=True), Parameter containing: tensor([-0.0533, 0.0500, -0.4279], requires_grad=True)]
model.state_dict()
model.state_dict()
はモデルをsaveするときに使うだけではなく、パラメータの可視化にも使える。list(model.parameters()))
との大きな違いは、OrderedDict形式になっているためにどのレイヤかを識別できること。
model = Net()
print(model.state_dict())
下記のような出力が得られる。
OrderedDict([('layers.0.weight', tensor([[ 0.4853, 0.3065, 0.4571], [-0.3992, -0.1177, 0.3469], [-0.4186, 0.2202, -0.2316]])), ('layers.0.bias', tensor([-0.1896, 0.4691, 0.5213])), ('layers.2.weight', tensor([[-0.5379, 0.2572, -0.4723], [-0.3963, -0.0715, -0.0839], [ 0.5445, -0.4799, -0.2490]])), ('layers.2.bias', tensor([ 0.3794, -0.4309, -0.2685]))])
各パラメータには名前が付いているので、keyで取り出すこともできる。
stdict = model.state_dict() print(stdict['layers.0.weight'])
出力は以下
tensor([[ 0.4853, 0.3065, 0.4571], [-0.3992, -0.1177, 0.3469], [-0.4186, 0.2202, -0.2316]])
ネットワークが大きくなるとmodel.state_dict()
の出力が膨大になって、keyの確認が難しくなる。keyだけを取り出したいときは下記のようにすればよい
print(stdict.keys()) k = list(stdict.keys())[0] print(k) print(stdict[k])
OrderedDictはindexでは要素にアクセスできないので、listにしてからindexにアクセスする。
odict_keys(['layers.0.weight', 'layers.0.bias', 'layers.2.weight', 'layers.2.bias']) layers.0.weight tensor([[ 0.4853, 0.3065, 0.4571], [-0.3992, -0.1177, 0.3469], [-0.4186, 0.2202, -0.2316]])
ここで紹介した2つの方法は、直接メンバをたたいたり、model.modules()
などを使ってひとつひとつネットワークを掘ってアクセスする方法よりも便利なんじゃないかな。もっと簡単な方法があったらまた追記します。