pytorchでtensorの画像サイズを縮小するにはadaptive_avg_pool2d
を使えばよかった。しかし拡大する際にはこの関数だとnearest neighbor になる。ということでtorch tensorでbicubic補間をやってみる。
まずは結果から。opencvでbucibucした場合とほとんど変わらない結果になる。
pytorchでの画像サイズの縮小はこちら tzmi.hatenablog.com
pytorchの関数を使う方法
下のスクラップから書く記事を書いた後にpytorchでbicubicするための関数を見つけたので追記する。使い方は簡単。
import torch from torch.nn import functional as F x = torch.randn(1,1,4,4) x_out = F.interpolate(x, (8,8), mode='bicubic', align_corners=False)
入力がn,c,h,wとなっていることだけ注意すればよい。
ニューラルネットのレイヤとして使いたい場合はnn.Upsample
import torch from torch import nn x = torch.randn(1,1,4,4) m = nn.Upsample((8,8), mode='bicubic', align_corners=False) x_out = m(x)
Bicubic 補間
スクラップからbicubicを書くのをやったことがなかった。ググってもいい情報が出てこないから「Pytorchでbicubicをpytorchで実装した記事」を書いてみた。c++じゃなくてpythonで書く。そしてforループは使わずに書く(重要)。
bicubicなんてただの周辺4x4画素のsinc関数重み付き平均である。つまり、①「入力画像の周辺4x4画素の画素値」と②「各画素に対する重み」が求まればよい。forループなしで。具体的に説明すると以下のようになる。
左はリサイズ後の画像における画素中心をリサイズ前の画素中心のグリッドの中に置いた図。出力画素値を計算するために入力画像の周辺4x4画素を使う。右は入力画像の画素間の距離を単位とした近似sinc関数の図。このsinc関数の中心をリサイズ後の画像の画素中心として、各入力画素中心における重みを計算する。つまり、ほしい情報は以下
- リサイズ後の各画素に対して、リサイズ前の周辺4x4画素の画素値(画素値用)
- リサイズ後の各画素をリサイズ前の画素座標に置いたとき、リサイズ前の周辺4x4画素までの距離(重み用)
それぞれのtensorのshapeは、 画素値の方は(n, c, h_out, w_out, 4, 4)
、距離の方は(h_out, w_out, 4, 4)
となる。重みつき平均だから最終的な関数は以下のようになるんだろう。
def bicubic(img, h_out=256, w_out=256, a=-0.5): weight, value = get_bicubic_weight_value(img, h_out, w_out, a=-0.5) dst = torch.mean(weight*value, dim=(4,5))/torch.mean(weight, dim=(2,3)) return dst
この重み(weight)と画素値(value)を求めたい。
Pytorchでのbicubicの実装
では関数の中身に行ってみよう。まずはリサイズ前の座標とリサイズ後の座標を確定させてみる。リサイズ後の左上の画素の左上を原点とみると、一番左上の画像の中心は(0.5,0.5)
となる。同様に右下の端は画素数が(h_in, w_in)
であれば、(h_in+0.5, w_in+0.5)
となる。リサイズ後も同じなので、まずは画素中心が[0.5, 1.5, 2.5...w_in+0.5]となる配列を作る。
xaxis_out = torch.arange(w_out).reshape(1,-1).repeat(h_out,1)+0.5 yaxis_out = torch.arange(h_out).reshape(-1,1).repeat(1,w_out)+0.5
次に作った配列にw/w_outを掛け算する。さらにリサイズ前の画像で、画像の左上の端を原点としていたところから、左上の画素中心を原点とするように戻すために0.5を引く。
xaxis_out = xaxis_out*(w/w_out)-0.5 yaxis_out = yaxis_out*(h/h_out)-0.5
周辺画素の座標も作っておく。
px = torch.arange(4).reshape(1,-1).repeat(4,1) - 1 py = torch.arange(4).reshape(-1,1).repeat(1,4) - 1
出力画像の画素の座標から入力画像における周辺4x4画素の画素値を得る。リサイズ後の画素が周辺4x4画素を持っていると考えると、(h_out, w_out, 4, 4)
のサイズを持つテンソルを作れば良い。リサイズ後の画素の整数部分を持ってくれば、入力画像における周辺画素を得ることができる。
# value x = torch.floor(xaxis_out) # integers y = torch.floor(yaxis_out) # integers ix = x.reshape(h_out,w_out,1,1) + px.reshape(1,1,4,4) + 2 iy = y.reshape(h_out,w_out,1,1) + py.reshape(1,1,4,4) + 2 ix = ix.type(torch.long).reshape(-1) iy = iy.type(torch.long).reshape(-1) value= F.pad(img, (2,2, 2,2), mode='reflect')[:,:,iy,ix] value = value.reshape(n,c,h_out,w_out,4,4)
重みの方は周辺画素までの距離をsinc関数に入れるだけだから、ほしいのは「リサイズ前画像で-1から2までのグリッドの中にいるリサイズ後画像の位置」(x,yともに0と1の間)。つまり周辺4x4画素がわかっている前提では、小数点だけとればよい。座標の小数点から入力画像の画素中心までの距離を求め、近似sinc関数で重みを求める。
# weight x = xaxis_out - torch.floor(xaxis_out) # After decimal points y = yaxis_out - torch.floor(yaxis_out) # After decimal points dx = x.reshape(h_out,w_out,1,1) - px.reshape(1,1,4,4) dy = y.reshape(h_out,w_out,1,1) - py.reshape(1,1,4,4) d = torch.sqrt(dx**2+dy**2) mask1 = (d<1).type(torch.float) mask2 = ((d>=1)&(d<2)).type(torch.float) weight = mask1*(1.0-(a+3)*d**2+(a+2)*d**3) + mask2*(a*(-4+8*d-5*d**2+d**3))
Pytorchのレイヤ用のモジュール化
最後にpytorchなのでnn.Module継承のクラスにする。こうすればニューラルネットのレイヤとして使える。
class Bicubic(nn.Module): def __init__(self, h_out, w_out, a=-0.5): super().__init__() self.h_out = h_out self.w_out = w_out self.a = a def forward(self, x): weight, value = self._get_bicubic_weight_value(x) x = torch.mean(weight*value, dim=(4,5))/torch.mean(weight, dim=(2,3)) return x def _get_bicubic_weight_value(self, img): # common h_out, w_out, a = self.h_out, self.w_out, self.a n,c,h,w = img.shape xaxis_out = torch.arange(w_out).reshape(1,-1).repeat(h_out,1)+0.5 yaxis_out = torch.arange(h_out).reshape(-1,1).repeat(1,w_out)+0.5 xaxis_out = xaxis_out*(w/w_out)-0.5 yaxis_out = yaxis_out*(h/h_out)-0.5 px = torch.arange(4).reshape(1,-1).repeat(4,1) - 1 py = torch.arange(4).reshape(-1,1).repeat(1,4) - 1 # weight x = xaxis_out - torch.floor(xaxis_out) # After decimal points y = yaxis_out - torch.floor(yaxis_out) # After decimal points dx = x.reshape(h_out,w_out,1,1) - px.reshape(1,1,4,4) dy = y.reshape(h_out,w_out,1,1) - py.reshape(1,1,4,4) d = torch.sqrt(dx**2+dy**2) mask1 = (d<1).type(torch.float) mask2 = ((d>=1)&(d<2)).type(torch.float) weight = mask1*(1.0-(a+3)*d**2+(a+2)*d**3) + mask2*(a*(-4+8*d-5*d**2+d**3)) # value x = torch.floor(xaxis_out) # integers y = torch.floor(yaxis_out) # integers ix = x.reshape(h_out,w_out,1,1) + px.reshape(1,1,4,4) + 2 iy = y.reshape(h_out,w_out,1,1) + py.reshape(1,1,4,4) + 2 ix = ix.type(torch.long).reshape(-1) iy = iy.type(torch.long).reshape(-1) value= F.pad(img, (2,2, 2,2), mode='reflect')[:,:,iy,ix] value = value.reshape(n,c,h_out,w_out,4,4) return weight, value
weight計算しているときに気になったんだけど、sinc関数じゃなくてパラメータにして最適化したらどんな形になるんだろうか。今度別の記事を書くときにやってみることにする。