Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

Pytorchでkerasのsummaryをやる

kerasを使っていたときは、model.summary()という関数があって、ネットワークの各レイヤにおける出力サイズがどうなっていくかを簡単に可視化できていた。

Pytorchはdefine by runなのでネットワーク内の各層のサイズはforward処理のときに決まる。なのでなんとなくsummaryができないのもわかるんだけど、例えば「このサイズのテンソルが入力されたときは中身はこんなサイズになるよ」みたいなことがわかっていると、バッチサイズなども決めやすくなる。

ということでptorchでのsummaryを自作してみることにした。中身がどうなってるか知ると、中間層や出力をいじれたりするので、勉強がてら自作する。

中身を理解しなくてもとりあえず使いたいという人はtorchsummaryを使えばいいかと思う。

pip install torchsummary

この記事で書くことはtorchsummaryを自分で書いてみるという話。

作りたいもの

以下のコマンドを打つとネットワークの情報を出力してくれるものを作りたい。いわゆるkerasのmodel.summary()

import torch
from torchvision import models

x = torch.randn(2,3,224,224)
model = models.vgg16()
summary(model, x)

出力はこちら

-------------------------------------------------------------------
          Layer              (type)         Output Shape    Param #
===================================================================
         conv_0            (Conv2d)       (64, 224, 224)       1792
   activation_1              (ReLU)       (64, 224, 224)          0
         conv_2            (Conv2d)       (64, 224, 224)      36928
   activation_3              (ReLU)       (64, 224, 224)          0
      pooling_4         (MaxPool2d)       (64, 112, 112)          0
         conv_5            (Conv2d)      (128, 112, 112)      73856
   activation_6              (ReLU)      (128, 112, 112)          0
         conv_7            (Conv2d)      (128, 112, 112)     147584
   activation_8              (ReLU)      (128, 112, 112)          0
      pooling_9         (MaxPool2d)        (128, 56, 56)          0
        conv_10            (Conv2d)        (256, 56, 56)     295168
  activation_11              (ReLU)        (256, 56, 56)          0
        conv_12            (Conv2d)        (256, 56, 56)     590080
  activation_13              (ReLU)        (256, 56, 56)          0
        conv_14            (Conv2d)        (256, 56, 56)     590080
  activation_15              (ReLU)        (256, 56, 56)          0
     pooling_16         (MaxPool2d)        (256, 28, 28)          0
        conv_17            (Conv2d)        (512, 28, 28)    1180160
  activation_18              (ReLU)        (512, 28, 28)          0
        conv_19            (Conv2d)        (512, 28, 28)    2359808
  activation_20              (ReLU)        (512, 28, 28)          0
        conv_21            (Conv2d)        (512, 28, 28)    2359808
  activation_22              (ReLU)        (512, 28, 28)          0
     pooling_23         (MaxPool2d)        (512, 14, 14)          0
        conv_24            (Conv2d)        (512, 14, 14)    2359808
  activation_25              (ReLU)        (512, 14, 14)          0
        conv_26            (Conv2d)        (512, 14, 14)    2359808
  activation_27              (ReLU)        (512, 14, 14)          0
        conv_28            (Conv2d)        (512, 14, 14)    2359808
  activation_29              (ReLU)        (512, 14, 14)          0
     pooling_30         (MaxPool2d)          (512, 7, 7)          0
     pooling_31 (AdaptiveAvgPool2d)          (512, 7, 7)          0
      linear_32            (Linear)              (4096,)  102764544
  activation_33              (ReLU)              (4096,)          0
     dropout_34           (Dropout)              (4096,)          0
      linear_35            (Linear)              (4096,)   16781312
  activation_36              (ReLU)              (4096,)          0
     dropout_37           (Dropout)              (4096,)          0
      linear_38            (Linear)              (1000,)    4097000
===================================================================
Total params: 138,357,544
Trainable params: 138,357,544
None-trainable params: 0
-------------------------------------------------------------------
Input size (MB): 1.15
Foward/backward pass size (MB): 218.78
Params size (MB): 527.79
Estimated total size (MB): 747.72
-------------------------------------------------------------------

hook

レイヤ関数の名前やパラメータなどはpytorchに既にあるmodel.modules関数やmodel.parameters関数などで得ることができる。問題は中間層のサイズで、これは既存関数で得ることができないので、hookを使ってとってこれるようにする。

