Pytorch中实现只导入部分模型参数的方式

作者:咆哮的阿杰 时间:2023-01-24 05:53:25 

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。


import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
 def __init__(self):
   super(Net,self).__init__()
   self.conv1 = nn.Conv2d(3,32,3,1)
   self.conv2 = nn.Conv2d(32,3,3,1)
   self.w = nn.Parameter(t.randn(3,10))
   for p in self.children():
     nn.init.xavier_normal_(p.weight.data)
     nn.init.constant_(p.bias.data, 0)
 def forward(self, x):
   out = self.conv1(x)
   out = self.conv2(x)

out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
   out = F.linear(out,weight=self.w)
   return out

然后我们保存这个网络的初始值。


model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。


import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F

class Net(Module):
 def __init__(self):
   super(Net, self).__init__()
   self.conv1 = nn.Conv2d(3, 32, 3, 1)
   self.conv2 = nn.Conv2d(32, 3, 3, 1)
   self.conv3 = nn.Conv2d(3,64,3,1)
   self.conv4 = nn.Conv2d(64,32,3,1)
   for p in self.children():
     nn.init.xavier_normal_(p.weight.data)
     nn.init.constant_(p.bias.data, 0)

self.w = nn.Parameter(t.randn(3, 10))
 def forward(self, x):
   out = self.conv1(x)
   out = self.conv2(x)

out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
   out = F.linear(out, weight=self.w)
   return out

我们现在试着导入之前保存的模型参数。


path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))

'''
RuntimeError: Error(s) in loading state_dict for Net:
Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias".
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型


path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。


for k in state_dict.keys():
 print(k)

'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
 print(k)

'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

来源:https://blog.csdn.net/qq_34914551/article/details/87871134

标签:Pytorch,导入,模型参数
0
投稿

猜你喜欢

  • PHP错误Parse error: syntax error, unexpected end of file in test.php on line 12解决方法

    2023-11-14 16:13:53
  • Python 八个数据清洗实例代码详解

    2022-12-08 23:50:36
  • Golang解析yaml文件操作指南

    2024-05-09 14:51:59
  • python绘图pyecharts+pandas的使用详解

    2022-02-03 18:00:44
  • 超酷的js图片轮播渐变效果

    2007-10-10 20:45:00
  • python使用Pycharm创建一个Django项目

    2023-11-01 22:33:13
  • [翻译]标记语言和样式手册 Chapter 16 下一步

    2008-02-22 17:47:00
  • python添加命令行参数的详细过程

    2022-03-18 15:34:27
  • SQL Server2005打开数据表中的XML内容时报错的解决办法

    2024-01-18 01:32:12
  • 详解CSS3中的属性选择符

    2008-04-24 14:30:00
  • python中的txt文件转换为XML

    2021-12-05 10:45:48
  • 文件上传服务器-jupyter 中python解压及压缩方式

    2021-06-03 22:32:21
  • python 实现任务管理清单案例

    2023-09-01 04:59:17
  • MYSQL随机抽取查询 MySQL Order By Rand()效率问题

    2024-01-28 03:01:30
  • Python二进制转化为十进制数学算法详解

    2021-11-09 19:45:20
  • php开启openssl的方法

    2023-11-14 06:52:51
  • pytest使用parametrize将参数化变量传递到fixture

    2022-03-28 23:30:18
  • Python使用Matplotlib实现Logos设计代码

    2021-02-04 19:18:34
  • 用python制作游戏外 挂

    2023-08-03 15:55:43
  • 简单的Python人脸识别系统

    2023-01-26 23:31:57
  • asp之家 网络编程 m.aspxhome.com