Exporting torchvision's official Mask R-CNN to ONNX

2020-04-20  qizhen816

With Torch 1.4 and Torchvision 0.5.0, the official Mask R-CNN can be exported to ONNX; see the export test in the torchvision repository:
https://github.com/pytorch/vision/blob/7b60f4db9707d7afdbb87fd4e8ef6906ca014720/test/test_onnx.py#L31
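The versions above are the ones this post was written against. Below is a minimal sketch for checking the installed versions before exporting. It assumes that later releases keep the export path working, which this post did not test. `parse_version` and `meets_minimum` are hypothetical helpers, not part of torch or torchvision.

```python
def parse_version(version):
    """Turn '0.5.0+cu101' or '1.4' into a comparable tuple of ints, e.g. (0, 5, 0)."""
    core = version.split("+")[0]  # drop local build tags such as '+cu101'
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:  # keep only the leading digits ('0a0' -> '0')
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:  # pad '1.4' to (1, 4, 0)
        parts.append(0)
    return tuple(parts[:3])


def meets_minimum(installed, minimum):
    """True if the installed version string is at least the minimum version string."""
    return parse_version(installed) >= parse_version(minimum)
```

For example, you would call `meets_minimum(torchvision.__version__, "0.5.0")` and `meets_minimum(torch.__version__, "1.4.0")`. The helper compares integer tuples instead of raw strings, so `"1.10.0"` is not treated as older than `"1.4.0"`.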

Some operations are not yet supported in onnxruntime, and inference is not particularly fast.
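You can put a number on the speed by timing the onnxruntime session directly, once the script below has produced `maskrcnn.onnx`. Here is a small sketch. `median_latency_ms` is a hypothetical helper, and the warm-up and repeat counts are arbitrary choices.

```python
import statistics
import time


def median_latency_ms(run_once, warmup=3, repeats=10):
    """Call run_once() repeatedly and return the median wall-clock latency in milliseconds.

    Warm-up calls are discarded so one-time costs (session setup,
    memory allocation) do not skew the measurement.
    """
    for _ in range(warmup):
        run_once()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        run_once()
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings)
```

With onnxruntime installed, create a session with `sess = onnxruntime.InferenceSession("maskrcnn.onnx", providers=["CPUExecutionProvider"])`. Then time it with `median_latency_ms(lambda: sess.run(None, {"images_tensors": x}))`. Here `x` is assumed to be a float32 NumPy array of shape `(3, 500, 400)`, matching the export input. The median is used because it is less sensitive to occasional slow runs than the mean.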

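```python
import onnx
import torch
import torch.onnx
import torchvision
import cv2
from torchvision import transforms

def pytorch2onnx():
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=False, pretrained_backbone=False, min_size=200, max_size=500)
    # Setting pretrained / pretrained_backbone to True downloads the pretrained weights automatically
    pthfile = 'maskrcnn_resnet50_fpn.pth'
    load = torch.load(pthfile, map_location='cpu')
    model.load_state_dict(load, strict=True)  # use load["model"] if the checkpoint wraps the weights in a dict
    model.eval()

    # The model takes a list of CHW float tensors in [0, 1], RGB channel order
    img = cv2.imread('imgs/2.jpg')
    if img is None:
        raise FileNotFoundError('imgs/2.jpg')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; torchvision models expect RGB
    img1 = cv2.resize(img, (400, 500))  # dsize is (width, height): 500 rows x 400 columns
    tt = transforms.ToTensor()
    img1 = [tt(img1)]

    input_names = ["images_tensors"]
    # Mask R-CNN returns four outputs per image; name all of them
    output_names = ["boxes", "labels", "scores", "masks"]
    torch.onnx.export(model, (img1,),
                      "maskrcnn.onnx",
                      verbose=True,
                      opset_version=11,
                      # the input is a single CHW tensor, so it only has axes 0, 1, 2
                      # dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], "scores": [0], "masks": [0, 1, 2, 3]},
                      # do_constant_folding=True,
                      input_names=input_names,
                      output_names=output_names)
    onnx.checker.check_model(onnx.load("maskrcnn.onnx"))  # sanity-check the exported graph
    print('pass')


if __name__ == "__main__":
    pytorch2onnx()
```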
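The script resizes with `cv2.resize(img, (400, 500))`. OpenCV's `dsize` is `(width, height)`, so the tensor is 500 rows by 400 columns. Because `min_size=200, max_size=500` were passed to the model, its internal transform rescales the image again before the backbone sees it.

Here is a rough sketch of that arithmetic. It assumes the torchvision 0.5 rule:

- scale the shorter side to `min_size`, but cap the longer side at `max_size`;
- floor the resized size;
- zero-pad to a multiple of 32.

`rcnn_input_size` is a hypothetical helper, and newer torchvision releases may round slightly differently.

```python
import math


def rcnn_input_size(height, width, min_size=200, max_size=500, size_divisible=32):
    """Approximate the (resized, padded) sizes produced by GeneralizedRCNNTransform.

    Assumes the shorter side is scaled to min_size unless that would push the
    longer side past max_size, the resized size is floored, and the batch is
    then zero-padded up to a multiple of size_divisible.
    """
    scale = min_size / min(height, width)
    if max(height, width) * scale > max_size:
        scale = max_size / max(height, width)
    resized_h = math.floor(height * scale)
    resized_w = math.floor(width * scale)
    padded_h = math.ceil(resized_h / size_divisible) * size_divisible
    padded_w = math.ceil(resized_w / size_divisible) * size_divisible
    return (resized_h, resized_w), (padded_h, padded_w)
```

For the 500 x 400 export image, `rcnn_input_size(500, 400)` returns `((250, 200), (256, 224))`. The backbone therefore sees a much smaller tensor than the one passed to `torch.onnx.export`.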

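Make sure the image used for export actually produces predictions. The likely cause of the failure is that, when nothing is detected, the mask head receives an empty tensor. The torchvision 0.5 `Conv2d`/`ConvTranspose2d` wrappers then build their empty output through the Python autograd function `_NewEmptyTensorOp`, which the ONNX exporter cannot convert. With a blank tensor as input, the export fails with: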
RuntimeError: ONNX export failed: Couldn't export Python operator _NewEmptyTensorOp
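One way to avoid this failure is to run the model on a few candidate inputs first and export with one that actually yields detections. Here is a minimal sketch. `pick_export_sample` is a hypothetical helper. It only assumes that the model returns one dict per image with a `scores` entry, as torchvision's detection models do in eval mode.

```python
def pick_export_sample(samples, predict, min_detections=1):
    """Return the first sample whose predictions contain at least min_detections boxes.

    samples: iterable of model inputs (e.g. lists of CHW tensors).
    predict: callable returning a list of per-image dicts with a "scores" entry.
    """
    for sample in samples:
        outputs = predict(sample)
        # Every image in the sample must have detections, otherwise the mask
        # head would see an empty tensor during tracing.
        if outputs and all(len(o["scores"]) >= min_detections for o in outputs):
            return sample
    raise ValueError("no sample produced detections; export would fail on an empty tensor")
```

With the script above, call it as `pick_export_sample(candidates, model)` inside `torch.no_grad()`. Here `candidates` is a hypothetical list of inputs shaped like `img1`. `len()` works on the 1-D `scores` tensor just as it does on a list.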
