Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

Pytorchで複素数CNNをやってみる

Python の numpy では複素数のdtypeがあるけど、pytorch tensorには複素数型がない。Githubtensor自体を複素数にしているものがあるけど、なんとなくまだcudaサポートしていないとか、発展途上な感じ。

github.com

そもそもtensor複素数にするとconvolutionとかも全部できるようにしないといけないからpytorch自体の改良が難しいんじゃないかなと思う。(Tensorflowはできるっぽい)

ということで、Pytorchのtensor複素数にしない方法で複素数CNNと同じモジュールを実装できるかどうかを確かめてみた。(注意:これがどんなアプリに使えるかはまだわかってません。音声では使えそうですが。。。)

複素数モデル

まずはブロック図

f:id:tzmi:20200509144043p:plain

入力したxから実部と虚部を取り出して、それぞれ実部と虚部のConv2dに入力する。出てきた値を複素数の演算で足し合わせて新しく実部と虚部を作る。最後に実部のみをReLUして出力する。

こんなモデルにしてみたいというモジュールを先に書いてみると以下のようになる。普通にConv2dReLUが交互に並ぶようなネットワークで、これを複素数でできるようにする。

class ComplexMLD(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            ComplexConv2d(1,4,3, padding=1),
            ComplexReLU(),
            ComplexConv2d(4,8,3, padding=1),
            ComplexReLU(),
            ComplexConv2d(8,10,3, padding=1),
            ComplexAbs(),
        )

    def forward(self, x):
        return self.layers(x)

複素数レイヤの定義

最初にnn.Sequentialでまとめたいネットワークを書いてみた。なので、入力は1変数のxになって繋がっている必要がある。計算は複素数にするんだけど、プログラム上は入出力にrealimagが出てこないように工夫する。

ComplexConv2d layer

まずはComplexConv2dモジュールを書いてみる。

class ComplexConv2d(nn.Module):
    def __init__(self, n_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.conv_real = nn.Conv2d(n_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, bias)
        self.conv_imag = nn.Conv2d(n_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, bias)

    def forward(self, x):
        if len(x.shape)==4: # n,c,w,h
            x_real = x
            x_imag = torch.zeros_like(x)
        else: # n,p,c,w,h
            x_real = x[:,0]
            x_imag = x[:,1]

        x_real_out = self.conv_real(x_real) - self.conv_imag(x_imag)
        x_imag_out = self.conv_real(x_imag) + self.conv_imag(x_real)
        x_real_out = x_real_out.unsqueeze(1)
        x_imag_out = x_imag_out.unsqueeze(1)
        return torch.cat((x_real_out, x_imag_out), dim=1)

Conv2dを使って畳み込みのreal部分とimag部分のモジュールを作った。forwardに入力されたxは、shapeを見て実部のみか複素数かを判断するようにする。複素数の場合は、ちょっとややこしいけどshapeがn,p,c,h,wとなるようなテンソルを想定する。pの部分がキモで、これがrealかimagを示す2つの値となる。

そのあとは単純に複素数の計算をして、もう一度concatしてから出力する。入出力を1変数にまとめないと、nn.Sequentialを通せないので注意。

ComplexReLU layer

次にComplexReLUを書いてみる。

class ComplexReLU(nn.Module):
    def forward(self, x):
        x_real = x[:,0]
        x_imag = x[:,1]

        x_real = F.relu(x_real, inplace=True)
        x_real_out = x_real.unsqueeze(1)
        x_imag_out = x_imag.unsqueeze(1)
        return torch.cat((x_real_out, x_imag_out), dim=1)

こちらは単純に実部のみにReLUをかける。なんでこんなことするかというと、実数のみの場合のReLUでは閾値処理をしているけど、複素数は2次元なので、複素平面上の線で0かそうでないかを分けたいから。また、任意の線ではなく単純に虚軸に沿った線を使う理由は、ComplexConv2dが複素数平面上での回転も表せるはずだからで、多分これでいいと思う。心配だったら回転用のレイヤも定義すればよいだけなわけだし。重要なのは活性化の前後で非線形になるようにするということ。

ComplexAbs layer

最後は出力を実数のみにする方法としてComplexAbsを実装してみる。

class ComplexAbs(nn.Module):
    def forward(self, x):
        x_real = x[:,0]
        x_imag = x[:,1]
        return torch.sqrt(x_real**2 * x_imag**2)

最適化できるかチェック

最適化できるかどうかのチェックはMNISTとかで学習させるまでもない。xを乱数で作ってそれをモデルに代入し、出力y_の平均値をとったlossが下がるどうかを調査すればいい。例えば以下のようにする。

if __name__ == '__main__':

    x = torch.randn(2,1,10,10)
    model = ComplexMLD()
    opt = torch.optim.Adam(model.parameters())

    for i in range(10):
        y_ = model.forward(x)
        loss = torch.mean(y_)

        opt.zero_grad()
        loss.backward()
        opt.step()

        print(i, loss.item())

出力結果。lossが落ちているのできちんと最適化できてそう。

0 0.053851932287216187
1 0.04652361944317818
2 0.040614720433950424
3 0.035857390612363815
4 0.03206970915198326
5 0.028942720964550972
6 0.026352455839514732
7 0.024185555055737495
8 0.022345853969454765
9 0.020740525797009468

まとめ

複素数のCNNを実装できることを確かめてみた。音声認識とかに使えそうに見えるけど、実はどんなアプリに使えるかはまだわかっていない。これが使えるアプリケーションを考えてみるのも楽しいかもしれない。