Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

Pytorch で LeNet-5 を論文どおりに実装してみる

LeNet5ライクのネットワークを実装したという記事は多いけど、LeNetを論文通りに再現してみた記事がないかなと思ったので、アーキテクチャの部分だけ再現してみた。

LeNet-5

f:id:tzmi:20200118222655p:plain

この図の中にもSub-samplingとかGaussian connectionとか聞いたことがない言葉がちらほら。ということで、論文を参考に再現してみる。

Sub-sampling layer

C1とS2の間にある sub-sampling layer というのがよく知られているレイヤとは少し違う。論文を読むと以下のように書いてある。

Layer S2 は14x14の6つの特徴マップのsub-sampling layerである。それぞれの特徴マップのユニットはC1における特徴マップと一致している2x2の隣接画素でつながっている。S2への4つの入力は加算され、学習可能な係数で積算され、学習可能なバイアスが加算される。2x2の受容フィールドはオーバラップしないので、S2の特徴マップはC1の半分になる。Layer S2は5880個の接続と12個の学習パラメータを持つ。

接続の数の計算は以下

5880 = 282 x 6 + 142 x 6

それぞれのチャネルが独立に重みとバイアスを持っているということか。よく考えると カーネルサイズが(1x1)のdepth-wise convolution ってことかな(これで学習パラメータが12個になるはず)。4つの入力が加算されるというところも含めて再現してみよう。

import torch
from torch import nn

