Python torch.onnx.export用法详细介绍

作者:Kmaeii 时间:2022-04-28 22:07:33 

函数原型

Python torch.onnx.export用法详细介绍

参数介绍

mode (torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction)

需要转换的模型,支持的模型类型有:torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction

args (tuple or torch.Tensor)

args可以被设置成三种形式

1.一个tuple

args = (x, y, z)

这个tuple应该与模型的输入相对应,任何非Tensor的输入都会被硬编码入onnx模型,所有Tensor类型的参数会被当做onnx模型的输入。

2.一个Tensor

args = torch.Tensor([1, 2, 3])

一般这种情况下模型只有一个输入

3.一个带有字典的tuple

args = (x,
       {'y': input_y,
        'z': input_z})

这种情况下,所有字典之前的参数会被当做“非关键字”参数传入网络,字典种的键值对会被当做关键字参数传入网络。如果网络中的关键字参数未出现在此字典中,将会使用默认值,如果没有设定默认值,则会被指定为None。

NOTE:

一个特殊情况,当网络本身最后一个参数为字典时,直接在tuple最后写一个字典则会被误认为关键字传参。所以,可以通过在tuple最后添加一个空字典来解决。

#错误写法:

torch.onnx.export(
   model,
   (x,
    # WRONG: will be interpreted as named arguments
    {y: z}),
   "test.onnx.pb")

# 纠正

torch.onnx.export(
   model,
   (x,
    {y: z},
    {}),
   "test.onnx.pb")

f

一个文件类对象或一个路径字符串,二进制的protocol buffer将被写入此文件

export_params (bool, default True)

如果为True则导出模型的参数。如果想导出一个未训练的模型,则设为False

verbose (bool, default False)

如果为True,则打印一些转换日志,并且onnx模型中会包含doc_string信息。

training (enum, default TrainingMode.EVAL)

枚举类型包括:

TrainingMode.EVAL - 以推理模式导出模型。

TrainingMode.PRESERVE - 如果model.training为False,则以推理模式导出;否则以训练模式导出。

TrainingMode.TRAINING - 以训练模式导出,此模式将禁止一些影响训练的优化操作。

input_names (list of str, default empty list)

按顺序分配给onnx图的输入节点的名称列表。

output_names (list of str, default empty list)

按顺序分配给onnx图的输出节点的名称列表。

operator_export_type (enum, default None)

默认为OperatorExportTypes.ONNX, 如果Pytorch built with DPYTORCH_ONNX_CAFFE2_BUNDLE,则默认为OperatorExportTypes.ONNX_ATEN_FALLBACK。

枚举类型包括:

OperatorExportTypes.ONNX - 将所有操作导出为ONNX操作。

OperatorExportTypes.ONNX_FALLTHROUGH - 试图将所有操作导出为ONNX操作,但碰到无法转换的操作(如onnx未实现的操作),则将操作导出为“自定义操作”,为了使导出的模型可用,运行时必须支持这些自定义操作。支持自定义操作方法见链接。

OperatorExportTypes.ONNX_ATEN - 所有ATen操作导出为ATen操作,ATen是Pytorch的内建tensor库,所以这将使得模型直接使用Pytorch实现。(此方法转换的模型只能被Caffe2直接使用)

OperatorExportTypes.ONNX_ATEN_FALLBACK - 试图将所有的ATen操作也转换为ONNX操作,如果无法转换则转换为ATen操作(此方法转换的模型只能被Caffe2直接使用)。例如:

# 转换前:
graph(%0 : Float):
 %3 : int = prim::Constant[value=0]()
 # conversion unsupported
 %4 : Float = aten::triu(%0, %3)
 # conversion supported
 %5 : Float = aten::mul(%4, %0)
 return (%5)

# 转换后:
graph(%0 : Float):
 %1 : Long() = onnx::Constant[value={0}]()
 # not converted
 %2 : Float = aten::ATen[operator="triu"](%0, %1)
 # converted
 %3 : Float = onnx::Mul(%2, %0)
 return (%3)

opset_version (int, default 9)

默认是9。值必须等于_onnx_main_opset或在_onnx_stable_opsets之内。具体可在torch/onnx/symbolic_helper.py中找到。例如:

_default_onnx_opset_version = 9

_onnx_main_opset = 13

