Pytorch中torch.unsqueeze()与torch.squeeze()函数详细解析

作者:cv_lhp 时间:2022-05-22 00:41:57 

一. torch.squeeze()函数解析

1. 官网链接

torch.squeeze(),如下图所示:

Pytorch中torch.unsqueeze()与torch.squeeze()函数详细解析

2. torch.squeeze()函数解析

torch.squeeze(input, dim=None, out=None)

squeeze()函数的功能是维度压缩。返回一个tensor(张量),其中 input 中维度大小为1的所有维都已删除。

举个例子:如果 input 的形状为 (A×1×B×C×1×D),那么返回的tensor的形状则为 (A×B×C×D)

当给定 dim 时,那么只在给定的维度(dimension)上进行压缩操作,注意给定的维度大小必须是1,否则不能进行压缩。

举个例子:如果 input 的形状为 (A×1×B),squeeze(input, dim=0)后,返回的tensor不变,因为第0维的大小为A,不是1;squeeze(input, 1)后,返回的tensor将被压缩为 (A×B)。

3. 代码举例

3.1 输入size=(2, 1, 2, 1, 2)的张量

x = torch.randn(size=(2, 1, 2, 1, 2))
x.shape

输出结果如下:
torch.Size([2, 1, 2, 1, 2])

3.2 把x中维度大小为1的所有维都已删除

y = torch.squeeze(x)#表示把x中维度大小为1的所有维都已删除
y.shape

输出结果如下:
torch.Size([2, 2, 2])

3.3 把x中第一维删除,但是第一维大小为2,不为1,因此结果删除不掉

y = torch.squeeze(x,0)#表示把x中第一维删除,但是第一维大小为2,不为1,因此结果删除不掉
y.shape

输出结果如下:
torch.Size([2, 1, 2, 1, 2])

3.4 把x中第二维删除,因为第二维大小是1,因此可以删掉

y = torch.squeeze(x,1)#表示把x中第二维删除,因为第二维大小是1,因此可以删掉
y.shape

输出结果如下:
torch.Size([2, 2, 1, 2])

3.5 把x中最后一维删除,但是最后一维大小为2,不为1,因此结果删除不掉

y = torch.squeeze(x,dim=-1)#表示把x中最后一维删除,但是最后一维大小为2,不为1,因此结果删除不掉
y.shape

输出结果如下:
torch.Size([2, 1, 2, 1, 2])

二.torch.unsqueeze()函数解析

1. 官网链接

torch.unsqueeze(),如下图所示:

Pytorch中torch.unsqueeze()与torch.squeeze()函数详细解析

2. torch.unsqueeze()函数解析

torch.unsqueeze(input, dim) → Tensor

unsqueeze()函数起升维的作用,参数dim表示在哪个地方加一个维度,注意dim范围在:[-input.dim() - 1, input.dim() + 1]之间,比如输入input是一维,则dim=0时数据为行方向扩,dim=1时为列方向扩,再大错误。

3. 代码举例

3.1 输入一维张量,在第0维(行)扩展,第0维大小为1

x = torch.tensor([1, 2, 3, 4])
y = torch.unsqueeze(x, 0)#在第0维扩展,第0维大小为1
y,y.shape

输出结果如下:
(tensor([[1, 2, 3, 4]]), torch.Size([1, 4]))

3.2 在第1维(列)扩展,第1维大小为1

y = torch.unsqueeze(x, 1)#在第1维扩展,第1维大小为1
y,y.shape

输出结果如下:
(tensor([[1],
         [2],
         [3],
         [4]]),
 torch.Size([4, 1]))

3.3 在第最后一维(也就是倒数第一维进行)扩展,最后一维大小为1

y = torch.unsqueeze(x, -1)#在第最后一维扩展,最后一维大小为1
y,y.shape

输出结果如下:
(tensor([[1],
         [2],
         [3],
         [4]]),
 torch.Size([4, 1]))

来源:https://blog.csdn.net/flyingluohaipeng/article/details/125092937

标签:torch.unsqueeze(),torch.squeeze(),函数
0
投稿

猜你喜欢

  • MySQL decimal unsigned更新负数转化为0

    2024-01-14 20:59:36
  • Go语言resty http包调用jenkins api实例

    2024-05-21 10:27:27
  • python去掉字符串中重复字符的方法

    2022-11-23 09:17:35
  • python 写入csv乱码问题解决方法

    2021-11-13 11:32:22
  • Python 私有函数的实例详解

    2023-03-07 08:30:40
  • pandas dataframe drop函数介绍

    2023-07-11 17:19:17
  • ORACLE8的分区管理

    2010-07-30 13:18:00
  • SqlServer 表连接教程(问题解析)

    2024-01-27 00:35:55
  • 微信小程序分包操作实战指南

    2024-04-16 08:47:57
  • python 利用openpyxl读取Excel表格中指定的行或列教程

    2022-08-06 21:22:54
  • Python操作word文档插入图片和表格的实例演示

    2023-09-20 08:21:09
  • python批量替换多文件字符串问题详解

    2023-05-08 23:48:06
  • JavaScript实现动态时钟效果

    2024-04-16 10:27:04
  • Go语言开发redis封装及简单使用详解

    2024-05-08 10:53:30
  • mysql定时任务(event事件)实现详解

    2024-01-25 13:22:18
  • 基于Django的乐观锁与悲观锁解决订单并发问题详解

    2021-07-14 19:42:08
  • Python 如何将integer转化为罗马数(3999以内)

    2023-01-19 12:46:51
  • CSS网页布局开发时的常见问题小结

    2009-08-13 12:17:00
  • python类中super()和__init__()的区别

    2021-04-17 16:03:02
  • vue中jsonp插件的使用方法示例

    2024-05-05 09:11:52
  • asp之家 网络编程 m.aspxhome.com