numpy实现RNN原理实现
作者:J k l 时间:2023-09-21 23:47:33
首先说明代码只是帮助理解,并未写出梯度下降部分,默认参数已经被固定,不影响理解。代码主要实现RNN原理,只使用numpy库,不可用于GPU加速。
import numpy as np
class Rnn():
def __init__(self, input_size, hidden_size, num_layers, bidirectional=False):
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bidirectional = bidirectional
def feed(self, x):
'''
:param x: [seq, batch_size, embedding]
:return: out, hidden
'''
# x.shape [sep, batch, feature]
# hidden.shape [hidden_size, batch]
# Whh0.shape [hidden_size, hidden_size] Wih0.shape [hidden_size, feature]
# Whh1.shape [hidden_size, hidden_size] Wih1.size [hidden_size, hidden_size]
out = []
x, hidden = np.array(x), [np.zeros((self.hidden_size, x.shape[1])) for i in range(self.num_layers)]
Wih = [np.random.random((self.hidden_size, self.hidden_size)) for i in range(1, self.num_layers)]
Wih.insert(0, np.random.random((self.hidden_size, x.shape[2])))
Whh = [np.random.random((self.hidden_size, self.hidden_size)) for i in range(self.num_layers)]
time = x.shape[0]
for i in range(time):
hidden[0] = np.tanh((np.dot(Wih[0], np.transpose(x[i, ...], (1, 0))) +
np.dot(Whh[0], hidden[0])
))
for i in range(1, self.num_layers):
hidden[i] = np.tanh((np.dot(Wih[i], hidden[i-1]) +
np.dot(Whh[i], hidden[i])
))
out.append(hidden[self.num_layers-1])
return np.array(out), np.array(hidden)
def sigmoid(x):
return 1.0/(1.0 + 1.0/np.exp(x))
if __name__ == '__main__':
rnn = Rnn(1, 5, 4)
input = np.random.random((6, 2, 1))
out, h = rnn.feed(input)
print(f'seq is {input.shape[0]}, batch_size is {input.shape[1]} ', 'out.shape ', out.shape, ' h.shape ', h.shape)
# print(sigmoid(np.random.random((2, 3))))
#
# element-wise multiplication
# print(np.array([1, 2])*np.array([2, 1]))
来源:https://blog.csdn.net/qq_43056256/article/details/114272542
标签:numpy,RNN
0
投稿
猜你喜欢
oracle 服务启动,关闭脚本(windows系统下)
2009-07-26 08:57:00
python中PyQuery库用法分享
2023-12-05 03:08:31
python opencv肤色检测的实现示例
2023-06-13 20:31:58
python语言线程标准库threading.local解读总结
2023-12-22 18:18:07
两侧背景自动延伸的CSS实现方法
2010-02-24 09:42:00
关于Django显示时间你应该知道的一些问题
2023-10-23 06:26:21
JavaScript 使用技巧精萃(.net html
2023-07-02 05:18:45
Tensorflow加载Vgg预训练模型操作
2023-10-13 10:56:23
python 串行执行和并行执行实例
2022-07-12 07:32:58
python geopandas读取、创建shapefile文件的方法
2022-09-23 16:57:19
如何在页面中快捷地添加翻页按钮?
2010-06-26 12:33:00
ASP如何输出字符
2007-09-22 18:41:00
ASP使用fso遍历文件及文件夹列出文件名
2008-10-27 19:32:00
详解Python Flask框架的安装及应用
2022-06-20 11:12:50
深入Oracle字符集的查看与修改详解
2023-06-25 22:13:15
Django CBV与FBV原理及实例详解
2023-02-14 20:39:01
pandas的排序、分组groupby及cumsum累计求和方式
2023-07-20 07:00:39
利用Echarts如何实现多段圆环图
2024-04-28 09:36:22
pytest自动化测试数据驱动yaml/excel/csv/json
2023-06-18 14:19:47
python版百度语音识别功能
2023-02-28 09:56:34