python神经网络Keras实现LSTM及其参数量详解

作者:Bubbliiiing 时间:2023-02-09 14:02:22 

什么是LSTM

1、LSTM的结构

python神经网络Keras实现LSTM及其参数量详解

我们可以看出,在n时刻,LSTM的输入有三个:

  • 当前时刻网络的输入值Xt;

  • 上一时刻LSTM的输出值ht-1;

  • 上一时刻的单元状态Ct-1。

LSTM的输出有两个:

  • 当前时刻LSTM输出值ht;

  • 当前时刻的单元状态Ct。

2、LSTM独特的门结构

LSTM用两个门来控制单元状态cn的内容:

  • 遗忘门(forget gate),它决定了上一时刻的单元状态cn-1有多少保留到当前时刻;

  • 输入门(input gate),它决定了当前时刻网络的输入c’n有多少保存到新的单元状态cn中。

LSTM用一个门来控制当前输出值hn的内容:

输出门(output gate),它利用当前时刻单元状态cn对hn的输出进行控制。

python神经网络Keras实现LSTM及其参数量详解

3、LSTM参数量计算

a、遗忘门

python神经网络Keras实现LSTM及其参数量详解

遗忘门这里需要结合ht-1和Xt来决定上一时刻的单元状态cn-1有多少保留到当前时刻;

由图我们可以得到,我们在这一环节需要计一个参数ft。

python神经网络Keras实现LSTM及其参数量详解

python神经网络Keras实现LSTM及其参数量详解

b、输入门

python神经网络Keras实现LSTM及其参数量详解

输入门这里需要结合ht-1和Xt来决定当前时刻网络的输入c’n有多少保存到单元状态cn中。

由图我们可以得到,我们在这一环节需要计算两个参数,分别是it。

python神经网络Keras实现LSTM及其参数量详解

和C’t

python神经网络Keras实现LSTM及其参数量详解

里面需要训练的参数分别是Wi、bi、WC和bC。

在定义LSTM的时候我们会使用到一个参数叫做units,其实就是神经元的个数,也就是LSTM的输出——ht的维度。

所以:

python神经网络Keras实现LSTM及其参数量详解

c、输出门

python神经网络Keras实现LSTM及其参数量详解

输出门利用当前时刻单元状态cn对hn的输出进行控制;

由图我们可以得到,我们在这一环节需要计一个参数ot。

python神经网络Keras实现LSTM及其参数量详解

里面需要训练的参数分别是Wo和bo。在定义LSTM的时候我们会使用到一个参数叫做units,其实就是神经元的个数,也就是LSTM的输出——ht的维度。所以:

python神经网络Keras实现LSTM及其参数量详解

d、全部参数量

所以所有的门总参数量为:

python神经网络Keras实现LSTM及其参数量详解

在Keras中实现LSTM

LSTM一般需要输入两个参数。

一个是unit、一个是input_shape。

LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))

unit用于指定神经元的数量。

input_shape用于指定输入的shape,分别指定TIME_STEPS和INPUT_SIZE。

实现代码

import numpy as np
from keras.models import Sequential
from keras.layers import Input,Activation,Dense
from keras.models import Model
from keras.datasets import mnist
from keras.layers.recurrent import LSTM
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])
x = LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))(inputs)
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)
model = Model(inputs,x)
adam = Adam(LR)
model.summary()
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
for i in range(50000):
   X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
   Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
   index_start += BATCH_SIZE
   cost = model.train_on_batch(X_batch,Y_batch)
   if index_start >= X_train.shape[0]:
       index_start = 0
   if i%100 == 0:
       cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
       print("accuracy:",accuracy)

实现效果:

10000/10000 [==============================] - 3s 340us/step
accuracy: 0.14040000014007092
10000/10000 [==============================] - 3s 310us/step
accuracy: 0.6507000041007995
10000/10000 [==============================] - 3s 320us/step
accuracy: 0.7740999992191792
10000/10000 [==============================] - 3s 305us/step
accuracy: 0.8516999959945679
10000/10000 [==============================] - 3s 322us/step
accuracy: 0.8669999945163727
10000/10000 [==============================] - 3s 324us/step
accuracy: 0.889699995815754
10000/10000 [==============================] - 3s 307us/step

来源:https://blog.csdn.net/weixin_44791964/article/details/103896884

标签:python,神经网络,Keras,LSTM,参数量
0
投稿

猜你喜欢

  • Pytorch中DataLoader的使用方法详解

    2023-07-19 04:45:39
  • asp如何做一个密码“生成器”?

    2010-07-12 18:51:00
  • 通过python爬虫赚钱的方法

    2023-04-27 11:48:17
  • php strstr查找字符串中是否包含某些字符的查找函数

    2023-11-17 01:42:23
  • 谈非线性任务流程的窗口打开方式

    2008-08-28 12:47:00
  • Python 数字转化成列表详情

    2023-09-24 06:53:25
  • Python实现密钥密码(加解密)实例详解

    2022-09-10 12:03:37
  • javascript 时间脚本收集

    2013-07-17 19:52:50
  • Python操作word文档插入图片和表格的实例演示

    2023-09-20 08:21:09
  • python Pexpect模块的使用

    2023-01-23 20:54:58
  • python递归删除指定目录及其所有内容的方法

    2022-12-12 02:42:51
  • 柳永法:vbs或asp采集文章时网页编码问题

    2009-02-04 10:50:00
  • python自制简易mysql连接池的实现示例

    2023-04-14 20:23:55
  • Ajax学习小贴士

    2007-10-24 23:21:00
  • PHP:微信小程序 微信支付服务端集成实例详解及源码下载

    2023-11-14 13:37:55
  • Python机器学习NLP自然语言处理基本操作词袋模型

    2023-08-20 06:23:30
  • YUI学习笔记(3)

    2009-01-21 16:24:00
  • 在前女友婚礼上用python把婚礼现场的WIFI名称改成了

    2023-05-26 15:15:49
  • Zend Studio去除编辑器的语法警告设置方法

    2023-10-11 17:10:15
  • layer弹窗插件操作方法详解

    2023-08-09 14:30:14
  • asp之家 网络编程 m.aspxhome.com