pytorch模型部署 pth转onnx的方法

作者:aoyou19 时间:2022-07-05 03:49:04 

Pytorch转ONNX的意义

一般来说转ONNX只是一个手段,在之后得到ONNX模型后还需要再将它做转换,比如转换到TensorRT上完成部署,或者有的人多加一步,从ONNX先转换到caffe,再从caffe到tensorRT。Pytorch自带的torch.onnx.export转换得到的ONNX,ONNXRuntime需要的ONNX,TensorRT需要的ONNX都是不同的。

将pytorch训练保存的pth文件转为onnx文件,为后续模型部署做准备。

一、分类模型

import torch
import os
import timm
import argparse
from utils_net import Resnet
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='classify_model.pth')
parser.add_argument("--save_onnx_path", default='classify_model.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--input_channel", default=1)
parser.add_argument("--num_classes", default=6)
args = parser.parse_args()
def pth_to_onnx(pth_path, onnx_path, in_hig, in_wid, in_chal, num_cls):
   if not onnx_path.endswith('.onnx'):
       print('Warning! The onnx model name is not correct,\
             please give a name that ends with \'.onnx\'!')
       return 0
   model = Resnet(num_classes=num_cls)
   model.load_state_dict(torch.load(pth_path))
   model.eval()
   print(f'{pth_path} model loaded')
   input_names = ['input']
   output_names = ['output']
   im = torch.rand(1, in_chal, in_hig, in_wid)
   torch.onnx.export(model, im, onnx_path,
                     verbose=False,
                     input_names=input_names,
                     output_names=output_names)
   print("Exporting .pth model to onnx model has been successful!")
   print(f"Onnx model save as {onnx_path}")
if __name__ == '__main__':
   pth_to_onnx(pth_path=args.pth_path,
               onnx_path=args.save_onnx_path,
               in_hig=args.input_height,
               in_wid=args.input_width,
               in_chal=args.input_channel,
               num_cls=args.num_classes)

运行结果:

classify_model.pth model loaded
Exporting .pth model to onnx model has been successful!
Onnx model save as classify_model.onnx

Process finished with exit code 0

二、分割模型

import torch
import os
import argparse
from utils_net import seg_net
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='segment_model.pth')
parser.add_argument("--save_onnx_path", default='segment_model.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--input_channel", default=1)
parser.add_argument("--num_classes", default=4)
args = parser.parse_args()
def pth_to_onnx(pth_path, onnx_path, in_hig, in_wid, in_channel, num_cls):
   if not onnx_path.endswith('.onnx'):
       print('Warning! The onnx model name is not correct,\
             please give a name that ends with \'.onnx\'!')
       return 0
   model = seg_net(in_channel=in_channel, num_cls=num_cls)
   model.load_state_dict(torch.load(pth_path))
   model.eval()
   print(f'{pth_path} model loaded')
   input_names = ['input']
   output_names = ['output']
   im = torch.rand(1, in_channel, in_hig, in_wid)
   torch.onnx.export(model, im, onnx_path,
                     verbose=False,
                     input_names=input_names,
                     output_names=output_names,
                     opset_version=11)
   print("Exporting .pth model to onnx model has been successful!")
   print(f"Onnx model save as {onnx_path}")
if __name__ == '__main__':
   pth_to_onnx(pth_path=args.pth_path,
               onnx_path=args.save_onnx_path,
               in_hig=args.input_height,
               in_wid=args.input_width,
               in_channel=args.input_channel,
               num_cls=args.num_classes)

运行结果:

segment_model.pth model loaded
Exporting .pth model to onnx model has been successful!
Onnx model save as segment_model.onnx

Process finished with exit code 0

三、目标检测模型

在这里插入代码片
import torch
import onnx
import argparse
from utils_net import YoloBody
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='yolo.pth')
parser.add_argument("--save_onnx_path", default='yolo.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--num_classes", default=2)
parser.add_argument("--anchors_mask", default=[[6, 7, 8], [3, 4, 5], [0, 1, 2]])
args = parser.parse_args()
def pth_to_onnx(pth_path: str, save_onnx_path: str, num_cls: int,
               in_hig: int, in_wid: int, anchor_mask: list,
               opset_version: int = 12, simplify: bool = False):
   """
   :param pth_path: pth文件文件
   :param save_onnx_path: 准备保存的onnx路径
   :param num_cls: 检测目标类别数
   :param in_hig: 网络输入高度
   :param in_wid: 网络输入宽度
   :param anchor_mask: anchor宽高索引
   :param opset_version: onnx算子集版本
   :param simplify: 是否对模型进行简化
   :return:保存onnx到指定路径
   """
   # Build model, load weights
   net = YoloBody(anchors_mask=anchor_mask,
                  num_classes=num_cls)
   # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
   # net.load_state_dict(torch.load(pth_path, map_location=device))
   net.load_state_dict(torch.load(pth_path))
   # print(next(net.parameters()).device)
   net = net.eval()
   print(f'{pth_path} model loaded')
   im = torch.zeros(1, 3, in_hig, in_wid).to('cpu')
   input_layer_names = ['images']
   output_layer_names = ['output']
   # Export the model
   print(f'Starting export with onnx {onnx.__version__}.')
   torch.onnx.export(net,
                     im,
                     f=save_onnx_path,
                     verbose=False,
                     opset_version=opset_version,
                     training=torch.onnx.TrainingMode.EVAL,
                     do_constant_folding=True,
                     input_names=input_layer_names,
                     output_names=output_layer_names,
                     dynamic_axes=None)
   # Checks
   model_onnx = onnx.load(save_onnx_path)  # load onnx model
   onnx.checker.check_model(model_onnx)  # check onnx model
   # Simplify onnx
   if simplify:
       import onnxsim
       print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
       model_onnx, check = onnxsim.simplify(
           model_onnx,
           dynamic_input_shape=False,
           input_shapes=None)
       assert check, 'assert check failed'
       onnx.save(model_onnx, save_onnx_path)
   print('Onnx model save as {}'.format(save_onnx_path))
if __name__ == '__main__':
   pth_to_onnx(pth_path=args.pth_path,
               save_onnx_path=args.save_onnx_path,
               num_cls=args.num_classes,
               in_hig=args.input_height,
               in_wid=args.input_width,
               anchor_mask=args.anchors_mask)

运行结果:

yolo.pth model loaded
Starting export with onnx 1.11.0.
Onnx model save as yolo.onnx

Process finished with exit code 0

参考链接:

1.yolo
2.模型部署翻车记:pytorch转onnx踩坑实录

来源:https://blog.csdn.net/aoyou19/article/details/129407797

标签:pytorch,部署,pth,onnx
0
投稿

猜你喜欢

  • Python元组定义及集合的使用

    2023-11-22 12:32:03
  • access MDB 转换为 Execl(ASP类)

    2008-07-19 12:10:00
  • pytorch模型预测结果与ndarray互转方式

    2023-12-06 02:35:11
  • js换图片效果可进行定时操作

    2023-08-23 07:45:34
  • Python模拟登录和登录跳转的参考示例

    2023-07-29 07:09:47
  • pytest解读fixtures之Teardown处理yield和addfinalizer方案

    2023-06-18 22:13:01
  • win10从零安装配置pytorch全过程图文详解

    2022-07-01 20:54:55
  • python math模块的基本使用教程

    2022-01-30 23:07:53
  • pandas每次多Sheet写入文件的方法

    2022-02-07 03:50:39
  • python在windows调用svn-pysvn的实现

    2022-03-15 05:13:48
  • php小技巧之过滤ascii控制字符

    2023-10-03 05:13:15
  • oracle 安装与SQLPLUS简单用法

    2009-06-10 17:49:00
  • [多图] Google Chrome 试用 Tips

    2009-12-09 15:49:00
  • [翻译]标记语言和样式手册 Chapter 11 打印样式

    2008-02-11 18:44:00
  • python一行输入n个数据问题

    2023-09-11 21:50:48
  • opencv-python 开发环境的安装、配置教程详解

    2022-04-25 22:14:58
  • Python实现文件操作帮助类的示例代码

    2023-02-14 16:46:57
  • 详解python如何调用C/C++底层库与互相传值

    2022-02-25 07:18:00
  • asp添加数据实现代码

    2011-02-05 10:42:00
  • k-means 聚类算法与Python实现代码

    2022-02-01 02:55:22
  • asp之家 网络编程 m.aspxhome.com