Pytorch模型转onnx模型实例
作者:joey.lei 时间:2022-09-06 03:39:24
如下所示:
import io
import torch
import torch.onnx
from models.C3AEModel import PlainC3AENetCBAM
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def test():
model = PlainC3AENetCBAM()
pthfile = r'/home/joy/Projects/models/emotion/PlainC3AENet.pth'
loaded_model = torch.load(pthfile, map_location='cpu')
# try:
# loaded_model.eval()
# except AttributeError as error:
# print(error)
model.load_state_dict(loaded_model['state_dict'])
# model = model.to(device)
#data type nchw
dummy_input1 = torch.randn(1, 3, 64, 64)
# dummy_input2 = torch.randn(1, 3, 64, 64)
# dummy_input3 = torch.randn(1, 3, 64, 64)
input_names = [ "actual_input_1"]
output_names = [ "output1" ]
# torch.onnx.export(model, (dummy_input1, dummy_input2, dummy_input3), "C3AE.onnx", verbose=True, input_names=input_names, output_names=output_names)
torch.onnx.export(model, dummy_input1, "C3AE_emotion.onnx", verbose=True, input_names=input_names, output_names=output_names)
if __name__ == "__main__":
test()
直接将PlainC3AENetCBAM替换成需要转换的模型,然后修改pthfile,输入和onnx模型名字然后执行即可。
注意:上面代码中注释的dummy_input2,dummy_input3,torch.onnx.export对应的是多个输入的例子。
在转换过程中遇到的问题汇总
RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible
在转换过程中遇到RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible的错误。
根据报的错误日志信息打开/home/joy/.tensorflow/venv/lib/python3.6/site-packages/torch/onnx/symbolic_helper.py,在相应位置添加print之后,可以定位到具体哪个op出问题。
例如:
在相应位置添加
print(v.node())
输出信息如下:
%124 : Long() = onnx::Gather[axis=0](%122, %121), scope: PlainC3AENetCBAM/Bottleneck[cbam]/CBAM[cbam]/ChannelGate[ChannelGate] # /home/joy/Projects/models/emotion/WhatsTheemotion/models/cbam.py:46:0
原因是pytorch中的tensor.size(1)方式onnx识别不了,需要修改成常量。
来源:https://blog.csdn.net/lei19880402/article/details/103721362
标签:Pytorch,模型,onnx
0
投稿
猜你喜欢
Python fileinput模块使用介绍
2023-08-22 14:32:12
Python如何获取pid和进程名字
2023-11-11 11:44:11
Python3使用PySynth制作音乐的方法
2021-03-18 19:41:01
Burpsuite模块之Burpsuite Intruder模块详解
2023-11-24 05:31:24
Mysql安装注意事项、安装失败的五个原因分析
2024-01-22 14:16:48
Linux安装Python虚拟环境virtualenv的方法
2022-07-07 00:33:36
Python中搜索和替换文件中的文本的实现(四种)
2022-04-23 01:03:39
Tensor 和 NumPy 相互转换的实现
2023-07-05 04:55:51
使用Python脚本对Linux服务器进行监控的教程
2022-06-19 18:27:26
PyTorch开源图像分类工具箱MMClassification详解
2023-11-21 02:20:06
ThinkPHP php 框架学习笔记
2023-09-10 08:20:32
python中property和setter装饰器用法
2022-04-20 21:38:03
详解利用上下文管理器扩展Python计时器
2023-11-07 09:33:48
python语言元素知识点详解
2023-07-30 03:33:08
HTML中的setCapture和releaseCapture使用介绍
2024-04-18 09:51:18
php的对象传值与引用传值代码实例讲解
2023-11-06 08:42:37
合并ThinkPHP配置文件以消除代码冗余的实现方法
2023-11-21 11:54:31
mysql中int(3)和int(10)的数值范围是否相同
2024-01-17 16:37:48
python按列索引提取文件夹内所有excel指定列汇总(示例代码)
2021-11-08 09:46:50
Python一行代码实现自动发邮件功能
2021-04-06 06:04:38