ディープラーニングでPytorchを使っていると、モデルの定義やモデルに組み込むモジュールの定義でクラスを作ることが多い。このとき、nn.Moduleを継承して定義するためsuperは必ず使う。ここでsuperの引数にクラス名を書くため、クラステンプレートのコピーの際にsuperの引数を書き換えるのを忘れて怒られることが多い。
実はpython3系では、クラスの継承部分のsuper()の中身は以下のように省略できる。
class Child(Mother): def __init__(self): super().__init__()
こうやって書けばsuperの引数を書き換え忘れて怒られることはなくなる。
Pytorchのモデル定義でよく起こること
pytorchではニューラルネットワークのモデルについて、必ずクラスを使ってnn.Moduleを継承して定義する。このモデルの定義について、githubのpytorchのコード見てると、よく以下のようなコードを見つける。
from torch import nn # networkの定義 class NewNet(nn.Module): def __init__(self): super(NewNet, self).__init__() # superの中に古いネットワークのクラス名を起き間違えるミスが多い self.conv2d = nn.Conv2d(1, 2, (3,3), stride=1, padding=1) def forward(self, x): x = self.conv2d(x) return x if __name__ == '__main__': model = NewNet()
superの部分にクラス名が入っているため、新しいネットワークを作るためにクラス名だけを書き換えると、だいたいsuperの部分を書き換え忘れたせいでエラーが起こって不便。 例えばsuperの部分を super(Net, self).__init__() などと書くと以下のようなエラーが出る。
14 class Net_new(nn.Module): 15 def __init__(self): ---> 16 super(Net, self).__init__() 17 self.conv2d = nn.Conv2d(1, 2, (3,3), stride=1, padding=1) 18 NameError: name 'Net' is not defined
継承(super)の部分で「Netという名前は定義されていません」と出てしまう。
superの引数省略
python3系ではクラス継承の際のsuperのかっこの中を省略できるため、以下のように書いてもよい。
from torch import nn # networkの定義 class NewNet(nn.Module): def __init__(self): super().__init__() # python3系ではsuperの中身は省略してよい self.conv2d = nn.Conv2d(1, 2, (3,3), stride=1, padding=1) def forward(self, x): x = self.conv2d(x) return x if __name__ == '__main__': model = NewNet()
上記のように書くことで上記のミスはなくなるし、長いクラス名をsuperの中に書くこともなくなるので便利。特に複雑なネットワークを書くときに一部をモジュール化して書くことが多いので、こういった場合に混乱しなくて済む。
注意点として、python2系ではsuperの引数は省略できない。おそらく多くのpytorchのモデル定義でsuperの引数省略をやっていないのはこのためかと。
個人的にはpytenvとかcondaとかdockerとかを使えば2系か3系かの環境を変更できるのだから、python3系にしてsuperの引数省略で書いてしまった方が楽じゃないかなと思う。