PytorchのDataLoaderって便利だなと思いつつも、forループ以外の方法でデータを取り出すことができなかったので、中身を少し調べてみた。以下のようにすればデータがとれることがわかった。
tmp = testloader.__iter__() x1, y1 = tmp.next() x2, y2 = tmp.next()
forループ以外のデータの取り出し方法
例えば、MNISTのデータを取り出すには、DataLoaderクラスのインスタンスに対して、__iter__()
とnext()
を使って以下のように書けばよい。
from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import MNIST if __name__ == '__main__': # datasetの読み出し bs = 128 # batch size transform = transforms.ToTensor() testset = MNIST(root='./data', train=False, download=True, transform=transform) testloader = DataLoader(testset, batch_size=bs, shuffle=False) # dataの取り出し tmp = testloader.__iter__() x1, y1 = tmp.next() x2, y2 = tmp.next()
testloader.__iter__().next()
というようにインスタンスを作らずに直接アクセスすると、イテレーションが回らずにずっと最初のデータしか出てこなくなるので注意。必ずtmp
みたいな何かしらの変数に入れてからやる必要がある。
forループを使う場合
他のサイトにもよく書いてある方法ではあるが、forループを使う場合についても書いておく。
from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import MNIST if __name__ == '__main__': # datasetの読み出し bs = 128 # batch size transform = transforms.ToTensor() testset = MNIST(root='./data', train=False, download=True, transform=transform) testloader = DataLoader(testset, batch_size=bs, shuffle=False) # dataの取り出し for i, data in enumerate(testloader): x, y = data