Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

Pytorchでクラス継承のsuper()の引数は省略できる

ディープラーニングでPytorchを使っていると、モデルの定義やモデルに組み込むモジュールの定義でクラスを作ることが多い。このとき、nn.Moduleを継承して定義するためsuperは必ず使う。ここでsuperの引数にクラス名を書くため、クラステンプレートのコピーの際にsuperの引数を書き換えるのを忘れて怒られることが多い。

実はpython3系では、クラスの継承部分のsuper()の中身は以下のように省略できる。

class Child(Mother):
    def __init__(self):
        super().__init__()

こうやって書けばsuperの引数を書き換え忘れて怒られることはなくなる。

Pytorchのモデル定義でよく起こること

pytorchではニューラルネットワークのモデルについて、必ずクラスを使ってnn.Moduleを継承して定義する。このモデルの定義について、githubのpytorchのコード見てると、よく以下のようなコードを見つける。

from torch import nn

# networkの定義                                                                 
class NewNet(nn.Module):
    def __init__(self):
        super(NewNet, self).__init__() # superの中に古いネットワークのクラス名を起き間違えるミスが多い
        self.conv2d = nn.Conv2d(1, 2, (3,3), stride=1, padding=1)

    def forward(self, x):
        x = self.conv2d(x)
        return x

if __name__ == '__main__':

    model = NewNet()

superの部分にクラス名が入っているため、新しいネットワークを作るためにクラス名だけを書き換えると、だいたいsuperの部分を書き換え忘れたせいでエラーが起こって不便。 例えばsuperの部分を super(Net, self).__init__() などと書くと以下のようなエラーが出る。

     14 class Net_new(nn.Module):
     15     def __init__(self):
---> 16         super(Net, self).__init__()
     17         self.conv2d = nn.Conv2d(1, 2, (3,3), stride=1, padding=1)
     18 

NameError: name 'Net' is not defined

継承(super)の部分で「Netという名前は定義されていません」と出てしまう。

superの引数省略

python3系ではクラス継承の際のsuperのかっこの中を省略できるため、以下のように書いてもよい。

from torch import nn

# networkの定義                                                                 
class NewNet(nn.Module):
    def __init__(self):
        super().__init__() # python3系ではsuperの中身は省略してよい
        self.conv2d = nn.Conv2d(1, 2, (3,3), stride=1, padding=1)

    def forward(self, x):
        x = self.conv2d(x)
        return x

if __name__ == '__main__':

    model = NewNet()

上記のように書くことで上記のミスはなくなるし、長いクラス名をsuperの中に書くこともなくなるので便利。特に複雑なネットワークを書くときに一部をモジュール化して書くことが多いので、こういった場合に混乱しなくて済む。

注意点として、python2系ではsuperの引数は省略できない。おそらく多くのpytorchのモデル定義でsuperの引数省略をやっていないのはこのためかと。

個人的にはpytenvとかcondaとかdockerとかを使えば2系か3系かの環境を変更できるのだから、python3系にしてsuperの引数省略で書いてしまった方が楽じゃないかなと思う。