Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

Pytorchで取得できるMNIST系のdataset一覧

Pytorchのデータセットに入っているMNISTとその亜種について調べてみた。これらのデータセットの呼び出し方と使い方についてまとめてみる。

取得できるMNIST系データセット

torchvision.datasetsに入っているMNIST系のデータセットは下記のコマンドで確認できる。

In [1]: from torchvision import datasets                                                                          
In [2]: [d for d in dir(datasets) if 'MNIST' in d]                                                                

Out[2]: ['EMNIST', 'FashionMNIST', 'KMNIST', 'MNIST', 'QMNIST']

torchvisionのバージョン'0.5.0'では、EMNIST, FashionMNIST, KMNIST, MNIST, QMNISTの5種類が使える模様。これらの使い方と、どのような画像なのかを書いていく。

MNIST

まずはMNISTのデータの読み込み方法。データはホームディレクトリ以下の./dataというディレクトリに保存するようにする。

import os
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets

if __name__ == '__main__':
    homedir = os.path.expanduser('~')
    datadir = '%s/.data' % homedir

    transform = transforms.Compose([transforms.ToTensor()])
    trainset = datasets.MNIST(datadir, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=256, shuffle=True)

基本的には上記のtrainset = dataset.xxxx(datadir,..xxxxの部分を入れ替えればよい。modelを学習する際はさらに次のように書く。

for i, (x, y) in enumerate(trainloader):
    #学習処理
    y_ = model.forward(x)
    :
    :

単純にバッチを得たい場合はtrainloaderを読みだした後に以下のコマンド。trainloaderの定義の際にshuffle=Trueにしていないと、同じデータが読み出されるので注意。

x, y = trainloader.__iter__().next()

shuffle=Falseで同じデータを読みだしたくないときはインスタンスをはさめばよい。例えばtrainloader.__iter__()tmpに入れる。

tmp = trainloader.__iter__()
x, y = tmp.next()
x2, y2 = tmp.next()

取り出したxをつながった画像にするには一度numpyに変換してから整形する。npは使わないので、numpyをimportする必要はない。

out = x.detach().numpy()
out = out.reshape(8,32,28,28)
out = out.transpose(0,2,1,3)
out = out.reshape(8*28,32*28)
out = (out*255).astype(np.uint8)

下記が出来上がった画像

f:id:tzmi:20200224191855p:plain

EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST

EMNIST

EMNISTは数字だけでなく、アルファベットなどに拡張したMNIST。EMNISTの場合はsplit=''の部分を書く必要がある。byclass, bymerge, balanced, letters, digits, mnistの6種類が選べるらしい。今回は面白そうなのでlettersを選んでみた。

trainset = datasets.EMNIST(datadir, split='letters', download=True, transform=transform)

lettersはxとyが逆に入っているようなので、torch tensorの時点で入れ替えておく

x = x.permute(0,1,3,2)

EMNISTのlettersの画像

f:id:tzmi:20200224192445p:plain

FashionMNIST

FashionMNISTは、文字ではなく服とか靴とかの画像で作られた10クラス分類用のグレー画像28x28のデータセット

trainset = datasets.FashionMNIST(datadir, download=True)

FashionMNISTの画像。

f:id:tzmi:20200224192134p:plain

KMNIST

日本語のくずし字のデータセット

trainset = datasets.KMNIST(datadir, download=True)

KMNISTの画像は以下。ランダムに並べると何かの呪文のように見える。。。

f:id:tzmi:20200224193331p:plain

以下参考

codh.rois.ac.jp

QMNIST

MNISTの拡張。MNISTはテストセットに60kの画像があるようなんだけど、ラベルが信頼できないから10kしか使っていない。QMNISTではテストに60kの画像を使えるようにしたよと本家に書いてあった。

trainset = datasets.QMNIST(datadir, download=True, transform=transform)

下記がQMNISTの画像。画像自体は同じデータセットなので、MNISTとほとんど変わらない。

f:id:tzmi:20200224193722p:plain

まとめ

Pytorchで使えるMNIST系のデータセットについてまとめてみた。他にも使えるデータセットはたくさんあるので、今度は別のデータセットについてまとめてみたいと思う。