Pytorch模型转onnx模型实例

作者:joey.lei 时间:2022-09-06 03:39:24 

如下所示:


import io
import torch
import torch.onnx
from models.C3AEModel import PlainC3AENetCBAM

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def test():
 model = PlainC3AENetCBAM()

pthfile = r'/home/joy/Projects/models/emotion/PlainC3AENet.pth'
 loaded_model = torch.load(pthfile, map_location='cpu')
 # try:
 #   loaded_model.eval()
 # except AttributeError as error:
 #   print(error)

model.load_state_dict(loaded_model['state_dict'])
 # model = model.to(device)

#data type nchw
 dummy_input1 = torch.randn(1, 3, 64, 64)
 # dummy_input2 = torch.randn(1, 3, 64, 64)
 # dummy_input3 = torch.randn(1, 3, 64, 64)
 input_names = [ "actual_input_1"]
 output_names = [ "output1" ]
 # torch.onnx.export(model, (dummy_input1, dummy_input2, dummy_input3), "C3AE.onnx", verbose=True, input_names=input_names, output_names=output_names)
 torch.onnx.export(model, dummy_input1, "C3AE_emotion.onnx", verbose=True, input_names=input_names, output_names=output_names)

if __name__ == "__main__":
test()

直接将PlainC3AENetCBAM替换成需要转换的模型,然后修改pthfile,输入和onnx模型名字然后执行即可。

注意:上面代码中注释的dummy_input2,dummy_input3,torch.onnx.export对应的是多个输入的例子。

在转换过程中遇到的问题汇总

RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible

在转换过程中遇到RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible的错误。

根据报的错误日志信息打开/home/joy/.tensorflow/venv/lib/python3.6/site-packages/torch/onnx/symbolic_helper.py,在相应位置添加print之后,可以定位到具体哪个op出问题。

例如:

在相应位置添加


print(v.node())

输出信息如下:


%124 : Long() = onnx::Gather[axis=0](%122, %121), scope: PlainC3AENetCBAM/Bottleneck[cbam]/CBAM[cbam]/ChannelGate[ChannelGate] # /home/joy/Projects/models/emotion/WhatsTheemotion/models/cbam.py:46:0

原因是pytorch中的tensor.size(1)方式onnx识别不了,需要修改成常量。

来源:https://blog.csdn.net/lei19880402/article/details/103721362

标签:Pytorch,模型,onnx
0
投稿

猜你喜欢

  • Python fileinput模块使用介绍

    2023-08-22 14:32:12
  • Python如何获取pid和进程名字

    2023-11-11 11:44:11
  • Python3使用PySynth制作音乐的方法

    2021-03-18 19:41:01
  • Burpsuite模块之Burpsuite Intruder模块详解

    2023-11-24 05:31:24
  • Mysql安装注意事项、安装失败的五个原因分析

    2024-01-22 14:16:48
  • Linux安装Python虚拟环境virtualenv的方法

    2022-07-07 00:33:36
  • Python中搜索和替换文件中的文本的实现(四种)

    2022-04-23 01:03:39
  • Tensor 和 NumPy 相互转换的实现

    2023-07-05 04:55:51
  • 使用Python脚本对Linux服务器进行监控的教程

    2022-06-19 18:27:26
  • PyTorch开源图像分类工具箱MMClassification详解

    2023-11-21 02:20:06
  • ThinkPHP php 框架学习笔记

    2023-09-10 08:20:32
  • python中property和setter装饰器用法

    2022-04-20 21:38:03
  • 详解利用上下文管理器扩展Python计时器

    2023-11-07 09:33:48
  • python语言元素知识点详解

    2023-07-30 03:33:08
  • HTML中的setCapture和releaseCapture使用介绍

    2024-04-18 09:51:18
  • php的对象传值与引用传值代码实例讲解

    2023-11-06 08:42:37
  • 合并ThinkPHP配置文件以消除代码冗余的实现方法

    2023-11-21 11:54:31
  • mysql中int(3)和int(10)的数值范围是否相同

    2024-01-17 16:37:48
  • python按列索引提取文件夹内所有excel指定列汇总(示例代码)

    2021-11-08 09:46:50
  • Python一行代码实现自动发邮件功能

    2021-04-06 06:04:38
  • asp之家 网络编程 m.aspxhome.com