Python の numpy では複素数のdtypeがあるけど、pytorch tensorには複素数型がない。Githubでtensor自体を複素数にしているものがあるけど、なんとなくまだcudaサポートしていないとか、発展途上な感じ。
そもそもtensorを複素数にするとconvolutionとかも全部できるようにしないといけないからpytorch自体の改良が難しいんじゃないかなと思う。(Tensorflowはできるっぽい)
ということで、Pytorchのtensorを複素数にしない方法で複素数CNNと同じモジュールを実装できるかどうかを確かめてみた。(注意:これがどんなアプリに使えるかはまだわかってません。音声では使えそうですが。。。)
複素数モデル
まずはブロック図
入力したx
から実部と虚部を取り出して、それぞれ実部と虚部のConv2d
に入力する。出てきた値を複素数の演算で足し合わせて新しく実部と虚部を作る。最後に実部のみをReLU
して出力する。
こんなモデルにしてみたいというモジュールを先に書いてみると以下のようになる。普通にConv2d
とReLU
が交互に並ぶようなネットワークで、これを複素数でできるようにする。
class ComplexMLD(nn.Module): def __init__(self): super().__init__() self.layers = nn.Sequential( ComplexConv2d(1,4,3, padding=1), ComplexReLU(), ComplexConv2d(4,8,3, padding=1), ComplexReLU(), ComplexConv2d(8,10,3, padding=1), ComplexAbs(), ) def forward(self, x): return self.layers(x)
複素数レイヤの定義
最初にnn.Sequential
でまとめたいネットワークを書いてみた。なので、入力は1変数のx
になって繋がっている必要がある。計算は複素数にするんだけど、プログラム上は入出力にreal
とimag
が出てこないように工夫する。
ComplexConv2d layer
まずはComplexConv2d
モジュールを書いてみる。
class ComplexConv2d(nn.Module): def __init__(self, n_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super().__init__() self.conv_real = nn.Conv2d(n_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.conv_imag = nn.Conv2d(n_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) def forward(self, x): if len(x.shape)==4: # n,c,w,h x_real = x x_imag = torch.zeros_like(x) else: # n,p,c,w,h x_real = x[:,0] x_imag = x[:,1] x_real_out = self.conv_real(x_real) - self.conv_imag(x_imag) x_imag_out = self.conv_real(x_imag) + self.conv_imag(x_real) x_real_out = x_real_out.unsqueeze(1) x_imag_out = x_imag_out.unsqueeze(1) return torch.cat((x_real_out, x_imag_out), dim=1)
Conv2dを使って畳み込みのreal部分とimag部分のモジュールを作った。forward
に入力されたx
は、shapeを見て実部のみか複素数かを判断するようにする。複素数の場合は、ちょっとややこしいけどshapeがn,p,c,h,w
となるようなテンソルを想定する。p
の部分がキモで、これがrealかimagを示す2つの値となる。
そのあとは単純に複素数の計算をして、もう一度concatしてから出力する。入出力を1変数にまとめないと、nn.Sequential
を通せないので注意。
ComplexReLU layer
次にComplexReLU
を書いてみる。
class ComplexReLU(nn.Module): def forward(self, x): x_real = x[:,0] x_imag = x[:,1] x_real = F.relu(x_real, inplace=True) x_real_out = x_real.unsqueeze(1) x_imag_out = x_imag.unsqueeze(1) return torch.cat((x_real_out, x_imag_out), dim=1)
こちらは単純に実部のみにReLUをかける。なんでこんなことするかというと、実数のみの場合のReLUでは閾値処理をしているけど、複素数は2次元なので、複素平面上の線で0かそうでないかを分けたいから。また、任意の線ではなく単純に虚軸に沿った線を使う理由は、ComplexConv2dが複素数平面上での回転も表せるはずだからで、多分これでいいと思う。心配だったら回転用のレイヤも定義すればよいだけなわけだし。重要なのは活性化の前後で非線形になるようにするということ。
ComplexAbs layer
最後は出力を実数のみにする方法としてComplexAbs
を実装してみる。
class ComplexAbs(nn.Module): def forward(self, x): x_real = x[:,0] x_imag = x[:,1] return torch.sqrt(x_real**2 * x_imag**2)
最適化できるかチェック
最適化できるかどうかのチェックはMNISTとかで学習させるまでもない。xを乱数で作ってそれをモデルに代入し、出力y_の平均値をとったlossが下がるどうかを調査すればいい。例えば以下のようにする。
if __name__ == '__main__': x = torch.randn(2,1,10,10) model = ComplexMLD() opt = torch.optim.Adam(model.parameters()) for i in range(10): y_ = model.forward(x) loss = torch.mean(y_) opt.zero_grad() loss.backward() opt.step() print(i, loss.item())
出力結果。lossが落ちているのできちんと最適化できてそう。
0 0.053851932287216187 1 0.04652361944317818 2 0.040614720433950424 3 0.035857390612363815 4 0.03206970915198326 5 0.028942720964550972 6 0.026352455839514732 7 0.024185555055737495 8 0.022345853969454765 9 0.020740525797009468
まとめ
複素数のCNNを実装できることを確かめてみた。音声認識とかに使えそうに見えるけど、実はどんなアプリに使えるかはまだわかっていない。これが使えるアプリケーションを考えてみるのも楽しいかもしれない。