tensorflow生成tflite格式模型的注意事项
近期有网友就之前的一篇文章问我问题(MobileNet SSD V2模型的压缩与tflite格式的转换(补充版) - 简书),于是我又重新走了一遍转换过程,发现了一些新的问题,在此注明一下。
Tensorflow不同版本
10月tensorflow2.0版本正式发布,API的变化很大,转换tflite的方式也发生了变化。
在tfv1中,tflite_convert.py文件支持graph_def_file,saved_model_dir,keras_model_file三种模型文件的输入,在tf2.0中则只剩下saved_model_dir,keras_model_file两种模型文件,但是之前通过tensorflow object detection api训练出来的模型正是graph_def_file格式,所以按照tfv1流程在tf2.0的环境下是肯定走不通的,会报如下错误:
tflite_convert.py: error: one of the arguments --saved_model_dir --keras_model_file is required
目前网上使用tensorflow object detection api的教程应该都是在tfv1环境下操作的,为了避免少走弯路,还是先使用tfv1环境操作为妙。
由于我还没有用tf2.0训练转换过模型,所以如何实现暂且不谈。
不同系统
上一篇文章是在ubuntu16.04,tf1.12的环境下实现的,但是一个月前的我的ubuntu系统崩溃掉了,而我又懒得重装一次(不得不吐槽在外星人上安装双系统真的是很麻烦啊),于是我现在本机的开发环境是win10+tf1.15。
在tf1.15的tflite_convert.py文件中,已经加入了对tf版本的判断(这在tf1.12中是没有的,具体从哪个版本加入的没有考证)。
由于上次转换模型已经是在8月份了,所以其实我也是看着自己的文章才能完整再来一遍,但是出现了下面这个错误:
Check failed: GetOpWithOutput(model, output_array) Specified output array "'TFLite_Detection_PostProcess'" is not produced by any op in this graph.
天秀啊,TFLite_Detection_PostProcess不是默认名字嘛,居然说木有。。。
在github上搜到了一位小哥的回复,说是在windows上需要把引号去掉,也就是输入的参数应该是
output_arrays=TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3
Toco/TFLite_Convert for TFLite Problem · Issue #22106 · tensorflow/tensorflow · GitHub
综上,在tfv1环境下,使用ubuntu和windows转换模型应该可以顺利完成了。