pytorch中的广播语义
作者:机器学习入坑者 时间:2023-04-22 15:16:36
pytorch的广播语义(broadcasting semantics),和numpy的很像,所以可以先看看numpy的文档:
1、什么是广播语义?
官方文档有这样一个解释:
In short, if a PyTorch operation supports broadcast, then its Tensor arguments can be automatically expanded to be of equal sizes (without making copies of the data).
这句话的意思大概是:简单的说,如果一个pytorch操作支持广播,那么它的Tensor参数可以自动的扩展为相同的尺寸(不需要复制数据)。
按照我的理解,应该是指算法计算过程中,不同的Tensor如果size
不同,但是符合一定的规则,那么可以自动的进行维度扩展,来实现Tensor
的计算。在维度扩展的过程中,并不是真的把维度小的Tensor复制为和维度大的Tensor相同,因为这样太浪费内存了。
2、广播语义的规则
首先来看标准的情况,两个Tensor的size相同,则可以直接计算:
x = torch.empty((4, 2, 3))
y = torch.empty((4, 2, 3))
print((x+y).size())
输出:
torch.Size([4, 2, 3])
但是,如果两个Tensor
的维度并不相同,pytorch也是可以根据下面的两个法则进行计算:
(1)Each tensor has at least one dimension.
(2)When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.
每个
Tensor
至少有一个维度。迭代标注尺寸时,从后面的标注开始
第一个规则要求每个参与计算的Tensor
至少有一个维度,第二个规则是指在维度迭代时,从最后一个维度开始,可以有三种情况:
维度相等
其中一个维度是1
其中一个维度不存在
3、不符合广播语义的例子
x = torch.empty((0, ))
y = torch.empty((2, 3))
print((x + y).size())
输出:
RuntimeError: The size of tensor a (0) must match the size of tensor b (3) at non-singleton dimension 1
这里,不满足第一个规则“每个参与计算的Tensor
至少有一个维度”。
x = torch.empty(5, 2, 4, 1)
y = torch.empty(3, 1, 1)
print((x + y).size())
输出:
RuntimeError: The size of tensor a (2) must match
the size of tensor b (3) at non-singleton dimension 1
这里,不满足第二个规则,因为从最后的维度开始迭代的过程中,倒数第三个维度:x是2,y是3。这并不符合第二条规则的三种情况,所以不能使用广播语义。
4、符合广播语义的例子
x = torch.empty(5, 3, 4, 1)
y = torch.empty(3, 1, 1)
print((x + y).size())
输出:
torch.Size([5, 3, 4, 1])
x是四维的,y是三维的,从最后一个维度开始迭代:
最后一维:x是1,y是1,满足规则二
倒数第二维:x是4,y是1,满足规则二
倒数第三维:x是3,y是3,满足规则一
倒数第四维:x是5,y是0,满足规则一
来源:https://zhuanlan.zhihu.com/p/338298069
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
Django中间件工作流程及写法实例代码
![](https://img.aspxhome.com/file/2023/3/85153_0s.png)
python使用selenium模拟浏览器进入好友QQ空间留言功能
![](https://img.aspxhome.com/file/2023/5/127875_0s.jpg)
Python实现选择排序
text-indent 隐藏文字时出现的 outline问题
MySQL数据库性能优化之表结构优化
实现php删除链表中重复的结点
利用Matlab绘制各类特殊图形的实例代码
![](https://img.aspxhome.com/file/2023/5/102815_0s.png)
Python socket模块方法实现详解
php 常用算法和时间复杂度
详解python 利用echarts画地图(热力图)(世界地图,省市地图,区县地图)
![](https://img.aspxhome.com/file/2023/9/135249_0s.png)
PHP登录(ajax提交数据和后台校验)实例分享
![](https://img.aspxhome.com/file/2023/1/132811_0s.png)
JavaScript中神奇的call()方法
使用python库xlsxwriter库来输出各种xlsx文件的示例
MySQL 管理
PHP中SimpleXML函数用法分析
python Django模板的使用方法(图文)
![](https://img.aspxhome.com/file/2023/9/63889_0s.png)
详解如何在微信小程序开发中正确的使用vant ui组件
![](https://img.aspxhome.com/file/2023/5/123715_0s.png)
Python中requests、aiohttp、httpx性能比拼
Python在游戏中的热更新实现
win10系统下Anaconda3安装配置方法图文教程
![](https://img.aspxhome.com/file/2023/6/82926_0s.png)