python神经网络使用Keras构建RNN训练
作者:Bubbliiiing 时间:2021-07-19 21:12:15
Keras中构建RNN的重要函数
1、SimpleRNN
SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。
from keras.layers import SimpleRNN
在实际使用时,需要用到几个参数。
model.add(
SimpleRNN(
batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
output_dim = CELL_SIZE,
)
)
其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。
2、model.train_on_batch
与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
具体训练过程如下:
for i in range(500):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
## acc
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
x = X_test[1].reshape(1,28,28)
全部代码
这是一个RNN神经网络的例子,用于识别手写体。
import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
model = Sequential()
# conv1
model.add(
SimpleRNN(
batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
output_dim = CELL_SIZE,
)
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
## tarin
for i in range(500):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
## acc
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
实验结果为:
10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698
来源:https://blog.csdn.net/weixin_44791964/article/details/101609556
标签:python,神经网络,Keras,RNN,训练
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
Mootools 1.2教程(4)——函数
2008-11-18 15:36:00
一个jquery日期选取插件源码
2009-12-23 19:15:00
![](https://img.aspxhome.com/file/UploadPic/20101/11/2009126213743-92s.png)
ajax返回中文乱码问题解决
2009-04-13 16:07:00
比较详细PHP生成静态页面教程
2023-10-14 18:54:31
让你的空间支持域名绑定子目录的解决办法
2010-09-15 10:03:00
事件触发列表与解说
2013-07-19 11:17:12
如何利用python正确地为图像添加高斯噪声
2023-08-03 08:26:22
![](https://img.aspxhome.com/file/2023/9/59459_0s.png)
将数据从MySQL迁移到 Oracle的注意事项
2008-12-03 15:41:00
剖析SQL Server 事务日志的收缩和截断
2009-01-15 13:04:00
JS数组方法汇总
2009-08-03 14:06:00
最令人蛋疼的10种用户体验设计师
2011-08-05 18:51:07
ASP连接MSSQL2005 数据库
2009-03-08 19:20:00
浅析数据完整性问题
2007-10-07 12:44:00
开发心得--写给想学Javascript朋友的一点经验之谈
2009-02-25 11:42:00
列举Python中吸引人的一些特性
2023-12-17 03:25:57
PHP文件上传功能实现逻辑分析
2023-05-25 02:28:30
asp中InstrRev的语法
2008-01-22 18:14:00
python GUI库图形界面开发之PyQt5 Qt Designer工具(Qt设计师)详细使用方法及Designer ui文件转py文件方法
2023-05-17 00:32:46
![](https://img.aspxhome.com/file/2023/4/64364_0s.png)
掀起抛弃IE6的高潮吧
2009-02-26 12:44:00
用FrontPage制作缩略图和图片重叠效果
2007-11-18 14:45:00