Pytorch模型转onnx模型实例
作者:joey.lei 时间:2022-09-06 03:39:24
如下所示:
import io
import torch
import torch.onnx
from models.C3AEModel import PlainC3AENetCBAM
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def test():
model = PlainC3AENetCBAM()
pthfile = r'/home/joy/Projects/models/emotion/PlainC3AENet.pth'
loaded_model = torch.load(pthfile, map_location='cpu')
# try:
# loaded_model.eval()
# except AttributeError as error:
# print(error)
model.load_state_dict(loaded_model['state_dict'])
# model = model.to(device)
#data type nchw
dummy_input1 = torch.randn(1, 3, 64, 64)
# dummy_input2 = torch.randn(1, 3, 64, 64)
# dummy_input3 = torch.randn(1, 3, 64, 64)
input_names = [ "actual_input_1"]
output_names = [ "output1" ]
# torch.onnx.export(model, (dummy_input1, dummy_input2, dummy_input3), "C3AE.onnx", verbose=True, input_names=input_names, output_names=output_names)
torch.onnx.export(model, dummy_input1, "C3AE_emotion.onnx", verbose=True, input_names=input_names, output_names=output_names)
if __name__ == "__main__":
test()
直接将PlainC3AENetCBAM替换成需要转换的模型,然后修改pthfile,输入和onnx模型名字然后执行即可。
注意:上面代码中注释的dummy_input2,dummy_input3,torch.onnx.export对应的是多个输入的例子。
在转换过程中遇到的问题汇总
RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible
在转换过程中遇到RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible的错误。
根据报的错误日志信息打开/home/joy/.tensorflow/venv/lib/python3.6/site-packages/torch/onnx/symbolic_helper.py,在相应位置添加print之后,可以定位到具体哪个op出问题。
例如:
在相应位置添加
print(v.node())
输出信息如下:
%124 : Long() = onnx::Gather[axis=0](%122, %121), scope: PlainC3AENetCBAM/Bottleneck[cbam]/CBAM[cbam]/ChannelGate[ChannelGate] # /home/joy/Projects/models/emotion/WhatsTheemotion/models/cbam.py:46:0
原因是pytorch中的tensor.size(1)方式onnx识别不了,需要修改成常量。
来源:https://blog.csdn.net/lei19880402/article/details/103721362
标签:Pytorch,模型,onnx
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
Swoole webSocket消息服务系统代码设计详解
2023-06-09 01:05:28
Python人工智能构建简单聊天机器人示例详解
2022-03-10 04:42:29
监控 url fragment变化的js代码
2023-08-25 10:20:58
python matplotlib.pyplot.plot()参数用法
2023-07-13 17:39:48
XML编程实例:用ASP+XML打造留言本
2008-05-04 13:37:00
DOM_window对象属性之--clipboardData对象操作代码
2011-02-05 10:49:00
pycharm无法导入lxml的解决办法
2023-08-24 04:34:39
![](https://img.aspxhome.com/file/2023/2/87922_0s.png)
python获取全国城市pm2.5、臭氧等空气质量过程解析
2023-06-04 21:46:07
Python中Generators教程的实现
2023-07-28 03:23:20
PHP实现sha-256哈希算法实例代码
2023-05-25 01:05:23
精细讲述SQL Server数据库备份多种方法
2009-01-13 13:33:00
python实现随机梯度下降法
2023-11-02 16:55:37
![](https://img.aspxhome.com/file/2023/8/105758_0s.png)
用python实现读取xlsx表格操作
2022-11-26 17:08:19
![](https://img.aspxhome.com/file/2023/7/103277_0s.png)
详解python 支持向量机(SVM)算法
2022-03-06 02:11:24
![](https://img.aspxhome.com/file/2023/4/75944_0s.png)
简单谈谈Python面向对象的相关知识
2022-08-25 19:11:23
python3中sorted函数里cmp参数改变详解
2022-11-11 17:21:49
操作Dom节点实现间歇滚动新闻
2009-10-16 20:51:00
善用用户反馈——浅谈用户反馈数据的处理
2010-07-09 16:58:00
![](https://img.aspxhome.com/file/UploadPic/20107/9/customer-support-banner-35s.jpg)
通过python检测字符串的字母
2023-01-11 22:49:47
Python实时监控网站浏览记录实现过程详解
2021-06-24 23:55:02
![](https://img.aspxhome.com/file/2023/8/97218_0s.jpg)