Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

PytorchでDataLoaderからデータを取り出す

PytorchのDataLoaderって便利だなと思いつつも、forループ以外の方法でデータを取り出すことができなかったので、中身を少し調べてみた。以下のようにすればデータがとれることがわかった。

tmp = testloader.__iter__()
x1, y1 = tmp.next() 
x2, y2 = tmp.next()

forループ以外のデータの取り出し方法

例えば、MNISTのデータを取り出すには、DataLoaderクラスのインスタンスに対して、__iter__()next()を使って以下のように書けばよい。

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

if __name__ == '__main__':

    # datasetの読み出し
    bs = 128 # batch size 
    transform = transforms.ToTensor()
    testset = MNIST(root='./data', train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=bs, shuffle=False)

    # dataの取り出し
    tmp = testloader.__iter__()
    x1, y1 = tmp.next() 
    x2, y2 = tmp.next()

testloader.__iter__().next()というようにインスタンスを作らずに直接アクセスすると、イテレーションが回らずにずっと最初のデータしか出てこなくなるので注意。必ずtmpみたいな何かしらの変数に入れてからやる必要がある。

forループを使う場合

他のサイトにもよく書いてある方法ではあるが、forループを使う場合についても書いておく。

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

if __name__ == '__main__':

    # datasetの読み出し
    bs = 128 # batch size 
    transform = transforms.ToTensor()
    testset = MNIST(root='./data', train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=bs, shuffle=False)

    # dataの取り出し
    for i, data in enumerate(testloader):
        x, y = data