Pytorch导出ONNX踩坑指南

2020-03-25  本文已影响0人  听松客未眠

相对与ONNX模型,Pytorch模型经常较为松散,API的限制也往往较为宽松。因此,在导出的过程中,不可避免地会遇到导出失败的问题。可以预见到,这块API可能在不久的将来会发生变化。

ONNX导出

ONNX导出的基本操作比较简单。官网上的例子是:

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
model = torchvision.models.alexnet(pretrained=True).cuda()

# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

可惜真要这么容易就好了

ONNX导出验证脚本

import onnxruntime
import numpy as np

sess = onnxruntime.InferenceSession('./model.onnx', None)

# 以图像分类为例,batchsize设为2测试导出模型支持batching。
sess.run(None, {'input_1': np.random.rand(2, 3, img_height, img_width).astype('float32')})

让导出模型支持同时处理多个数据(Batching)

支持Batching需要制定Dynamic Axes,即可变的维度。

案例:

torch.export(...,
  input_names=['input_1'],
  output_names=['output_1'],
  dynamic_axes={
    'input_1': [0],  # 第0维是batch dimension
    'output_1': [0],
  })

解决Caffe2运行报错

keep_initializers_as_inputs 这个参数是False的情况下,在Caffe2中报错:IndexError: _Map_base::at. 参考https://github.com/onnx/onnx/issues/2458

opset 11在onnxruntime中运行时没使用GPU

问题比较复杂。貌似tensorflow也有类似问题。导出时添加参数do_constant_folding=True或许可以解决。
参考https://github.com/NVIDIA/triton-inference-server/issues/1080

List of tensor的导出

定长list

定长list会导出为一个tuple

变长list

Pytorch 1.4,ONNX 9不支持变长List的导出。之后的Pytorch版本有支持,需要更高版本的ONNX

不支持的操作

不一致的Operator

Expand

Pytorch中,Expand未改动的dim可以指定为-1,导出到ONNX中时,需要手动指定每个dim的值。如:

Pytorch:
a = a.expand(10, -1, -1)
ONNX:
a = a.expand(10, a.size(1), a.size(2))

Squeeze

Pytorch中,Squeeze一个不为1维的dim不会有任何效果。ONNX会报错

上一篇 下一篇

猜你喜欢

热点阅读