tensorflow使用range_input_producer多线程读取数据实例

作者:lyg5623 时间:2022-10-19 16:43:21 

先放关键代码:


i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

原理解析:

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;

0,1,2,0,1,2

队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。

如果num_epochs不指定,则队列内容是这样子:

0,1,2,0,1,2,0,1,2,0,1,2...

队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。

下面是完整的演示代码。

数据文件test.txt内容:


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

main.py内容:


import tensorflow as tf
import codecs

BATCH_SIZE = 6
NUM_EXPOCHES = 5

def input_producer():
array = codecs.open("test.txt").readlines()
array = map(lambda line: line.strip(), array)
i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
return inputs

class Inputs(object):
def __init__(self):
 self.inputs = input_producer()

def main(*args, **kwargs):
inputs = Inputs()
init = tf.group(tf.initialize_all_variables(),
    tf.initialize_local_variables())
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
sess.run(init)
try:
 index = 0
 while not coord.should_stop() and index<10:
  datalines = sess.run(inputs.inputs)
  index += 1
  print("step: %d, batch data: %s" % (index, str(datalines)))
except tf.errors.OutOfRangeError:
 print("Done traing:-------Epoch limit reached")
except KeyboardInterrupt:
 print("keyboard interrput detected, stop training")
finally:
 coord.request_stop()
coord.join(threads)
sess.close()
del sess

if __name__ == "__main__":
main()

输出:


step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done traing:-------Epoch limit reached

如果range_input_producer去掉参数num_epochs=1,则输出:


step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:


InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
[[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。

来源:https://blog.csdn.net/lyg5623/article/details/69387917

标签:tensorflow,多线程,读取,数据
0
投稿

猜你喜欢

  • MYSQL插入处理重复键值的几种方法

    2024-01-22 05:41:28
  • php文件类型MIME对照表(比较全)

    2023-06-08 07:24:10
  • Python md5与sha1加密算法用法分析

    2021-04-21 01:51:44
  • 如何安装控制器JavaScript生成插件详解

    2024-04-10 10:51:51
  • Python实现贪吃蛇小游戏(单人模式)

    2023-09-26 23:14:42
  • js中string转int把String类型转化成int类型

    2024-05-03 15:30:11
  • 详解Node.js 中使用 ECDSA 签名遇到的坑

    2024-05-08 09:36:01
  • Python遍历列表时删除元素案例

    2023-09-03 16:08:09
  • 一篇文章让你搞清楚JavaScript事件循环

    2024-04-19 09:53:02
  • 不是原型继承那么简单!prototype的深度探索

    2008-03-07 12:42:00
  • django query模块

    2021-12-01 09:16:22
  • Tornado 多进程实现分析详解

    2022-06-13 20:51:56
  • go语言实现sftp包上传文件和文件夹到远程服务器操作

    2024-05-08 10:22:18
  • python中matplotlib的颜色及线条控制的示例

    2023-11-04 08:11:50
  • python GUI库图形界面开发之PyQt5信号与槽多窗口数据传递详细使用方法与实例

    2022-06-05 14:11:58
  • ASP中取得图片宽度和高度的类

    2008-10-29 12:38:00
  • Python实现合并excel表格的方法分析

    2022-04-24 21:30:22
  • 分享9个好用的Python技巧

    2021-03-15 18:43:05
  • Vue computed 计算属性代码实例

    2024-05-09 15:14:39
  • 浅谈JavaScript 中的延迟加载属性模式

    2024-04-17 10:29:56
  • asp之家 网络编程 m.aspxhome.com