python神经网络使用Keras构建RNN训练

作者:Bubbliiiing 时间:2021-07-19 21:12:15 

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
   SimpleRNN(
       batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
       output_dim = CELL_SIZE,
   )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
   X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
   Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
   index_start += BATCH_SIZE
   cost = model.train_on_batch(X_batch,Y_batch)
   if index_start >= X_train.shape[0]:
       index_start = 0
   if i%100 == 0:
       ## acc
       cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
       ## W,b = model.layers[0].get_weights()
       print("accuracy:",accuracy)
       x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
model = Sequential()
# conv1
model.add(
   SimpleRNN(
       batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
       output_dim = CELL_SIZE,
   )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
## tarin
for i in range(500):
   X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
   Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
   index_start += BATCH_SIZE
   cost = model.train_on_batch(X_batch,Y_batch)
   if index_start >= X_train.shape[0]:
       index_start = 0
   if i%100 == 0:
       ## acc
       cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
       ## W,b = model.layers[0].get_weights()
       print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

来源:https://blog.csdn.net/weixin_44791964/article/details/101609556

标签:python,神经网络,Keras,RNN,训练
0
投稿

猜你喜欢

  • python中的列表与元组的使用

    2023-07-23 08:25:12
  • php数组转换js数组操作及json_encode的用法详解

    2024-05-03 15:13:44
  • 详解Python函数print用法

    2023-06-10 03:47:34
  • python实现字符串和字典的转换

    2023-03-02 02:57:18
  • java 中JDBC连接数据库代码和步骤详解及实例代码

    2024-01-27 16:35:14
  • 深入了解Python enumerate和zip

    2021-11-15 12:08:23
  • django中只使用ModleForm的表单验证

    2021-02-03 15:41:50
  • 浅探express路由和中间件的实现

    2024-05-11 10:17:08
  • 使用字符串建立查询能加快服务器的解析速度吗?

    2010-07-14 21:03:00
  • Python编程之微信推送模板消息功能示例

    2022-11-15 03:45:04
  • mysql5.58的编译安装

    2011-01-29 16:26:00
  • 如何使数据库中取出的数据保持原有格式

    2008-11-27 16:16:00
  • python logging模块书写日志以及日志分割详解

    2023-02-23 12:52:16
  • python用moviepy对视频进行简单的处理

    2023-08-03 07:02:15
  • 使用matplotlib库实现图形局部数据放大显示的实践

    2021-01-13 18:47:13
  • Python 内置高阶函数详细

    2022-07-26 11:02:07
  • mysql community server 8.0.12安装配置方法图文教程

    2024-01-21 19:28:04
  • 收缩后对数据库的使用有影响吗?

    2024-01-21 09:41:48
  • Python编程基础之构造方法和析构方法详解

    2022-02-26 02:38:03
  • js处理自己不能定义二维数组的方法详解

    2023-09-06 21:25:12
  • asp之家 网络编程 m.aspxhome.com