物体检测4:Android上应用物体监测(TF1)
2020-03-31 本文已影响0人
古风子
tf
在上一章节训练一个物体检测器,
我们将训练后的模型导出成了pb文件,用在PC侧tensorflow物体监测;本章节,我们尝试在Android手机上转化我们训练好的模型,供手机端tensorflow-lite使用
生成pb和pbtxt文件
#~/tensorflow/models/research/object_detection$
python export_tflite_ssd_graph.py \
--pipeline_config_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_config/ssd_mobilenet_v1_raccoon.config \
--trained_checkpoint_prefix=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/model.ckpt-62236 \
--output_directory=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train \
--add_postprocessing_op=true
输出结果为:
(base) jiadongfeng@jiadongfeng:~/tensorflow/dataset/raccoon_dataset/jdf_train$ ls | grep tflite_graph
tflite_graph.pb
tflite_graph.pbtxt
pb文件转化成tflite
运行以下命令:
#~/anaconda2/lib/python2.7/site-packages/tensorflow/lite
toco \
--graph_def_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/tflite_graph.pb \
--output_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=QUANTIZED_UINT8 \
--mean_values=128 \
--std_dev_values=128 \
--change_concat_input_ranges=false \
--allow_custom_ops
会提示以下错误:
F tensorflow/lite/toco/tooling_util.cc:1709] Array FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_0/Relu6, which is an input to the DepthwiseConv operator producing the output array FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6, is lacking min/max data, which is necessary for quantization. If accuracy matters, either target a non-quantized output format, or run quantized training with your model from a floating point checkpoint to change the input graph to contain min/max information. If you don't care about accuracy, you can pass --default_ranges_min= and --default_ranges_max= for easy experimentation.
Aborted (core dumped)
错误解决方案一:
使用非量化的转换,需要将inference_type=QUANTIZED_UINT8 改为—inference_type=FLOAT并添加:
--default_ranges_min
--default_ranges_max
Quantized模型里面的权重参数用1个字节的uint8类型表示,模型大小是Float版本的四分之一;后续我们再讲解怎么生成Quantized的模型文件
最后运行以下命令:
toco \
--graph_def_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/tflite_graph.pb \
--output_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=FLOAT \
--mean_values=128 \
--std_dev_values=128 \
--change_concat_input_ranges=false \
--allow_custom_ops \
--default_ranges_min=0\
--default_ranges_max=6
生成detect.tflite文件:
base) jiadongfeng@jiadongfeng:~/tensorflow/dataset/raccoon_dataset/jdf_train$ ls | grep detect
detect.tflite
生成的文件达到22MB,而原生的支持80个物种监测的tflite文件(Quantized类型)相机集成物体监测,仅仅为3MB左右;
解决方案二:
使用量化转换,将inference_type和input_data_type设置为QUANTIZED_UINT8 ;
参数default_ranges_min和default_ranges_max也需要设置
toco \
--graph_def_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/tflite_graph.pb \
--output_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=QUANTIZED_UINT8 \
--input_data_type=QUANTIZED_UINT8 \
--mean_values=128 \
--std_dev_values=128 \
--change_concat_input_ranges=false \
--allow_custom_ops \
--default_ranges_min=0\
--default_ranges_max=6
此种方案生成的tflite文件,比非量化模式减少了四倍,精度稍有下降
tflite集成到相机中
- 首先,相机集成流程参考相机集成物体监测,我们这里直接替换tflite文件,修改对应的label_map文件即可
- 然后,原生的监测模型是量化后的模型,而我们的是float类型的模型;所以需要将TF_OD_API_IS_QUANTIZED 改为false
private static final boolean TF_OD_API_IS_QUANTIZED = true;
detector =
TFLiteObjectDetectionAPIModel.create(
cameraActivity.getAssets(),
TF_OD_API_MODEL_FILE,
TF_OD_API_LABELS_FILE,
TF_OD_API_INPUT_SIZE,
TF_OD_API_IS_QUANTIZED);
- 最后,集成后的监测效果图为: