pytorch中的广播语义

作者:机器学习入坑者 时间:2023-04-22 15:16:36 

pytorch的广播语义(broadcasting semantics),和numpy的很像,所以可以先看看numpy的文档:

1、什么是广播语义?

官方文档有这样一个解释:

In short, if a PyTorch operation supports broadcast, then its Tensor arguments can be automatically expanded to be of equal sizes (without making copies of the data).

这句话的意思大概是:简单的说,如果一个pytorch操作支持广播,那么它的Tensor参数可以自动的扩展为相同的尺寸(不需要复制数据)。

按照我的理解,应该是指算法计算过程中,不同的Tensor如果size不同,但是符合一定的规则,那么可以自动的进行维度扩展,来实现Tensor的计算。在维度扩展的过程中,并不是真的把维度小的Tensor复制为和维度大的Tensor相同,因为这样太浪费内存了。

2、广播语义的规则

首先来看标准的情况,两个Tensor的size相同,则可以直接计算:

x = torch.empty((4, 2, 3))
y = torch.empty((4, 2, 3)) 
print((x+y).size()) 

输出:

torch.Size([4, 2, 3]) 

但是,如果两个Tensor的维度并不相同,pytorch也是可以根据下面的两个法则进行计算:

  • (1)Each tensor has at least one dimension.

  • (2)When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.

  • 每个Tensor至少有一个维度。

  • 迭代标注尺寸时,从后面的标注开始

第一个规则要求每个参与计算的Tensor至少有一个维度,第二个规则是指在维度迭代时,从最后一个维度开始,可以有三种情况:

  • 维度相等

  • 其中一个维度是1

  • 其中一个维度不存在

3、不符合广播语义的例子

x = torch.empty((0, ))
y = torch.empty((2, 3)) 
print((x + y).size())

输出:

RuntimeError: The size of tensor a (0) must match  the size of tensor b (3) at non-singleton dimension 1 

这里,不满足第一个规则“每个参与计算的Tensor至少有一个维度”。

x = torch.empty(5, 2, 4, 1) 
y = torch.empty(3, 1, 1) 
print((x + y).size())

输出:

RuntimeError: The size of tensor a (2) must match 
the size of tensor b (3) at non-singleton dimension 1 

这里,不满足第二个规则,因为从最后的维度开始迭代的过程中,倒数第三个维度:x是2,y是3。这并不符合第二条规则的三种情况,所以不能使用广播语义。

4、符合广播语义的例子

x = torch.empty(5, 3, 4, 1) 
y = torch.empty(3, 1, 1) 
print((x + y).size()) 

输出:

torch.Size([5, 3, 4, 1]) 

x是四维的,y是三维的,从最后一个维度开始迭代:

  • 最后一维:x是1,y是1,满足规则二 

  • 倒数第二维:x是4,y是1,满足规则二 

  • 倒数第三维:x是3,y是3,满足规则一

  • 倒数第四维:x是5,y是0,满足规则一 

来源:https://zhuanlan.zhihu.com/p/338298069

标签:pytorch,广播,语义
0
投稿

猜你喜欢

  • Django中间件工作流程及写法实例代码

    2021-09-01 17:47:38
  • python使用selenium模拟浏览器进入好友QQ空间留言功能

    2021-06-24 16:24:16
  • Python实现选择排序

    2021-06-17 03:23:40
  • text-indent 隐藏文字时出现的 outline问题

    2007-12-02 17:31:00
  • MySQL数据库性能优化之表结构优化

    2012-05-08 07:10:34
  • 实现php删除链表中重复的结点

    2023-09-05 09:36:15
  • 利用Matlab绘制各类特殊图形的实例代码

    2021-01-05 21:10:31
  • Python socket模块方法实现详解

    2021-02-12 20:52:39
  • php 常用算法和时间复杂度

    2023-11-05 10:30:49
  • 详解python 利用echarts画地图(热力图)(世界地图,省市地图,区县地图)

    2021-10-24 06:59:30
  • PHP登录(ajax提交数据和后台校验)实例分享

    2024-04-28 09:43:41
  • JavaScript中神奇的call()方法

    2024-04-30 09:52:41
  • 使用python库xlsxwriter库来输出各种xlsx文件的示例

    2022-04-27 14:50:30
  • MySQL 管理

    2024-01-13 14:42:43
  • PHP中SimpleXML函数用法分析

    2023-06-23 11:52:09
  • python Django模板的使用方法(图文)

    2022-03-30 04:23:52
  • 详解如何在微信小程序开发中正确的使用vant ui组件

    2024-05-25 15:18:33
  • Python中requests、aiohttp、httpx性能比拼

    2023-10-17 05:27:26
  • Python在游戏中的热更新实现

    2022-04-05 14:10:15
  • win10系统下Anaconda3安装配置方法图文教程

    2022-08-06 23:01:49
  • asp之家 网络编程 m.aspxhome.com