Pytorchにはmodel.register_forward_hookという関数があって、modelのfoward関数を使ったときに、中身の情報を取ってくることができる。あらかじめ罠を貼っておいてforward処理のときにひっかかってもらうイメージ。

以下が情報を得るための関数。

from torch import nn

def getInfos(model, x):
    layers, nparams, shapes, trainables = [], [], [], []

    def hook_fn(m, input, output):
        # layer
        layer = str(m.__class__)[:-2].split('.')[-2:]

        # shape
        input_shape = tuple(input[0].shape)
        output_shape = tuple(output[0].shape)
        shape = [input_shape, output_shape]

        # nparam
        nparam = 0
        trainable = False
        if len(list(m.children()))==0:
            for i, p in enumerate(m.parameters()):
                nparam += torch.prod(torch.tensor(p.shape)).item()
                trainable = p.requires_grad

        layers.append(layer)
        shapes.append(shape)
        nparams.append(nparam)
        trainables.append(trainable)

    # set hook to get infos
    for m in model.modules():
        if isinstance(m, model.__class__):
            continue
        elif isinstance(m, nn.Sequential):
            continue
        elif isinstance(m, nn.ModuleList):
            continue
        else:
            m.register_forward_hook(hook_fn)

    y_ = model(x)

    return layers, shapes, nparams, trainables

関数内部でhook_fnという関数を定義している。実際の処理はforループの部分から始まっていて、SequentialModuleList以外のレイヤについて、register_forward_hookを指定している。こうするとforward 処理のときに各レイヤでhook_fn関数が走って、最初に定義したリストに情報がどんどん入っていくというしくみ。

中間層の情報を取りながら学習できるので中間層や勾配を常時可視化しながら学習なんてこともできそう。

出力部分

こちらは上で説明したgetInfos関数を使ってとってきた情報をprintで出力するだけ。割と色々な情報を持ってこれるので、工夫してここに書いた情報以外も出力してみるなんてことをしてもいいかもしれない。

def summary(model, x):
    model = model.to('cpu')
    x = x.to('cpu')
    input_size = x.shape
    nstr = 67

    # titles
    summary_str = '-'*nstr + '\n'
    summary_str += '%15s%20s %20s %10s\n' % ('Layer', '(type)', 'Output Shape', 'Param #')
    summary_str += '='*nstr + '\n'

    # parameter details
    layers, shapes, nparams, trainables = getInfos(model, x)
    total_output = 0
    for i, (layer, shape, nparam) in enumerate(zip(layers, shapes, nparams)):
        str_layer = '%s_%d' % (layer[0], i)
        str_type = '(%s)' % layer[1]
        summary_str += '%15s%20s %20s %10s\n' % (str_layer, str_type, shape[1], nparam)
        total_output += torch.prod(torch.tensor(shape[1]))

    summary_str += '='*nstr + '\n'

    # parameter summaries
    total_params = torch.sum(torch.tensor(nparams)).item()
    trainable_params = torch.sum(torch.tensor(nparams)*torch.tensor(trainables)).item()
    total_input_size = torch.prod(torch.tensor(input_size)).item() *4. / (1024**2)
    total_output_size = 2. * total_output *4. / (1024**2) # x2: forward & backward
    total_param_size = total_params*4./(1024**2)
    total_size = total_param_size + total_output_size + total_input_size
    #summary_str += 'Total params: %d\n' % total_params
    summary_str += 'Total params: {:,}\n'.format(total_params)
    summary_str += 'Trainable params: {:,}\n'.format(trainable_params)
    summary_str += 'None-trainable params: {:,}\n'.format(total_params - trainable_params)
    summary_str += '-'*nstr + '\n'
    summary_str += 'Input size (MB): %.2f\n' % total_input_size
    summary_str += 'Foward/backward pass size (MB): %.2f\n' % total_output_size
    summary_str += 'Params size (MB): %.2f\n' % total_param_size
    summary_str += 'Estimated total size (MB): %.2f\n' % total_size
    summary_str += '-'*nstr + '\n'
    print(summary_str)

※注意点として、同じレイヤを複数回利用するネットワークには対応していません。同じレイヤを複数回利用する場合は、パラメータを複数回数えます。これはバグではなく、複数回利用に対応するには機能拡張が必要です。

参考サイト

元ネタはこちらのgithub

github.com