pytorch中的squeeze函数、cat函数使用

作者:zhuanse 时间:2022-03-27 14:32:24 

1 squeeze(): 去除size为1的维度,包括行和列。

至于维度大于等于2时,squeeze()不起作用。

行、例:


>>> torch.rand(4, 1, 3)

(0 ,.,.) =
 0.5391  0.8523  0.9260

(1 ,.,.) =
 0.2507  0.9512  0.6578

(2 ,.,.) =
 0.7302  0.3531  0.9442

(3 ,.,.) =
 0.2689  0.4367  0.6610
[torch.FloatTensor of size 4x1x3]

>>> torch.rand(4, 1, 3).squeeze()

0.0801  0.4600  0.1799
0.0236  0.7137  0.6128
0.0242  0.3847  0.4546
0.9004  0.5018  0.4021
[torch.FloatTensor of size 4x3]

列、例:


>>> torch.rand(4, 3, 1)

(0 ,.,.) =
 0.7013
 0.9818
 0.9723

(1 ,.,.) =
 0.9902
 0.8354
 0.3864

(2 ,.,.) =
 0.4620
 0.0844
 0.5707

(3 ,.,.) =
 0.5722
 0.2494
 0.5815
[torch.FloatTensor of size 4x3x1]

>>> torch.rand(4, 3, 1).squeeze()

0.8784  0.6203  0.8213
0.7238  0.5447  0.8253
0.1719  0.7830  0.1046
0.0233  0.9771  0.2278
[torch.FloatTensor of size 4x3]

不变、例:


>>> torch.rand(4, 3, 2)

(0 ,.,.) =
 0.6618  0.1678
 0.3476  0.0329
 0.1865  0.4349

(1 ,.,.) =
 0.7588  0.8972
 0.3339  0.8376
 0.6289  0.9456

(2 ,.,.) =
 0.1392  0.0320
 0.0033  0.0187
 0.8229  0.0005

(3 ,.,.) =
 0.2327  0.6264
 0.4810  0.6642
 0.8625  0.6334
[torch.FloatTensor of size 4x3x2]

>>> torch.rand(4, 3, 2).squeeze()

(0 ,.,.) =
 0.0593  0.8910
 0.9779  0.1530
 0.9210  0.2248

(1 ,.,.) =
 0.7938  0.9362
 0.1064  0.6630
 0.9321  0.0453

(2 ,.,.) =
 0.0189  0.9187
 0.4458  0.9925
 0.9928  0.7895

(3 ,.,.) =
 0.5116  0.7253
 0.0132  0.6673
 0.9410  0.8159
[torch.FloatTensor of size 4x3x2]

2 cat函数


>>> t1=torch.FloatTensor(torch.randn(2,3))
>>> t1

-1.9405  1.2009  0.0018
0.9463  0.4409 -1.9017
[torch.FloatTensor of size 2x3]

>>> t2=torch.FloatTensor(torch.randn(2,2))
>>> t2

0.0942  0.1581
1.1621  1.2617
[torch.FloatTensor of size 2x2]

>>> torch.cat((t1, t2), 1)

-1.9405  1.2009  0.0018  0.0942  0.1581
0.9463  0.4409 -1.9017  1.1621  1.2617
[torch.FloatTensor of size 2x5]

补充:pytorch中 max()、view()、 squeeze()、 unsqueeze()

查了好多博客都似懂非懂,后来写了几个小例子,瞬间一目了然。

一、torch.max()


import torch  
a=torch.randn(3)
print("a:\n",a)
print('max(a):',torch.max(a))

b=torch.randn(3,4)
print("b:\n",b)
print('max(b,0):',torch.max(b,0))
print('max(b,1):',torch.max(b,1))

输出:

a:
tensor([ 0.9558, 1.1242, 1.9503])
max(a): tensor(1.9503)
b:
tensor([[ 0.2765, 0.0726, -0.7753, 1.5334],
[ 0.0201, -0.0005, 0.2616, -1.1912],
[-0.6225, 0.6477, 0.8259, 0.3526]])
max(b,0): (tensor([ 0.2765, 0.6477, 0.8259, 1.5334]), tensor([ 0, 2, 2, 0]))
max(b,1): (tensor([ 1.5334, 0.2616, 0.8259]), tensor([ 3, 2, 2]))

max(a),用于一维数据,求出最大值。

max(a,0),计算出数据中一列的最大值,并输出最大值所在的行号。

max(a,1),计算出数据中一行的最大值,并输出最大值所在的列号。


print('max(b,1):',torch.max(b,1)[1])

输出:只输出行最大值所在的列号


max(b,1): tensor([ 3,  2,  2])

torch.max(b,1)[0], 只返回最大值的每个数

二、view()

a.view(i,j)表示将原矩阵转化为i行j列的形式

i为-1表示不限制行数,输出1列


a=torch.randn(3,4)
print(a)

输出:

tensor([[-0.8146, -0.6592, 1.5100, 0.7615],
[ 1.3021, 1.8362, -0.3590, 0.3028],
[ 0.0848, 0.7700, 1.0572, 0.6383]])