class SubSampling2d(nn.Module):
    def __init__(self, nc=6):
        super().__init__()
        self.conv = nn.Conv2d(nc, nc, 1, grops=nc)

    def forward(self, x):
        n,c,w,h = x.shape
        x = x.reshape(n,c,w//2,2,h//2,2) # w,hは2の倍数
        x = torch.sum(x, dim=(3,5))
        x = self.conv(x)
        return x

Masked Convolution

S2とC3の間では一部のチャネルをmaskしたconvolutionをやっている模様。Dropoutに似ているけど、こういうネットワークもあまり見ない。これはどうやって実装するかは悩みどころ。いろいろ迷ったけどmaskは論文のtable. 1を参照して直接書いて、conv2dの重みをマスクすることにした。バイアスは、、、、まあいいか。(コンストラクタの引数変えるとmaskとサイズが合わなくなって動かなくなるので注意)

import torch
from torch import nn

class MaskedConv2d(nn.Module):
    def __init__(self, inch=6, outch=16, kernel_size=(5,5)):
        super().__init__()
        self.conv = nn.Conv2d(inch, outch, kernel_size=kernel_size)

    def forward(self, x):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        mask = torch.tensor([[1,0,0,0,1,1, 1,0,0,1,1,1, 1,0,1,1],
                             [1,1,0,0,0,1, 1,1,0,0,1,1, 1,1,0,1],
                             [1,1,1,0,0,0, 1,1,1,0,0,1, 0,1,1,1],
                             [0,1,1,1,0,0, 1,1,1,1,0,0, 1,0,1,1],
                             [0,0,1,1,1,0, 0,1,1,1,1,0, 1,1,0,1],
                             [0,0,0,1,1,1, 0,0,1,1,1,1, 0,1,1,1]],
                            dtype=torch.float32).to(device)
        mask = mask.transpose(1,0).reshape(16, 6, 1, 1)
        self.conv.weight.data = self.conv.weight.data * mask
        x = self.conv(x)
        return x

Convolution (Full connect)

LeNet-5ライクのネットワークを書くとき、最後の2層はLinearで書くけど、オリジナルの論文ではC5はConvolution層となる。論文中には画像サイズが大きくなってもいいように畳み込みにすると書いてある。(結局最後に12x7にしてるからflattenにしても同じだと思う。)

Activation

ActivationについてもReLUではないので、一応書いておく。基本的には中間層ではsigmoidを使う。最後のActivationは tanh となっているが、2つの係数(A, S)がついている。Batchnormがないので、パラメータをうまく調整しないと勾配消失がおきそう。

 f(a) = A tanh(Sa)

パラメータAとSは定数のようで、(A, S) = (1.7159, 2/3)と書いてある。レイヤとして書いてみると以下のようになる。

import torch
from torch import nn

class LeNetTanh(nn.Module):
    def __init__(self, a=1.7159, s=2./3.):
        super().__init__()
        self.a = a
        self.s = s

    def forward(self, x):
        return self.a * torch.tanh(self.s * x)

Gaussian connection

最終層F6では、各クラスに対して予め決めておいた特徴パラメータ(正解データ)とのL2距離を、RBF (Radial basis function)で計算する。(softmax + cross entropyではない)。特徴パラメータ(正解データ)は+1か-1の値で構成される(84 x 10)の行列であり、学習では更新しない。この特徴パラメータ(正解データ)はランダムで作ってもよいが、12x7のASCII文字のビットマップ文字で作る。間違えるのであれば類似した文字間で間違えるようになる。類似した文字が間違いやすいようにしておくことで、システムの後処理で修正できるし、複数の正解に解釈可能な文字列を後処理で抜き出すこともできるようになる。

なるほど、確かにone-hotベクトルだと全てが直行しているから、どれと間違えやすいというようなことは起こらない。これなら類似したものを間違えるから識別器のミスは人間の直感に合ったものになるような気がした。

さて、この12x7の文字ビットマップをどのように作るかを考えるのに時間がかかった。フォントを画像化するコードを書いて作ってみた。

import numpy as np
import PIL
from PIL import ImageDraw
from PIL import ImageFont
import cv2
from skimage import io as skio

if __name__ == '__main__':

    # Carlito-Bold, pathとfontsize
    ttfontname = '/usr/share/fonts/truetype/crosextra/Carlito-Bold.ttf'
    fontsize = 15

    # 背景色,フォントの色を設定
    bg_color, text_color = 0, 255

    for i in range(10):
        text = '%d' % i

        # 画像を用意して文字列を描く
        canvasSize = (14,24)
        img  = PIL.Image.new('L', canvasSize, bg_color)
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype(ttfontname, fontsize)
        text_w, text_h = draw.textsize(text,font=font)
        img_size = int(max(text_w, text_h))

        # 文字の大きさを使ってもう一度画像をつくる
        canvasSize = (img_size, img_size)
        img  = PIL.Image.new('L', canvasSize, bg_color)
        draw = ImageDraw.Draw(img)

        # 中央に配置
        left = img_size//2 - text_w//2
        top = img_size//2 - int(text_h//2*1.3)
        draw.text((left, top), text, fill=text_color, font=font)

        #img
        img = np.array(img)
        mask = img>127
        img[mask] = 255
        img[~mask] = 0
        img = img[:,2:-2]
        img = cv2.resize(img, (7, 12), interpolation=cv2.INTER_NEAREST) 
        skio.imsave('./target/%s.png' % text, img) # 画像を保存する
        img2 = cv2.resize(img, (140, 240), interpolation=cv2.INTER_NEAREST)

画像は各ラベル毎に作ってみたが、以下では可視化するためにつなげてみた。

f:id:tzmi:20200201135840p:plain

これらを正解として、推定結果をMSE lossで近づけるように最適化していく。

loss関数

loss関数は、正解を近づける項と不正解を遠ざける項を持つ。log尤度を使っているので、不正解のlossは無限に下がることができて、正解のlossが落ちるように学習できない状況になる。この状況を避けるために不正解のlossには最小値  e^{-j}の項を追加している。要は不正解lossが小さくなりすぎなければいいのだから、hingeとかで制限してもいいのかもしれない。というかhingeにしてnegativeサンプル入れるとtriplet lossになる?

 \displaystyle
E1 = \frac{1}{N} \sum_{p}^{P} (y_i)

 \displaystyle
E2 = log(e^{-j} + \sum_{i} e^{-y_i})

 \displaystyle
E = E1 + E2

logとexpが入り組んでるので32bit浮動少数点が計算に弱そうなのと、近似の式変形も面倒くさいのでlossの実装は止めて単純なMSE lossにすることにした。普通にMSEで正解を近づける項だけでも収束するので。

学習

やっと学習の部分まできた。ここにくるまで休日を1日分使ってしまった。とりあえず、MNISTを用意して学習してみる。コードは以下

import numpy as np
from skimage import io as skio

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision import datasets

from lenet5 import LeNet5

if __name__ == '__main__':

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # modelとoptimizerの定義
    model = LeNet5().to(device)
    opt = torch.optim.Adam(model.parameters())

    # datasetの読み込み                                                                    
    batch_size = 32
    transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, ), (0.5, ))])

    trainset = datasets.MNIST(root='./data',
                              train=True,
                              download=True,
                              transform=transform)
    trainloader = DataLoader(trainset,
                             batch_size=batch_size,
                             shuffle=True)

    testset = datasets.MNIST(root='./data',
                             train=False,
                             download=True,
                             transform=transform)
    testloader = DataLoader(testset,
                            batch_size=batch_size,
                            shuffle=False)

    # 正解データ                                                                           
    target = []
    for i in range(10):
        img = skio.imread('target/%d.png' % i)
        img = img/255.
        target.append(img.flatten())

    target = np.array(target).astype(np.float32)
    target = torch.from_numpy(target)

    # training                                                                             
    losses = []
    for iepoch in range(2):
        for iiter, (x, label) in enumerate(trainloader, 0):
            # to GPU                                                                       
            x = x.to(device)
            y = target.to(device)

            # 32x32へ拡大                                                                  
            x = F.interpolate(x, (32, 32), mode='bicubic', align_corners=False)

            # 推定                                                                         
            y_ = model.forward(x) # y_.shape = (bs, 84)                                    

            # loss計算                                                                     
            yi_target = (y_-y[label])**2
            loss_pos = yi_target.mean()

            loss = loss_pos
            losses.append(loss.item())

            opt.zero_grad()
            loss.backward()
            opt.step()

            print('epoch: %03d, (iiter=%05d) loss=%.5f'
                  % (iepoch, iiter, loss.item()))

    # test                                                                                 
    total, tp = 0, 0
    print('test')
    for (x, label) in testloader:
        x = x.to(device)
    y = target.to(device)
        label = label.to(device)

    x = F.interpolate(x, (32, 32), mode='bicubic', align_corners=False)
        y_ = model.forward(x)

    y_ = y_.reshape(y_.shape[0], 1, 84)
        y = y.reshape(1, 10, 84)
        label_ = ((y_-y)**2).sum(dim=2).argmin(dim=1)

    total += label.shape[0]
        tp += (label_==label).sum().item()

    acc = tp/total
    print('accuracy= %.3f' % acc)

結果

accuracy= 0.963

精度もまあまあ。lossもちゃんと下がったらしい。あんまり深入りしすぎてもしょうがないのでこれくらいにしておく。(横軸はイテレーションの1/10)

f:id:tzmi:20200223151245p:plain

まとめ

LeNet-5のアーキテクチャ再現をやってみた。ちゃんと読んでみると最近発明されているような構造やlossなどが出てきて驚いた。あとは、one-hot vectorでlossを作るのがいかに簡単なのかが理解できた。次はもっと新しい論文を再現したい。