PyTorch中torch.matmul()函数常见用法总结

作者:wendy_ya 时间:2023-03-28 16:01:31 

一、函数介绍

pytorch中两个张量的乘法可以分为两种:

  • 两个张量对应元素相乘,在PyTorch中可以通过torch.mul函数(或*运算符)实现;

  • 两个张量矩阵相乘,在PyTorch中可以通过torch.matmul函数实现;

torch.matmul(input, other) → Tensor
计算两个张量input和other的矩阵乘积
【注意】:matmul函数没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作。

二、常见用法

torch.matmul()也是一种类似于矩阵相乘操作的tensor连乘操作。但是它可以利用python中的广播机制,处理一些维度不同的tensor结构进行相乘操作。这也是该函数与torch.bmm()区别所在。

2.1 两个一维向量的乘积运算

若两个tensor都是一维的,则返回两个向量的点积运算结果:

import torch
x = torch.tensor([1,2])
y = torch.tensor([3,4])
print(x,y)
print(torch.matmul(x,y),torch.matmul(x,y).size())

运行结果:

tensor([1, 2]) tensor([3, 4])
tensor(11) torch.Size([])

PyTorch中torch.matmul()函数常见用法总结

2.2 两个二维矩阵的乘积运算

若两个tensor都是二维的,则返回两个矩阵的矩阵相乘结果:

import torch
x = torch.tensor([[1,2],[3,4]])
y = torch.tensor([[5,6,7],[8,9,10]])
print(torch.matmul(x,y),torch.matmul(x,y).size())

运行结果:

tensor([[21, 24, 27],[47, 54, 61]]) torch.Size([2, 3])

PyTorch中torch.matmul()函数常见用法总结

2.3 一个一维向量和一个二维矩阵的乘积运算

若input为一维,other为二维,则先将input的一维向量扩充到二维(维数前面插入长度为1的新维度),然后进行矩阵乘积,得到结果后再将此维度去掉,得到的与input的维度相同。

import torch
x = torch.tensor([1,2])
y = torch.tensor([[5,6,7],[8,9,10]])
print(torch.matmul(x,y),torch.matmul(x,y).size())

运行结果:

tensor([21, 24, 27]) torch.Size([3])

【分析】:首先将x维度从(2)扩充为(,2),然后将x(,2) 与y(2,3)进行相乘,得到(,3),最后去掉一维部分,得到(3)

PyTorch中torch.matmul()函数常见用法总结

2.4 一个二维矩阵和一个一维向量的乘积运算

若input为二维,other为一维,则先将other的一维向量扩充到二维(维数后面插入长度为1的新维度),然后进行矩阵乘积,得到结果后再将此维度去掉,得到的与other的维度相同。

import torch
x = torch.tensor([[1,2,3],[4,5,6]])
y = torch.tensor([7,8,9])
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())

运行结果:

tensor([ 50, 122])
torch.Size([2])

【分析】:首先y维度从(3)扩充为(3,),然后将x(2,3)与x(2,)进行相乘,得到(2,),最后去掉一维部分,得到(2)

【总结】:2.3和2.4基本类似,唯一不同的是2.3中一维向量和二维矩阵的乘积运算需要在一维向量前面插入长度为1的新维度(x为一维向量,y为二维矩阵);2.4中二维矩阵和一维向量的乘积运算需要在一维向量后面插入长度为1的新维度(x为二维矩阵,y为一维向量)。

2.5 其他

其他的暂时用不上,有需要的可以自行查阅相关资料~

参考:https://cloud.tencent.com/developer/article/1802317

来源:https://blog.csdn.net/didi_ya/article/details/121158666

标签:PyTorch,torch.matmul()
0
投稿

猜你喜欢

  • opencv-python图像增强解读

    2022-10-10 04:29:16
  • python列表插入append(), extend(), insert()用法详解

    2021-05-12 13:32:40
  • Python Socket实现远程木马弹窗详解

    2022-11-28 10:04:39
  • python通过索引遍历列表的方法

    2021-07-06 06:06:53
  • Python利用Turtle绘画简单图形

    2021-12-02 04:59:21
  • python实现从字典中删除元素的方法

    2023-11-10 17:26:33
  • 百度工程师讲PHP函数的实现原理及性能分析(三)

    2023-10-20 01:33:03
  • Python编写memcached启动脚本代码实例

    2023-02-13 19:59:51
  • python Dijkstra算法实现最短路径问题的方法

    2022-02-21 03:08:51
  • Python+Qt身体特征识别人数统计源码窗体程序(使用步骤)

    2021-06-03 10:40:54
  • 用于分页的两个Asp函数

    2007-09-07 10:09:00
  • OpenCV图像变换之傅里叶变换的一些应用

    2023-12-01 22:11:34
  • 09七夕节各大搜索引擎LOGO欣赏

    2009-08-27 15:34:00
  • 关于python字符串方法分类详解

    2023-12-30 22:51:44
  • Oracle判断表、列、主键是否存在的方法

    2023-07-22 19:13:06
  • 在漏洞利用Python代码真的很爽

    2023-11-24 15:57:29
  • 零基础使用Python读写处理Excel表格的方法

    2021-01-02 13:33:07
  • 用python实现前向分词最大匹配算法的示例代码

    2023-12-03 00:18:46
  • python+requests接口压力测试500次,查看响应时间的实例

    2021-09-29 08:27:56
  • em与px的区别以及em特点和应用

    2008-11-11 12:03:00
  • asp之家 网络编程 m.aspxhome.com