Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

PytorchでPGGANを実装する

高解像度の画像を生成できるProgressive GAN (PGGAN)を実装してみた。

f:id:tzmi:20200507215124j:plain

色々と苦労があって1週間以上時間を使った。ガチで研究するなら再現で1週間くらいかかるようなくらいはやらないといけないのかもしれない。論文ではさらに評価指標をどうするかなどの細かい考察があるので、すごいとしか言いようがない。

基礎モジュール

ではまず基礎モジュールから説明して、ネットワークを組み上げていくことにする。

Pixel normalization

PGGANではBatchnormを使わずにPixelnormという方法を用いる。チャネル方向に二乗平均をとってsqrtをした値で割り算をする。この値はバッチに依存して変化するので、Batchnormのように値の保存はしない。

import torch
from torch import nn

class PixelNorm(nn.Module):
    def forward(self, x):
        eps = 1e-7
        mean = torch.mean(x**2, dim=1, keepdims=True)
        return x / (torch.sqrt(mean)+eps)

Equalized learning rate

各レイヤの重みを入力チャネルサイズで正規化する。Heの初期化と似た効果を期待するもの。入力チャネルサイズに応じてConvolutionの出力値の分布が変化するのを防ぐ効果がある。 実装上では重みを毎回正規化すると値がどんどん小さくなってしまうので、Pixelnormの出力に対して、重みではなく特徴量自体を正規化する。 モジュール名はWeight scaleだけど、実際にscaleしているのは特徴量であることに注意。特徴量に重みをかけて入力サイズで正規化することは線形処理なので、特徴量を入力サイズで正規化してから重みをかけても同じ計算となる。

class WeightScale(nn.Module):
    def forward(self, x, gain=2):
        scale = (gain/x.shape[1])**0.5
        return x * scale

Minibatch std.

出力の多様性を保証するためのモジュール。正解データのバッチが持つ多様性(標準偏差平均)をDiscriminatorの判断材料にするというもの。入力されたテンソルに対して、各バッチの標準偏差を求め、チャネルと縦横方向に平均化したテンソルを作る。作ったテンソルを入力テンソルにconcatして出力する。

class MiniBatchStd(nn.Module):
    def forward(self, x):
        std = torch.std(x, dim=0, keepdim=True)
        mean = torch.mean(std, dim=(1,2,3), keepdim=True)
        n,c,h,w = x.shape
        mean = torch.ones(n,1,h,w, dtype=x.dtype, device=x.device)*mean
        return torch.cat((x,mean), dim=1)

Conv2d Module

Conv2d周辺がややこしいので、module化する。Conv2dの入力前に上で定義したWeightScaleを使う。また、Conv2dではzero paddingではなくReflectionPad2dを使う。これによりzero paddingと比較して画像の端付近にアーティファクトができなくなる。さらに、Batchnormの代わりに上で定義したPixelnormを使う。 レイヤ定義で重みの初期化も行っておく。

class Conv2d(nn.Module):
    def __init__(self, inch, outch, kernel_size, padding=0):
        super().__init__()
        self.layers = nn.Sequential(
            WeightScale(),
            nn.ReflectionPad2d(padding),
            nn.Conv2d(inch, outch, kernel_size, padding=0),
            PixelNorm(),
            )
        nn.init.kaiming_normal_(self.layers[2].weight)

    def forward(self, x):
        return self.layers(x)

ConvModule G

Generator用のConvolutionモジュール。最初の部分だけ特別扱い。

