聊聊pytorch中Optimizer与optimizer.step()的用法

作者:wang xiang 时间:2022-03-16 22:45:34 

当我们想指定每一层的学习率时:


optim.SGD([
                   {'params': model.base.parameters()},
                   {'params': model.classifier.parameters(), 'lr': 1e-3}
               ], lr=1e-2, momentum=0.9)

这意味着model.base的参数将会使用1e-2的学习率,model.classifier的参数将会使用1e-3的学习率,并且0.9的momentum将会被用于所有的参数。

进行单次优化

所有的optimizer都实现了step()方法,这个方法会更新所有的参数。它能按两种方式来使用:


optimizer.step()

这是大多数optimizer所支持的简化版本。一旦梯度被如backward()之类的函数计算好后,我们就可以调用这个函数。

例子


for input, target in dataset:
       optimizer.zero_grad()
       output = model(input)
       loss = loss_fn(output, target)
       loss.backward()
       optimizer.step()        
optimizer.step(closure)

一些优化算法例如Conjugate Gradient和LBFGS需要重复多次计算函数,因此你需要传入一个闭包去允许它们重新计算你的模型。这个闭包应当清空梯度,计算损失,然后返回。

例子:


for input, target in dataset:
   def closure():
       optimizer.zero_grad()
       output = model(input)
       loss = loss_fn(output, target)
       loss.backward()
       return loss
   optimizer.step(closure)

补充:Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别

首先需要明确optimzier优化器的作用, 形象地来说,优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用,这也是机器学习里面最一般的方 * 。

从优化器的作用出发,要使得优化器能够起作用,需要主要两个东西:

1. 优化器需要知道当前的网络或者别的什么模型的参数空间

这也就是为什么在训练文件中,正式开始训练之前需要将网络的参数放到优化器里面,比如使用pytorch的话总会出现类似如下的代码:


optimizer_G = Adam(model_G.parameters(), lr=train_c.lr_G)   # lr 使用的是初始lr
optimizer_D = Adam(model_D.parameters(), lr=train_c.lr_D)

2. 需要知道反向传播的梯度信息

我们还是从代码入手,如下所示是Pytorch 中SGD优化算法的step()函数具体写法,具体SGD的写法放在参考部分。


def step(self, closure=None):
           """Performs a single optimization step.
           Arguments:
               closure (callable, optional): A closure that reevaluates the model
                   and returns the loss.
           """
           loss = None
           if closure is not None:
               loss = closure()

for group in self.param_groups:
               weight_decay = group['weight_decay']
               momentum = group['momentum']
               dampening = group['dampening']
               nesterov = group['nesterov']

for p in group['params']:
                   if p.grad is None:
                       continue
                   d_p = p.grad.data
                   if weight_decay != 0:
                       d_p.add_(weight_decay, p.data)
                   if momentum != 0:
                       param_state = self.state[p]
                       if 'momentum_buffer' not in param_state:
                           buf = param_state['momentum_buffer'] = d_p.clone()
                       else:
                           buf = param_state['momentum_buffer']
                           buf.mul_(momentum).add_(1 - dampening, d_p)
                       if nesterov:
                           d_p = d_p.add(momentum, buf)
                       else:
                           d_p = buf    
                   p.data.add_(-group['lr'], d_p)    
           return loss

从上面的代码可以看到step这个函数使用的是参数空间(param_groups)中的grad,也就是当前参数空间对应的梯度,这也就解释了为什么optimzier使用之前需要zero清零一下,因为如果不清零,那么使用的这个grad就得同上一个mini-batch有关,这不是我们需要的结果。

再回过头来看,我们知道optimizer更新参数空间需要基于反向梯度,因此,当调用optimizer.step()的时候应当是loss.backward()的时候,这也就是经常会碰到,如下情况


total_loss.backward()
optimizer_G.step()

loss.backward()在前,然后跟一个step。

那么为什么optimizer.step()需要放在每一个batch训练中,而不是epoch训练中,这是因为现在的mini-batch训练模式是假定每一个训练集就只有mini-batch这样大,因此实际上可以将每一次mini-batch看做是一次训练,一次训练更新一次参数空间,因而optimizer.step()放在这里。

scheduler.step()按照Pytorch的定义是用来更新优化器的学习率的,一般是按照epoch为单位进行更换,即多少个epoch后更换一次学习率,因而scheduler.step()放在epoch这个大循环下。

来源:https://blog.csdn.net/qq_40178291/article/details/99963586

标签:pytorch,Optimizer,optimizer.step
0
投稿

猜你喜欢

  • suggest项目总结-用户体验篇

    2008-01-30 20:04:00
  • 浅谈Webpack多页应用HMR卡住问题

    2023-07-20 01:27:12
  • 用画为5.12地震受灾同胞们祈福 Ⅱ

    2008-05-31 07:37:00
  • Go语言集成开发环境之VS Code安装使用

    2023-08-29 13:06:38
  • Microsoft Enterprise Library 5.0 如何集成MyS

    2011-03-16 15:19:00
  • python 使用百度AI接口进行人脸对比的步骤

    2021-07-06 02:20:06
  • php使用ZipArchive函数实现文件的压缩与解压缩

    2023-07-12 20:58:19
  • 关于windos10环境下编译python3版pjsua库的问题

    2021-06-04 08:12:13
  • 原生JS实现左右箭头选择日期实例代码

    2023-08-06 04:55:27
  • php引用返回与取消引用的详解

    2023-11-20 02:50:07
  • js鼠标动画特效

    2007-09-26 18:31:00
  • Python statistics模块示例详解

    2023-01-27 11:49:47
  • python实现简单名片管理系统

    2023-06-13 08:03:12
  • 从if else到switch case再到抽象

    2010-11-05 18:30:00
  • python2.7和NLTK安装详细教程

    2021-03-30 22:41:19
  • 对Django 转发和重定向的实例详解

    2023-06-19 04:17:53
  • 模仿IE自动完成功能

    2010-03-18 15:51:00
  • Django--权限Permissions的例子

    2021-02-16 01:44:51
  • python检测空间储存剩余大小和指定文件夹内存占用的实例

    2022-10-30 06:52:51
  • 解决Pandas的DataFrame输出截断和省略的问题

    2021-10-28 10:22:19
  • asp之家 网络编程 m.aspxhome.com