Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

PytorchのDataLoaderから任意のバッチを出力

Pytorchで自作Datasetを作って、さらにDataLoaderからのバッチ出力を(x, y)だけではなく(x1, x2, y)とか(x, y, c)などと自由に制御したいときがある。

例えばtriplet lossを使いたいときは、ターゲットクラスyに対して複数の入力(x1, x2)を取ってくる必要があるし、ネットワークにconditionを入力したいときは(x, y, c)の3つをバッチで取ってくる必要がある。

まずは答え。collate_fnを定義してDataloaderの引数に入れる。

def collate_fn(batch):
    bx, by, bc = list(zip(*batch))
    bx = torch.stack(bx)
    by = torch.stack(by)
    bc = torch.stack(bc)
    return bx, by, bc

trainset = MyDataset()
trainloader = Dataloader(trainset, collate_fn=collate_fn)

for i, (x, y, c) in enumerate(trainloader):
    y_ = model.forward(x, c)
    :
    :

関数の名前もcollate_fnにしたけど、実はなんでもよくて、Dataloaderの引数のひとつであるcollate_fnに定義した関数を代入するだけでよい。最初Dataloderを自作するのかなとか考えたんだけど、そんな必要は全くなかった。

MyDataset classの定義

自作のデータセットの定義をする。データセットとしては./dataディレクトリにpngの画像データが3枚、./csvディレクトリに1行3列のcsvデータが3つあるとする。

この条件で下記の自作データセットクラスを作成する。このクラスは__getitem__関数でデータを1つだけ取得する。バッチの取得はこのクラスではなくDataLoaderクラスが行う。

import numpy as np
import pandas as pd
from glob import glob
from skimage import io as skio

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

class MyDataset(Dataset):
    def __init__(self, imgpath='./data', csvpath='./csv', transform=None):
        self.transform = transform
        self.imgfiles = sorted(glob('%s/*.png' % imgpath))
        self.csvfiles = sorted(glob('%s/*.csv' % csvpath))

    def __len__(self):
        return len(self.csvfiles)

    def __getitem__(self, idx):
        x = skio.imread(self.imgfiles[idx])
        yy = pd.read_csv(self.csvfiles[idx], header=None),
        yy = np.array(yy, dtype=np.float32)[0]
        x = self.transform(x) if self.transform else x
        c = yy[1:]
        y = yy[:1]
        y = torch.from_numpy(y)
        c = torch.from_numpy(c)
        return x, y, c

collate_fnの定義

collate_fnはDataLoaderクラスに読み込ませるためのバッチ回収用の関数。DataLoaderはデフォルトでは基本的な(x, y)程度の入力にしか対応できないが、この関数を編集することで様々な出力が可能となる。書き方は非常に簡単で以下のように書く。

def collate_fn(batch):
    bx, by , bc = list(zip(*batch))
    bx = torch.stack(bx)
    by = torch.stack(by)
    bc = troch.stack(bc)
    return bx, by, bc

物体検知のようにyとして画像枚に個数が異なる矩形の座標を返したいときは、torch.stackをとって直接リストで返せばよい。

DataLoaderへの代入

DataLoaderは編集しなくてよい。そのまま使う。

    trainset = MyDataset()
    trainloader = DataLoader(trainset, batch_size=2, collate_fn=collate_fn)

    for i, (x, y, c) in enumerate(trainloader):
        y_ = model.forward(x, c)

これで、自作Datasetから自分の好きな出力を得られるようになる。