PyTorch深度学习LSTM从input输入到Linear输出

作者:Cyril_KI 时间:2022-04-03 23:11:32 

LSTM介绍

关于LSTM的具体原理,可以参考:

https://www.jb51.net/article/178582.htm

https://www.jb51.net/article/178423.htm

系列文章:

PyTorch搭建双向LSTM实现时间序列负荷预测

PyTorch搭建LSTM实现多变量多步长时序负荷预测

PyTorch搭建LSTM实现多变量时序负荷预测

PyTorch搭建LSTM实现时间序列负荷预测

LSTM参数

关于nn.LSTM的参数,官方文档给出的解释为:

PyTorch深度学习LSTM从input输入到Linear输出

总共有七个参数,其中只有前三个是必须的。由于大家普遍使用PyTorch的DataLoader来形成批量数据,因此batch_first也比较重要。LSTM的两个常见的应用场景为文本处理和时序预测,因此下面对每个参数我都会从这两个方面来进行具体解释。

  • input_size:在文本处理中,由于一个单词没法参与运算,因此我们得通过Word2Vec来对单词进行嵌入表示,将每一个单词表示成一个向量,此时input_size=embedding_size。

  • 比如每个句子中有五个单词,每个单词用一个100维向量来表示,那么这里input_size=100;

  • 在时间序列预测中,比如需要预测负荷,每一个负荷都是一个单独的值,都可以直接参与运算,因此并不需要将每一个负荷表示成一个向量,此时input_size=1。

  • 但如果我们使用多变量进行预测,比如我们利用前24小时每一时刻的[负荷、风速、温度、压强、湿度、天气、节假日信息]来预测下一时刻的负荷,那么此时input_size=7。

  • hidden_size:隐藏层节点个数。可以随意设置。

  • num_layers:层数。nn.LSTMCell与nn.LSTM相比,num_layers默认为1。

  • batch_first:默认为False,意义见后文。

Inputs

关于LSTM的输入,官方文档给出的定义为:

PyTorch深度学习LSTM从input输入到Linear输出

可以看到,输入由两部分组成:input、(初始的隐状态h_0,初始的单元状态c_0)

其中input:

input(seq_len, batch_size, input_size)
  • seq_len:在文本处理中,如果一个句子有7个单词,则seq_len=7;在时间序列预测中,假设我们用前24个小时的负荷来预测下一时刻负荷,则seq_len=24。

  • batch_size:一次性输入LSTM中的样本个数。在文本处理中,可以一次性输入很多个句子;在时间序列预测中,也可以一次性输入很多条数据。

  • input_size:见前文。

(h_0, c_0):

h_0(num_directions * num_layers, batch_size, hidden_size)
c_0(num_directions * num_layers, batch_size, hidden_size)

h_0和c_0的shape一致。

  • num_directions:如果是双向LSTM,则num_directions=2;否则num_directions=1。

  • num_layers:见前文。

  • batch_size:见前文。

  • hidden_size:见前文。

Outputs

关于LSTM的输出,官方文档给出的定义为:

PyTorch深度学习LSTM从input输入到Linear输出

可以看到,输出也由两部分组成:otput、(隐状态h_n,单元状态c_n)

其中output的shape为:

output(seq_len, batch_size, num_directions * hidden_size)

h_n和c_n的shape保持不变,参数解释见前文。

batch_first

如果在初始化LSTM时令batch_first=True,那么input和output的shape将由:

input(seq_len, batch_size, input_size)
output(seq_len, batch_size, num_directions * hidden_size)

变为:

input(batch_size, seq_len, input_size)
output(batch_size, seq_len, num_directions * hidden_size)

即batch_size提前。

案例

简单搭建一个LSTM如下所示:

class LSTM(nn.Module):
   def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
       super().__init__()
       self.input_size = input_size
       self.hidden_size = hidden_size
       self.num_layers = num_layers
       self.output_size = output_size
       self.num_directions = 1 # 单向LSTM
       self.batch_size = batch_size
       self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
       self.linear = nn.Linear(self.hidden_size, self.output_size)
   def forward(self, input_seq):
       h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
       c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
       seq_len = input_seq.shape[1] # (5, 30)
       # input(batch_size, seq_len, input_size)
       input_seq = input_seq.view(self.batch_size, seq_len, 1)  # (5, 30, 1)
       # output(batch_size, seq_len, num_directions * hidden_size)
       output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
       output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)
       pred = self.linear(output) # pred(150, 1)
       pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)
       pred = pred[:, -1, :]  # (5, 1)
       return pred

