pytorch如何冻结某层参数的实现
作者:Pr4da 时间:2021-02-03 11:49:36
在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:
class Model(nn.Module):
def __init__(self):
super(Transfer_model, self).__init__()
self.linear1 = nn.Linear(20, 50)
self.linear2 = nn.Linear(50, 20)
self.linear3 = nn.Linear(20, 2)
def forward(self, x):
pass
假如我们想要冻结linear1层,需要做如下操作:
model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False
最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。
filter(function, iterable)
function: 判断函数
iterable: 可迭代对象
filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。
该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。
filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。
来源:https://blog.csdn.net/qq_40210586/article/details/103878155
标签:pytorch,冻结,参数
0
投稿
猜你喜欢
python实现石头剪刀布小游戏
2022-03-22 15:47:36
用python的哈希函数对密码加密
2021-06-10 21:02:58
Django的URLconf中使用缺省视图参数的方法
2021-05-03 17:46:29
快速了解Python开发中的cookie及简单代码示例
2023-05-29 11:04:05
解析SQLServer任意列之间的聚合
2024-01-17 12:48:29
详解vue过度效果与动画transition使用示例
2023-07-02 16:45:03
Tensorflow实现卷积神经网络的详细代码
2022-02-20 22:14:06
python GUI库图形界面开发之PyQt5切换按钮控件QPushButton详细使用方法与实例
2021-06-13 09:13:14
linux下源码安装mysql5.6.20教程
2024-01-16 20:13:42
PyTorch搭建CNN实现风速预测
2022-09-11 17:40:19
人工智能学习pyTorch自建数据集及可视化结果实现过程
2022-08-04 14:54:33
python中wordcloud安装的方法小结
2022-07-11 04:29:44
python使用Image处理图片常用技巧分析
2023-01-17 14:51:38
Python如何用str.format()批量生成网址(豆瓣读书为例)
2022-10-02 15:38:41
细化解析:SQL Server 2000 的各种版本
2009-02-05 15:41:00
Python趣味挑战之用pygame实现飞机塔防游戏
2022-07-18 04:00:02
在IE下用getAttribute时要小心
2008-08-21 12:54:00
浅谈SQL Server 对于内存的管理[图文]
2024-01-14 07:41:44
python偏函数partial用法
2023-09-24 22:25:06
PHP动态生成javascript文件的2个例子
2024-05-11 09:25:44