Pytorch中index_select() 函数的实现理解

作者:清晨的光明 时间:2023-11-26 16:24:32 

函数形式:


index_select(
dim,
index
)

参数:

  • dim:表示从第几维挑选数据,类型为int值;

  • index:表示从第一个参数维度中的哪个位置挑选数据,类型为torch.Tensor类的实例;

刚开始学习pytorch,遇到了index_select(),一开始不太明白几个参数的意思,后来查了一下资料,算是明白了一点。


a = torch.linspace(1, 12, steps=12).view(3, 4)
print(a)
b = torch.index_select(a, 0, torch.tensor([0, 2]))
print(b)
print(a.index_select(0, torch.tensor([0, 2])))
c = torch.index_select(a, 1, torch.tensor([1, 3]))
print(c)

先定义了一个tensor,这里用到了linspace和view方法。

第一个参数是索引的对象,第二个参数0表示按行索引,1表示按列进行索引,第三个参数是一个tensor,就是索引的序号,比如b里面tensor[0, 2]表示第0行和第2行,c里面tensor[1, 3]表示第1列和第3列。

输出结果如下:

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],
        [ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],
        [ 9., 10., 11., 12.]])
tensor([[ 2.,  4.],
        [ 6.,  8.],
        [10., 12.]])

功能:从张量的某个维度的指定位置选取数据。

代码实例:


t = torch.arange(24).reshape(2, 3, 4) # 初始化一个tensor,从0到23,形状为(2,3,4)
print("t--->", t)

index = torch.tensor([1, 2]) # 要选取数据的位置
print("index--->", index)

data1 = t.index_select(1, index) # 第一个参数:从第1维挑选, 第二个参数:从该维中挑选的位置
print("data1--->", data1)

data2 = t.index_select(2, index) # 第一个参数:从第2维挑选, 第二个参数:从该维中挑选的位置
print("data2--->", data2)

运行结果: 

t---> tensor([[[ 0,  1,  2,  3],
               [ 4,  5,  6,  7],
               [ 8,  9, 10, 11]],
 
              [[12, 13, 14, 15],
               [16, 17, 18, 19],
               [20, 21, 22, 23]]])
 
index---> tensor([1, 2])
 
data1---> tensor([[[ 4,  5,  6,  7],
                   [ 8,  9, 10, 11]],
 
                  [[16, 17, 18, 19],
                   [20, 21, 22, 23]]])
 
data2---> tensor([[[ 1,  2],
                   [ 5,  6],
                   [ 9, 10]],
 
                  [[13, 14],
                   [17, 18],
                   [21, 22]]])

来源:https://blog.csdn.net/kdongyi/article/details/103099589

标签:Pytorch,index,select
0
投稿

猜你喜欢

  • ORACLE常见错误代码的分析与解决(三)

    2024-01-25 12:26:01
  • Python基于Pymssql模块实现连接SQL Server数据库的方法详解

    2024-01-15 03:13:17
  • PHP使用flock实现文件加锁的方法

    2023-10-29 21:26:59
  • 最新Pygame zero最全集合

    2022-07-18 13:14:16
  • Django使用Mysql数据库已经存在的数据表方法

    2023-07-21 15:24:59
  • Flask模板引擎之Jinja2语法介绍

    2021-11-15 21:08:11
  • 详解使用vue脚手架工具搭建vue-webpack项目

    2024-05-21 10:29:19
  • Python3环境安装Scrapy爬虫框架过程及常见错误

    2021-10-19 00:01:05
  • Tornado高并发处理方法实例代码

    2022-10-13 15:30:07
  • webpack 打包压缩js和css的方法示例

    2023-07-02 05:18:32
  • 使用python检查yaml配置文件是否符合要求

    2021-06-23 05:27:53
  • Python如何使用PIL Image制作GIF图片

    2023-08-24 22:42:17
  • SQLSERVER聚集索引和主键(Primary Key)的误区认识

    2024-01-14 07:49:56
  • Python写一个简单的api接口的实现

    2023-07-23 20:20:53
  • mysql数据库远程访问设置方法

    2024-01-14 11:25:34
  • Oracle静态注册与动态注册详解

    2024-01-19 22:31:31
  • django实现前后台交互实例

    2022-04-12 20:53:33
  • Go语言的Windows下环境配置以及简单的程序结构讲解

    2023-08-26 16:04:10
  • PHP之CI框架学习讲解

    2023-07-03 21:25:13
  • ASP访问SQL Server内置对象

    2008-04-05 06:49:00
  • asp之家 网络编程 m.aspxhome.com