pytorch中torch.topk()函数的快速理解

作者:Neo很努力 时间:2023-09-07 10:21:52 

函数作用:

pytorch中torch.topk()函数的快速理解

pytorch中torch.topk()函数的快速理解

该函数的作用即按字面意思理解,topk:取数组的前k个元素进行排序。

通常该函数返回2个值,第一个值为排序的数组,第二个值为该数组中获取到的元素在原数组中的位置标号。

举个栗子:

import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader

####################准备一个数组#########################
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
            [3,4,5,1,1,1,1,1,1,1,1],
            [7,8,9,1,1,1,1,1,1,1,1],
            [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)

####################打印这个原数组#########################
print('tensor1:')
print(tensor1)

#################使用torch.topk()这个函数##################
print('使用torch.topk()这个函数得到:')

'''k=3代表从原数组中取得3个元素,dim=1表示从原数组中的第一维获取元素
(在本例中是分别从[10,1,2,1,1,1,1,1,1,1,10]、[3,4,5,1,1,1,1,1,1,1,1]、
 [7,8,9,1,1,1,1,1,1,1,1]、[1,4,7,1,1,1,1,1,1,1,1]这四个数组中获取3个元素)
其中largest=True表示从大到小取元素'''
print(torch.topk(tensor1, k=3, dim=1, largest=True))

#################打印这个函数第一个返回值####################
print('函数第一个返回值topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])

#################打印这个函数第二个返回值####################
print('函数第二个返回值topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''

#######################运行结果##########################
tensor1:
tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],
       [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])

使用torch.topk()这个函数得到:

'得到的values是原数组dim=1的四组从大到小的三个元素值;
得到的indices是获取到的元素值在原数组dim=1中的位置。'

torch.return_types.topk(
values=tensor([[10., 10.,  2.],
       [ 5.,  4.,  3.],
       [ 9.,  8.,  7.],
       [ 7.,  4.,  1.]]),
indices=tensor([[ 0, 10,  2],
       [ 2,  1,  0],
       [ 2,  1,  0],
       [ 2,  1,  0]]))

函数第一个返回值topk[0]如下
tensor([[10., 10.,  2.],
       [ 5.,  4.,  3.],
       [ 9.,  8.,  7.],
       [ 7.,  4.,  1.]])

函数第二个返回值topk[1]如下
tensor([[ 0, 10,  2],
       [ 2,  1,  0],
       [ 2,  1,  0],
       [ 2,  1,  0]])
'''

该函数功能经常用来获取张量或者数组中最大或者最小的元素以及索引位置,是一个经常用到的基本函数。

实例演示

任务一:

取top1(最大值):

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
       [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
       [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
       [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
print(pred)
values, indices = pred.topk(1, dim=0, largest=True, sorted=True)
print(indices)
print(values)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=0, keepdim=True)
print(indices_max)
print(indices_max == indices)
输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
       [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
       [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
       [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
tensor([[1, 1, 1, 1, 1]])
tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])
tensor([[1, 1, 1, 1, 1]])
tensor([[True, True, True, True, True]])

任务二:

按行取出topk,将小于topk的置为inf:

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
       [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
       [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
       [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
print(pred)
top_k = 2  # 按行求出每一行的最大的前两个值
filter_value=-float('Inf')
indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]
print(indices_to_remove)
pred[indices_to_remove] = filter_value  # 对于topk之外的其他元素的logits值设为负无穷
print(pred)

输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
       [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
       [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
       [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
tensor([[4],
       [4],
       [4],
       [3]])
tensor([[0.4053],
       [1.8823],
       [1.7255],
       [0.3849]])
tensor([[ True, False,  True,  True, False],
       [ True, False,  True,  True, False],
       [ True,  True, False,  True, False],
       [ True, False,  True, False,  True]])
tensor([[   -inf, -0.3873,    -inf,    -inf,  0.4053],
       [   -inf,  1.4164,    -inf,    -inf,  1.8823],
       [   -inf,    -inf,  1.2590,    -inf,  1.7255],
       [   -inf,  0.3041,    -inf,  0.3849,    -inf]])

任务三:

import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
            [3,4,5,1,1,1,1,1,1,1,1],
            [7,8,9,1,1,1,1,1,1,1,1],
            [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)
# tensor2=torch.tensor([[3,2,1],
#                       [6,5,4],
#                       [1,4,7],
#                       [9,8,7]],dtype=torch.float32)
#
print('tensor1:')
print(tensor1)
print('直接输出topk,会得到两个东西,我们需要的是第二个indices')
print(torch.topk(tensor1, k=3, dim=1, largest=True))
print('topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])
print('topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''
tensor1:
tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],
       [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])
直接输出topk,会得到两个东西,我们需要的是第二个indices
torch.return_types.topk(
values=tensor([[10., 10.,  2.],
       [ 5.,  4.,  3.],
       [ 9.,  8.,  7.],
       [ 7.,  4.,  1.]]),
indices=tensor([[ 0, 10,  2],
       [ 2,  1,  0],
       [ 2,  1,  0],
       [ 2,  1,  0]]))
topk[0]如下
tensor([[10., 10.,  2.],
       [ 5.,  4.,  3.],
       [ 9.,  8.,  7.],
       [ 7.,  4.,  1.]])
topk[1]如下
tensor([[ 0, 10,  2],
       [ 2,  1,  0],
       [ 2,  1,  0],
       [ 2,  1,  0]])
'''

来源:https://blog.csdn.net/qq_45193872/article/details/119878804

标签:pytorch,torch.topk(),函数
0
投稿

猜你喜欢

  • C#编程实现连接ACCESS数据库实例详解

    2024-01-19 10:25:15
  • 5招带你轻松优化MySQL count(*)查询性能

    2024-01-28 04:40:58
  • vue实现动态控制el-table表格列的展示与隐藏

    2024-05-05 09:07:26
  • Mac在python3环境下安装virtualwrapper遇到的问题及解决方法

    2021-06-22 17:32:38
  • Python之Web框架Django项目搭建全过程

    2022-03-21 03:06:08
  • ASP编程中使用数组方法实例讲解

    2008-02-15 08:25:00
  • numpy.float32的典型用法

    2022-04-30 04:36:30
  • 通过备份记录获取数据库的增长情况

    2024-01-17 20:21:26
  • Vue中keep-alive组件的深入理解

    2023-07-02 17:00:38
  • MySql .frm数据库文件导入的问题

    2024-01-18 19:08:09
  • 关于python中remove的一些坑小结

    2022-09-22 02:14:31
  • 基于pytorch padding=SAME的解决方式

    2023-03-03 05:05:28
  • Python Logging 日志记录入门学习

    2022-05-17 14:48:39
  • python使用matplotlib库生成随机漫步图

    2021-03-01 15:11:37
  • Python函数中的可变长参数详解

    2022-08-01 06:04:57
  • MySQL事件与触发器专题精炼

    2024-01-22 09:08:21
  • 基于mpvue的小程序项目搭建的步骤

    2024-05-13 09:12:13
  • 浅谈Python中range与Numpy中arange的比较

    2021-05-21 00:15:48
  • Flask接口签名sign原理与实例代码浅析

    2023-06-21 23:18:13
  • Python下opencv图像阈值处理的使用笔记

    2023-08-24 21:28:27
  • asp之家 网络编程 m.aspxhome.com