class ConvModuleG(nn.Module):
    '''
    Args:
        out_size: (int), Ex.: 16 (resolution)
        inch: (int),  Ex.: 256
        outch: (int), Ex.: 128
    '''
    def __init__(self, out_size, inch, outch, first=False):
        super().__init__()

        if first:
            layers = [
                Conv2d(inch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                Conv2d(outch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
            ]

        else:
            layers = [
                nn.Upsample((out_size, out_size), mode='nearest'),
                Conv2d(inch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                Conv2d(outch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
            ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

ConvModule D

次はDiscriminator用のConvolution module。こちらは最後の部分だけ特別扱いとする。

class ConvModuleD(nn.Module):
    '''
    Args:
        out_size: (int), Ex.: 16 (resolution)
        inch: (int),  Ex.: 256
        outch: (int), Ex.: 128
    '''
    def __init__(self, out_size, inch, outch, final=False):
        super().__init__()

        if final:
            layers = [
                MiniBatchStd(), # final block only
                Conv2d(inch+1, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                Conv2d(outch, outch, 4, padding=0), 
                nn.LeakyReLU(0.2, inplace=False),
                nn.Conv2d(outch, 1, 1, padding=0), 
            ]
        else:
            layers = [
                Conv2d(inch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                Conv2d(outch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                nn.AdaptiveAvgPool2d((out_size, out_size)),
            ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

Generator

__init__でとりあえず、全ネットワークを定義しておいて、forward部分でどこまで使うかを決める。forwardはちょっと長くなるけど、進捗パラメータresを用いて条件分岐する構造とする。層を増やしたいときはnp.arrayの部分を増やせばよい。

class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        # conv modules & toRGBs
        scale = 1
        inchs  = np.array([512,256,128,64,32,16], dtype=np.uint32)*scale
        outchs = np.array([256,128, 64,32,16, 8], dtype=np.uint32)*scale
        sizes = np.array([4,8,16,32,64,128], dtype=np.uint32)
        firsts = np.array([True, False, False, False, False, False], dtype=np.bool)
        blocks, toRGBs = [], []
        for s, inch, outch, first in zip(sizes, inchs, outchs, firsts):
            blocks.append(ConvModuleG(s, inch, outch, first))
            toRGBs.append(nn.Conv2d(outch, 3, 1, padding=0))

        self.blocks = nn.ModuleList(blocks)
        self.toRGBs = nn.ModuleList(toRGBs)

    def forward(self, x, res, eps=1e-7):
        # to image
        n,c = x.shape
        x = x.reshape(n,c//16,4,4)

        # for the highest resolution
        res = min(res, len(self.blocks))

        # get integer by floor
        nlayer = max(int(res-eps), 0)
        for i in range(nlayer):
            x = self.blocks[i](x)

        # high resolution
        x_big = self.blocks[nlayer](x)
        dst_big = self.toRGBs[nlayer](x_big)

        if nlayer==0:
            x = dst_big
        else:
            # low resolution
            x_sml = F.interpolate(x, x_big.shape[2:4], mode='nearest')
            dst_sml = self.toRGBs[nlayer-1](x_sml)
            alpha = res - int(res-eps)
            x = (1-alpha)*dst_sml + alpha*dst_big

        #return x, n, res
        return torch.sigmoid(x)

出力は[0,1]の範囲とするためにsigmoidでactivationする。

Discriminator

こちらも進捗パラメータresを用いた分岐構造となる。定義やforward処理はGeneratorと同じ。こちらも層を増やしたいときはnp.arrayの部分を増やせばよい。

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.minbatch_std = MiniBatchStd()

        # conv modules & toRGBs
        scale = 1
        inchs = np.array([256,128, 64,32,16, 8], dtype=np.uint32)*scale
        outchs  = np.array([512,256,128,64,32,16], dtype=np.uint32)*scale
        sizes = np.array([1,4,8,16,32,64], dtype=np.uint32)
        finals = np.array([True, False, False, False, False, False], dtype=np.bool)
        blocks, fromRGBs = [], []
        for s, inch, outch, final in zip(sizes, inchs, outchs, finals):
            fromRGBs.append(nn.Conv2d(3, inch, 1, padding=0))
            blocks.append(ConvModuleD(s, inch, outch, final=final))

        self.fromRGBs = nn.ModuleList(fromRGBs)
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x, res):
        # for the highest resolution
        res = min(res, len(self.blocks))

        # get integer by floor
        eps = 1e-7
        n = max(int(res-eps), 0)

        # high resolution
        x_big = self.fromRGBs[n](x)
        x_big = self.blocks[n](x_big)

        if n==0:
            x = x_big
        else:
            # low resolution
            x_sml = F.adaptive_avg_pool2d(x, x_big.shape[2:4])
            x_sml = self.fromRGBs[n-1](x_sml)
            alpha = res - int(res-eps)
            x = (1-alpha)*x_sml + alpha*x_big

        for i in range(n):
            x = self.blocks[n-1-i](x)

        return x

loss関数をWGAN-GPにするので、sigmoidは使わないようにする。

Training

さて、やっと学習部分まできた。学習ではlossにWGAN-GPを用いるのと、mean teacher のように学習させたネットワークのパラメータを移動平均していく。WGAN-GPについては、別の記事で説明しようかと思うので関数だけ。オリジナルと違う部分は進捗パラメータresが入っているあたり。

def gradient_penalty(netD, real, fake, res, batch_size, gamma=1):
    device = real.device
    alpha = torch.rand(batch_size, 1, 1, 1, requires_grad=True).to(device)
    x = alpha*real + (1-alpha)*fake
    d_ = netD.forward(x, res)
    g = torch.autograd.grad(outputs=d_, inputs=x,
                            grad_outputs=torch.ones(d_.shape).to(device),
                            create_graph=True, retain_graph=True,only_inputs=True)[0]
    g = g.reshape(batch_size, -1)
    return ((g.norm(2,dim=1)/gamma-1.0)**2).mean()

ここから学習部分。netGのパラメータ値はnetG_mavgに代入していき、学習がある程度進んだら、netG_mavgの値を使って推定を行う。オリジナルよりバッチサイズを半分にしてあるので、lrも半分にしてある。バッチサイズが小さいので、多様性が小さくなるかもしれない。

f __name__ == '__main__':

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    netG = models.Generator().to(device)
    netD = models.Discriminator().to(device)
    netG_mavg = models.Generator().to(device) # moving average
    optG = torch.optim.Adam(netG.parameters(), lr=0.0005, betas=(0.0, 0.99))
    optD = torch.optim.Adam(netD.parameters(), lr=0.0005, betas=(0.0, 0.99))
    criterion = torch.nn.BCELoss()

    # dataset
    transform = transforms.Compose([transforms.CenterCrop(160),
                                    transforms.Resize((128,128)),
                                    transforms.ToTensor(), ])

    trainset = datasets.CelebA('~/data', download=True, split='train',
                               transform=transform)

    bs = 8
    train_loader = DataLoader(trainset, batch_size=bs, shuffle=True)

    # training
    nepoch = 10
    losses = []
    res_step = 15000
    j = 0
    # constant random inputs
    z0 = torch.randn(16, 512*16).to(device)
    z0 = torch.clamp(z0, -1.,1.)
    for iepoch in range(nepoch):
        if j==res_step*6.5:
            optG.param_groups[0]['lr'] = 0.0001
            optD.param_groups[0]['lr'] = 0.0001

        for i, data in enumerate(train_loader):
            x, y = data
            x = x.to(device)
            res = j/res_step

            ### train generator ###
            z = torch.randn(bs, 512*16).to(x.device)
            x_ = netG.forward(z, res)
            d_ = netD.forward(x_, res) # fake
            lossG = -d_.mean() # WGAN_GP

            optG.zero_grad()
            lossG.backward()
            optG.step()

            # update netG_mavg by moving average
            momentum = 0.995 # remain momentum
            alpha = min(1.0-(1/(j+1)), momentum)
            for p_mavg, p in zip(netG_mavg.parameters(), netG.parameters()):
                p_mavg.data = alpha*p_mavg.data + (1.0-alpha)*p.data

            ### train discriminator ###
            z = torch.randn(x.shape[0], 512*16).to(x.device)
            x_ = netG.forward(z, res)
            x = F.adaptive_avg_pool2d(x, x_.shape[2:4])
            d = netD.forward(x, res)   # real
            d_ = netD.forward(x_, res) # fake
            loss_real = -d.mean()
            loss_fake = d_.mean()
            loss_gp = gradient_penalty(netD, x.data, x_.data, res, x.shape[0])
            loss_drift = (d**2).mean()

            beta_gp = 10.0
            beta_drift = 0.001
            lossD = loss_real + loss_fake + beta_gp*loss_gp + beta_drift*loss_drift

            optD.zero_grad()
            lossD.backward()
            optD.step()

            print('ep: %02d %04d %04d lossG=%.10f lossD=%.10f' %
                  (iepoch, i, j, lossG.item(), lossD.item()))

            losses.append([lossG.item(), lossD.item()])
            j += 1

            if j%500==0:
                netG_mavg.eval()
                z = torch.randn(16, 512*16).to(x.device)
                x_0 = netG_mavg.forward(z0, res)
                x_ = netG_mavg.forward(z, res)

                dst = torch.cat((x_0, x_), dim=0)
                dst = F.interpolate(dst, (128, 128), mode='nearest')
                dst = dst.to('cpu').detach().numpy()
                n, c, h, w = dst.shape
                dst = dst.reshape(4,8,c,h,w)
                dst = dst.transpose(0,3,1,4,2)
                dst = dst.reshape(4*h,8*w,3)
                dst = np.clip(dst*255., 0, 255).astype(np.uint8)
                skio.imsave('out/img_%03d_%05d.png' % (iepoch, j), dst)

                losses_ = np.array(losses)
                niter = losses_.shape[0]//100*100
                x_iter = np.arange(100)*(niter//100) + niter//200
                plt.plot(x_iter, losses_[:niter,0].reshape(100,-1).mean(1))
                plt.plot(x_iter, losses_[:niter,1].reshape(100,-1).mean(1))
                plt.tight_layout()
                plt.savefig('out/loss_%03d_%05d.png' % (iepoch, j))
                plt.clf()

                netG_mavg.train()

            if j >= res_step*7:
                break

            if j%100==0:
                coolGPU()

        if j >= res_step*7:
            break

Loss

今回は128x128まで。まずはloss関数。青がGでオレンジがD。横軸はイテレーション回数で縦軸がlossの値。

f:id:tzmi:20200507223727p:plain

出力画像

では、低解像度から結果を乗せていく。今回は解像度は15000イテレーションで切り替わるようにした。この値をもっと大きくするとさらにきれいな画像が得られるんだと思う。データセットはCelebAを使った。オリジナルの画素が256以下なのと、計算リソースがなかったので、128x128まで学習することにした。

4x4

f:id:tzmi:20200507223448j:plain

8x8

f:id:tzmi:20200507223512j:plain

16x16

f:id:tzmi:20200507223529j:plain

32x32

f:id:tzmi:20200507223545j:plain

64x64

f:id:tzmi:20200507223600j:plain

128x128

f:id:tzmi:20200507215124j:plain

まとめ

2020年のGWでまとまった休みがあったのでPGGANを頑張って実装してみたけど、想像の3倍くらい大変だったので、振り返ってみればやらない方がよかったかもしれない。提案者の苦労やソースの中身を何も知らなくてもgithubで簡単に利用できるというのは非常に大切なことなんだなと感じた1週間だった。