其中定义模型的代码为:

self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.linear = nn.Linear(self.hidden_size, self.output_size)

我们加上具体的数字:

self.lstm = nn.LSTM(self.input_size=1, self.hidden_size=64, self.num_layers=5, batch_first=True)
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)

再看前向传播:

def forward(self, input_seq):
   h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
   c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
   seq_len = input_seq.shape[1]  # (5, 30)
   # input(batch_size, seq_len, input_size)
   input_seq = input_seq.view(self.batch_size, seq_len, 1)  # (5, 30, 1)
   # output(batch_size, seq_len, num_directions * hidden_size)
   output, _ = self.lstm(input_seq, (h_0, c_0))  # output(5, 30, 64)
   output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size)  # (5 * 30, 64)
   pred = self.linear(output) # (150, 1)
   pred = pred.view(self.batch_size, seq_len, -1)  # (5, 30, 1)
   pred = pred[:, -1, :]  # (5, 1)
   return pred

假设用前30个预测下一个,则seq_len=30,batch_size=5,由于设置了batch_first=True,因此,输入到LSTM中的input的shape应该为:

input(batch_size, seq_len, input_size) = input(5, 30, 1)

但实际上,经过DataLoader处理后的input_seq为:

input_seq(batch_size, seq_len) = input_seq(5, 30)

(5, 30)表示一共5条数据,每条数据的维度都为30。为了匹配LSTM的输入,我们需要对input_seq的shape进行变换:

input_seq = input_seq.view(self.batch_size, seq_len, 1)  # (5, 30, 1)

然后将input_seq送入LSTM:

output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)

根据前文,output的shape为:

output(batch_size, seq_len, num_directions * hidden_size) = output(5, 30, 64)

全连接层的定义为:

self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)

因此,我们需要将output的第二维度变换为64(150, 64):

output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)

然后将output送入全连接层:

pred = self.linear(output) # pred(150, 1)

得到的预测值shape为(150, 1)。我们需要将其进行还原,变成(5, 30, 1):

pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)

在用DataLoader处理了数据后,得到的input_seq和label的shape分别为:

input_seq(batch_size, seq_len) = input_seq(5, 30)label(batch_size, output_size) = label(5, 1)

由于输出是输入右移,我们只需要取pred第二维度(time)中的最后一个数据:

pred = pred[:, -1, :] # (5, 1)

这样,我们就得到了预测值,然后与label求loss,然后再反向更新参数即可。

时间序列预测的一个真实案例请见:PyTorch搭建LSTM实现时间序列预测(负荷预测)

来源:https://blog.csdn.net/Cyril_KI/article/details/122557880

标签:PyTorch,深度学习,LSTM,input,Linear
0
投稿

猜你喜欢

  • Python编程argparse入门浅析

    2023-11-05 09:53:01
  • 深入解析MySQL索引的原理与优化策略

    2024-01-19 02:29:02
  • Python模块学习 datetime介绍

    2023-08-15 16:01:18
  • 在matplotlib的图中设置中文标签的方法

    2023-10-10 07:17:53
  • 使用PHP实现微信摇一摇周边红包

    2023-11-14 12:04:22
  • 简单了解python列表和元组的区别

    2022-02-11 17:14:43
  • Python简单计算数组元素平均值的方法示例

    2021-02-22 17:52:20
  • vue使用百度地图报错BMap is not defined问题及解决

    2024-04-26 17:42:02
  • Python Print实现在输出中插入变量的例子

    2022-06-07 11:12:09
  • Python中安装库的常用方法介绍

    2022-04-03 08:13:17
  • JavaScript实现设计模式中的单例模式的一些技巧总结

    2024-05-02 16:21:11
  • Selenium定位浏览器弹窗方法实例总结

    2022-07-03 05:17:24
  • 利用Python实现翻译HTML中的文本字符串

    2022-02-20 23:03:13
  • 详解php中的类与对象(继承)

    2023-11-23 14:07:09
  • php输出全部gb2312编码内的汉字方法

    2023-10-04 05:56:31
  • Python 里最强的地图绘制神器

    2023-07-17 12:36:43
  • Python数组拼接np.concatenate实现过程

    2023-11-12 04:26:10
  • Python实现8个概率分布公式的方法详解

    2022-05-14 08:14:37
  • Python2.6版本中实现字典推导 PEP 274(Dict Comprehensions)

    2022-04-13 02:53:50
  • JavaScript实现网页动态生成表格

    2024-04-16 09:24:00
  • asp之家 网络编程 m.aspxhome.com