Pytorch通过保存为ONNX模型转TensorRT5的实现
作者:小关学长 时间:2023-10-22 13:45:27
1 Pytorch以ONNX方式保存模型
def saveONNX(model, filepath):
'''
保存ONNX模型
:param model: 神经网络模型
:param filepath: 文件保存路径
'''
# 神经网络输入数据类型
dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
torch.onnx.export(model, dummy_input, filepath, verbose=True)
2 利用TensorRT5中ONNX解析器构建Engine
def ONNX_build_engine(onnx_file_path):
'''
通过加载onnx文件,构建engine
:param onnx_file_path: onnx文件路径
:return: engine
'''
# 打印日志
G_LOGGER = trt.Logger(trt.Logger.WARNING)
with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
builder.max_batch_size = 100
builder.max_workspace_size = 1 << 20
print('Loading ONNX file from path {}...'.format(onnx_file_path))
with open(onnx_file_path, 'rb') as model:
print('Beginning ONNX file parsing')
parser.parse(model.read())
print('Completed parsing of ONNX file')
print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
engine = builder.build_cuda_engine(network)
print("Completed creating Engine")
# 保存计划文件
# with open(engine_file_path, "wb") as f:
# f.write(engine.serialize())
return engine
3 构建TensorRT运行引擎进行预测
def loadONNX2TensorRT(filepath):
'''
通过onnx文件,构建TensorRT运行引擎
:param filepath: onnx文件路径
'''
# 计算开始时间
Start = time()
engine = self.ONNX_build_engine(filepath)
# 读取测试集
datas = DataLoaders()
test_loader = datas.testDataLoader()
img, target = next(iter(test_loader))
img = img.numpy()
target = target.numpy()
img = img.ravel()
context = engine.create_execution_context()
output = np.empty((100, 10), dtype=np.float32)
# 分配内存
d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
bindings = [int(d_input), int(d_output)]
# pycuda操作缓冲区
stream = cuda.Stream()
# 将输入数据放入device
cuda.memcpy_htod_async(d_input, img, stream)
# 执行模型
context.execute_async(100, bindings, stream.handle, None)
# 将预测结果从从缓冲区取出
cuda.memcpy_dtoh_async(output, d_output, stream)
# 线程同步
stream.synchronize()
print("Test Case: " + str(target))
print("Prediction: " + str(np.argmax(output, axis=1)))
print("tensorrt time:", time() - Start)
del context
del engine
补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT
近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。
后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。
是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。
来源:https://blog.csdn.net/qq_38003892/article/details/89314108
标签:Pytorch,ONNX,TensorRT5
0
投稿
猜你喜欢
python 正则表达式参数替换实例详解
2022-08-11 18:21:44
Oracle不同数据库间对比分析脚本
2024-01-17 07:05:07
Django创建项目+连通mysql的操作方法
2024-01-12 17:16:42
MySQL binlog_ignore_db 参数的具体使用
2024-01-20 12:32:59
Python实现单项链表的最全教程
2021-12-21 09:45:26
PHP单例模式简单用法示例
2023-11-18 19:45:41
JavaScript 自动分号插入(JavaScript synat:auto semicolon insertion)
2013-08-09 10:14:56
Python Pandas中根据列的值选取多行数据
2023-02-16 04:17:59
PHP连接MySQL数据库的三种方式实例分析【mysql、mysqli、pdo】
2023-11-15 01:31:13
纯js封装的ajax功能函数与用法示例
2024-05-11 09:09:20
python实现复制文件到指定目录
2022-09-25 20:53:12
最基础的Python的socket编程入门教程
2022-10-13 03:38:46
交互设计师应该具备哪些素质
2009-03-12 12:21:00
python 解析html之BeautifulSoup
2021-02-09 23:09:12
vue-cli4.5.x快速搭建项目
2024-04-27 15:52:18
Firefox 3.6新功能预览
2009-12-01 14:23:00
SuperSocket 信息: (SpnRegister) : Error 1355。解决方法
2024-01-17 22:54:02
Python将主机名转换为IP地址的方法
2023-09-06 21:30:42
Python实现照片卡通化
2021-03-29 18:45:40
python基于 Web 实现 m3u8 视频播放的实例
2022-06-15 22:16:40