Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

PytorchでEfficientNetを実装してみる

話題のEfficientNetを実装してみる。基本的な構造はNASNetとほぼ変わらないんだけど、EfficientNet特有の広さ深さ解像度などのパラメータも含めてコードを書いてみる。

f:id:tzmi:20200206180841p:plain 画像はこちらのサイトから引用しました。

環境

python 3.7.4
torch 1.0.0

ヘッダ

import math
import torch
from torch import nn

Swish activation layer

EfficientNetではReLUの代わりにSwishが使わている。以下にReLU関数とSwish関数の違いを示す。SwishはReLUとほぼ同じ形をしているが、ReLUに比べて連続的となる。

f:id:tzmi:20200213231641p:plain

まずはactivationのSwishのクラスから書く。Swish関数は入力にsigmoidを掛けた形をしている。レイヤとしては最下層の関数なので、特に__init__()も定義しなくてよいかと。

class Swish(nn.Module):  # Swish activation                                      
    def forward(self, x):
        return x * torch.sigmoid(x)

Squeeze Excitation module

いわゆるSE-module。1x1のサイズに縮小し、チャネル方向にも縮小と拡大を行ってから、最後にsigmoidに入れて、入力xにかける。チャネル方向のattention。途中のactivationはSwishを使う。Conv2dの初期化もここで行っておくと後で深堀りしなくてもよくなる。

class SEblock(nn.Module): # Squeeze Excitation                                  
    def __init__(self, ch_in, ch_sq):
    super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ch_in, ch_sq, 1),
            Swish(),
            nn.Conv2d(ch_sq, ch_in, 1),
        )
        self.se.apply(weights_init)

    def forward(self, x):
        return x * torch.sigmoid(self.se(x))

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)

    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight)
        nn.init.zeros_(m.bias)

ConvBN module

次はConv2dとBatchNorm2dを合わせたレイヤ。ConvBNReLUという書き方があるみたいだけど、ReLUじゃなくてSwishを使うあたりがややこしいので、Activationなしで書いておく。ここでもConv2dの初期化をしておく。

class ConvBN(nn.Module):
    def __init__(self, ch_in, ch_out, kernel_size,
                 stride=1, padding=0, groups=1):
    super().__init__()
    self.layers = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size,
                      stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(ch_out),
    )
        self.layers.apply(weights_init)

    def forward(self, x):
        return self.layers(x)

Drop Connect layer

BM module の最終層でstochastic depth的なことやるためにdrop connectを使っているようなので、せっかくだし後で応用効きそうだからClassから書いてみることにした。

具体的には、入力された4階テンソルからデータ列方向にマスクを作ってバッチの一部を切り落とすというところ。残す率、切り落とす率をkeep_rateとかdrop_rateと書いておくとどちらなのかがわかりやすいかなと思った。trainingのときに出力をkeep_rateで割る(x.div(keep_rate)の部分)ようにしておけばevalのときの出力は恒等関数でよくなる。

class DropConnect(nn.Module):
    def __init__(self, drop_rate):
        super().__init__()
        self.drop_rate = drop_rate

    def forward(self, x):
        if self.training:
            keep_rate = 1.0 - self.drop_rate
            r = torch.rand([x.size(0),1,1,1], dtype=x.dtype).to(x.device)
            r += keep_rate
            mask = r.floor()
            return x.div(keep_rate) * mask
        else:
            return x

BM module

