Pytorchで自作Datasetを作って、さらにDataLoaderからのバッチ出力を(x, y)だけではなく(x1, x2, y)とか(x, y, c)などと自由に制御したいときがある。
例えばtriplet lossを使いたいときは、ターゲットクラスyに対して複数の入力(x1, x2)を取ってくる必要があるし、ネットワークにconditionを入力したいときは(x, y, c)の3つをバッチで取ってくる必要がある。
まずは答え。collate_fn
を定義してDataloaderの引数に入れる。
def collate_fn(batch): bx, by, bc = list(zip(*batch)) bx = torch.stack(bx) by = torch.stack(by) bc = torch.stack(bc) return bx, by, bc trainset = MyDataset() trainloader = Dataloader(trainset, collate_fn=collate_fn) for i, (x, y, c) in enumerate(trainloader): y_ = model.forward(x, c) : :
関数の名前もcollate_fnにしたけど、実はなんでもよくて、Dataloaderの引数のひとつであるcollate_fnに定義した関数を代入するだけでよい。最初Dataloderを自作するのかなとか考えたんだけど、そんな必要は全くなかった。
MyDataset classの定義
自作のデータセットの定義をする。データセットとしては./data
ディレクトリにpngの画像データが3枚、./csv
ディレクトリに1行3列のcsvデータが3つあるとする。
この条件で下記の自作データセットクラスを作成する。このクラスは__getitem__
関数でデータを1つだけ取得する。バッチの取得はこのクラスではなくDataLoaderクラスが行う。
import numpy as np import pandas as pd from glob import glob from skimage import io as skio import torch from torch.utils.data import DataLoader from torch.utils.data import Dataset from torchvision import transforms class MyDataset(Dataset): def __init__(self, imgpath='./data', csvpath='./csv', transform=None): self.transform = transform self.imgfiles = sorted(glob('%s/*.png' % imgpath)) self.csvfiles = sorted(glob('%s/*.csv' % csvpath)) def __len__(self): return len(self.csvfiles) def __getitem__(self, idx): x = skio.imread(self.imgfiles[idx]) yy = pd.read_csv(self.csvfiles[idx], header=None), yy = np.array(yy, dtype=np.float32)[0] x = self.transform(x) if self.transform else x c = yy[1:] y = yy[:1] y = torch.from_numpy(y) c = torch.from_numpy(c) return x, y, c
collate_fnの定義
collate_fn
はDataLoaderクラスに読み込ませるためのバッチ回収用の関数。DataLoaderはデフォルトでは基本的な(x, y)程度の入力にしか対応できないが、この関数を編集することで様々な出力が可能となる。書き方は非常に簡単で以下のように書く。
def collate_fn(batch): bx, by , bc = list(zip(*batch)) bx = torch.stack(bx) by = torch.stack(by) bc = troch.stack(bc) return bx, by, bc
物体検知のようにyとして画像枚に個数が異なる矩形の座標を返したいときは、torch.stackをとって直接リストで返せばよい。
DataLoaderへの代入
DataLoaderは編集しなくてよい。そのまま使う。
trainset = MyDataset() trainloader = DataLoader(trainset, batch_size=2, collate_fn=collate_fn) for i, (x, y, c) in enumerate(trainloader): y_ = model.forward(x, c)
これで、自作Datasetから自分の好きな出力を得られるようになる。