pytorch 实现模型不同层设置不同的学习率方式

作者:-wxrui- 时间:2023-11-20 00:29:25 

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。


class net(torch.nn.Module):
 def __init__(self):
   super(net, self).__init__()
   # backbone
   self.backbone = ...
   # detect
   self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。


base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
 {"params": logits_params, "lr": config.lr},
 {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)

来源:https://blog.csdn.net/qq_42110481/article/details/81025575

标签:pytorch,模型,学习率
0
投稿

猜你喜欢

  • Django上传xlsx文件直接转化为DataFrame或直接保存的方法

    2023-12-02 14:42:16
  • pytorch之torchvision.transforms图像变换实例

    2021-05-19 05:44:05
  • python读取word文档的方法

    2023-11-24 08:56:28
  • Go 热加载之fresh详解

    2024-03-23 14:27:26
  • vue-cli使用stimulsoft.reports.js的详细教程

    2024-04-09 10:58:59
  • asp如何在第10000名来访者访问时显示中奖页面?

    2010-06-18 19:45:00
  • 可以自动轮换的页签 tabs

    2008-02-21 12:25:00
  • 快速解决pandas.read_csv()乱码的问题

    2023-07-10 21:14:47
  • 对Python _取log的几种方式小结

    2021-12-19 02:18:48
  • Python使用smtp和pop简单收发邮件完整实例

    2022-01-07 05:48:40
  • PHP使用GIFEncoder类生成gif动态滚动字幕

    2024-05-08 09:34:46
  • python保存字典数据到csv文件的完整代码

    2023-04-09 17:02:02
  • jupyter notebook 恢复误删单元格或者历史代码的实现

    2022-03-03 16:13:45
  • 推荐值得学习的12款python-web开发框架

    2021-10-20 21:46:10
  • 基于python读取.mat文件并取出信息

    2021-10-24 12:06:26
  • MySQL查看和修改时区的方法

    2024-01-15 05:42:33
  • Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作

    2022-05-06 14:26:39
  • asp获取文件md5值

    2008-10-13 09:10:00
  • pandas ix &iloc &loc的区别

    2023-03-12 16:31:54
  • springboot+idea+maven 多模块项目搭建的详细过程(连接数据库进行测试)

    2024-01-19 08:04:34
  • asp之家 网络编程 m.aspxhome.com