_onnx_stable_opsets = [7, 8, 9, 10, 11, 12]

_export_onnx_opset_version = _default_onnx_opset_version

do_constant_folding (bool, default False)

是否使用“常量折叠”优化。常量折叠将使用一些算好的常量来优化一些输入全为常量的节点。

example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None)

当需输入模型为ScriptModule 或 ScriptFunction时必须提供。此参数用于确定输出的类型和形状,而不跟踪(tracing )模型的执行。

dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict)

通过以下规则设置动态的维度:

KEY(str) - 必须是input_names或output_names指定的名称,用来指定哪个变量需要使用到动态尺寸。

VALUE(dict or list) - 如果是一个dict,dict中的key是变量的某个维度,dict中的value是我们给这个维度取的名称。如果是一个list,则list中的元素都表示此变量的某个维度。

具体可参考如下示例:

class SumModule(torch.nn.Module):
   def forward(self, x):
       return torch.sum(x, dim=1)

# 以动态尺寸模式导出模型

torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                 input_names=["x"], output_names=["sum"],
                 dynamic_axes={
                     # dict value: manually named axes
                     "x": {0: "my_custom_axis_name"},
                     # list value: automatic names
                     "sum": [0],
                 })

### 导出后的节点信息

##input

input {
 name: "x"
 ...
     shape {
       dim {
         dim_param: "my_custom_axis_name"  # axis 0
       }
       dim {
         dim_value: 2  # axis 1
...

##output
output {
 name: "sum"
 ...
     shape {
       dim {
         dim_param: "sum_dynamic_axes_1"  # axis 0
...

keep_initializers_as_inputs (bool, default None)

NONE

custom_opsets (dict<str, int>, default empty dict)

NONE

Torch.onnx.export执行流程:

1、如果输入到torch.onnx.export的模型是nn.Module类型,则默认会将模型使用torch.jit.trace转换为ScriptModule

2、使用args参数和torch.jit.trace将模型转换为ScriptModule,torch.jit.trace不能处理模型中的循环和if语句

3、如果模型中存在循环或者if语句,在执行torch.onnx.export之前先使用torch.jit.script将nn.Module转换为ScriptModule

4、模型转换成onnx之后,预测结果与之前会有稍微的差别,这些差别往往不会改变模型的预测结果,比如预测的概率在小数点之后五六位有差别。

来源:https://blog.csdn.net/Dteam_f/article/details/122487634

标签:python,torch.onnx.export,参数
0
投稿

猜你喜欢

  • python二叉树遍历的实现方法

    2021-09-19 03:53:14
  • ASP图片分页代码 (通用)

    2009-06-22 12:57:00
  • php基于协程实现异步的方法分析

    2023-06-11 10:08:39
  • python opencv 直方图反向投影的方法

    2022-10-07 18:37:37
  • 解决vue组件中click事件失效的问题

    2023-07-02 16:34:10
  • python删除指定类型(或非指定)的文件实例详解

    2022-04-10 06:46:52
  • python实现控制台打印的方法

    2021-12-18 12:21:04
  • Python issubclass和isinstance函数的具体使用

    2021-08-22 01:39:06
  • Mysql Error Code : 1436 Thread stack overrun

    2024-01-23 14:04:04
  • 详解javascript中var与ES6规范中let、const区别与用法

    2024-05-09 15:06:17
  • 使用typescript快速开发一个cli的实现示例

    2023-08-30 07:25:25
  • Oracle数据库的备份及恢复策略研究

    2010-07-16 12:54:00
  • OpenCV图像处理GUI功能详解

    2021-01-26 15:55:34
  • python实现多线程暴力破解登陆路由器功能代码分享

    2023-08-28 21:27:01
  • JAVA及PYTHON质数计算代码对比解析

    2023-08-29 23:41:31
  • Python定时执行之Timer用法示例

    2021-09-14 21:46:01
  • 详解CSS3中的属性选择符

    2008-04-24 14:30:00
  • vue3.0+vue-router+element-plus初实践

    2024-05-21 10:17:49
  • 用php来改写404错误页让你的页面更友好

    2023-10-26 20:16:21
  • Django的URLconf中使用缺省视图参数的方法

    2021-05-03 17:46:29
  • asp之家 网络编程 m.aspxhome.com