Pytorch之Variable的用法

作者:啧啧啧biubiu 时间:2022-01-19 04:16:39 

1.简介

torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现

Variable和tensor的区别和联系

Variable是篮子,而tensor是鸡蛋,鸡蛋应该放在篮子里才能方便拿走(定义variable时一个参数就是tensor)

Variable这个篮子里除了装了tensor外还有requires_grad参数,表示是否需要对其求导,默认为False

Variable这个篮子呢,自身有一些属性

比如grad,梯度variable.grad是d(y)/d(variable)保存的是变量y对variable变量的梯度值,如果requires_grad参数为False,所以variable.grad返回值为None,如果为True,返回值就为对variable的梯度值

比如grad_fn,对于用户自己创建的变量(Variable())grad_fn是为none的,也就是不能调用backward函数,但对于由计算生成的变量,如果存在一个生成中间变量的requires_grad为true,那其的grad_fn不为none,反则为none

比如data,这个就很简单,这个属性就是装的鸡蛋(tensor)

Varibale包含三个属性:

data:存储了Tensor,是本体的数据 grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致 grad_fn:指向Function对象,用于反向传播的梯度计算之用

代码1


import numpy as np
import torch
from torch.autograd import Variable

x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)

y = x + temp + 2
y = y.mean() #求平均数

y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)

输出1

none

(因为requires_grad=False)

代码2


import numpy as np
import torch
from torch.autograd import Variable

x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)

y = x + temp + 2
y = y.mean() #求平均数

y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(temp.grad) # d(y)/d(temp)

输出2

tensor([[0.2500, 0.2500],
[0.2500, 0.2500]])

代码3


import numpy as np
import torch
from torch.autograd import Variable

x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)

y = x + 2
y = y.mean() #求平均数

y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)

输出3

Traceback (most recent call last):
File "path", line 12, in <module>
y.backward()

(报错了,因为生成变量y的中间变量只有x,而x的requires_grad是False,所以y的grad_fn是none)

代码4


import numpy as np
import torch
from torch.autograd import Variable

x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)

y = x + 2
y = y.mean() #求平均数

#y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(y.grad_fn) # d(y)/d(x)

输出4

none

2.grad属性

在每次backward后,grad值是会累加的,所以利用BP算法,每次迭代是需要将grad清零的。

x.grad.data.zero_()

(in-place操作需要加上_,即zero_)

来源:https://blog.csdn.net/qq_37385726/article/details/81706820

标签:Pytorch,Variable
0
投稿

猜你喜欢

  • Python闭包和装饰器用法实例详解

    2021-04-07 10:05:02
  • HTML的优化杂记

    2010-03-10 10:39:00
  • Oracle9i在Win2k环境下的完全卸载

    2010-07-28 13:03:00
  • 对抗MySQL数据库解密高手

    2008-12-25 13:14:00
  • [译]2009年海外Web设计风潮(下)

    2009-01-23 09:34:00
  • 使用Python获取CPU、内存和硬盘等windowns系统信息的2个例子

    2023-08-26 23:12:32
  • python 三边测量定位的实现代码

    2023-02-03 08:37:31
  • 如何用Python绘制棒棒糖图表

    2021-05-02 06:26:33
  • Python之dict(或对象)与json之间的互相转化实例

    2023-05-14 04:26:00
  • 聚族索引、非聚族索引、组合索引的含义和用途

    2010-07-02 21:51:00
  • php中让上传的文件大小在上传前就受限制的两种解决方法

    2023-10-25 17:53:12
  • 简单掌握Python的Collections模块中counter结构的用法

    2023-05-17 00:20:13
  • 减少SQL Server死锁的方法

    2009-01-05 13:49:00
  • Python统计python文件中代码,注释及空白对应的行数示例【测试可用】

    2023-04-30 00:11:19
  • 你知道吗实现炫酷可视化只要1行python代码

    2022-06-10 13:36:16
  • 网站图片与文本谁更重要?(中英文对照)

    2008-10-17 10:25:00
  • 比较一下看看自己掌握了多少SQL快捷键

    2009-01-04 14:04:00
  • python模块之subprocess模块级方法的使用

    2022-05-10 03:28:32
  • 使用Python3内置文档高效学习以及官方中文文档

    2022-06-13 08:14:45
  • Python pywifi ERROR Open handle failed问题及解决

    2021-01-16 03:54:28
  • asp之家 网络编程 m.aspxhome.com