Pythonいぬ

pythonを使った画像処理に関する記事を書いていきます

Pytorchで特徴点のデータ拡張

Deep learningで学習するとき、ほとんどどんなタスクでも使うデータ拡張(data augmentation)。画像の輝度値をランダムで変化させたり、zoom in/outや画像のshift, flip, 回転などを使って、ネットワークを入力の分布にできる正則化の効果がある。

だけどOpenPoseなどで使うkeypoint抽出ではターゲットのkeypointの座標も動かさないといけない。ので、shift, zoomなどを自分で書こうとするとわりと大変。affineなんてもってのほか。ということで、data augmentationのライブラリであるalbumentationsを使ってkeypointのaugmentationをやってみる。

特徴点 (keypoint)

画像1枚に対して、その画像上のいくつかの点を特徴点とする。学習で使うことを考えているので、この特徴点は画像の特徴を表すものでなくてもOK。画像上のどこに打ってもよい。ただし、自分で学習用の正解データを作る場合は各画像で矛盾がないように打つ必要がある。

例えば、以下のように特徴点を打つとする。

f:id:tzmi:20200517212307p:plain

例として、それぞれ左目、右目、鼻の3箇所に特徴点を打ってみた。この特徴点は画像とペアになって保存されるとする。

この特徴点を含んだdata augmentationをしてみたいと思う。画像の方は拡大縮小やシフトなど色々やればいいと思うが、これに合わせて座標のほうも動かすのは以外と大変。シフトや拡大をやりすぎて特に画像の外に出たときにどう扱うかも考える必要がある。

ということで、自作するのではなくdata augmentationのライブラリを使うことにする。

特徴点のデータ拡張

albumentationsを使う。

github.com

本家では「pip installして使ってね」と書いてあるけど、自分の環境で直接pip install したら環境がバグって入れ直しになった。なので、condaなどで新しい環境を作ってからインストールした方がよいかもしれない。(自己責任でお願いします)

Pytorchにtransformsという関数があるけど、albumentationsはこれを拡張したdata augmentation用のライブラリ。

ちょうど「シフトスケール回転」という名前の関数があるので使ってみることにする。使いやすいように関数でラップして画像と特徴点を入れるとdata augmentationされた画像と特徴点が返ってくるようにする。

import albumentations as albu

def data_augmentation(img, keypoints):
    ssr = albu.ShiftScaleRotate(shift_limit=0.5,
                                scale_limit=0.5,
                                rotate_limit=20,
                                border_mode=1,
                                p=1.0)
    kpfunc = albu.KeypointParams(format='xy', remove_invisible=False)
    trans_func = albu.Compose([ssr,], p=1.0, keypoint_params=kpfunc)
    transform = trans_func(image=img, keypoints=keypoints)

    img = transform['image']
    keypoints = transform['keypoints']

    return img, keypoints

KyepointParamsremove_invisibleという引数があるが、これは特徴点がdata augmentationで画像の外に飛んでも、外の座標が返ってくるようにする引数。これをTrueにすると画像の内側の座標のみが返るようになるため、入出力でkeypointの数が変化することがある。

必ずdata augmentationするようにしたいので、pの値は1.0にしておく。pを0.5にすると確率0.5でdata augmentationをすることになる。

使い方としては例えば以下のようにすればよいかと。

import numpy as np

if __name__ == '__main__':
    img_org = xxxxxxx
    keypoints_org = np.array( ((98,134), (197,106), (175,166)), dtype=np.float64)

    imgs, kps = []
    for i in range(18):
        img, keypoints = data_augmentation(img_org, keypoints_org)

        imgs.append(img)
        kps.append(keypoints)

他にもPytorchなら自作のDataSetに組み込んでもいいかもしれない。

結果

上記の方法でdata augmentationをやってみた結果。

f:id:tzmi:20200517212216p:plain

いい感じ。きちんとkeypointもdata augmentationできてる。