b=a.view(-1,1)
print(b)

输出:

tensor([[-0.8146],
[-0.6592],
[ 1.5100],
[ 0.7615],
[ 1.3021],
[ 1.8362],
[-0.3590],
[ 0.3028],
[ 0.0848],
[ 0.7700],
[ 1.0572],
[ 0.6383]])

i为1,j为-1表示不限制列数,输出1行


b=a.view(1,-1)
print(b)

输出:

tensor([[-0.8146, -0.6592, 1.5100, 0.7615, 1.3021, 1.8362, -0.3590,
0.3028, 0.0848, 0.7700, 1.0572, 0.6383]])

i为-1,j为2表示不限制行数,输出2列


b=a.view(-1,2)
print(b)

输出:

tensor([[-0.8146, -0.6592],
[ 1.5100, 0.7615],
[ 1.3021, 1.8362],
[-0.3590, 0.3028],
[ 0.0848, 0.7700],
[ 1.0572, 0.6383]])

i为-1,j为3表示不限制行数,输出3列

i为4,j为3表示输出4行3列


b=a.view(-1,3)
print(b)
b=a.view(4,3)
print(b)

输出:

tensor([[-0.8146, -0.6592, 1.5100],
[ 0.7615, 1.3021, 1.8362],
[-0.3590, 0.3028, 0.0848],
[ 0.7700, 1.0572, 0.6383]])
tensor([[-0.8146, -0.6592, 1.5100],
[ 0.7615, 1.3021, 1.8362],
[-0.3590, 0.3028, 0.0848],
[ 0.7700, 1.0572, 0.6383]])

三、

1.torch.squeeze()

压缩矩阵,我理解为降维

a.squeeze(i) 压缩第i维,如果这一维维数是1,则这一维可有可无,便可以压缩


import torch  
a=torch.randn(1,3,4)
print(a)
b=a.squeeze(0)
print(b)
c=a.squeeze(1)
print(c

输出:

tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]]])

一页三行4列的矩阵

第0维为1,则可以通过squeeze(0)删掉,转化为三行4列的矩阵

tensor([[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]])

第1维不为1,则不可以压缩

tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]]])

2.torch.unsqueeze()

unsqueeze(i) 表示将第i维设置为1

对压缩为3行4列后的矩阵b进行操作,将第0维设置为1


c=b.unsqueeze(0)
print(c)

输出一个一页三行四列的矩阵

tensor([[[ 0.0661, -0.2386, -0.6610, 1.5774],
[ 1.2210, -0.1084, -0.1166, -0.2379],
[-1.0012, -0.4363, 1.0057, -1.5180]]])

将第一维设置为1


c=b.unsqueeze(1)
print(c)

输出一个3页,一行,4列的矩阵

tensor([[[-1.0067, -1.1477, -0.3213, -1.0633]],
[[-2.3976, 0.9857, -0.3462, -0.3648]],
[[ 1.1012, -0.4659, -0.0858, 1.6631]]])

另外,squeeze、unsqueeze操作不改变原矩阵

来源:https://blog.csdn.net/abc781cba/article/details/79663190

标签:pytorch,squeeze,cat
0
投稿

猜你喜欢

  • 如何使用flask将模型部署为服务

    2021-11-11 06:02:48
  • Oracle对两个数据表交集的查询

    2010-07-26 12:51:00
  • 在ASP.NET 2.0中操作数据之三十九:在DataList的编辑界面里添加验证控件

    2023-07-06 02:02:48
  • python字符串替换的2种方法

    2022-12-27 20:59:24
  • Python深度学习TensorFlow神经网络基础概括

    2022-08-13 02:57:18
  • python实现输入三角形边长自动作图求面积案例

    2023-08-12 01:47:49
  • 彻底弄懂CSS盒子模式之二(导航栏实例)

    2007-05-11 16:52:00
  • 用python分割TXT文件成4K的TXT文件

    2022-06-27 02:12:44
  • Python中tell()方法的使用详解

    2021-06-29 16:21:59
  • 使用numpy对数组求平均时如何忽略nan值

    2023-09-19 20:37:24
  • python pprint模块中print()和pprint()两者的区别

    2023-10-18 07:34:18
  • js实现屏蔽默认快捷键调用自定义事件示例

    2023-09-05 09:28:31
  • Go实现一个配置包详解

    2024-05-22 10:29:57
  • vue不通过路由直接获取url中参数的方法示例

    2024-04-30 08:41:06
  • django实现前后台交互实例

    2022-04-12 20:53:33
  • Oracle 中文字段进行排序的sql语句

    2024-01-22 13:26:43
  • Django-Scrapy生成后端json接口的方法示例

    2021-07-16 18:46:46
  • pytorch 批次遍历数据集打印数据的例子

    2022-06-09 08:23:46
  • DW表格应用之细线框的制作

    2008-02-03 19:00:00
  • Python+OpenCV绘制灰度直方图详解

    2023-06-09 18:50:50
  • asp之家 网络编程 m.aspxhome.com