torchvision官方Mask RCNN 转ONNX
2020-04-20 本文已影响0人
qizhen816
在 torch==1.4、torchvision==0.5.0 版本中,官方的 Mask R-CNN 已经可以导出为 ONNX,可参考官方测试用例:
https://github.com/pytorch/vision/blob/7b60f4db9707d7afdbb87fd4e8ef6906ca014720/test/test_onnx.py#L31
不过在 onnxruntime 上仍有部分算子不受支持,推理速度也不算特别快。
import onnx
import torch
import torch.onnx
import torchvision
import cv2
import numpy as np
from torchvision import transforms
def pytorch2onnx(pthfile='maskrcnn_resnet50_fpn.pth',
                 image_path='imgs/2.jpg',
                 onnx_path='maskrcnn.onnx'):
    """Export torchvision's Mask R-CNN (ResNet-50 FPN) to an ONNX file.

    The model is built without downloading any weights, a local state dict
    is loaded into it, and one real image is used as the example input that
    ``torch.onnx.export`` traces (opset 11).

    Args:
        pthfile: Path to a state dict for ``maskrcnn_resnet50_fpn``; it is
            loaded on CPU with ``strict=True``.
        image_path: Image used as the tracing input. It must yield at least
            one detection (see Raises).
        onnx_path: Destination path of the exported ONNX model.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
        RuntimeError: If the model produces no detections for the image.
            Exporting with empty predictions fails (with torchvision 0.5 the
            error mentions ``_NewEmptyTensorOp``), so this is checked first.
    """
    # pretrained / pretrained_backbone=True would download weights
    # automatically; here the weights come from the local checkpoint instead.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=False, pretrained_backbone=False, min_size=200, max_size=500)
    state_dict = torch.load(pthfile, map_location='cpu')
    # If the checkpoint stores the weights under a "model" key,
    # pass state_dict["model"] here instead.
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    img = cv2.imread(image_path)
    if img is None:  # cv2.imread signals failure by returning None, not raising
        raise FileNotFoundError('could not read image: {}'.format(image_path))
    # OpenCV loads images as BGR; torchvision detection models expect RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # cv2.resize takes (width, height): 400 wide x 500 high.
    img = cv2.resize(img, (400, 500))
    # ToTensor converts HWC uint8 in [0, 255] to CHW float in [0, 1]; the
    # detection model takes a list of such 3-D tensors, one per image.
    images = [transforms.ToTensor()(img)]

    # Tracing with empty predictions fails during export; fail early instead.
    with torch.no_grad():
        predictions = model(images)
    if predictions[0]['boxes'].shape[0] == 0:
        raise RuntimeError(
            'model produced no detections for {}; export needs an input image '
            'that yields at least one prediction'.format(image_path))

    input_names = ["images_tensors"]
    output_names = ["outputs"]
    torch.onnx.export(model, (images,),
                      onnx_path,
                      verbose=True,
                      opset_version=11,
                      # Optional settings, disabled here:
                      # dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
                      # do_constant_folding=True,
                      input_names=input_names,
                      output_names=output_names)
    print('pass')
if __name__ == "__main__":
    # Only run the export when this file is executed as a script, not on import.
    pytorch2onnx()
导出时一定要确保输入图片能让模型产生检测结果;如果使用空白 tensor(模型没有任何检测输出),导出会报如下错误:
RuntimeError: ONNX export failed: Couldn't export Python operator _NewEmptyTensorOp