Pytorch模型转onnx模型实例

作者:joey.lei 时间:2022-09-06 03:39:24 

如下所示:


import io
import torch
import torch.onnx
from models.C3AEModel import PlainC3AENetCBAM

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def test():
 model = PlainC3AENetCBAM()

pthfile = r'/home/joy/Projects/models/emotion/PlainC3AENet.pth'
 loaded_model = torch.load(pthfile, map_location='cpu')
 # try:
 #   loaded_model.eval()
 # except AttributeError as error:
 #   print(error)

model.load_state_dict(loaded_model['state_dict'])
 # model = model.to(device)

#data type nchw
 dummy_input1 = torch.randn(1, 3, 64, 64)
 # dummy_input2 = torch.randn(1, 3, 64, 64)
 # dummy_input3 = torch.randn(1, 3, 64, 64)
 input_names = [ "actual_input_1"]
 output_names = [ "output1" ]
 # torch.onnx.export(model, (dummy_input1, dummy_input2, dummy_input3), "C3AE.onnx", verbose=True, input_names=input_names, output_names=output_names)
 torch.onnx.export(model, dummy_input1, "C3AE_emotion.onnx", verbose=True, input_names=input_names, output_names=output_names)

if __name__ == "__main__":
test()

直接将PlainC3AENetCBAM替换成需要转换的模型,然后修改pthfile,输入和onnx模型名字然后执行即可。

注意:上面代码中注释的dummy_input2,dummy_input3,torch.onnx.export对应的是多个输入的例子。

在转换过程中遇到的问题汇总

RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible

在转换过程中遇到RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible的错误。

根据报的错误日志信息打开/home/joy/.tensorflow/venv/lib/python3.6/site-packages/torch/onnx/symbolic_helper.py,在相应位置添加print之后,可以定位到具体哪个op出问题。

例如:

在相应位置添加


print(v.node())

输出信息如下:


%124 : Long() = onnx::Gather[axis=0](%122, %121), scope: PlainC3AENetCBAM/Bottleneck[cbam]/CBAM[cbam]/ChannelGate[ChannelGate] # /home/joy/Projects/models/emotion/WhatsTheemotion/models/cbam.py:46:0

原因是pytorch中的tensor.size(1)方式onnx识别不了,需要修改成常量。

来源:https://blog.csdn.net/lei19880402/article/details/103721362

标签:Pytorch,模型,onnx
0
投稿

猜你喜欢

  • Swoole webSocket消息服务系统代码设计详解

    2023-06-09 01:05:28
  • Python人工智能构建简单聊天机器人示例详解

    2022-03-10 04:42:29
  • 监控 url fragment变化的js代码

    2023-08-25 10:20:58
  • python matplotlib.pyplot.plot()参数用法

    2023-07-13 17:39:48
  • XML编程实例:用ASP+XML打造留言本

    2008-05-04 13:37:00
  • DOM_window对象属性之--clipboardData对象操作代码

    2011-02-05 10:49:00
  • pycharm无法导入lxml的解决办法

    2023-08-24 04:34:39
  • python获取全国城市pm2.5、臭氧等空气质量过程解析

    2023-06-04 21:46:07
  • Python中Generators教程的实现

    2023-07-28 03:23:20
  • PHP实现sha-256哈希算法实例代码

    2023-05-25 01:05:23
  • 精细讲述SQL Server数据库备份多种方法

    2009-01-13 13:33:00
  • python实现随机梯度下降法

    2023-11-02 16:55:37
  • 用python实现读取xlsx表格操作

    2022-11-26 17:08:19
  • 详解python 支持向量机(SVM)算法

    2022-03-06 02:11:24
  • 简单谈谈Python面向对象的相关知识

    2022-08-25 19:11:23
  • python3中sorted函数里cmp参数改变详解

    2022-11-11 17:21:49
  • 操作Dom节点实现间歇滚动新闻

    2009-10-16 20:51:00
  • 善用用户反馈——浅谈用户反馈数据的处理

    2010-07-09 16:58:00
  • 通过python检测字符串的字母

    2023-01-11 22:49:47
  • Python实时监控网站浏览记录实现过程详解

    2021-06-24 23:55:02
  • asp之家 网络编程 m.aspxhome.com