PyTorch加载模型model.load_state_dict()问题及解决

作者:是否龙磊磊真的一无所有 时间:2022-11-08 07:03:53 

PyTorch加载模型model.load_state_dict()问题

希望将训练好的模型加载到新的网络上。

如上面题目所描述的,PyTorch在加载之前保存的模型参数的时候,遇到了问题。

Unexpected key(s) in state_dict: "module.features. ...".,Expected ".features....". 直接原因是key值名字不对应。

表明了加载过程中,期望获得的key值为feature...,而不是module.features....。

这是由模型保存过程中导致的,模型应该是在DataParallel模式下面,也就是采用了多GPU训练模型,然后直接保存的。

You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without . You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back.

解决上面的问题有三个办法: 

1. 对load的模型创建新的字典

去掉不需要的key值"module".

# original saved file with DataParallel
state_dict = torch.load('checkpoint.pt')  # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
   name = k[7:] # remove `module.`,表面从第7个key值字符取到最后一个字符,正好去掉了module.
   new_state_dict[name] = v #新字典的key值对应的value为一一对应的值。
# load params
model.load_state_dict(new_state_dict) # 从新加载这个模型。

2. 直接用空白''代替'module.'

model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('checkpoint.pt').items()})

# 相当于用''代替'module.'。
#直接使得需要的键名等于期望的键名。

3. 最简单的方法

加载模型之后,接着将模型DataParallel,此时就可以load_state_dict。

如果有多个GPU,将模型并行化,用DataParallel来操作。

这个过程会将key值加一个"module. ***"。

model = VGGNet()
params=model.state_dict() #获得模型的原始状态以及参数。
for k,v in params.items():
   print(k) #只打印key值,不打印具体参数。

4. 总结

从出错显示的问题就可以看出,key值不匹配,因此可以选择多种方法,将模型参数加载进去。

这个方法通常会在load_state_dict过程中遇到。将训练好的一个网络参数,移植到另外一个网络上面,继续训练。

或者将训练好的网络checkpoint加载进模型,再次进行训练。可以打印出model state_dict来看出两者的差别。

model = VGGNet()
params=model.state_dict() #获得模型的原始状态以及参数。
for k,v in params.items():
   print(k) #只打印key值,不打印具体参数。

features.0.0.weight   
features.0.1.weight
features.1.conv.3.weight
features.1.conv.4.num_batches_tracked

PyTorch加载模型model.load_state_dict()问题及解决

model = VGGNet()
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
# Load weights to resume from checkpoint。
# print('**************************************')
# 这个方法能够直接打印出你保存的checkpoint的键和值。
for k,v in checkpoint.items():
   print(k)
print("*****************************************")

输出结果为:

module.features.0.0.weight",

"module.features.0.1.weight",

"module.features.0.1.bias

可以看出不匹配,模型的参数中,key值不同,多了module。

PS: 追加

在移植参数的过程中,对于出现 .total_ops和.total_params结尾的参数,可参考以下代码:

from collections import OrderedDict
checkpoint = torch.load(
   pretrained_model_file_path,
   map_location=(None if use_cuda and not remap_to_cpu else "cpu"))
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
   if not k.endswith('total_ops') and not k.endswith('total_params'):
       name = k[7:]
       new_state_dict[name] = v

最后

来源:https://blog.csdn.net/qq_32998593/article/details/89343507

标签:PyTorch,加载模型,model.load,state,dict
0
投稿

猜你喜欢

  • python scatter散点图用循环分类法加图例

    2021-07-26 01:44:01
  • python皮尔逊相关性数据分析分析及实例代码

    2021-03-12 13:23:34
  • SQL Server格式转换函数Cast、Convert介绍

    2024-01-22 18:42:16
  • Python实现随机取一个矩阵数组的某几行

    2021-10-04 16:45:52
  • python调用ffmpeg命令行工具便捷操作视频示例实现过程

    2023-12-19 07:48:22
  • 理解Proxy及使用Proxy实现vue数据双向绑定操作

    2024-04-26 17:41:43
  • Oracle 异构服务实践

    2007-08-17 10:00:00
  • Python程序退出方式小结

    2021-12-21 19:29:55
  • 四个Python操作Excel的常用脚本分享

    2023-12-04 07:04:27
  • 浅谈MySQL之浅入深出页原理

    2024-01-18 20:38:29
  • MySQL语句执行顺序和编写顺序实例解析

    2024-01-26 12:39:18
  • python通过http下载文件的方法详解

    2021-11-11 04:26:23
  • T-SQL 查询语句的执行顺序解析

    2024-01-14 08:00:00
  • 用 Python 元类的特性实现 ORM 框架

    2022-02-12 12:45:24
  • 通过遮罩层实现浮层DIV登录的js代码

    2024-06-24 00:08:58
  • Vue Cli与BootStrap结合实现表格分页功能

    2024-05-09 15:21:44
  • Git配置.gitignore文件忽略被指定的文件上传

    2022-09-23 03:00:08
  • python Pandas如何对数据集随机抽样

    2023-10-02 08:28:13
  • OpenCV-Python实现腐蚀与膨胀的实例

    2023-06-05 18:07:07
  • 在import scipy.misc 后找不到 imsave的解决方案

    2023-08-09 05:21:45
  • asp之家 网络编程 m.aspxhome.com