pytorch对可变长度序列的处理方法详解

作者:深度学习1 时间:2022-11-11 23:19:39 

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) – 变长序列 被填充后的 batch

lengths (list[int]) – Variable 中 每个序列的长度。

batch_first (bool, optional) – 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) – 将要被填充的 batch

batch_first (bool, optional) – 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:


import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1

tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step

# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)

# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))

#forward
out, _ = rnn(pack, h0)

# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:


111 (Variable containing:
(0 ,.,.) =
0.5406 0.3584
-0.1403 0.0308

(1 ,.,.) =
-0.6855 -0.9307
0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

来源:104.116.116.112.58.47.47.119.119.119.46.99.110.98.108.111.103.115.46.99.111.109.47.108.105.110.100.97.120.105.110.47.112.47.56.48.53.50.48.52.51.46.104.116.109.108.

标签:pytorch,变长,序列
0
投稿

猜你喜欢

  • python向量化与for循环耗时对比分析

    2023-12-21 14:14:59
  • 解决Oracle安装遇到Enterprise Manager配置失败问题

    2024-01-21 03:48:28
  • golang 如何通过反射创建新对象

    2024-04-27 15:24:38
  • Python实现MySQL操作的方法小结【安装,连接,增删改查等】

    2024-01-16 07:02:33
  • Python中使用OpenCV库来进行简单的气象学遥感影像计算

    2021-02-02 09:45:49
  • 浅谈Python3中print函数的换行

    2023-12-15 18:24:31
  • python sys模块使用方法介绍

    2021-02-03 09:19:16
  • Oracle 分析函数RANK(),ROW_NUMBER(),LAG()等的使用方法

    2009-11-05 21:45:00
  • 基于Python实现一个简单的学生管理系统

    2023-07-23 23:06:40
  • MYSQL初学者命令行使用指南

    2024-01-15 08:46:33
  • pandas中的DataFrame数据遍历解读

    2023-12-03 21:57:34
  • 通过python读取txt文件和绘制柱形图的实现代码

    2023-11-23 11:24:43
  • 定位?浮动?自适应?

    2008-06-30 14:20:00
  • javascript面向对象技术基础(二)

    2010-02-07 13:09:00
  • Node+OCR实现图像文字识别功能

    2024-04-22 13:01:41
  • 五个有趣的Python整蛊小程序合集

    2022-10-27 12:34:10
  • 详解Python判定IP地址合法性的三种方法

    2021-12-02 14:35:10
  • Python读写及备份oracle数据库操作示例

    2024-01-21 17:09:45
  • node.js用fs.rename强制重命名或移动文件夹的方法

    2024-05-13 10:05:37
  • 关于python如何生成exe文件

    2021-06-05 00:52:33
  • asp之家 网络编程 m.aspxhome.com