PytorchでVAEをやってみる。理論的なところは他に説明がたくさんあるので省略する。VAEは最初から自分で書こうとするといろんなところでハマるので、とりあえず動くモデルを載せてみることにする。
VAEモデルの定義
まずはモデルの実装から。学習はネットワーク全体で行い、推定時はデコーダだけ使いたいので、デコーダだけ関数で書いておく。lossもmuとlogvarを外に出すとややこしくなるので、メンバ関数にしておく。
import torch from torch import nn from torch.nn import functional as F class VAE(nn.Module): def __init__(self): super().__init__() self.efc1 = nn.Linear(784, 400) self.fc21 = nn.Linear(400, 10) self.fc22 = nn.Linear(400, 10) self.dfc3 = nn.Linear(10, 400) self.dfc4 = nn.Linear(400, 784) def forward(self, x): # encoder x = x.reshape(x.shape[0], -1) # flatten x = F.relu(self.efc1(x)) # sampler ここではreluしないこと!! self.mu = self.fc21(x) self.logvar = self.fc22(x) std = torch.exp(0.5*self.logvar) eps = torch.randn_like(std) if self.training: x = self.mu + std*eps else: x = self.mu # decoder x = self.decoder(x) return x def decoder(self, x): x = F.relu(self.dfc3(x)) x = torch.sigmoid(self.dfc4(x)) return x def loss(self, x_, x): mu, logvar = self.mu, self.logvar x = x.reshape(x.shape[0], -1) # flatten bce = F.binary_cross_entropy(x_, x, reduction='sum') kld = 0.5 * torch.sum(mu**2 + logvar.exp() - logvar - 1) return bce + kld
途中のself.training
のif文があるところでは、学習と推定で処理を分けている。
KLDのlogva.exp()-logvar
の部分は中間層の広がりやすさなので、中間層のチャネルが増減したときにこれも変化すべき。ということでVAEのKLDには平均ではなく和が使われる。というように直感的に理解している。
この部分は昔からずっと気になっているところ。理論的に正しいかどうかは置いておいて、BCEとKLDの間に収束しやすいハイパーパラメータが多分あるんじゃないか。
なぜかというと、このlossだとBCEの大きさが出力画像サイズに依存してしまっているから。lossが出力画像サイズ依存とかおかしいと思うんだけどこれも自然なんだろうか。出力サイズが大きくなると中間層がガウス分布になりにくくなるというのは、いったいどういうことなんだろう。
まあとりあえずMNISTではこれでいいということで、次は学習。
学習
datasetとtrainloaderを使うとこんなに簡単に書ける。ここでの注意点としては出力がsigmoidで活性化されているため、入力は[0,1]に正規化されていることが必要。間違えてNormarize入れて[-1,1]とかにすると学習できなくなるので注意。早く収束させるためにlrは5 epoch目までは0.01にして、それ以降は0.001とするようにしておく。
import torch from torchvision import transforms from torchvision import datasets if __name__ == '__main__': #main() device = 'cuda' if torch.cuda.is_available() else 'cpu' model = VAE().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # dataset transform = transforms.ToTensor() trainset = datasets.MNIST(root='~/.data', train=True, download=True, transform=transform) bs = 128 # batch_size trainloader = DataLoader(trainset, batch_size=bs, shuffle=True) # training nepoch = 10 model = model.train() for iepoch in range(nepoch): lr = 0.001 if iepoch>=5 else 0.01 optimizer.param_groups[0]['lr'] = lr for iiter, (x, _) in enumerate(trainloader): x = x.to(device) x_ = model.forward(x) loss = model.loss(x_, x) optimizer.zero_grad() loss.backward() optimizer.step() print('%03d epoch, loss=%.8f' % (iepoch, loss.item()))
今回使ったモデルの規模ならGPUあれば学習は数分以内に終わるはず。多分CPUでも10分くらいかな。
画像生成
まずはオートエンコーダの出力
まずまず。KLDの項が強すぎると中間層がただのガウス分布になるため、オートエンコーダの出力は全て同じになる。逆にKLDが弱すぎると乱数から生成した画像がめちゃくちゃになるはず。
では、続いて乱数から画像生成
変なのもいるけどなんとか数字の形状をとどめているので、まあまあいいんじゃないかな。今回は中間層10チャネルでやったけど、中間層のチャネル数が大きすぎると自由度が大きくなってデータが混ざりにくくなり、KLDが弱すぎる場合と同じように生成画像がめちゃくちゃになる。
今回の結果は、オートエンコーダの出力が結構きちんとできていて、乱数からの画像生成がまあまあという感じだったので、KLDの項が弱すぎたのかもしれない。Lossの重みと中間層のチャネル数をうまく調整しないといけないので、うまく中間層をばらつかせつつきれいな画像を出す(つまり乱数からきれいな画像を生成する)のは難しい。(MNISTならなんとかできるけど、もう少し大きめの画像だともっと難しくなる)
もう少しいろいろやってみようかなと思ったけど、大変なのでこれくらいにしておく。