从Pytorch模型pth文件中读取参数成numpy矩阵的操作
作者:木盏 时间:2021-12-27 11:05:53
目的:
把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备。
Pytorch给了很方便的读取参数接口:
nn.Module.parameters()
直接看demo:
from torchvision.models.alexnet import alexnet
model = alexnet(pretrained=True).eval().cuda()
parameters = model.parameters()
for p in parameters:
numpy_para = p.detach().cpu().numpy()
print(type(numpy_para))
print(numpy_para.shape)
上面得到的numpy_para就是numpy参数了~
Note:
model.parameters()是以一个生成器的形式迭代返回每一层的参数。所以用for循环读取到各层的参数,循环次数就表示层数。
而每一层的参数都是torch.nn.parameter.Parameter类型,是Tensor的子类,所以直接用tensor转numpy(即p.detach().cpu().numpy())的方法就可以直接转成numpy矩阵。
方便又好用,爆赞~
补充:pytorch训练好的.pth模型转换为.pt
将python训练好的.pth文件转为.pt
import torch
import torchvision
from unet import UNet
model = UNet(3, 2)#自己定义的网络模型
model.load_state_dict(torch.load("best_weights.pth"))#保存的训练模型
model.eval()#切换到eval()
example = torch.rand(1, 3, 320, 480)#生成一个随机输入维度的输入
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://muzhan.blog.csdn.net/article/details/113066030
标签:Pytorch,pth,numpy,矩阵
0
投稿
猜你喜欢
Pygame实现监听鼠标示例详解
2021-12-16 00:42:58
解读数据库的嵌套查询的性能问题
2024-01-20 17:00:06
10款最佳Python开发工具推荐,每一款都是神器
2022-04-13 06:54:13
Python利用PyPDF2快速拆分PDF文档
2021-11-06 09:39:23
MySQL实时监控工具orztop的使用介绍
2024-01-13 18:15:17
Python2与Python3的区别实例分析
2021-01-07 11:47:17
Python快速转换numpy数组中Nan和Inf的方法实例说明
2021-11-12 06:24:03
使用python BeautifulSoup库抓取58手机维修信息
2022-08-10 01:55:20
asp中设置session过期时间方法总结
2013-06-01 19:52:04
SQL Server错误代码大全及解释(留着备用)
2012-07-11 16:17:03
python调用windows api锁定计算机示例
2021-09-08 03:28:38
Python 简单计算要求形状面积的实例
2022-10-19 09:02:33
VSCode 最全实用插件小结
2022-12-11 17:03:47
ASP + XML + JavaScript 实现动态无限级联动菜单
2008-06-13 06:31:00
SQL Server查询条件IN中能否使用变量的示例详解
2024-01-15 17:55:55
使用Golang的Context管理上下文的方法
2023-06-29 06:37:23
实例讲解MySQL中乐观锁和悲观锁
2024-01-19 00:46:35
深入了解vue-router原理并实现一个小demo
2024-04-30 10:25:31
[翻译]标记语言和样式手册 chapter 6 短语元素
2008-01-25 16:37:00
关于MySQL编码问题的经验总结
2007-08-23 16:10:00