深度学习TextRNN的tensorflow1.14实现示例

作者:我是王大你是谁 时间:2023-12-31 18:59:23 

实现对下一个单词的预测

RNN 原理自己找,这里只给出简单例子的实现代码

import tensorflow as tf
import numpy as np
tf.reset_default_graph()
sentences = ['i love damao','i like mengjun','we love all']
words = list(set(" ".join(sentences).split()))
word2idx = {v:k for k,v in enumerate(words)}
idx2word = {k:v for k,v in enumerate(words)}
V = len(words)   # 词典大小
step = 2   # 时间序列长度
hidden = 5   # 隐层大小
dim = 50   # 词向量维度
# 制作输入和标签
def make_batch(sentences):
   input_batch = []
   target_batch = []
   for sentence in sentences:
       words = sentence.split()
       input = [word2idx[word] for word in words[:-1]]
       target = word2idx[words[-1]]
       input_batch.append(input)
       target_batch.append(np.eye(V)[target])   # 这里将标签改为 one-hot 编码,之后计算交叉熵的时候会用到
   return input_batch, target_batch
# 初始化词向量
embedding = tf.get_variable(shape=[V, dim], initializer=tf.random_normal_initializer(), name="embedding")
X = tf.placeholder(tf.int32, [None, step])
XX = tf.nn.embedding_lookup(embedding,  X)
Y = tf.placeholder(tf.int32, [None, V])
# 定义 cell
cell = tf.nn.rnn_cell.BasicRNNCell(hidden)
# 计算各个时间点的输出和隐层输出的结果
outputs, hiddens = tf.nn.dynamic_rnn(cell, XX, dtype=tf.float32)     # outputs: [batch_size, step, hidden] hiddens: [batch_size, hidden]
# 这里将所有时间点的状态向量都作为了后续分类器的输入(也可以只将最后时间节点的状态向量作为后续分类器的输入)
W = tf.Variable(tf.random_normal([step*hidden, V]))
b = tf.Variable(tf.random_normal([V]))
L = tf.matmul(tf.reshape(outputs,[-1, step*hidden]), W) + b
# 计算损失并进行优化
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=L))
optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)
# 预测
prediction = tf.argmax(L, 1)
# 初始化 tf
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# 喂训练数据
input_batch, target_batch = make_batch(sentences)
for epoch in range(5000):
   _, loss = sess.run([optimizer, cost], feed_dict={X:input_batch, Y:target_batch})
   if (epoch+1)%1000 == 0:
       print("epoch: ", '%04d'%(epoch+1), 'cost= ', '%04f'%(loss))
# 预测数据
predict = sess.run([prediction], feed_dict={X: input_batch})
print([sentence.split()[:2] for sentence in sentences], '->', [idx2word[n] for n in predict[0]])

结果打印

epoch:  1000 cost=  0.008979
epoch:  2000 cost=  0.002754
epoch:  3000 cost=  0.001283
epoch:  4000 cost=  0.000697
epoch:  5000 cost=  0.000406
[['i', 'love'], ['i', 'like'], ['we', 'love']] -> ['damao', 'mengjun', 'all'] 

来源:https://juejin.cn/post/6949412624215834638

标签:tensorflow,深度学习,TextRNN
0
投稿

猜你喜欢

  • 零基础学python应该从哪里入手

    2023-04-27 20:44:56
  • python 如何查看pytorch版本

    2021-02-05 05:38:59
  • Python实战之实现康威生命游戏

    2022-06-30 14:21:12
  • 用python实现文件备份

    2022-04-19 06:28:19
  • Python三元运算实现方法

    2021-12-27 06:02:52
  • 微信小程序开发之获取用户手机号码(php接口解密)

    2023-11-15 03:34:59
  • pytorch 如何把图像数据集进行划分成train,test和val

    2023-12-26 15:28:10
  • js获取css的各种样式并且设置他们的方法

    2024-04-18 10:10:33
  • 多表关联同时更新多条不同的记录方法分享

    2011-11-03 17:34:25
  • Apache下禁止特定目录执行PHP 提高服务器安全性

    2023-10-25 20:10:50
  • python中的Elasticsearch操作汇总

    2022-01-29 10:44:45
  • 探究MySQL中索引和提交频率对InnoDB表写入速度的影响

    2024-01-26 08:03:22
  • Python编码类型转换方法详解

    2022-02-19 07:13:54
  • 如何利用insert into values插入多条数据

    2024-01-24 04:39:54
  • MyBatis实现Mysql数据库分库分表操作和总结(推荐)

    2024-01-24 07:19:11
  • Python中logging日志库实例详解

    2023-10-04 13:26:25
  • MySQL操作数据库和表的常用命令新手教程

    2024-01-23 23:18:36
  • Javascript实现图片懒加载插件的方法

    2024-04-19 10:16:44
  • 总结网络IO模型与select模型的Python实例讲解

    2021-10-16 22:09:41
  • SQL server 2005中设置自动编号字段的方法

    2024-01-12 13:55:47
  • asp之家 网络编程 m.aspxhome.com