pytorch加载预训练模型与自己模型不匹配的解决方案

作者:找不到服务器1703 时间:2023-06-17 14:22:24 

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。


model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
   if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
       err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。


model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
   if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
       continue
   model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码


model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
   if m >= len1 or n >= len2:
       break
   layername1, layername2 = model_list1[m], model_list2[n]
   w1, w2 = model_dict1[layername1], model_dict2[layername2]
   if w1.shape != w2.shape:
       continue
   model_dict2[layername2] = model_dict1[layername1]
   m += 1
   n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~


#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
   def __init__(self, rgb_range):
       super(Classification_att, self).__init__()
       self.vgg19 =models.vgg19(pretrained=True)
       vgg = models.vgg19(pretrained=True).features
       conv_modules = [m for m in vgg]
       self.vgg_conv = nn.Sequential(*conv_modules[:37])
       classfi = models.vgg19(pretrained=True).classifier
       classif_modules = [n for n in classfi]
       self.vgg_class = nn.Sequential(*classif_modules[:4])
       vgg_mean = (0.485, 0.456, 0.406)
       vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
       self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
       for p in self.vgg_conv.parameters():
           p.requires_grad = False
       for p in self.vgg_class.parameters():
           p.requires_grad = False
       self.classifi = nn.Sequential(
           nn.Linear(4096, 1024),
           nn.ReLU(True),
           nn.Linear(1024, 256),
           nn.ReLU(True),
           nn.Linear(256, 64),
       )

def forward(self, x):
       x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear',
       align_corners=False)
       x = self.sub_mean(x)
       x = self.vgg_conv(x)  
       x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是


x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步


x = torch.flatten(x, 1)

补上就好了!

来源:https://blog.csdn.net/qq_34288751/article/details/114160725

标签:pytorch,加载,预训练,模型
0
投稿

猜你喜欢

  • Python生成指定数量的优惠码实操内容

    2021-06-12 18:29:25
  • 一文详解如何使用Python批量拼接图片

    2023-08-12 04:15:37
  • element跨分页操作选择详解

    2023-07-02 16:38:47
  • 详解vue-Resource(与后端数据交互)

    2024-06-05 09:15:06
  • 详解django2中关于时间处理策略

    2021-09-09 23:13:09
  • Python实现自动签到脚本的示例代码

    2021-07-07 14:08:37
  • MySQL 事务概念与用法深入详解

    2024-01-14 02:56:06
  • python3.8.3安装教程及环境配置的详细教程(64-bit)

    2023-03-01 05:21:10
  • C#创建数据库及导入sql脚本的方法

    2024-01-23 04:08:04
  • Python中Selenium模拟JQuery滑动解锁实例

    2021-10-21 09:49:52
  • mysql中循环截取用户信息并插入到目标表对应的字段中

    2024-01-23 21:42:51
  • tensorflow的ckpt及pb模型持久化方式及转化详解

    2022-12-10 17:32:08
  • Python+pandas编写命令行脚本操作excel的tips详情

    2023-09-23 22:35:53
  • Python 代码实现各种酷炫功能

    2022-12-24 10:40:16
  • SQL Server内存机制详解

    2024-01-20 09:57:48
  • Vue实现自定义字段导出EXCEL的示例代码

    2024-04-27 16:12:09
  • ASP.NET中的几种弹出框提示基本实现方法

    2023-07-13 00:23:50
  • Python OS模块常用函数说明

    2022-08-28 06:34:39
  • Python爬取数据保存为Json格式的代码示例

    2022-10-13 17:11:36
  • python使用selenium实现批量文件下载

    2022-09-23 15:27:08
  • asp之家 网络编程 m.aspxhome.com