Pytorch之保存读取模型实例
作者:啧啧啧biubiu 时间:2023-04-03 02:15:11
pytorch保存数据
pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。
# 保存模型示例代码
print('===> Saving models...')
state = {
'state': model.state_dict(),
'epoch': epoch # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')
保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。
pytorch读取数据
pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。
下方的代码和上方的保存代码可以搭配使用。
print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
try:
checkpoint = torch.load('./checkpoint/autoencoder.t7')
model.load_state_dict(checkpoint['state']) # 从字典中依次读取
start_epoch = checkpoint['epoch']
print('===> Load last checkpoint data')
except FileNotFoundError:
print('Can\'t found autoencoder.t7')
else:
start_epoch = 0
print('===> Start from scratch')
以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。
def vgg19(pretrained=False, **kwargs):
"""VGG 19-layer model (configuration "E")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['E']), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
return model
假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:
model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))
也就是pytorch的读取函数进行读取即可。
来源:https://blog.csdn.net/qq_37385726/article/details/81943980
标签:Pytorch,保存,读取,模型
0
投稿
猜你喜欢
Python基础之循环语句相关知识总结
2021-03-19 18:57:19
go程序员日常开发效率神器汇总
2024-02-16 23:04:40
详解Mysql中的JSON系列操作函数
2024-01-20 02:08:08
asp长文章分页显示思路
2007-08-23 13:54:00
Python中判断input()输入的数据的类型
2023-03-14 17:02:15
redis数据库及与python交互用法简单示例
2024-01-18 03:05:06
深入理解Python中的内置常量
2023-01-21 02:57:47
python QT界面关闭线程池的线程跟随退出完美解决方案
2023-01-01 11:56:21
Python求解平方根的方法
2023-02-13 13:25:47
vue实现二维码扫码功能(带样式)
2024-04-10 10:31:39
Python环境搭建过程从安装到Hello World
2023-03-03 07:41:36
Python实现的服务器示例小结【单进程、多进程、多线程、非阻塞式】
2023-02-24 00:19:25
Python文本统计功能之西游记用字统计操作示例
2023-08-10 00:12:59
mysql 服务意外停止1067错误解决办法小结
2024-01-26 05:56:38
Linux安装Pytorch1.8GPU(CUDA11.1)的实现
2021-12-20 10:02:00
Python实现绘制圣诞树和烟花的示例代码
2022-03-22 16:22:20
pytorch中retain_graph==True的作用说明
2021-08-03 09:15:26
python实现邮件发送功能
2023-10-11 02:27:09
python Celery定时任务的示例
2023-12-28 14:08:39
window.onload使用指南
2024-04-18 10:58:51