Object Detection 4: Running Object Detection on Android (TF1)

2020-03-31  古风子

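In the previous chapter, Training an Object Detector, we exported the trained model as a pb file for TensorFlow object detection on the PC. In this chapter, we convert that trained model so that TensorFlow Lite can use it on an Android phone.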

Generating the pb and pbtxt files

# Run from ~/tensorflow/models/research/object_detection
python export_tflite_ssd_graph.py \
    --pipeline_config_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_config/ssd_mobilenet_v1_raccoon.config  \
    --trained_checkpoint_prefix=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/model.ckpt-62236 \
    --output_directory=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train \
    --add_postprocessing_op=true

The output is:

(base) jiadongfeng@jiadongfeng:~/tensorflow/dataset/raccoon_dataset/jdf_train$ ls | grep tflite_graph
tflite_graph.pb
tflite_graph.pbtxt



Converting the pb file to tflite

Run the following command:

#~/anaconda2/lib/python2.7/site-packages/tensorflow/lite
toco  \
 --graph_def_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/tflite_graph.pb \
 --output_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/detect.tflite \
 --input_shapes=1,300,300,3 \
 --input_arrays=normalized_input_image_tensor \
 --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'  \
 --inference_type=QUANTIZED_UINT8 \
 --mean_values=128 \
 --std_dev_values=128 \
 --change_concat_input_ranges=false \
 --allow_custom_ops 

This fails with the following error:

 F tensorflow/lite/toco/tooling_util.cc:1709] Array FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_0/Relu6, which is an input to the DepthwiseConv operator producing the output array FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6, is lacking min/max data, which is necessary for quantization. If accuracy matters, either target a non-quantized output format, or run quantized training with your model from a floating point checkpoint to change the input graph to contain min/max information. If you don't care about accuracy, you can pass --default_ranges_min= and --default_ranges_max= for easy experimentation.
Aborted (core dumped)

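Solution 1:
Use a non-quantized conversion: change --inference_type=QUANTIZED_UINT8 to --inference_type=FLOAT and add:
--default_ranges_min
--default_ranges_max
These two flags supply fallback min/max ranges for quantization, so a FLOAT conversion likely works without them; the command below keeps them anyway.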

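In a quantized model each weight is stored as a 1-byte uint8 value instead of a 4-byte float32, so the model is roughly one quarter the size of the float version. How to generate a properly quantized model file is covered later.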
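The one-quarter figure follows directly from the bytes per weight. Here is a minimal sketch of that arithmetic; the parameter count is a hypothetical round number chosen for illustration, not a value measured from this model, and the estimate ignores graph-structure overhead in the file.

```python
# Bytes used to store a single weight in each representation.
BYTES_PER_FLOAT32 = 4
BYTES_PER_UINT8 = 1


def weights_size_mb(num_weights, bytes_per_weight):
    """Approximate size in MB of the weights alone (no file overhead)."""
    return num_weights * bytes_per_weight / (1024 * 1024)


# Hypothetical parameter count, used only to illustrate the ratio.
num_weights = 5500000

float_mb = weights_size_mb(num_weights, BYTES_PER_FLOAT32)
uint8_mb = weights_size_mb(num_weights, BYTES_PER_UINT8)

print("float32 weights: %.1f MB" % float_mb)
print("uint8 weights:   %.1f MB" % uint8_mb)
print("size ratio:      %.1f" % (float_mb / uint8_mb))
```

Whatever the real parameter count is, the ratio between the two stays at 4 because only the bytes per weight change.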

Finally, run the following command:

toco  \
 --graph_def_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/tflite_graph.pb \
 --output_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/detect.tflite \
 --input_shapes=1,300,300,3 \
 --input_arrays=normalized_input_image_tensor \
 --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'  \
 --inference_type=FLOAT \
 --mean_values=128 \
 --std_dev_values=128 \
 --change_concat_input_ranges=false \
 --allow_custom_ops \
 --default_ranges_min=0 \
 --default_ranges_max=6

This generates the detect.tflite file:

(base) jiadongfeng@jiadongfeng:~/tensorflow/dataset/raccoon_dataset/jdf_train$ ls | grep detect
detect.tflite

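The generated file is about 22MB, whereas the stock quantized tflite file that detects 80 object classes (the one used in the earlier camera integration article, 相机集成物体监测) is only about 3MB.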

Solution 2:
Use a quantized conversion: set both inference_type and input_data_type to QUANTIZED_UINT8.
The default_ranges_min and default_ranges_max parameters must also be set.

toco  \
 --graph_def_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/tflite_graph.pb \
 --output_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/detect.tflite \
 --input_shapes=1,300,300,3 \
 --input_arrays=normalized_input_image_tensor \
 --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'  \
 --inference_type=QUANTIZED_UINT8 \
 --input_data_type=QUANTIZED_UINT8 \
 --mean_values=128 \
 --std_dev_values=128 \
 --change_concat_input_ranges=false \
 --allow_custom_ops \
 --default_ranges_min=0 \
 --default_ranges_max=6

The tflite file generated this way is about one quarter the size of the non-quantized one, at the cost of a slight drop in accuracy.
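The quantized command above relies on two linear mappings: mean_values and std_dev_values describe how uint8 input pixels map to real values, and default_ranges_min and default_ranges_max give a fallback real-valued range for every array that lacks min/max data. The sketch below works through both in plain Python. It uses the input formula from the TOCO documentation, real = (quantized - mean) / std_dev, and the standard affine uint8 scheme; the function names are made up for illustration and this is not TOCO's actual code.

```python
def input_to_real(q, mean=128.0, std_dev=128.0):
    """Map a uint8 input value to the real value the float graph expects."""
    return (q - mean) / std_dev


def quant_params(range_min, range_max, num_levels=256):
    """Return (scale, zero_point) for an affine mapping of a real range to uint8."""
    scale = (range_max - range_min) / (num_levels - 1)
    zero_point = int(round(-range_min / scale))
    return scale, zero_point


def quantize(x, scale, zero_point):
    """Quantize a real value, clamping to the uint8 range [0, 255]."""
    q = int(round(x / scale)) + zero_point
    return max(0, min(255, q))


def dequantize(q, scale, zero_point):
    """Recover the approximate real value from a uint8 code."""
    return (q - zero_point) * scale


# Input mapping with mean_values=128 and std_dev_values=128.
for pixel in (0, 128, 255):
    print("pixel", pixel, "-> real", input_to_real(pixel))

# Fallback activation range [0, 6], as passed via the default_ranges flags.
scale, zero_point = quant_params(0.0, 6.0)
for x in (0.0, 3.1, 6.0, 7.5):
    q = quantize(x, scale, zero_point)
    print("real", x, "-> uint8", q, "-> real", dequantize(q, scale, zero_point))
```

Values outside [0, 6] are clamped and values inside it are rounded to one of 256 steps, which is one way to picture where the small accuracy loss comes from when a single default range is applied to every array.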

Integrating the tflite file into the camera app

[Figure: 集成浣熊监测.png (integrating raccoon detection)]
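    // true matches the quantized model from solution 2; set this to false for the FLOAT model from solution 1.
    private static final boolean TF_OD_API_IS_QUANTIZED = true;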

    detector =
          TFLiteObjectDetectionAPIModel.create(
                            cameraActivity.getAssets(),
                            TF_OD_API_MODEL_FILE,
                            TF_OD_API_LABELS_FILE,
                            TF_OD_API_INPUT_SIZE,
                            TF_OD_API_IS_QUANTIZED);
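The four arrays named in --output_arrays are what the detector reads back after each inference. Assuming the usual layout of the TFLite_Detection_PostProcess op (boxes as normalized [ymin, xmin, ymax, xmax], then class indices, scores, and the number of detections), the decoding step can be sketched in Python as follows. The function name, the label list, and the sample values are hypothetical and only illustrate the data flow; they are not taken from the demo app.

```python
def decode_detections(boxes, classes, scores, num_detections,
                      image_width, image_height, labels, min_score=0.5):
    """Turn raw post-process outputs into labeled pixel-space boxes."""
    results = []
    for i in range(int(num_detections)):
        if scores[i] < min_score:
            continue  # Skip low-confidence detections.
        ymin, xmin, ymax, xmax = boxes[i]
        results.append({
            "label": labels[int(classes[i])],
            "score": scores[i],
            # Convert normalized coordinates to (left, top, right, bottom) pixels.
            "box": (
                int(round(xmin * image_width)),
                int(round(ymin * image_height)),
                int(round(xmax * image_width)),
                int(round(ymax * image_height)),
            ),
        })
    return results


# Hypothetical outputs for a single 300x300 frame.
labels = ["raccoon"]  # Assumes class index 0 is the first label.
boxes = [[0.25, 0.25, 0.75, 0.5], [0.0, 0.0, 1.0, 1.0]]
classes = [0.0, 0.0]
scores = [0.9, 0.3]
num_detections = 2.0

detections = decode_detections(boxes, classes, scores, num_detections,
                               300, 300, labels)
for d in detections:
    print(d["label"], d["score"], d["box"])
```

Only the first sample detection passes the 0.5 score threshold here; the threshold itself is an arbitrary choice for the example.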
[Figure: 浣熊识别.png (raccoon recognition result)]