Pytorchのデータセットに入っているMNISTとその亜種について調べてみた。これらのデータセットの呼び出し方と使い方についてまとめてみる。
取得できるMNIST系データセット
torchvision.datasets
に入っているMNIST系のデータセットは下記のコマンドで確認できる。
In [1]: from torchvision import datasets In [2]: [d for d in dir(datasets) if 'MNIST' in d] Out[2]: ['EMNIST', 'FashionMNIST', 'KMNIST', 'MNIST', 'QMNIST']
torchvisionのバージョン'0.5.0'では、EMNIST
, FashionMNIST
, KMNIST
, MNIST
, QMNIST
の5種類が使える模様。これらの使い方と、どのような画像なのかを書いていく。
MNIST
まずはMNISTのデータの読み込み方法。データはホームディレクトリ以下の./data
というディレクトリに保存するようにする。
import os from torch.utils.data import DataLoader from torchvision import transforms from torchvision import datasets if __name__ == '__main__': homedir = os.path.expanduser('~') datadir = '%s/.data' % homedir transform = transforms.Compose([transforms.ToTensor()]) trainset = datasets.MNIST(datadir, download=True, transform=transform) trainloader = DataLoader(trainset, batch_size=256, shuffle=True)
基本的には上記のtrainset = dataset.xxxx(datadir,..
のxxxx
の部分を入れ替えればよい。modelを学習する際はさらに次のように書く。
for i, (x, y) in enumerate(trainloader): #学習処理 y_ = model.forward(x) : :
単純にバッチを得たい場合はtrainloader
を読みだした後に以下のコマンド。trainloader
の定義の際にshuffle=True
にしていないと、同じデータが読み出されるので注意。
x, y = trainloader.__iter__().next()
shuffle=False
で同じデータを読みだしたくないときはインスタンスをはさめばよい。例えばtrainloader.__iter__()
をtmp
に入れる。
tmp = trainloader.__iter__() x, y = tmp.next() x2, y2 = tmp.next()
取り出したx
をつながった画像にするには一度numpyに変換してから整形する。np
は使わないので、numpy
をimportする必要はない。
out = x.detach().numpy() out = out.reshape(8,32,28,28) out = out.transpose(0,2,1,3) out = out.reshape(8*28,32*28) out = (out*255).astype(np.uint8)
下記が出来上がった画像
EMNIST
, FashionMNIST
, KMNIST
, MNIST
, QMNIST
EMNIST
EMNISTは数字だけでなく、アルファベットなどに拡張したMNIST。EMNISTの場合はsplit=''の部分を書く必要がある。byclass
, bymerge
, balanced
, letters
, digits
, mnist
の6種類が選べるらしい。今回は面白そうなのでletters
を選んでみた。
trainset = datasets.EMNIST(datadir, split='letters', download=True, transform=transform)
letters
はxとyが逆に入っているようなので、torch tensorの時点で入れ替えておく
x = x.permute(0,1,3,2)
EMNISTのletters
の画像
FashionMNIST
FashionMNISTは、文字ではなく服とか靴とかの画像で作られた10クラス分類用のグレー画像28x28のデータセット。
trainset = datasets.FashionMNIST(datadir, download=True)
FashionMNISTの画像。
KMNIST
日本語のくずし字のデータセット
trainset = datasets.KMNIST(datadir, download=True)
KMNISTの画像は以下。ランダムに並べると何かの呪文のように見える。。。
以下参考
QMNIST
MNISTの拡張。MNISTはテストセットに60kの画像があるようなんだけど、ラベルが信頼できないから10kしか使っていない。QMNISTではテストに60kの画像を使えるようにしたよと本家に書いてあった。
trainset = datasets.QMNIST(datadir, download=True, transform=transform)
下記がQMNISTの画像。画像自体は同じデータセットなので、MNISTとほとんど変わらない。
まとめ
Pytorchで使えるMNIST系のデータセットについてまとめてみた。他にも使えるデータセットはたくさんあるので、今度は別のデータセットについてまとめてみたいと思う。