pytorch fine-tune 预训练的模型操作

作者:This is bill 时间:2023-05-02 01:05:25 

之一:

torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。

安装


pip install torchvision

如何 fine-tune

以 resnet18 为例:


from torchvision import models
from torch import nn
from torch import optim

resnet_model = models.resnet18(pretrained=True)
# pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中
# 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。

# 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是
# 1. 查看 resnet 的源码
# 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替换掉这个层
resnet_model.fc= nn.Linear(in_features=..., out_features=100)

# 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。

# 如果只想训练 最后一层的话,应该做的是:
# 1. 将其它层的参数 requires_grad 设置为 False
# 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数
# 3. 然后 backward, step 就可以了

# 这一步可以节省大量的时间,因为多数的参数不需要计算梯度
for para in list(resnet_model.parameters())[:-2]:
   para.requires_grad=False

optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)

...

为什么

这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么

这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:


def __setattr__(self, name, value):
   def remove_from(*dicts):
       for d in dicts:
           if name in d:
               del d[name]

首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:

预训练的模型中 有个 名字叫fc 的 Module。

在类定义外,我们 将另一个 Module 重新 赋值给了 fc。

类定义内的 fc 对应的 Module 就会从 模型中 删除。

之二:

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

pytorch fine-tune 预训练的模型操作

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:


def weight_init(m):
# 使用isinstance来判断m属于什么类型
   if isinstance(m, nn.Conv2d):
       n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
       m.weight.data.normal_(0, math.sqrt(2. / n))
   elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
       m.weight.data.fill_(1)
       m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。


model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
   param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)

# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:


ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                    model.parameters())

optimizer = torch.optim.SGD([
           {'params': base_params},
           {'params': model.fc.parameters(), 'lr': 1e-3}
           ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

之三:

pytorch finetune模型

文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。

pytorch 模型的存储与读取

其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的

单独存储模型参数

存储时使用:


torch.save(the_model.state_dict(), PATH)

读取时:


the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

存储模型与参数

存储:


torch.save(the_model, PATH)

读取:


the_model = torch.load(PATH)

模型的参数

fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。

pytorch模型参数的形式

模型的参数是以字典的形式存储的。


model_dict = the_model.state_dict(),
for k,v in model_dict.items():
   print(k)

即可看到所有的键值

如果想修改模型的参数,给相应的键值赋值即可


model_dict[k] = new_value

最后更新模型的参数


the_model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型


model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
# 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
           k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的


model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
   keys.append(k)
i = 0
for k,v in model_dict.items():
   if v.size() == pretrained_dict[keys[i]].size():
       print(k, ',', keys[i])
        model_dict[k]=pretrained_dict[keys[i]]
   i = i + 1
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

来源:https://blog.csdn.net/Scythe666/article/details/82809615

标签:pytorch,fine-tune,预训练,模型
0
投稿

猜你喜欢

  • Python读取配置文件的实战操作

    2021-08-12 19:48:09
  • IE 下 href 的 BUG

    2008-11-10 12:32:00
  • 为Python的web框架编写MVC配置来使其运行的教程

    2022-05-30 01:54:32
  • Python+OpenCV人脸检测原理及示例详解

    2021-07-31 19:31:51
  • django静态文件加载的方法

    2022-12-26 13:57:56
  • Python装饰器decorator用法实例

    2023-02-06 23:26:43
  • ASP编写完整的一个IP所在地搜索类

    2007-10-18 10:43:00
  • Python实现连点器的示例代码

    2023-04-17 00:11:29
  • Asp性能优化之Response.IsClientConnected属性及其应用示例

    2008-09-18 12:13:00
  • 一个asp版XMLDOM操作类

    2011-04-19 10:50:00
  • Laravel使用PHPQRCODE实现生成带有LOGO的二维码图片功能示例

    2024-05-03 15:28:12
  • Python xlwt工具使用详解,生成excel栏位宽度可自适应内容长度

    2024-01-03 20:20:20
  • 解决vue组件中click事件失效的问题

    2023-07-02 16:34:10
  • python tkinter库的Text记录点击路经和删除记录详情

    2021-04-15 03:41:13
  • vue实现瀑布流组件滑动加载更多

    2024-05-02 17:09:45
  • Django如何实现内容缓存示例详解

    2022-02-23 15:33:01
  • CentOS 7下安装Python 3.5并与Python2.7兼容并存详解

    2021-09-18 03:03:32
  • 在python3.64中安装pyinstaller库的方法步骤

    2022-08-12 10:27:05
  • golang指数运算操作

    2024-01-30 22:35:53
  • Python实现快速保存微信公众号文章中的图片

    2021-02-18 23:03:25
  • asp之家 网络编程 m.aspxhome.com