浅谈Pytorch torch.optim优化器个性化的使用

作者:小河沟大河沟 时间:2023-12-19 08:47:12 

一、简化前馈网络LeNet


import torch as t

class LeNet(t.nn.Module):
def __init__(self):
 super(LeNet, self).__init__()
 self.features = t.nn.Sequential(
  t.nn.Conv2d(3, 6, 5),
  t.nn.ReLU(),
  t.nn.MaxPool2d(2, 2),
  t.nn.Conv2d(6, 16, 5),
  t.nn.ReLU(),
  t.nn.MaxPool2d(2, 2)
 )
 # 由于调整shape并不是一个class层,
 # 所以在涉及这种操作(非nn.Module操作)需要拆分为多个模型
 self.classifiter = t.nn.Sequential(
  t.nn.Linear(16*5*5, 120),
  t.nn.ReLU(),
  t.nn.Linear(120, 84),
  t.nn.ReLU(),
  t.nn.Linear(84, 10)
 )

def forward(self, x):
 x = self.features(x)
 x = x.view(-1, 16*5*5)
 x = self.classifiter(x)
 return x

net = LeNet()

二、优化器基本使用方法

建立优化器实例

循环:

清空梯度

向前传播

计算Loss

反向传播

更新参数


from torch import optim

# 通常的step优化过程
optimizer = optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad() # net.zero_grad()

input_ = t.autograd.Variable(t.randn(1, 3, 32, 32))
output = net(input_)
output.backward(output)

optimizer.step()

三、网络模块参数定制

为不同的子网络参数不同的学习率,finetune常用,使分类器学习率参数更高,学习速度更快(理论上)。

1.经由构建网络时划分好的模组进行学习率设定,


# # 直接对不同的网络模块制定不同学习率
optimizer = optim.SGD([{'params': net.features.parameters()}, # 默认lr是1e-5
     {'params': net.classifiter.parameters(), 'lr': 1e-2}], lr=1e-5)

2.以网络层对象为单位进行分组,并设定学习率


# # 以层为单位,为不同层指定不同的学习率
# ## 提取指定层对象
special_layers = t.nn.ModuleList([net.classifiter[0], net.classifiter[3]])
# ## 获取指定层参数id
special_layers_params = list(map(id, special_layers.parameters()))
print(special_layers_params)
# ## 获取非指定层的参数id
base_params = filter(lambda p: id(p) not in special_layers_params, net.parameters())
optimizer = t.optim.SGD([{'params': base_params},
      {'params': special_layers.parameters(), 'lr': 0.01}], lr=0.001)

四、在训练中动态的调整学习率


'''调整学习率'''
# 新建optimizer或者修改optimizer.params_groups对应的学习率
# # 新建optimizer更简单也更推荐,optimizer十分轻量级,所以开销很小
# # 但是新的优化器会初始化动量等状态信息,这对于使用动量的优化器(momentum参数的sgd)可能会造成收敛中的震荡
# ## optimizer.param_groups:长度2的list,optimizer.param_groups[0]:长度6的字典
print(optimizer.param_groups[0]['lr'])
old_lr = 0.1
optimizer = optim.SGD([{'params': net.features.parameters()},
     {'params': net.classifiter.parameters(), 'lr': old_lr*0.1}], lr=1e-5)

可以看到optimizer.param_groups结构,[{'params','lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'},{……}],集合了优化器的各项参数。

torch.optim的灵活使用

重写sgd优化器


import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
def __init__(self, params, lr=required, momentum=0, dampening=0, weight_decay1=0, weight_decay2=0, nesterov=False):
 defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
     weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
 if nesterov and (momentum <= 0 or dampening != 0):
  raise ValueError("Nesterov momentum requires a momentum and zero dampening")
 super(SGD, self).__init__(params, defaults)

def __setstate__(self, state):
 super(SGD, self).__setstate__(state)
 for group in self.param_groups:
  group.setdefault('nesterov', False)

def step(self, closure=None):
 """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """
 loss = None
 if closure is not None:
  loss = closure()

for group in self.param_groups:
  weight_decay1 = group['weight_decay1']
  weight_decay2 = group['weight_decay2']
  momentum = group['momentum']
  dampening = group['dampening']
  nesterov = group['nesterov']

for p in group['params']:
   if p.grad is None:
    continue
   d_p = p.grad.data
   if weight_decay1 != 0:
    d_p.add_(weight_decay1, torch.sign(p.data))
   if weight_decay2 != 0:
    d_p.add_(weight_decay2, p.data)
   if momentum != 0:
    param_state = self.state[p]
    if 'momentum_buffer' not in param_state:
     buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
     buf.mul_(momentum).add_(d_p)
    else:
     buf = param_state['momentum_buffer']
     buf.mul_(momentum).add_(1 - dampening, d_p)
    if nesterov:
     d_p = d_p.add(momentum, buf)
    else:
     d_p = buf

p.data.add_(-group['lr'], d_p)

return loss

来源:https://www.cnblogs.com/ranjiewen/p/9240512.html

标签:Pytorch,torch.optim,优化器
0
投稿

猜你喜欢

  • js显示世界时间示例(包括世界各大城市)

    2024-04-10 13:54:37
  • Python cookbook(数据结构与算法)将名称映射到序列元素中的方法

    2021-06-06 01:26:54
  • Golang实现常见的限流算法的示例代码

    2024-04-25 13:22:35
  • Python CSV模块使用实例

    2022-02-04 18:56:36
  • 浅谈python迭代器

    2023-07-21 21:56:47
  • Python 基于FIR实现Hilbert滤波器求信号包络详解

    2023-07-13 01:31:47
  • Sql Server2016 正式版安装程序图解教程

    2024-01-21 18:48:53
  • 详解SQL中的DQL查询语言

    2024-01-24 01:51:54
  • python之PySide2安装使用及QT Designer UI设计案例教程

    2023-01-18 06:42:53
  • SQL Server配置管理器无法连接到WMI提供程序

    2024-01-23 23:49:30
  • Django对接elasticsearch实现全文检索的示例代码

    2023-07-02 01:31:29
  • python实现线性回归算法

    2021-04-11 12:36:48
  • Go语言基础go接口用法示例详解

    2024-04-30 10:06:53
  • Django如何实现RBAC权限管理

    2021-05-20 19:14:27
  • css学习笔记:DIV水平垂直居中

    2009-06-19 12:45:00
  • Python实现识别图片内容的方法分析

    2022-01-04 21:10:51
  • python绘制多个子图的实例

    2023-01-31 02:24:24
  • 基于Python实现射击小游戏的制作

    2021-05-02 17:13:48
  • python 绘制国旗的示例

    2023-01-05 19:29:32
  • Python可视化程序调用流程解析

    2022-07-18 15:53:07
  • asp之家 网络编程 m.aspxhome.com