Implementing how an RNN works with numpy

Author: J k l  Date: 2023-09-21 23:47:33

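A note up front: this code is only meant to aid understanding. It does not implement gradient descent (training), and the weights are treated as fixed (randomly initialized, never learned), which does not get in the way of understanding. The code implements the core RNN forward computation using only the numpy library, so it cannot use GPU acceleration.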
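For reference, the recurrence computed by `feed` can be written out as follows, in the code's own column layout (each hidden state is a [hidden_size, batch] matrix). This is our restatement of the code, not a formula from the original post. Note that the code uses no bias terms, whereas PyTorch's `nn.RNN` adds the biases b_ih and b_hh by default.

$$
\begin{aligned}
h_t^{(0)} &= \tanh\!\left(W_{ih}^{(0)}\, x_t^{\top} + W_{hh}^{(0)}\, h_{t-1}^{(0)}\right), \\
h_t^{(l)} &= \tanh\!\left(W_{ih}^{(l)}\, h_t^{(l-1)} + W_{hh}^{(l)}\, h_{t-1}^{(l)}\right), \quad l = 1, \dots, L-1, \\
y_t &= h_t^{(L-1)},
\end{aligned}
$$

Here x_t is the [batch, feature] slice at time step t (so x_t^T is [feature, batch]), W_ih^(0) is [hidden_size, feature], every other weight matrix is [hidden_size, hidden_size], L is `num_layers`, and every initial state h_{-1}^(l) is zero.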


import numpy as np


class Rnn():

    def __init__(self, input_size, hidden_size, num_layers, bidirectional=False):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Stored but not used by feed() below
        self.bidirectional = bidirectional
        # Create the weights once so they stay fixed across calls to feed()
        # (np.random.random samples uniformly from [0, 1); nothing is trained).
        # Layer 0 maps input features; deeper layers take the previous layer's hidden state.
        self.Wih = [np.random.random((hidden_size, input_size))] + \
                   [np.random.random((hidden_size, hidden_size)) for _ in range(1, num_layers)]
        self.Whh = [np.random.random((hidden_size, hidden_size)) for _ in range(num_layers)]

    def feed(self, x):
        '''
        Forward pass of a multi-layer tanh RNN (no bias terms).

        :param x: [seq, batch_size, embedding]
        :return: out [seq, hidden_size, batch_size], hidden [num_layers, hidden_size, batch_size]
        '''
        # x.shape      [seq, batch, feature]
        # hidden.shape [hidden_size, batch]
        # Whh[0].shape [hidden_size, hidden_size]  Wih[0].shape [hidden_size, feature]
        # Whh[1].shape [hidden_size, hidden_size]  Wih[1].shape [hidden_size, hidden_size]
        x = np.asarray(x)
        hidden = [np.zeros((self.hidden_size, x.shape[1])) for _ in range(self.num_layers)]
        out = []

        for t in range(x.shape[0]):
            # Layer 0: x[t] is [batch, feature]; transpose it to [feature, batch]
            hidden[0] = np.tanh(np.dot(self.Wih[0], x[t].T) +
                                np.dot(self.Whh[0], hidden[0]))
            # Layers 1..num_layers-1: the input is the previous layer's new hidden state
            for layer in range(1, self.num_layers):
                hidden[layer] = np.tanh(np.dot(self.Wih[layer], hidden[layer - 1]) +
                                        np.dot(self.Whh[layer], hidden[layer]))
            # The output at step t is the top layer's hidden state
            out.append(hidden[-1])

        return np.array(out), np.array(hidden)

def sigmoid(x):
    # Logistic function 1 / (1 + e^(-x)); not used by the tanh RNN above
    return 1.0 / (1.0 + np.exp(-x))

if __name__ == '__main__':
    rnn = Rnn(1, 5, 4)  # input_size=1, hidden_size=5, num_layers=4
    inputs = np.random.random((6, 2, 1))  # [seq=6, batch=2, feature=1]
    out, h = rnn.feed(inputs)
    # out.shape is (6, 5, 2) and h.shape is (4, 5, 2)
    print(f'seq is {inputs.shape[0]}, batch_size is {inputs.shape[1]}, out.shape {out.shape}, h.shape {h.shape}')
    # print(sigmoid(np.random.random((2, 3))))
    #
    # element-wise multiplication
    # print(np.array([1, 2]) * np.array([2, 1]))
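The constructor accepts a `bidirectional` flag, but `feed` never uses it. Below is a minimal sketch of what a single bidirectional layer could look like, under our own assumptions: the same bias-free tanh cell and column layout as above, and hypothetical helper names `rnn_layer` and `bidirectional_rnn_layer` that are not part of the original code. One pass runs over the sequence, a second pass with separate weights runs over the reversed sequence, the second pass's outputs are flipped back into time order, and the two are concatenated along the hidden axis.

```python
import numpy as np


def rnn_layer(x, W_ih, W_hh, h0=None):
    """Single-layer tanh RNN in the column layout used above (no bias terms).

    x: [seq, batch, feature], W_ih: [hidden, feature], W_hh: [hidden, hidden]
    Returns out [seq, hidden, batch] and the last hidden state [hidden, batch].
    """
    seq, batch, _ = x.shape
    h = np.zeros((W_hh.shape[0], batch)) if h0 is None else h0
    out = []
    for t in range(seq):
        # x[t].T is [feature, batch], matching the column layout
        h = np.tanh(W_ih @ x[t].T + W_hh @ h)
        out.append(h)
    return np.stack(out), h


def bidirectional_rnn_layer(x, fwd_weights, bwd_weights):
    """Hypothetical bidirectional layer: a forward pass plus a pass over the reversed sequence."""
    out_f, h_f = rnn_layer(x, *fwd_weights)
    out_b, h_b = rnn_layer(x[::-1], *bwd_weights)
    out_b = out_b[::-1]  # re-align backward outputs with the original time steps
    # Concatenate along the hidden axis: [seq, 2 * hidden, batch]
    return np.concatenate([out_f, out_b], axis=1), np.stack([h_f, h_b])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    seq, batch, feature, hidden = 6, 2, 1, 5
    x = rng.random((seq, batch, feature))

    def make_weights():
        # Random weights in [0, 1), as in the original code; no training
        return rng.random((hidden, feature)), rng.random((hidden, hidden))

    out, h = bidirectional_rnn_layer(x, make_weights(), make_weights())
    print(out.shape, h.shape)  # (6, 10, 2) (2, 5, 2)
```

Because the backward pass is flipped back into time order, its final hidden state lines up with time step 0, while the forward pass's final state lines up with the last step. Stacking more layers would feed this concatenated output into the next layer, which is why a deeper bidirectional layer's input weights would need 2 * hidden columns.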

Source: https://blog.csdn.net/qq_43056256/article/details/114272542

Tags: numpy, RNN