Pytorch中的Broadcasting问题

作者:luputo 时间:2022-10-03 06:26:58 

Numpy、Pytorch中的broadcasting

写在前面

自己一直都不清楚numpy、pytorch里面不同维数的向量之间的element wise的计算究竟是按照什么规则来确认维数匹配和不匹配的情况的,比如

>>> b = np.ones((4,5))
>>> a = np.arange(5)
>>> c = a + b
>>> c.shape
(4, 5)
>>> c
array([[1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.]])

上面这种情况就会自动让a和b的维数匹配,a加到了b的每一行上

>>> b = np.ones((5,4))
>>> a = np.arange(5)
>>> c = a + b
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (5,) (5,4)

这种情况就无法匹配,此时我们希望的是a能自动加到b的每一列上,但结果看来好像不行

虽然一直存在这种疑惑,但因为平时遇到的各种运算都比较简单,遇到这种不是直接匹配的array的加法第一直觉就是去console里面试一试,报错就换个姿势再试一试,总归问题可以快速地解决,但是最近在写模型的时候,遇到了绕不过去的问题,所以去查了文档,本文就以解决那个问题为目标,来解释清楚pytorch(numpy也是一样)中的broadcasting semantics的问题

问题描述

我有一个数据Tensor,维数是64 &times; 2048 64\times204864&times;2048,现在我想通过对这64 6464个2048 20482048维的向量做attention(也就是做一个加权和)来得到一个2048 20482048维的向量,因为模型的需要,我需要用五组不同的权值向量来计算出五个不同的加权结果,也就是我的计算结果应该是一个5 &times; 2048 5\times 20485&times;2048维的向量,因为在64 6464个向量上加权,所以一组权值向量是64 6464维,五组就是5 &times; 64 5\times 645&times;64维

尝试解决

现在我手头上有两个Tensor,一个是数据Tensor(64 &times; 2048 64\times 204864&times;2048)另一个是权值Tensor(5 &times; 64 5\times 645&times;64),我GAN!直到我写到了这里,我才发现这不是一个矩阵乘法就能解决的问题嘛+_+,当然,我想给自己正名,这里我简化了一下问题所以才发现原来这么容易就解决了,而原来我在写代码的时候因为还要考虑batch_size等问题才云里雾里不知道咋办,还好当时没想出来,所以去查了文档发现了新的东西,然后写文章的时候想到也算是完满了(不然也不会发现自己好涝)

以上都是题外话,现在,我们还是考虑用愚蠢的element wise的方法来解决,好在现在有两种方法可以解决问题,所以我们可以用来相互检验一下,element wise的解决方法就是,我希望这5个64维的权值向量分别和这64个2048维的向量进行element wise的乘法,也就是第一个64维权值向量先对64个2048维向量加权得到一个2048维的向量,然后第二个64维权值向量先对64个2048维向量加权得到一个2048维的向量&hellip;,以此类推总共五个,最终得到五个64 &times; 2048 64&times;204864&times;2048维的向量,然后求和得到最后的5 &times; 2048 5&times;20485&times;2048维的向量

那么按照平常的习惯,我就去先试试pytorch能不能直接地理解我的想法

>>> import torch
>>> bs = 10 # batch_size
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> out = att * x
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (64) must match the size of tensor b (2048) at non-singleton dimension 2

直接乘不行,因为维数是不匹配的,那怎样的维数才算匹配呢?

BROADCASTING SEMANTICS

以下内容主要来源于自官方文档

很多pytorch的运算是支持broadcasting semantics的,而简单来说,如果运算支持broadcast,则参与运算的Tensor会自动进行扩展来使得运算符左右的Tensor维数匹配,而无需人手动地去拷贝其中的某个Tensor,这就类似于我们开头的那个例子

>>> b = np.ones((4,5))
>>> a = np.arange(5)
>>> c = a + b
>>> c.shape
(4, 5)
>>> c
array([[1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.]])

我们无需让a的维数和b一样,因为numpy自动帮我们做了

这里的另一个重要的概念是broadcastable,如果两个Tensor是broadcastable的,那么就可以对他俩使用支持broadcast的运算,比如直接加减乘除

而两个向量要是broadcast的话,必须满足以下两个条件

  • 每个tensor至少是一维的

  • 两个tensor的维数从后往前,对应的位置要么是相等的,要么其中一个是1,或者不存在

