pytorch 实现模型不同层设置不同的学习率方式
作者:-wxrui- 时间:2023-11-20 00:29:25
在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。
为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。
class net(torch.nn.Module):
def __init__(self):
super(net, self).__init__()
# backbone
self.backbone = ...
# detect
self....
在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。
base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
{"params": logits_params, "lr": config.lr},
{"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)
来源:https://blog.csdn.net/qq_42110481/article/details/81025575
标签:pytorch,模型,学习率
0
投稿
猜你喜欢
Django上传xlsx文件直接转化为DataFrame或直接保存的方法
2023-12-02 14:42:16
pytorch之torchvision.transforms图像变换实例
2021-05-19 05:44:05
python读取word文档的方法
2023-11-24 08:56:28
Go 热加载之fresh详解
2024-03-23 14:27:26
vue-cli使用stimulsoft.reports.js的详细教程
2024-04-09 10:58:59
asp如何在第10000名来访者访问时显示中奖页面?
2010-06-18 19:45:00
可以自动轮换的页签 tabs
2008-02-21 12:25:00
快速解决pandas.read_csv()乱码的问题
2023-07-10 21:14:47
对Python _取log的几种方式小结
2021-12-19 02:18:48
Python使用smtp和pop简单收发邮件完整实例
2022-01-07 05:48:40
PHP使用GIFEncoder类生成gif动态滚动字幕
2024-05-08 09:34:46
python保存字典数据到csv文件的完整代码
2023-04-09 17:02:02
jupyter notebook 恢复误删单元格或者历史代码的实现
2022-03-03 16:13:45
推荐值得学习的12款python-web开发框架
2021-10-20 21:46:10
基于python读取.mat文件并取出信息
2021-10-24 12:06:26
MySQL查看和修改时区的方法
2024-01-15 05:42:33
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
2022-05-06 14:26:39
asp获取文件md5值
2008-10-13 09:10:00
pandas ix &iloc &loc的区别
2023-03-12 16:31:54
springboot+idea+maven 多模块项目搭建的详细过程(连接数据库进行测试)
2024-01-19 08:04:34