解决Pytorch修改预训练模型时遇到key不匹配的情况

作者:月亮不秃头 时间:2022-11-29 15:43:43 

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题


#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights)
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')

# 将新保存的网络代替之前的预训练模型
   ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
   net = ssd_net
   ...
   if args.resume:
       ...
   else:
       base_weights = torch.load(args.save_folder + args.basenet)
       #args.basenet为ssd_base.pth
       print('Loading base network...')
       ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载


from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
   name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
   new_state_dict[name] = v
   ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.jb51.net/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:


  def load_specific_param(self, state_dict, param_name, model_path):
       param_dict = torch.load(model_path)
       for i in state_dict:
           key = param_name + '.' + i
           state_dict[i].copy_(param_dict[key])
       del param_dict

对'num_batches_tracked进行过滤:


  def load_specific_param(self, state_dict, param_name, model_path):
       param_dict = torch.load(model_path)
       param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
       for i in state_dict:
           key = param_name + '.' + i
           if 'num_batches_tracked' in key:
               continue
           state_dict[i].copy_(param_dict[key])
       del param_dict

来源:https://blog.csdn.net/weixin_44039925/article/details/99447653

标签:Pytorch,预训练,key,不匹配
0
投稿

猜你喜欢

  • 简单介绍Python中用于求最小值的min()方法

    2021-05-27 23:00:11
  • SQL Server数据库备份出错及应对措施

    2009-04-20 17:02:00
  • 教你利用Selenium+python自动化来解决pip使用异常

    2022-11-17 18:49:08
  • js自动闭合html标签(自动补全html标记)

    2023-08-25 07:06:35
  • django admin 后台实现三级联动的示例代码

    2023-11-04 04:01:43
  • 还在手动盖楼抽奖?教你用Python实现自动评论盖楼抽奖(一)

    2023-12-26 21:32:41
  • Seaborn数据分析NBA球员信息数据集

    2021-06-27 03:36:04
  • 使用phpMyAdmin进行mysql数据库备份和还原的方法

    2008-10-13 20:56:00
  • Go语言字符串基础示例详解

    2023-07-17 03:14:56
  • python之数字图像处理方式

    2023-02-02 18:27:09
  • 在python中使用requests 模拟浏览器发送请求数据的方法

    2022-05-05 03:17:35
  • python区块链创建多个交易教程

    2021-05-28 13:40:42
  • Python基于pygame模块播放MP3的方法示例

    2023-09-22 12:33:19
  • 教你快速掌握一些方便易用的SQL语句

    2008-11-28 15:21:00
  • Python 利用argparse模块实现脚本命令行参数解析

    2022-12-01 16:11:55
  • 如何学习Python time模块

    2023-07-30 17:14:59
  • 使用Python的Dataframe取两列时间值相差一年的所有行方法

    2023-11-11 06:50:25
  • OpenCV3.0+Python3.6实现特定颜色的物体追踪

    2021-05-13 09:01:03
  • PHP的mysqli_thread_id()函数讲解

    2023-06-13 10:09:43
  • Thinkphp5.0 框架的请求方式与响应方式分析

    2023-11-15 00:07:09
  • asp之家 网络编程 m.aspxhome.com