Making only specified variables propagate gradients backward in PyTorch

Author: 美利坚节度使    Date: 2022-03-27 15:35:30

How do you make only specified variables propagate gradients backward in PyTorch?

(Or, put the other way: how do you keep a specified variable out of backpropagation?)

Suppose we have the following formulas and want to differentiate L with respect to xvar:

out1 = xvar * y1var
out2 = xvar * y2var
L = (out1 - out2)^2

considered under three setups, (1) through (3):

In (1), differentiating L with respect to xvar picks up contributions from both the out1 branch and the out2 branch;

In (2), differentiating L with respect to xvar only picks up the out2 branch, because out1 has requires_grad=False;

In (3), differentiating L with respect to xvar only picks up the out1 branch, because out2 has requires_grad=False.

Verified as follows:


#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 23 10:02:04 2018
@author: hy
"""

import torch
from torch.autograd import Variable
print("PyTorch version: {}".format(torch.__version__))
x = torch.Tensor([1])
xvar = Variable(x, requires_grad=True)   # leaf variable we differentiate with respect to
y1 = torch.Tensor([2])
y2 = torch.Tensor([7])
y1var = Variable(y1)   # requires_grad defaults to False
y2var = Variable(y2)
# (1) both branches stay in the graph, so the gradient flows through out1 and out2
print("For (1)")
print("xvar requires_grad: {}".format(xvar.requires_grad))
print("y1var requires_grad: {}".format(y1var.requires_grad))
print("y2var requires_grad: {}".format(y2var.requires_grad))
out1 = xvar * y1var
print("out1 requires_grad: {}".format(out1.requires_grad))
out2 = xvar * y2var
print("out2 requires_grad: {}".format(out2.requires_grad))
L = torch.pow(out1 - out2, 2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()  # clear the accumulated gradient before the next case
# (2) detach out1: the gradient reaches xvar only through out2
print("For (2)")
print("xvar requires_grad: {}".format(xvar.requires_grad))
print("y1var requires_grad: {}".format(y1var.requires_grad))
print("y2var requires_grad: {}".format(y2var.requires_grad))
out1 = xvar * y1var
print("out1 requires_grad: {}".format(out1.requires_grad))
out2 = xvar * y2var
print("out2 requires_grad: {}".format(out2.requires_grad))
out1 = out1.detach()  # cut out1 off from the graph; it is now a constant
print("after out1.detach(), out1 requires_grad: {}".format(out1.requires_grad))
L = torch.pow(out1 - out2, 2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
# (3) detach out2: the gradient reaches xvar only through out1
print("For (3)")
print("xvar requires_grad: {}".format(xvar.requires_grad))
print("y1var requires_grad: {}".format(y1var.requires_grad))
print("y2var requires_grad: {}".format(y2var.requires_grad))
out1 = xvar * y1var
print("out1 requires_grad: {}".format(out1.requires_grad))
out2 = xvar * y2var
print("out2 requires_grad: {}".format(out2.requires_grad))
out2 = out2.detach()  # cut out2 off from the graph
print("after out2.detach(), out2 requires_grad: {}".format(out2.requires_grad))
L = torch.pow(out1 - out2, 2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
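
As a sanity check, the three gradients can be worked out by hand for x = 1, y1 = 2, y2 = 7:

(1) L = (2x - 7x)^2 = 25x^2, so dL/dx = 50x = 50;
(2) out1 is frozen at the constant 2, L = (2 - 7x)^2, so dL/dx = -14(2 - 7x) = 70;
(3) out2 is frozen at the constant 7, L = (2x - 7)^2, so dL/dx = 4(2x - 7) = -20.

If the script behaves as described, xvar.grad should print 50, 70, and -20 respectively.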

In PyTorch, a variable whose requires_grad is False is excluded from gradient backpropagation.

However, you cannot simply assign out1.requires_grad = False: out1 is not a leaf variable (it was produced by an operation), and PyTorch only lets you change the requires_grad flag of leaf variables, so the assignment raises a RuntimeError.

Instead, Variable (and Tensor in modern PyTorch) provides a detach() method; the tensor it returns is cut off from the computation graph and has requires_grad=False.
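
A minimal sketch of both points on a recent PyTorch version (the tensor names are illustrative, and the exact error text may differ between versions):

import torch

x = torch.tensor([1.0], requires_grad=True)
out = x * 2                      # non-leaf: produced by an operation
try:
    out.requires_grad = False    # only allowed on leaf tensors
except RuntimeError as e:
    print("direct assignment failed: {}".format(e))

frozen = out.detach()            # the supported way: detach from the graph
print(frozen.requires_grad)      # False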

Note: if out1 and out2 both have requires_grad=False, then L does not require grad either, and L.backward() raises an error because no gradient can reach xvar.
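
A minimal sketch of that failure mode (error text may vary by version):

import torch

x = torch.tensor([1.0], requires_grad=True)
out1 = (x * 2).detach()
out2 = (x * 7).detach()
L = (out1 - out2).pow(2)    # L.requires_grad is False: the graph is fully cut
try:
    L.backward()
except RuntimeError as e:
    print("backward failed: {}".format(e))
print(x.grad)               # None: no gradient ever reached x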

Addendum:

volatile=True marks a variable as not computing gradients. Reference: "Volatile is recommended for purely inference mode, when you're sure you won't be even calling .backward(). It's more efficient than any other autograd setting - it will use the absolute minimal amount of memory to evaluate the model. volatile also determines that requires_grad is False."
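
Note that volatile was removed in PyTorch 0.4; the torch.no_grad() context manager now covers this inference-only use case. A minimal sketch:

import torch

x = torch.tensor([1.0], requires_grad=True)
with torch.no_grad():       # nothing computed inside tracks gradients
    out = x * 2
print(out.requires_grad)    # False: out is detached from the graph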

Source: https://blog.csdn.net/ying86615791/article/details/80417465
