Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

Pytorch で VAE の実装と画像生成

PytorchでVAEをやってみる。理論的なところは他に説明がたくさんあるので省略する。VAEは最初から自分で書こうとするといろんなところでハマるので、とりあえず動くモデルを載せてみることにする。

VAEモデルの定義

まずはモデルの実装から。学習はネットワーク全体で行い、推定時はデコーダだけ使いたいので、デコーダだけ関数で書いておく。lossもmuとlogvarを外に出すとややこしくなるので、メンバ関数にしておく。

import torch
from torch import nn
from torch.nn import functional as F

class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.efc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 10)
        self.fc22 = nn.Linear(400, 10)
        self.dfc3 = nn.Linear(10, 400)
        self.dfc4 = nn.Linear(400, 784)

    def forward(self, x):
        # encoder                                                                    
        x = x.reshape(x.shape[0], -1) # flatten                                      
        x = F.relu(self.efc1(x))

        # sampler ここではreluしないこと!!                                              
        self.mu = self.fc21(x)
        self.logvar = self.fc22(x)
        std = torch.exp(0.5*self.logvar)
        eps = torch.randn_like(std)
        if self.training:
            x = self.mu + std*eps
        else:
            x = self.mu

        # decoder                                                                    
        x = self.decoder(x)
        return x

    def decoder(self, x):
        x = F.relu(self.dfc3(x))
        x = torch.sigmoid(self.dfc4(x))
        return x

    def loss(self, x_, x):
        mu, logvar = self.mu, self.logvar
        x = x.reshape(x.shape[0], -1) # flatten                                      
        bce = F.binary_cross_entropy(x_, x, reduction='sum')
        kld = 0.5 * torch.sum(mu**2 + logvar.exp() - logvar - 1)
        return bce + kld

途中のself.trainingのif文があるところでは、学習と推定で処理を分けている。

KLDのlogva.exp()-logvarの部分は中間層の広がりやすさなので、中間層のチャネルが増減したときにこれも変化すべき。ということでVAEのKLDには平均ではなく和が使われる。というように直感的に理解している。

この部分は昔からずっと気になっているところ。理論的に正しいかどうかは置いておいて、BCEとKLDの間に収束しやすいハイパーパラメータが多分あるんじゃないか。

なぜかというと、このlossだとBCEの大きさが出力画像サイズに依存してしまっているから。lossが出力画像サイズ依存とかおかしいと思うんだけどこれも自然なんだろうか。出力サイズが大きくなると中間層がガウス分布になりにくくなるというのは、いったいどういうことなんだろう。

まあとりあえずMNISTではこれでいいということで、次は学習。

学習

datasetとtrainloaderを使うとこんなに簡単に書ける。ここでの注意点としては出力がsigmoidで活性化されているため、入力は[0,1]に正規化されていることが必要。間違えてNormarize入れて[-1,1]とかにすると学習できなくなるので注意。早く収束させるためにlrは5 epoch目までは0.01にして、それ以降は0.001とするようにしておく。

import torch
from torchvision import transforms
from torchvision import datasets

if __name__ == '__main__':
    #main()                                                                          

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = VAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # dataset                                                                        
    transform = transforms.ToTensor()
    trainset = datasets.MNIST(root='~/.data', train=True,
                              download=True, transform=transform)
    bs = 128 # batch_size                                                            
    trainloader = DataLoader(trainset, batch_size=bs, shuffle=True)

    # training                                                                       
    nepoch = 10
    model = model.train()
    for iepoch in range(nepoch):
        lr = 0.001 if iepoch>=5 else 0.01
        optimizer.param_groups[0]['lr'] = lr
        for iiter, (x, _) in enumerate(trainloader):
            x = x.to(device)
            x_ = model.forward(x)
            loss = model.loss(x_, x)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('%03d epoch, loss=%.8f' % (iepoch, loss.item()))

今回使ったモデルの規模ならGPUあれば学習は数分以内に終わるはず。多分CPUでも10分くらいかな。

画像生成

まずはオートエンコーダの出力

f:id:tzmi:20200329220350j:plain

まずまず。KLDの項が強すぎると中間層がただのガウス分布になるため、オートエンコーダの出力は全て同じになる。逆にKLDが弱すぎると乱数から生成した画像がめちゃくちゃになるはず。

では、続いて乱数から画像生成

f:id:tzmi:20200329220409j:plain

変なのもいるけどなんとか数字の形状をとどめているので、まあまあいいんじゃないかな。今回は中間層10チャネルでやったけど、中間層のチャネル数が大きすぎると自由度が大きくなってデータが混ざりにくくなり、KLDが弱すぎる場合と同じように生成画像がめちゃくちゃになる。

今回の結果は、オートエンコーダの出力が結構きちんとできていて、乱数からの画像生成がまあまあという感じだったので、KLDの項が弱すぎたのかもしれない。Lossの重みと中間層のチャネル数をうまく調整しないといけないので、うまく中間層をばらつかせつつきれいな画像を出す(つまり乱数からきれいな画像を生成する)のは難しい。(MNISTならなんとかできるけど、もう少し大きめの画像だともっと難しくなる)

もう少しいろいろやってみようかなと思ったけど、大変なのでこれくらいにしておく。