最近は特に何もない中の人です。前回からの続きで、今回はM5StickV向けのモデルを作っていきます。
コードについてはこちらから:https://github.com/shtsno24/DAE_for_M5StickV
1-1.ターゲットモデル
M5Stack+KPUの作例で多く出ているのがYOLOなどを使用した物体検出です。まあ同じものをやっても面白みに欠けるので、今回はCNNを用いたDeep Autoencoderを作成していきます。Autoencoderを端的に言えば、入力データと同じ出力データを得られるようにトレーニングしたモデルです。一見無駄に見えますが、モデル構造を工夫することで、画像圧縮・復元ができるようになります(もう少し話すと、tf.kerasで容易にデータセットが構築できるCIFAR-10の入力画像をそのまま復元することにしました。モデルの入出力は32x32x3(トレーニング時は1x32x32x3)の画像になります(入出力のデータは0-1に正規化する)。
1-2.ライブラリ
M5StickV向けのモデル作成にはTensorflowまたはCaffeが対応しているようです。今回は対応しているレイヤの多いTensorflowを使います。Tensorflowによる開発の場合、Tensorflow ->Tensorflow Lite -> Kmodelとなるので、Tensorflow Liteへの対応がカギになります。
NNCaseのアーキテクチャ。 資料は見当たらないが、ONNX・PaddlePaddle(百度のDLフレームワーク)も対応しているらしい。 (NNCase/README.mdより) |
1-3.モデル作成前の準備
Tensorflowをpipでインストールすればモデルの作成・推論はできます。トレーニングに関してはGPUが欲しいので、Google Colabを使うか、自前でGPU付きのマシンを準備したほうがよいでしょう。中の人は、ラボのUbuntuマシンにDocker環境を作りました。
1-4.モデル
今回作成するCNNベースのDeep Autoencoderは、全層CNNで構成します。これはKPU
がCNN向けに最適化されているためであり(もっと言えば1x1 or 3x3カーネル、ストライド=1 or 2)、また全結合を使うよりもパラメータ数を削減し、小メモリ化を図るためです。(参考:NNCase/FAQ_EN.md)今回作成したモデル全体図(tflite)。レイヤが確認できないほど長くてもメモリに載るのはありがたい。 |
モデルを確認する際にはnetronを使いました。tf.kerasの.h5や.tfliteを読み込めるので、レイヤの種類の確認にはうってつけでした。
モデルの作成で注意したいのが、使用できるレイヤの種類です。今回はtf.kerasを使用したので、kerasのレイヤからTensorflow Liteのレイヤに変換できるかが第一関門となります。このページを参考にしながら、対応するkerasのレイヤと紐づけていくのが楽でした。YOLO、UNetあたりなら、ほぼそのまま変換ができそうです。
第二関門はTensorflow Liteのレイヤからkmodelのレイヤに変換できるかです。こちらはNNCaseのGitHubに一覧があるので、そこから対応するレイヤを確認しましょう。
コード:https://github.com/shtsno24/DAE_for_M5StickV/blob/master/Model.py
1-5.トレーニング
入力画像。実際は32x32の画像なのでものすごく小さい |
トレーニングは通常のtf.kerasを使った場合と同じです。今回は、訓練データ50000枚、検証データ10000枚で、100epoch回しました。バッチサイズは100枚とし、lossはMSE、OptimizerはAdamを使用しました。
tf.kerasモデルのトレーニング結果。輪郭がぼやけているが、雰囲気はとらえられている。 |
コード:https://github.com/shtsno24/DAE_for_M5StickV/blob/master/Train.py
1-6.TFLiteへの変換
tf.kerasに変換したモデルを今度は.tfliteに変換します。このやり方は、ここを参考にするとわかりやすいです。参考例ではkeras付属のMobilenetを読み出して変換していますが、今回は自作のモデルを読み込んだうえで変換するので、書き方が少し変わってきます。
コード:https://github.com/shtsno24/DAE_for_M5StickV/blob/master/Convert_to_tflite.py
.tfliteの推論結果。tf.kerasの結果と大差ない。 |
コード:https://github.com/shtsno24/DAE_for_M5StickV/blob/master/Convert_to_tflite.py
1-7.次回予告
.tfliteまで変換できたので、次回はkmodelへの変換となります(前回と今回でまだ本題に入っていない...)。
©2020 shts All Right Reserved.
0 件のコメント:
コメントを投稿