もともとはpixel-wise, depth-wise, pixel-wise の順に並べるレイヤだった気がするけど、 pixel-wise, depth-wise, squeeze excitation, drop connect という順に並ぶ。activationはReLUの代わりにSwishを使う。(ch_dqの計算とpaddingの部分でpython3での整数割り算の//を使っているのでpython2では動かないです。) strideが1でかつ入力chと出力chが同じ場合(EfficientNetのパラメータのひとつである「深さ」が増えた場合)だけdrop connectを使うようにする。

class BMConvBlock(nn.Module):
    def __init__(self, ch_in, ch_out,
                 expand_ratio, stride, kernel_size,
                 reduction_ratio=4, drop_connect_rate=0.2):
        super().__init__()
        self.use_residual = (ch_in==ch_out) & (stride==1)
        ch_med = int(ch_in * expand_ratio)
        ch_sq = max(1, ch_in//reduction_ratio)

        # define network                                                        
        if expand_ratio != 1.0:
            layers = [ConvBN(ch_in, ch_med, 1),Swish()]
        else:
            layers = []

        layers.extend([
            ConvBN(ch_med, ch_med, kernel_size, stride=stride,
                   padding=(kernel_size-1)//2, groups=ch_med), # depth-wise    
            Swish(),
            SEblock(ch_med, ch_sq), # Squeeze Excitation                        
            ConvBN(ch_med, ch_out, 1), # pixel-wise                             
        ])

        if self.use_residual:
            self.drop_connect = DropConnect(drop_connect_rate)

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_residual:
            return x + self.drop_connect(self.layers(x))
        else:
            return self.layers(x)

Flatten layer

よく、一番上層のネットワークのforward()関数の中でx=x.mean(2,3)というのを見るけど、個人的にはなんとなく最上位のネットワークのforward()の中にはレイヤ以外の処理が入っていてほしくない。ということでFlatten()を使いたい。けど、pytorch 1.0.0にはFlatten()がないので、自分で書く。(pytorch 1.4.0にはnn.Flatten()がある)

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)

EfficientNet

ここまで書いてきたmoduleをまとめてefficentNetを作る。ちょっとだけ工夫したところとしては、classifier()の中にAdaptiveAvgPool2d(1)Flatten()を入れてfeatures()から接続できるようにしたところ。最初にあるパラメータはベースラインのB0の構成パラメータであり、EfficientNetのパラメータ(広さ、深さ、解像度)とは別ものなので注意。

class EfficientNet(nn.Module):
    def __init__(self, width_mult=1.0, depth_mult=1.0,
                 resolution=False, dropout_rate=0.2, 
                 input_ch=3, num_classes=1000):
        super().__init__()

        # expand_ratio, channel, repeats, stride, kernel_size                   
        settings = [
            [1,  16, 1, 1, 3],  # MBConv1_3x3, SE, 112 -> 112                   
            [6,  24, 2, 2, 3],  # MBConv6_3x3, SE, 112 ->  56                   
            [6,  40, 2, 2, 5],  # MBConv6_5x5, SE,  56 ->  28                   
            [6,  80, 3, 2, 3],  # MBConv6_3x3, SE,  28 ->  14                   
            [6, 112, 3, 1, 5],  # MBConv6_5x5, SE,  14 ->  14                   
            [6, 192, 4, 2, 5],  # MBConv6_5x5, SE,  14 ->   7                   
            [6, 320, 1, 1, 3]   # MBConv6_3x3, SE,   7 ->   7]                  
        ]

        ch_out = int(math.ceil(32*width_mult))
        features = [nn.AdaptiveAvgPool2d(resolution)] if resolution else []
        features.extend([ConvBN(input_ch, ch_out, 3, stride=2), Swish()])

        ch_in = ch_out
        for t, c, n, s, k in settings:
            ch_out  = int(math.ceil(c*width_mult))
            repeats = int(math.ceil(n*depth_mult))
            for i in range(repeats):
                stride = s if i==0 else 1
                features.extend([BMConvBlock(ch_in, ch_out, t, stride, k)])
                ch_in = ch_out

        ch_last = int(math.ceil(1280*width_mult))
        features.extend([ConvBN(ch_in, ch_last, 1), Swish()])

        self.features = nn.Sequential(*features)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(ch_last, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

EfficientNetB0からB7まで

EfiicientNetのパラメータ(広さ、深さ、解像度)を使ってB0からB7までのネットワークを定義する。パラメータが4つあって最後のが何だろうと思っていたが、途中のDrop connectと最後のdropoutのdrop rateだったらしい。EfficientNetは最適化の際に解像度(resolution)を使っているので、引数にresolutionも入れるべきかなと思ったけど、object detectionなどの他の用途に使いたいこともあるから、resolutionはデフォルトではNoneに設定した。

※pre-trainedは手元に計算資源がないので、作れなかったです。すいません。

def _efficientnet(w_mult, d_mult, resolution, drop_rate,
                  input_ch, num_classes=1000):
    model = EfficientNet(w_mult, d_mult,
                         resolution, drop_rate,
                         input_ch, num_classes)
    return model


def efficientnet_b0(input_ch=3, num_classes=1000):
    #(w_mult, d_mult, resolution, droprate) = (1.0, 1.0, 224, 0.2)              
    return _efficientnet(1.0, 1.0, None, 0.2, input_ch, num_classes)

def efficientnet_b1(input_ch=3, num_classes=1000):
    #(w_mult, d_mult, resolution, droprate) = (1.0, 1.1, 240, 0.2)              
    return _efficientnet(1.0, 1.1, None, 0.2, input_ch, num_classes)

def efficientnet_b2(input_ch=3, num_classes=1000):
    #(w_mult, d_mult, resolution, droprate) = (1.1, 1.2, 260, 0.3)              
    return _efficientnet(1.1, 1.2, None, 0.3, input_ch, num_classes)

def efficientnet_b3(input_ch=3, num_classes=1000):
    #(w_mult, d_mult, resolution, droprate) = (1.2, 1.4, 300, 0.3)              
    return _efficientnet(1.2, 1.4, None, 0.3, input_ch, num_classes)

def efficientnet_b4(input_ch=3, num_classes=1000):
    #(w_mult, d_mult, resolution, droprate) = (1.4, 1.8, 380, 0.4)              
    return _efficientnet(1.4, 1.8, None, 0.4, input_ch, num_classes)

def efficientnet_b5(input_ch=3, num_classes=1000):
    #(w_mult, d_mult, resolution, droprate) = (1.6, 2.2, 456, 0.4)              
    return _efficientnet(1.6, 2.2, None, 0.4, input_ch, num_classes)

def efficientnet_b6(input_ch=3, num_classes=1000):
    #(w_mult, d_mult, resolution, droprate) = (1.8, 2.6, 528, 0.5)              
    return _efficientnet(1.8, 2.6, None, 0.5, input_ch, num_classes)

def efficientnet_b7(input_ch=3, num_classes=1000):
    #(w_mult, d_mult, resolution, droprate) = (2.0, 3.1, 600, 0.5)              
    return _efficientnet(2.0, 3.1, None, 0.5, input_ch, num_classes)

動作確認

一応、B0だけは動く(cifar10のlossが下がっていく)ことを確認した。手持ちのGPUがしょぼいやつなので、GTX1080が手に入ったらもう少し大きいネットワークで試してみようかと思う。

まとめ

EfficientNetの実装について紹介した。内容は非常にシンプルだけど、実際に自分で書いてみると結構大変だった。普段から色々ネットワークを書く練習をしていないといけないと感じた。

参考にしたサイト

github.com

github.com

qiita.com