pytorch如何冻结某层参数的实现

作者:Pr4da 时间:2021-02-03 11:49:36 

在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:


class Model(nn.Module):
def __init__(self):
 super(Transfer_model, self).__init__()
 self.linear1 = nn.Linear(20, 50)
 self.linear2 = nn.Linear(50, 20)
 self.linear3 = nn.Linear(20, 2)

def forward(self, x):
pass

假如我们想要冻结linear1层,需要做如下操作:


model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False

 最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。


optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。

filter(function, iterable)

  • function: 判断函数

  • iterable: 可迭代对象

filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。

来源:https://blog.csdn.net/qq_40210586/article/details/103878155

标签:pytorch,冻结,参数
0
投稿

猜你喜欢

  • python实现石头剪刀布小游戏

    2022-03-22 15:47:36
  • 用python的哈希函数对密码加密

    2021-06-10 21:02:58
  • Django的URLconf中使用缺省视图参数的方法

    2021-05-03 17:46:29
  • 快速了解Python开发中的cookie及简单代码示例

    2023-05-29 11:04:05
  • 解析SQLServer任意列之间的聚合

    2024-01-17 12:48:29
  • 详解vue过度效果与动画transition使用示例

    2023-07-02 16:45:03
  • Tensorflow实现卷积神经网络的详细代码

    2022-02-20 22:14:06
  • python GUI库图形界面开发之PyQt5切换按钮控件QPushButton详细使用方法与实例

    2021-06-13 09:13:14
  • linux下源码安装mysql5.6.20教程

    2024-01-16 20:13:42
  • PyTorch搭建CNN实现风速预测

    2022-09-11 17:40:19
  • 人工智能学习pyTorch自建数据集及可视化结果实现过程

    2022-08-04 14:54:33
  • python中wordcloud安装的方法小结

    2022-07-11 04:29:44
  • python使用Image处理图片常用技巧分析

    2023-01-17 14:51:38
  • Python如何用str.format()批量生成网址(豆瓣读书为例)

    2022-10-02 15:38:41
  • 细化解析:SQL Server 2000 的各种版本

    2009-02-05 15:41:00
  • Python趣味挑战之用pygame实现飞机塔防游戏

    2022-07-18 04:00:02
  • 在IE下用getAttribute时要小心

    2008-08-21 12:54:00
  • 浅谈SQL Server 对于内存的管理[图文]

    2024-01-14 07:41:44
  • python偏函数partial用法

    2023-09-24 22:25:06
  • PHP动态生成javascript文件的2个例子

    2024-05-11 09:25:44
  • asp之家 网络编程 m.aspxhome.com