PyTorch中的Variable变量详解

作者:Wei Ji 时间:2023-02-19 18:48:47 

一、了解Variable

顾名思义,Variable就是 变量 的意思。实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属性。

具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生片花,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了。(也就是说,pytorch都是有tensor计算的,而tensor里面的参数都是Variable的形式)。如果用Variable计算的话,那返回的也是一个同类型的Variable。

【tensor 是一个多维矩阵】

用一个例子说明,Variable的定义:


import torch
from torch.autograd import Variable # torch 中 Variable 模块
tensor = torch.FloatTensor([[1,2],[3,4]])
# 把鸡蛋放到篮子里, requires_grad是参不参与误差反向传播, 要不要计算梯度
variable = Variable(tensor, requires_grad=True)

print(tensor)
"""
1 2
3 4
[torch.FloatTensor of size 2x2]
"""

print(variable)
"""
Variable containing:
1 2
3 4
[torch.FloatTensor of size 2x2]
"""

注:tensor不能反向传播,variable可以反向传播。

二、Variable求梯度

Variable计算时,它会逐渐地生成计算图。这个图就是将所有的计算节点都连接起来,最后进行误差反向传递的时候,一次性将所有Variable里面的梯度都计算出来,而tensor就没有这个能力。


v_out.backward() # 模拟 v_out 的误差反向传递

print(variable.grad) # 初始 Variable 的梯度
'''
0.5000 1.0000
1.5000 2.0000
'''

三、获取Variable里面的数据

直接print(Variable) 只会输出Variable形式的数据,在很多时候是用不了的。所以需要转换一下,将其变成tensor形式。


print(variable)  # Variable 形式
"""
Variable containing:
1 2
3 4
[torch.FloatTensor of size 2x2]
"""

print(variable.data) # 将variable形式转为tensor 形式
"""
1 2
3 4
[torch.FloatTensor of size 2x2]
"""

print(variable.data.numpy()) # numpy 形式
"""
[[ 1. 2.]
[ 3. 4.]]
"""

扩展

在PyTorch中计算图的特点总结如下:

autograd根据用户对Variable的操作来构建其计算图。

1、requires_grad

variable默认是不需要被求导的,即requires_grad属性默认为False,如果某一个节点的requires_grad为True,那么所有依赖它的节点requires_grad都为True。

2、volatile

variable的volatile属性默认为False,如果某一个variable的volatile属性被设为True,那么所有依赖它的节点volatile属性都为True。volatile属性为True的节点不会求导,volatile的优先级比requires_grad高。

3、retain_graph

多次反向传播(多层监督)时,梯度是累加的。一般来说,单次反向传播后,计算图会free掉,也就是反向传播的中间缓存会被清空【这就是动态度的特点】。为进行多次反向传播需指定retain_graph=True来保存这些缓存。

4、backward()

反向传播,求解Variable的梯度。放在中间缓存中。

来源:https://blog.csdn.net/qq_19329785/article/details/85029116

标签:PyTorch,Variable,变量
0
投稿

猜你喜欢

  • Python 读写 Matlab Mat 格式数据的操作

    2023-08-23 01:21:12
  • 基于Python实现一键找出磁盘里所有猫照

    2022-06-20 03:52:32
  • python3代码输出嵌套式对象实例详解

    2021-09-16 07:35:55
  • 解决pytorch下只打印tensor的数值不打印出device等信息的问题

    2023-04-20 18:25:52
  • Python 聊聊socket中的listen()参数(数字)到底代表什么

    2022-10-17 00:49:25
  • MYSQL建立外键失败几种情况记录Can't create table不能创建表

    2024-01-22 19:57:22
  • SQL Server 2005通用分页存储过程及多表联接应用

    2024-01-13 22:39:31
  • Python爬虫番外篇之Cookie和Session详解

    2022-02-09 18:56:44
  • PHP魔术方法__ISSET、__UNSET使用实例

    2024-05-22 10:09:08
  • SQL Server 数据库实用SQL语句

    2024-01-14 00:10:40
  • python为什么会环境变量设置不成功

    2023-01-18 04:33:08
  • SQL Server2000的安全策略

    2007-08-06 17:14:00
  • js金额浮点格式化控件

    2008-08-01 16:52:00
  • vue中echarts的用法及与elementui-select的协同绑定操作

    2024-05-10 14:20:13
  • 如何利用Golang解析读取Mysql备份文件

    2024-01-28 20:51:19
  • 可插入图片的TEXT文本框

    2024-02-25 20:07:36
  • typecho统计博客所有文章的字数实例详解

    2023-06-13 07:52:36
  • 基于logstash实现日志文件同步elasticsearch

    2023-09-01 14:45:57
  • 从trim原型函数看js正则表达式的性能

    2008-12-11 13:55:00
  • python适合人工智能的理由和优势

    2021-08-10 11:01:12
  • asp之家 网络编程 m.aspxhome.com