高解像度の画像を生成できるProgressive GAN (PGGAN)を実装してみた。
色々と苦労があって1週間以上時間を使った。ガチで研究するなら再現で1週間くらいかかるようなくらいはやらないといけないのかもしれない。論文ではさらに評価指標をどうするかなどの細かい考察があるので、すごいとしか言いようがない。
基礎モジュール
ではまず基礎モジュールから説明して、ネットワークを組み上げていくことにする。
Pixel normalization
PGGANではBatchnormを使わずにPixelnormという方法を用いる。チャネル方向に二乗平均をとってsqrtをした値で割り算をする。この値はバッチに依存して変化するので、Batchnormのように値の保存はしない。
import torch from torch import nn class PixelNorm(nn.Module): def forward(self, x): eps = 1e-7 mean = torch.mean(x**2, dim=1, keepdims=True) return x / (torch.sqrt(mean)+eps)
Equalized learning rate
各レイヤの重みを入力チャネルサイズで正規化する。Heの初期化と似た効果を期待するもの。入力チャネルサイズに応じてConvolutionの出力値の分布が変化するのを防ぐ効果がある。
実装上では重みを毎回正規化すると値がどんどん小さくなってしまうので、Pixelnormの出力に対して、重みではなく特徴量自体を正規化する。
モジュール名はWeight scale
だけど、実際にscaleしているのは特徴量であることに注意。特徴量に重みをかけて入力サイズで正規化することは線形処理なので、特徴量を入力サイズで正規化してから重みをかけても同じ計算となる。
class WeightScale(nn.Module): def forward(self, x, gain=2): scale = (gain/x.shape[1])**0.5 return x * scale
Minibatch std.
出力の多様性を保証するためのモジュール。正解データのバッチが持つ多様性(標準偏差平均)をDiscriminatorの判断材料にするというもの。入力されたテンソルに対して、各バッチの標準偏差を求め、チャネルと縦横方向に平均化したテンソルを作る。作ったテンソルを入力テンソルにconcatして出力する。
class MiniBatchStd(nn.Module): def forward(self, x): std = torch.std(x, dim=0, keepdim=True) mean = torch.mean(std, dim=(1,2,3), keepdim=True) n,c,h,w = x.shape mean = torch.ones(n,1,h,w, dtype=x.dtype, device=x.device)*mean return torch.cat((x,mean), dim=1)
Conv2d Module
Conv2d周辺がややこしいので、module化する。Conv2dの入力前に上で定義したWeightScale
を使う。また、Conv2dではzero paddingではなくReflectionPad2d
を使う。これによりzero paddingと比較して画像の端付近にアーティファクトができなくなる。さらに、Batchnormの代わりに上で定義したPixelnormを使う。 レイヤ定義で重みの初期化も行っておく。
class Conv2d(nn.Module): def __init__(self, inch, outch, kernel_size, padding=0): super().__init__() self.layers = nn.Sequential( WeightScale(), nn.ReflectionPad2d(padding), nn.Conv2d(inch, outch, kernel_size, padding=0), PixelNorm(), ) nn.init.kaiming_normal_(self.layers[2].weight) def forward(self, x): return self.layers(x)
ConvModule G
Generator用のConvolutionモジュール。最初の部分だけ特別扱い。
class ConvModuleG(nn.Module): ''' Args: out_size: (int), Ex.: 16 (resolution) inch: (int), Ex.: 256 outch: (int), Ex.: 128 ''' def __init__(self, out_size, inch, outch, first=False): super().__init__() if first: layers = [ Conv2d(inch, outch, 3, padding=1), nn.LeakyReLU(0.2, inplace=False), Conv2d(outch, outch, 3, padding=1), nn.LeakyReLU(0.2, inplace=False), ] else: layers = [ nn.Upsample((out_size, out_size), mode='nearest'), Conv2d(inch, outch, 3, padding=1), nn.LeakyReLU(0.2, inplace=False), Conv2d(outch, outch, 3, padding=1), nn.LeakyReLU(0.2, inplace=False), ] self.layers = nn.Sequential(*layers) def forward(self, x): return self.layers(x)
ConvModule D
次はDiscriminator用のConvolution module。こちらは最後の部分だけ特別扱いとする。
class ConvModuleD(nn.Module): ''' Args: out_size: (int), Ex.: 16 (resolution) inch: (int), Ex.: 256 outch: (int), Ex.: 128 ''' def __init__(self, out_size, inch, outch, final=False): super().__init__() if final: layers = [ MiniBatchStd(), # final block only Conv2d(inch+1, outch, 3, padding=1), nn.LeakyReLU(0.2, inplace=False), Conv2d(outch, outch, 4, padding=0), nn.LeakyReLU(0.2, inplace=False), nn.Conv2d(outch, 1, 1, padding=0), ] else: layers = [ Conv2d(inch, outch, 3, padding=1), nn.LeakyReLU(0.2, inplace=False), Conv2d(outch, outch, 3, padding=1), nn.LeakyReLU(0.2, inplace=False), nn.AdaptiveAvgPool2d((out_size, out_size)), ] self.layers = nn.Sequential(*layers) def forward(self, x): return self.layers(x)
Generator
__init__
でとりあえず、全ネットワークを定義しておいて、forward
部分でどこまで使うかを決める。forward
はちょっと長くなるけど、進捗パラメータres
を用いて条件分岐する構造とする。層を増やしたいときはnp.array
の部分を増やせばよい。
class Generator(nn.Module): def __init__(self): super().__init__() # conv modules & toRGBs scale = 1 inchs = np.array([512,256,128,64,32,16], dtype=np.uint32)*scale outchs = np.array([256,128, 64,32,16, 8], dtype=np.uint32)*scale sizes = np.array([4,8,16,32,64,128], dtype=np.uint32) firsts = np.array([True, False, False, False, False, False], dtype=np.bool) blocks, toRGBs = [], [] for s, inch, outch, first in zip(sizes, inchs, outchs, firsts): blocks.append(ConvModuleG(s, inch, outch, first)) toRGBs.append(nn.Conv2d(outch, 3, 1, padding=0)) self.blocks = nn.ModuleList(blocks) self.toRGBs = nn.ModuleList(toRGBs) def forward(self, x, res, eps=1e-7): # to image n,c = x.shape x = x.reshape(n,c//16,4,4) # for the highest resolution res = min(res, len(self.blocks)) # get integer by floor nlayer = max(int(res-eps), 0) for i in range(nlayer): x = self.blocks[i](x) # high resolution x_big = self.blocks[nlayer](x) dst_big = self.toRGBs[nlayer](x_big) if nlayer==0: x = dst_big else: # low resolution x_sml = F.interpolate(x, x_big.shape[2:4], mode='nearest') dst_sml = self.toRGBs[nlayer-1](x_sml) alpha = res - int(res-eps) x = (1-alpha)*dst_sml + alpha*dst_big #return x, n, res return torch.sigmoid(x)
出力は[0,1]の範囲とするためにsigmoidでactivationする。
Discriminator
こちらも進捗パラメータres
を用いた分岐構造となる。定義やforward
処理はGeneratorと同じ。こちらも層を増やしたいときはnp.array
の部分を増やせばよい。
class Discriminator(nn.Module): def __init__(self): super().__init__() self.minbatch_std = MiniBatchStd() # conv modules & toRGBs scale = 1 inchs = np.array([256,128, 64,32,16, 8], dtype=np.uint32)*scale outchs = np.array([512,256,128,64,32,16], dtype=np.uint32)*scale sizes = np.array([1,4,8,16,32,64], dtype=np.uint32) finals = np.array([True, False, False, False, False, False], dtype=np.bool) blocks, fromRGBs = [], [] for s, inch, outch, final in zip(sizes, inchs, outchs, finals): fromRGBs.append(nn.Conv2d(3, inch, 1, padding=0)) blocks.append(ConvModuleD(s, inch, outch, final=final)) self.fromRGBs = nn.ModuleList(fromRGBs) self.blocks = nn.ModuleList(blocks) def forward(self, x, res): # for the highest resolution res = min(res, len(self.blocks)) # get integer by floor eps = 1e-7 n = max(int(res-eps), 0) # high resolution x_big = self.fromRGBs[n](x) x_big = self.blocks[n](x_big) if n==0: x = x_big else: # low resolution x_sml = F.adaptive_avg_pool2d(x, x_big.shape[2:4]) x_sml = self.fromRGBs[n-1](x_sml) alpha = res - int(res-eps) x = (1-alpha)*x_sml + alpha*x_big for i in range(n): x = self.blocks[n-1-i](x) return x
loss関数をWGAN-GPにするので、sigmoidは使わないようにする。
Training
さて、やっと学習部分まできた。学習ではlossにWGAN-GPを用いるのと、mean teacher のように学習させたネットワークのパラメータを移動平均していく。WGAN-GPについては、別の記事で説明しようかと思うので関数だけ。オリジナルと違う部分は進捗パラメータres
が入っているあたり。
def gradient_penalty(netD, real, fake, res, batch_size, gamma=1): device = real.device alpha = torch.rand(batch_size, 1, 1, 1, requires_grad=True).to(device) x = alpha*real + (1-alpha)*fake d_ = netD.forward(x, res) g = torch.autograd.grad(outputs=d_, inputs=x, grad_outputs=torch.ones(d_.shape).to(device), create_graph=True, retain_graph=True,only_inputs=True)[0] g = g.reshape(batch_size, -1) return ((g.norm(2,dim=1)/gamma-1.0)**2).mean()
ここから学習部分。netG
のパラメータ値はnetG_mavg
に代入していき、学習がある程度進んだら、netG_mavg
の値を使って推定を行う。オリジナルよりバッチサイズを半分にしてあるので、lr
も半分にしてある。バッチサイズが小さいので、多様性が小さくなるかもしれない。
f __name__ == '__main__': device = 'cuda' if torch.cuda.is_available() else 'cpu' netG = models.Generator().to(device) netD = models.Discriminator().to(device) netG_mavg = models.Generator().to(device) # moving average optG = torch.optim.Adam(netG.parameters(), lr=0.0005, betas=(0.0, 0.99)) optD = torch.optim.Adam(netD.parameters(), lr=0.0005, betas=(0.0, 0.99)) criterion = torch.nn.BCELoss() # dataset transform = transforms.Compose([transforms.CenterCrop(160), transforms.Resize((128,128)), transforms.ToTensor(), ]) trainset = datasets.CelebA('~/data', download=True, split='train', transform=transform) bs = 8 train_loader = DataLoader(trainset, batch_size=bs, shuffle=True) # training nepoch = 10 losses = [] res_step = 15000 j = 0 # constant random inputs z0 = torch.randn(16, 512*16).to(device) z0 = torch.clamp(z0, -1.,1.) for iepoch in range(nepoch): if j==res_step*6.5: optG.param_groups[0]['lr'] = 0.0001 optD.param_groups[0]['lr'] = 0.0001 for i, data in enumerate(train_loader): x, y = data x = x.to(device) res = j/res_step ### train generator ### z = torch.randn(bs, 512*16).to(x.device) x_ = netG.forward(z, res) d_ = netD.forward(x_, res) # fake lossG = -d_.mean() # WGAN_GP optG.zero_grad() lossG.backward() optG.step() # update netG_mavg by moving average momentum = 0.995 # remain momentum alpha = min(1.0-(1/(j+1)), momentum) for p_mavg, p in zip(netG_mavg.parameters(), netG.parameters()): p_mavg.data = alpha*p_mavg.data + (1.0-alpha)*p.data ### train discriminator ### z = torch.randn(x.shape[0], 512*16).to(x.device) x_ = netG.forward(z, res) x = F.adaptive_avg_pool2d(x, x_.shape[2:4]) d = netD.forward(x, res) # real d_ = netD.forward(x_, res) # fake loss_real = -d.mean() loss_fake = d_.mean() loss_gp = gradient_penalty(netD, x.data, x_.data, res, x.shape[0]) loss_drift = (d**2).mean() beta_gp = 10.0 beta_drift = 0.001 lossD = loss_real + loss_fake + beta_gp*loss_gp + beta_drift*loss_drift optD.zero_grad() lossD.backward() optD.step() print('ep: %02d %04d %04d lossG=%.10f lossD=%.10f' % (iepoch, i, j, lossG.item(), lossD.item())) losses.append([lossG.item(), lossD.item()]) j += 1 if j%500==0: netG_mavg.eval() z = torch.randn(16, 512*16).to(x.device) x_0 = netG_mavg.forward(z0, res) x_ = netG_mavg.forward(z, res) dst = torch.cat((x_0, x_), dim=0) dst = F.interpolate(dst, (128, 128), mode='nearest') dst = dst.to('cpu').detach().numpy() n, c, h, w = dst.shape dst = dst.reshape(4,8,c,h,w) dst = dst.transpose(0,3,1,4,2) dst = dst.reshape(4*h,8*w,3) dst = np.clip(dst*255., 0, 255).astype(np.uint8) skio.imsave('out/img_%03d_%05d.png' % (iepoch, j), dst) losses_ = np.array(losses) niter = losses_.shape[0]//100*100 x_iter = np.arange(100)*(niter//100) + niter//200 plt.plot(x_iter, losses_[:niter,0].reshape(100,-1).mean(1)) plt.plot(x_iter, losses_[:niter,1].reshape(100,-1).mean(1)) plt.tight_layout() plt.savefig('out/loss_%03d_%05d.png' % (iepoch, j)) plt.clf() netG_mavg.train() if j >= res_step*7: break if j%100==0: coolGPU() if j >= res_step*7: break
Loss
今回は128x128まで。まずはloss関数。青がGでオレンジがD。横軸はイテレーション回数で縦軸がlossの値。
出力画像
では、低解像度から結果を乗せていく。今回は解像度は15000イテレーションで切り替わるようにした。この値をもっと大きくするとさらにきれいな画像が得られるんだと思う。データセットはCelebAを使った。オリジナルの画素が256以下なのと、計算リソースがなかったので、128x128まで学習することにした。
4x4
8x8
16x16
32x32
64x64
128x128
まとめ
2020年のGWでまとまった休みがあったのでPGGANを頑張って実装してみたけど、想像の3倍くらい大変だったので、振り返ってみればやらない方がよかったかもしれない。提案者の苦労やソースの中身を何も知らなくてもgithubで簡単に利用できるというのは非常に大切なことなんだなと感じた1週間だった。