将自己的数据集制作成TFRecord格式教程

作者:v1_vivian 时间:2022-02-01 14:49:37 

在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入

此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格

式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。

1. 原本的数据集

此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。

将自己的数据集制作成TFRecord格式教程

2.制作成TFRecord格式

tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。


#生成整数型的属性
def _int64_feature(value):
return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))

#生成字符串类型的属性
def _bytes_feature(value):
return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))

#制作TFRecord格式
def createTFRecord(filename,mapfile):
class_map = {}
data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
classes = {'xiansu60','xiansu100'}
#输出TFRecord文件的地址

writer = tf.python_io.TFRecordWriter(filename)

for index,name in enumerate(classes):
 class_path=data_dir+name+'/'
 class_map[index] = name
 for img_name in os.listdir(class_path):
  img_path = class_path + img_name #每个图片的地址
  img = Image.open(img_path)
  img= img.resize((224,224))
  img_raw = img.tobytes()   #将图片转化成二进制格式
  example = tf.train.Example(features = tf.train.Features(feature = {
   'label':_int64_feature(index),
   'image_raw': _bytes_feature(img_raw)
  }))
  writer.write(example.SerializeToString())
writer.close()

txtfile = open(mapfile,'w+')
for key in class_map.keys():
 txtfile.writelines(str(key)+":"+class_map[key]+"\n")
txtfile.close()

此段代码,运行完后会产生生成的.tfrecord文件。

3. 读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程


#读取train.tfrecord中的数据
def read_and_decode(filename):
#创建一个reader来读取TFRecord文件中的样例
reader = tf.TFRecordReader()
#创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
#从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
_,serialized_example = reader.read(filename_queue)
#  print _,serialized_example

#解析读入的一个样例,如果需要解析多个,可以用parse_example
features = tf.parse_single_example(
serialized_example,
features = {'label':tf.FixedLenFeature([], tf.int64),
   'image_raw': tf.FixedLenFeature([], tf.string),})
#将字符串解析成图像对应的像素数组
img = tf.decode_raw(features['image_raw'], tf.uint8)
img = tf.reshape(img,[224, 224, 3]) #reshape为128*128*3通道图片
img = tf.image.per_image_standardization(img)
labels = tf.cast(features['label'], tf.int32)
return img, labels

4. 将图片几个一打包,形成batch


def createBatch(filename,batchsize):
images,labels = read_and_decode(filename)

min_after_dequeue = 10
capacity = min_after_dequeue + 3 * batchsize

image_batch, label_batch = tf.train.shuffle_batch([images, labels],
             batch_size=batchsize,
             capacity=capacity,
             min_after_dequeue=min_after_dequeue
             )

label_batch = tf.one_hot(label_batch,depth=2)
return image_batch, label_batch

5.主函数


if __name__ =="__main__":
#训练图片两张为一个batch,进行训练,测试图片一起进行测试
mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
#  createTFRecord(train_filename,mapfile)
test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
#  createTFRecord(test_filename,mapfile)
image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
with tf.Session() as sess:
 initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
 sess.run(initop)
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess = sess, coord = coord)

try:
  step = 0
  while 1:
   _image_batch,_label_batch = sess.run([image_batch,label_batch])
   step += 1
   print step
   print (_label_batch)
 except tf.errors.OutOfRangeError:
  print (" trainData done!")

try:
  step = 0
  while 1:
   _test_images,_test_labels = sess.run([test_images,test_labels])
   step += 1
   print step
#     print _image_batch.shape
   print (_test_labels)
 except tf.errors.OutOfRangeError:
  print (" TEST done!")
 coord.request_stop()
 coord.join(threads)

此时,生成的batch,就可以feed进网络了。

来源:https://blog.csdn.net/v1_vivian/article/details/77898414

标签:数据集,TFRecord
0
投稿

猜你喜欢

  • 使用SQL2000将现有代码作为Web服务提供

    2009-02-19 17:20:00
  • 用Python登录好友QQ空间点赞的示例代码

    2023-08-08 09:29:40
  • SQL查询入门(上篇) 推荐收藏

    2011-09-30 11:47:11
  • 如何将txt文本中的数据轻松导入MySQL表中

    2009-03-06 17:35:00
  • 简单解析PHP程序的运行流程

    2023-06-22 07:35:41
  • 利用Python正则表达式过滤敏感词的方法

    2023-05-07 05:05:18
  • 使用Golang的Context管理上下文的方法

    2023-06-29 06:37:23
  • 轻设计,让网站灵敏轻便的6个技巧

    2009-12-07 21:26:00
  • GIt在pyCharm的详细使用教程记录

    2021-11-21 02:21:05
  • Pandas实现自定义Excel格式并导出多个sheet表

    2022-10-04 18:46:34
  • python爬虫中采集中遇到的问题整理

    2022-10-17 03:32:23
  • Qt5 实现主窗口状态栏显示时间

    2022-05-29 23:54:45
  • 22个HTML5的初级技巧

    2010-12-17 12:39:00
  • 使用Pytorch来拟合函数方式

    2021-06-22 18:10:45
  • 一个ACCESS数据库数据传递的方法

    2008-03-05 11:58:00
  • 后端开发使用pycharm的技巧(推荐)

    2021-11-16 14:50:07
  • python命令行参数解析OptionParser类用法实例

    2022-06-21 17:57:24
  • python GUI库图形界面开发之PyQt5布局控件QVBoxLayout详细使用方法与实例

    2022-10-12 11:37:27
  • Go语言的IO库那么多纠结该如何选择

    2023-10-08 07:16:46
  • 快速配置PHPMyAdmin方法

    2023-07-16 07:05:20
  • asp之家 网络编程 m.aspxhome.com