pytorch模型转onnx模型的方法详解

作者:挣扎的笨鸟 时间:2021-07-20 06:36:37 

学习目标

1.掌握pytorch模型转换到onnx模型

2.顺利运行onnx模型

3.比对onnx模型和pytorch模型的输出结果

学习大纲

  • pytorch模型转换onnx模型

  • 运行onnx模型

  • onnx模型输出与pytorch模型比对

学习内容

前提条件:需要安装onnx 和 onnxruntime,可以通过 pip install onnx 和 pip install onnxruntime 进行安装

1 . pytorch 转 onnx

pytorch 转 onnx 只需要一个函数 torch.onnx.export

torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)

参数说明:

  • model——需要导出的pytorch模型

  • args——模型的输入参数,满足输入层的shape正确即可。

  • path——输出的onnx模型的位置。例如‘yolov5.onnx’。

  • export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。

  • verbose——是否打印模型转换信息。default=False。

  • input_names——输入节点名称。default=None。

  • output_names——输出节点名称。default=None。

  • do_constant_folding——是否使用常量折叠(不了解),默认即可。default=True。

  • dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道。
    格式如下 :
    1)仅list(int) dynamic_axes={‘input’:[0,2,3],‘output’:[0,1]}
    2)仅dict<int, string> dynamic_axes={&lsquo;input&rsquo;:{0:&lsquo;batch&rsquo;,2:&lsquo;height&rsquo;,3:&lsquo;width&rsquo;},&lsquo;output&rsquo;:{0:&lsquo;batch&rsquo;,1:&lsquo;c&rsquo;}}
    3)mixed dynamic_axes={&lsquo;input&rsquo;:{0:&lsquo;batch&rsquo;,2:&lsquo;height&rsquo;,3:&lsquo;width&rsquo;},&lsquo;output&rsquo;:[0,1]}

  • opset_version&mdash;&mdash;opset的版本,低版本不支持upsample等操作。

import torch
import torch.nn
import onnx

model = torch.load('best.pt')
model.eval()

input_names = ['input']
output_names = ['output']

x = torch.randn(1,3,32,32,requires_grad=True)

torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')

2 . 运行onnx模型

检查onnx模型,并使用onnxruntime运行。

import onnx
import onnxruntime as ort

model = onnx.load('best.onnx')
onnx.checker.check_model(model)

session = ort.InferenceSession('best.onnx')
x=np.random.randn(1,3,32,32).astype(np.float32)  # 注意输入type一定要np.float32!!!!!
# x= torch.randn(batch_size,chancel,h,w)

outputs = session.run(None,input = { 'input' : x })

参数说明:

  • output_names: default=None
    用来指定输出哪些,以及顺序
    若为None,则按序输出所有的output,即返回[output_0,output_1]
    若为[&lsquo;output_1&rsquo;,&lsquo;output_0&rsquo;],则返回[output_1,output_0]
    若为[&lsquo;output_0&rsquo;],则仅返回[output_0:tensor]

  • input:dict
    可以通过session.get_inputs().name获得名称
    其中key值要求与torch.onnx.export中设定的一致

3.onnx模型输出与pytorch模型比对

import numpy as np
np.testing.assert_allclose(torch_result[0].detach().numpu(),onnx_result,rtol=0.0001)

如前所述,经验表明,ONNX 模型的运行效率明显优于原 PyTorch 模型,这似乎是源于 ONNX 模型生成过程中的优化,这也导致了模型的生成过程比较耗时,但整体效率依旧可观。

此外,根据对 ONNX 模型和 PyTorch 模型运行结果的统计分析(误差的均值和标准差),可以看出 ONNX 模型的运行结果误差很小、基本可靠。

内容参考:https://zhuanlan.zhihu.com/p/422290231

来源:https://blog.csdn.net/weixin_38989668/article/details/123840882

标签:pytorch,模型,onnx
0
投稿

猜你喜欢

  • Python爬虫之获取心知天气API实时天气数据并弹窗提醒

    2023-04-17 14:40:58
  • 揭开HTML 5工作草稿的神秘面纱

    2008-02-13 08:25:00
  • python中有关时间日期格式转换问题

    2023-03-17 07:43:12
  • Python单链表的简单实现方法

    2021-08-14 01:58:33
  • Asp无组件上传进度条解决方案

    2010-04-24 16:01:00
  • 如何运行Python程序的方法

    2023-01-13 07:56:03
  • 关于爬虫中scrapy.Request的更多参数用法

    2023-10-14 02:20:26
  • python实现多线程的方式及多条命令并发执行

    2023-08-09 11:37:20
  • Python Web版语音合成实例详解

    2021-11-28 04:37:20
  • python日志logging模块使用方法分析

    2023-01-06 17:22:51
  • Go语言中slice作为参数传递时遇到的一些“坑”

    2023-08-05 02:05:12
  • 小试Python中的pack()使用方法

    2021-02-03 06:00:43
  • 如何更优雅地写python代码

    2022-03-03 04:53:24
  • 教程:MySQL中多表操作和批处理方法

    2009-07-30 08:20:00
  • Python上下文管理器类和上下文管理器装饰器contextmanager用法实例分析

    2022-05-01 15:04:21
  • VS 2008的性能改进

    2007-10-07 21:42:00
  • 利用Python读取Excel表内容的详细过程

    2022-10-24 05:43:33
  • tensorflow 保存模型和取出中间权重例子

    2021-05-11 07:30:11
  • asp如何用JMail POP3接收电子邮件?

    2010-06-13 13:09:00
  • 微信小程序与php 实现微信支付的简单实例

    2023-11-14 15:22:07
  • asp之家 网络编程 m.aspxhome.com