这是官方的例子解释

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# 相同维数的tensor一定是broadcastable的

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# 不是broadcastable的,因为每个tensor维数至少要是1

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# 是broadcastable的,因为从后往前看,一定要注意是从后往前看!
# 第一个维度都是1,相等,满足第二个条件
# 第二个维度其中有一个是1,满足第二个条件
# 第三个维度都是3,相等,满足第二个条件
# 第四个维度其中有一个不存在,满足第二个条件

# 但是
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# 不是broadcastable的,因为从后往前看第三个维度是不match的 2!=3,且都不是1

如果x和y是broadcastable的,那么结果的tensor的size按照如下的规则计算

  • 如果两者的维度不一样,那么就自动增加1维(也就是unsqueeze)

  • 对于结果的每个维度,它取x和y在那一维上的最大值

官方的例子

>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

此外,关于broadcast导致的就地(in-place)操作和梯度运算的兼容性等问题,可以自行参考官方文档

解决问题

上面我们看到,要想两个Tensor支持element wise的运算,需要它们是broadcastable的,而要想它们是broadcastable的,就需要它们的维度自后向前逐一匹配,回到我们原来的问题中,我们有两个Tensor x(64 &times; 2048) att(5 &times; 64),为了让它们broadcastable,我们只需要

>>> import torch
>>> bs = 10 # batch_size
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> x = x.unsqueeze(1)
>>> att = att.view(1,*att.shape,1)
>>> x.shape
torch.Size([10, 1, 64, 2048])
>>> att.shape
torch.Size([1, 5, 64, 1])
>>> out = x * att
>>> out.shape
torch.Size([10, 5, 64, 2048])

最后我们来验证两种方法是否结果相同

>>> import torch
>>> bs = 10 
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> out1 = torch.matmul(att,x)  # 直接矩阵相乘
>>> out.shape
torch.Size([10, 5, 2048])

>>> x = x.unsqueeze(1)
>>> att = att.view(1,*att.shape,1)
>>> out2 = x * att  # element wise的方法
>>> out2 = out2.sum(dim=2)

>>> test = torch.sum((out1-out2)<0.00001)  # 浮点数有微小的误差
>>> test
tensor(102400)
>>> out1.numel()  # 最后表明两个out向量是相等的
102400

Reference

[1] https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#module-numpy.doc.broadcasting

[2] https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics

来源:https://blog.csdn.net/luo3300612/article/details/100100291

标签:Pytorch,Broadcasting
0
投稿

猜你喜欢

  • 小白学Python之实现OCR识别

    2022-02-12 20:35:48
  • Golang限流库与漏桶和令牌桶的使用介绍

    2024-05-10 13:57:50
  • Python使用scrapy采集时伪装成HTTP/1.1的方法

    2023-07-07 01:28:40
  • 使用标准的表单字段名

    2008-06-30 14:14:00
  • MySQL性能参数详解之Max_connect_errors 使用介绍

    2024-01-21 13:31:36
  • videocapture库制作python视频高速传输程序

    2023-08-22 14:47:48
  • 10个顶级Python实用库推荐

    2023-08-27 17:41:46
  • keras的ImageDataGenerator和flow()的用法说明

    2021-12-12 08:54:57
  • 人工智能学习Pytorch数据集分割及动量示例详解

    2021-04-29 11:28:55
  • Python中集合创建与使用详解

    2022-04-30 05:29:42
  • php cookie中点号(句号)自动转为下划线问题

    2023-09-07 11:05:04
  • SQL Server中字符串函数的用法详解

    2024-01-14 05:42:56
  • 教你如何在Pygame 中移动你的游戏角色

    2022-03-29 10:04:29
  • python实现将html表格转换成CSV文件的方法

    2023-08-25 00:48:41
  • 利用python numpy+matplotlib绘制股票k线图的方法

    2022-12-16 07:21:52
  • 用FrontPage制作缩略图和图片重叠效果

    2007-11-18 14:45:00
  • JavaScript实现切换多张图片

    2024-04-17 09:54:18
  • 对“关于购物车的想法”的一些回复

    2009-03-10 18:15:00
  • MySQL表设计优化与索引 (八)

    2010-10-25 19:46:00
  • Python如何利用%操作符格式化字符串详解

    2022-07-17 14:08:39
  • asp之家 网络编程 m.aspxhome.com