Pytorch: 转onnx及精度验证

2020-06-22  本文已影响0人  wzNote

1. 环境配置

pytorch
onnxruntime==1.2.0 (1.3.0版本会报错ImportError: cannot import name 'get_all_providers')
onnxruntime-gpu==1.2.0
cuda10.1+cudnn7.6

2. 模型准备和转换

用torch.save()存储模型结构和权重

model = torch.load('pix2pix.pth', map_location=torch.device('cuda'))

单卡训练的模型

torch.onnx._export(model, dummy_input, "pix2pix.onnx", verbose=True, opset_version=11)

多卡训练的模型

torch.onnx._export(model, dummy_input, "pix2pix.onnx", verbose=True, opset_version=11)

3. 验证是否有精度损失

import onnxruntime
import numpy as np
from onnxruntime.datasets import get_example

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# 得到torch模型的输出
dummy_input = torch.randn(1, 3, 256, 256, device='cuda')
model.eval()
with torch.no_grad():
    torch_out = model(dummy_input)
print(torch_out)

# 得到onnx模型的输出
example_model = get_example('D:/workspace_python/model_utils/pix2pix.onnx') #一定要写绝对路径
sess = onnxruntime.InferenceSession(example_model)
onnx_out = sess.run(None, {input_name: to_numpy(dummy_input)})

# 判断输出结果是否一致,小数点后3位一致即可
np.testing.assert_almost_equal(to_numpy(torch_out), onnx_out[0], decimal=3)
上一篇下一篇

猜你喜欢

热点阅读