pytorch中使用LSTM详解
作者:qyhyzard 时间:2021-01-08 04:27:10
LSMT层
可以在troch.nn
模块中找到LSTM类
lstm = torch.nn.LSTM(*paramsters)
1、__init__方法
首先对nn.LSTM
类进行实例化,需要传入的参数如下图所示:
一般我们关注这4个:
input_size
表示输入的每个token的维度,也可以理解为一个word的embedding的维度。hidden_size
表示隐藏层也就是记忆单元C的维度,也可以理解为要将一个word的embedding维度转变成另一个大小的维度。除了C,在LSTM中输出的H的维度与C的维度是一致的。num_layers
表示有多少层LSTM,加深网络的深度,这个参数对LSTM的输出的维度是有影响的(后文会提到)。bidirectional
表示是否需要双向LSTM,这个参数也会对后面的输出有影响。
2、forward方法的输入
将数据input传入forward方法进行前向传播时有3个参数可以输入,见下图:
这里要注意的是
input
参数各个维度的意义,一般来说如果不在实例化时制定batch_first=True
,那么input
的第一个维度是输入句子的长度seq_len,第二个维度是批量的大小,第三个维度是输入句子的embedding维度也就是input_size,这个参数要与__init__
方法中的第一个参数对应。另外记忆细胞中的两个参数
h_0
和c_0
可以选择自己初始化传入也可以不传,系统默认是都初始化为0。传入的话注意维度[bidirectional * num_layers, batch_size, hidden_size]。
3、forward方法的输出
forward方法的输出如下图所示:
一般采用如下形式:
out,(h_n, c_n) = lstm(x)
out
表示在最后一层上,每一个时间步的输出,也就是句子有多长,这个out的输出就有多长;其维度为[seq_len, batch_size, hidden_size * bidirectional]。因为如果的双向LSTM,最后一层的输出会把正向的和反向的进行拼接,故需要hidden_size * bidirectional。h_n
表示的是每一层(双向算两层)在最后一个时间步上的输出;其维度为[bidirectional * num_layers, batch_size, hidden_size]
假设是双向的LSTM,且是3层LSTM,双向每个方向算一层,两个方向的组合起来叫一层LSTM,故共会有6层(3个正向,3个反向)。所以h_n是每层的输出,bidirectional * num_layers = 6。c_n
表示的是每一层(双向算两层)在最后一个时间步上的记忆单元,意义不同,但是其余均与 h_n
一样。
LSTMCell
可以在troch.nn
模块中找到LSTMCell类
lstm = torch.nn.LSTMCell(*paramsters)
它的__init__
方法的参数设置与LSTM类似,但是没有num_layers
参数,因为这就是一个细胞单元,谈不上多少层和是否双向。forward
的输入和输出与LSTM均有所不同:
其相比LSTM,输入没有了时间步的概念,因为只有一个Cell单元;输出 也没有out
参数,因为就一个Cell,out
就是h_1
,h_1
和c_1
也因为只有一个Cell单元,其没有层数上的意义,故只是一个Cell的输出的维度[batch_size, hidden_size].
代码演示如下:
rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
hx = torch.randn(3, 20) # (batch, hidden_size)
cx = torch.randn(3, 20)
output = []
# 从输入的第一个维度也就是seq_len上遍历,每循环一次,输入一个单词
for i in range(input.size()[0]):
# 更新细胞记忆单元
hx, cx = rnn(input[i], (hx, cx))
# 将每个word作为输入的输出存起来,相当于LSTM中的out
output.append(hx)
output = torch.stack(output, dim=0)
来源:https://blog.csdn.net/qq_42961603/article/details/119638341
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
Javascript怎样使用SessionStorage和LocalStorage
python同时替换多个字符串方法示例
Mobile Web下的编码设计
![](https://img.aspxhome.com/file/UploadPic/20101/28/mobile-web-75s.jpg)
MYSQL初学者扫盲
Python+Qt身体特征识别人数统计源码窗体程序(使用步骤)
![](https://img.aspxhome.com/file/2023/5/66745_0s.jpg)
mysql5在rhel5下乱码问题及解决方法
![](https://img.aspxhome.com/file/UploadPic/201012/3/2010123165047481s.png)
php下pdo的mysql事务处理用法实例
一个免刷新页面的JavaScript日历
![](https://img.aspxhome.com/file/UploadPic/200712/26/20071226131443435s.jpg)
什么是XSLT,什么是XPath
如何使用pycharm连接Databricks的步骤详解
![](https://img.aspxhome.com/file/2023/3/72983_0s.png)
文字解说Golang Goroutine和线程的区别
24式加速你的Python(小结)
![](https://img.aspxhome.com/file/2023/2/62572_0s.png)
浅谈Python中range和xrange的区别
三分钟掌握PHP操作数据库
![](https://img.aspxhome.com/file/2023/9/55279_0s.png)
使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
![](https://img.aspxhome.com/file/2023/9/78509_0s.jpg)
Dreamweaver表格布局经验谈
IE8 的 JSON 解析 Bug
通过实例了解python__slots__使用方法
PHP连接MSSQL方法汇总
Python下载ts文件视频且合并的操作方法
![](https://img.aspxhome.com/file/2023/5/76985_0s.jpg)