pytorch中index_select()的用法详解

作者:g_blink 时间:2022-01-20 19:44:05 

pytorch中index_select()的用法


index_select(input, dim, index)

功能:在指定的维度dim上选取数据,不如选取某些行,列

参数介绍

  • 第一个参数input是要索引查找的对象

  • 第二个参数dim是要查找的维度,因为通常情况下我们使用的都是二维张量,所以可以简单的记忆: 0代表行,1代表列

  • 第三个参数index是你要索引的序列,它是一个tensor对象

刚开始学习pytorch,遇到了index_select(),一开始不太明白几个参数的意思,后来查了一下资料,算是明白了一点。


a = torch.linspace(1, 12, steps=12).view(3, 4)
print(a)
b = torch.index_select(a, 0, torch.tensor([0, 2]))
print(b)
print(a.index_select(0, torch.tensor([0, 2])))
c = torch.index_select(a, 1, torch.tensor([1, 3]))
print(c)

先定义了一个tensor,这里用到了linspace和view方法。

第一个参数是索引的对象,第二个参数0表示按行索引,1表示按列进行索引,第三个参数是一个tensor,就是索引的序号,比如b里面tensor[0, 2]表示第0行和第2行,c里面tensor[1, 3]表示第1列和第3列。

输出结果如下:

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],
        [ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],
        [ 9., 10., 11., 12.]])
tensor([[ 2.,  4.],
        [ 6.,  8.],
        [10., 12.]])

示例2 


import torch

x = torch.Tensor([[[1, 2, 3],
         [4, 5, 6]],

[[9, 8, 7],
         [6, 5, 4]]])
print(x)
print(x.size())
index = torch.LongTensor([0, 0, 1])
print(torch.index_select(x, 0, index))
print(torch.index_select(x, 0, index).size())
print(torch.index_select(x, 1, index))
print(torch.index_select(x, 1, index).size())
print(torch.index_select(x, 2, index))
print(torch.index_select(x, 2, index).size())

input的张量形状为2×2×3,index为[0, 0, 1]的向量

分别从0、1、2三个维度来使用index_select()函数,并输出结果和形状,维度大于2就会报错因为input最大只有三个维度

输出:

tensor([[[1., 2., 3.],
         [4., 5., 6.]],
 
        [[9., 8., 7.],
         [6., 5., 4.]]])
torch.Size([2, 2, 3])
tensor([[[1., 2., 3.],
         [4., 5., 6.]],
 
        [[1., 2., 3.],
         [4., 5., 6.]],
 
        [[9., 8., 7.],
         [6., 5., 4.]]])
torch.Size([3, 2, 3])
tensor([[[1., 2., 3.],
         [1., 2., 3.],
         [4., 5., 6.]],
 
        [[9., 8., 7.],
         [9., 8., 7.],
         [6., 5., 4.]]])
torch.Size([2, 3, 3])
tensor([[[1., 1., 2.],
         [4., 4., 5.]],
 
        [[9., 9., 8.],
         [6., 6., 5.]]])
torch.Size([2, 2, 3])

对结果进行分析:

index是大小为3的向量,输入的张量形状为2×2×3

dim = 0时,输出的张量形状为3×2×3

dim = 1时,输出的张量形状为2×3×3

dim = 2时,输出的张量形状为2×2×3

注意输出张量维度的变化与index大小的关系,结合输出的张量与原始张量来分析index_select()函数的作用

来源:https://blog.csdn.net/g_blink/article/details/102854188

标签:pytorch,index,select()
0
投稿

猜你喜欢

  • Python多进程fork()函数详解

    2023-06-08 19:41:37
  • Oracle数据库编写有效事务指导方针

    2009-03-19 17:41:00
  • Python自动化之批量处理工作簿和工作表

    2023-02-16 08:07:30
  • ASP 日期的加减运算实现代码

    2011-03-08 10:47:00
  • 在Python中使用SimpleParse模块进行解析的教程

    2021-04-11 12:17:53
  • 在Python中操作文件之seek()方法的使用教程

    2023-08-01 14:58:01
  • asp中文件与文件夹常用处理函数(文件后缀、创建文件等)

    2011-02-20 11:00:00
  • SQL Server 2000安装图解教程

    2009-09-09 19:59:00
  • ASP获取网页内容(解决乱码问题)

    2009-07-26 10:44:00
  • python 利用turtle模块画出没有角的方格

    2022-03-09 04:25:04
  • Oracle数据表分区的策略

    2010-07-28 12:59:00
  • Django与DRF结合的全局异常处理方案详解

    2021-05-19 22:53:16
  • 用IE浏览器UTF-8页面是一片空白

    2009-06-14 19:55:00
  • Python教程之类型转换详解

    2021-03-23 02:48:17
  • 客齐集社区头像显示效果代码

    2008-04-03 13:15:00
  • 从 msxml6.dll 中获取 DOMDocument 对象的方法与属性

    2009-02-22 18:46:00
  • 使用typescript快速开发一个cli的实现示例

    2023-08-30 07:25:25
  • 分析SQL Server中数据库的快照工作原理

    2009-01-19 14:03:00
  • ACCESS入门教程:用向导建立数据库

    2008-01-17 12:46:00
  • 为你的有序列表添加个性样式

    2009-02-23 13:12:00
  • asp之家 网络编程 m.aspxhome.com