いろいろなことを同時に走らせて、全部中途半端になっている中の人です。マウスの設計とFPGA入門、バイトの教材開発を同時に進めるのはやめておきましょう。
前回に引き続き今回もNeural Network Libraries(NNabla)でFCNを動かす話です。
今回は実装編ですが学習編と結構似たコードになるので、学習用のコードから流用できます。やったぜ。
とりあえず必要なライブラリ群を読み込みます。
#-*-coding:utf-8-*-
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import os
import sys
import cv2
import numpy as np
学習ではsolversやinitializer、データセットを載せるためのdata_iteratorをimportしていましたが、今回は不要なので削除。データセット(npz)の読み込みはこの前と同じものを使用します。学習時になかったものとしては各ピクセルに割り当てられたラベルを画像に変換する関数を実装しました。(もっと頭良く組めるはずなんですが、僕には難しいです。)
def label2img(label):
"""
* 1 | road | blue
* 2 | out of road | green
* 3 | line | red
* 4 | backgrownd | black
* 5 | object | yellow
* 6 | other rane | white
*
* out : 3-dimensional numpy array([ch][y][x])
* buff : 3-dimensional numpy array([1][y][x])
* label : 4-dimensional numpy array([1][6][y][x])
* color : [[B,G,R]]
"""
color = np.array([[255,0,0],[0,255,0],[0,0,255],[0,0,0],[0,255,255],[255,255,255]])
buff = np.argmax(label, axis = 1)
out = np.zeros((3,buff.shape[1],buff.shape[2]))
for i in range(len(color)):
out[0][buff[0] == i] = color[i][0]
out[1][buff[0] == i] = color[i][1]
out[2][buff[0] == i] = color[i][2]
return out.astype(np.uint8)
ネットワークは学習時と同じものを使用してください。ここまでで必要な関数がそろいました。では実際に推論を走らせるコードに移ります。
try:
#Loading Dataset
train, teach = load_data(NPZ,"img","img_test")
#Set Params to network
nn.clear_parameters()
x = nn.Variable(train[0:1].shape)
y = network(x, test=True)
t = nn.Variable(teach[0:1].shape)
print("x:",x.shape, "y:", y.shape, "t:", t.shape)
#Search if Params file or not
if os.path.exists(param_file) == True:
nn.load_parameters(param_file)
else:
print("Parameter file was not found!!!")
sys.exit()
#Test
for i in range(train.shape[0]):
x.d, t.d = train[i:i+1], teach[i:i+1]
y.forward()
input_img = x.d[0].transpose(1,2,0).astype(np.uint8)
predict_img = y.d
predict_img = label2img(predict_img).transpose(1,2,0)
ground_truth_img = t.d
ground_truth_img = label2img(ground_truth_img).transpose(1,2,0)
show_img = np.concatenate((input_img,
predict_img,
ground_truth_img,
),axis=1)
cv2.imshow("imshow", show_img[::-1,:,:])
cv2.waitKey(1)
except:
import traceback
traceback.print_exc()
finally:
input(">>")
cv2.destroyAllWindows()
load_data関数で推論させるデータを取り込みます。今回は入力画像、推論結果に加えて教師画像も載せたいので、ついでに取り込みます。
その後パラメータの読み込みまでは学習時とほぼ同じです。学習済みパラメータの読み込み部のみ、パラメータがなかったらプログラムを終了できるように変えてあります。
for分の中では実際に推論をしています。t.dに教師画像を入れてありますが、推論のみを行うのであれば不要になります。
また推論自体はy.forward()で行われ、y.dに推論結果が入っています。あとの部分はcv2での画像表示に関する部分です。そこまで難しくはないはずですが、cv2で画像を表示させる場合、データ構造がNNabla側の(1, color, row, col)から(row, col, color)に代わるのでそれをimg2label関数とnp.transposeで変換しています。
ここまでで推論部分が完成しました。学習部分さえできればその応用で推論ができるのは結構楽ですね(Chainerも似てるけど)。あとはおまけでGPUあたりのお話ができればいいかなぁ。
©2018 shts All Right Reserved.