Tensorflow使用tfrecord输入数据格式

作者:ruyiweicas 时间:2022-06-18 22:55:40 

Tensorflow 提供了一种统一的格式来存储数据,这个格式就是TFRecord,上一篇文章中所提到的方法当数据的来源更复杂,每个样例中的信息更丰富的时候就很难有效的记录输入数据中的信息了,于是Tensorflow提供了TFRecord来统一存储数据,接下来我们就来介绍如何使用TFRecord来同意输入数据的格式。

1. TFRecord格式介绍

TFRecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Example的定义


message Example {
Features features = 1;
};

message Features{
map<string,Feature> featrue = 1;
};

message Feature{
 oneof kind{
   BytesList bytes_list = 1;
   FloatList float_list = 2;
   Int64List int64_list = 3;
 }
};

从上述代码可以看到,ft.train.Example 的数据结构相对简洁。tf.train.Example中包含了一个从属性名称到取值的字典,其中属性名称为一个字符串,属性的取值可以为字符串(BytesList ),实数列表(FloatList )或整数列表(Int64List )。例如我们可以将解码前的图片作为字符串,图像对应的类别标号作为整数列表。

2. 将自己的数据转化为TFRecord格式

准备数据

在上一篇中,我们为了像伟大的MNIST致敬,所以选择图像的前缀来进行不同类别的分类依据,但是大多数的情况下,在进行分类任务的过程中,不同的类别都会放在不同的文件夹下,而且类别的个数往往浮动性又很大,所以针对这样的情况,我们现在利用不同类别在不同文件夹中的图像来生成TFRecord.

我们在Iris&Contact这个文件夹下有两个文件夹,分别为iris,contact。对于每个文件夹中存放的是对应的图片

转换数据

数据准备好以后,就开始准备生成TFRecord,具体代码如下:


import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt

cwd='/home/ruyiwei/Documents/Iris&Contact/'
classes={'iris','contact'}
writer= tf.python_io.TFRecordWriter("iris_contact.tfrecords")

for index,name in enumerate(classes):
 class_path=cwd+name+'/'
 for img_name in os.listdir(class_path):
   img_path=class_path+img_name
   img=Image.open(img_path)
   img= img.resize((512,80))
   img_raw=img.tobytes()
   #plt.imshow(img) # if you want to check you image,please delete '#'
   #plt.show()
   example = tf.train.Example(features=tf.train.Features(feature={
     "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
     'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
   }))
   writer.write(example.SerializeToString())

writer.close()

3. Tensorflow从TFRecord中读取数据


def read_and_decode(filename): # read iris_contact.tfrecords
 filename_queue = tf.train.string_input_producer([filename])# create a queue

reader = tf.TFRecordReader()
 _, serialized_example = reader.read(filename_queue)#return file_name and file
 features = tf.parse_single_example(serialized_example,
                   features={
                     'label': tf.FixedLenFeature([], tf.int64),
                     'img_raw' : tf.FixedLenFeature([], tf.string),
                   })#return image and label

img = tf.decode_raw(features['img_raw'], tf.uint8)
 img = tf.reshape(img, [512, 80, 3]) #reshape image to 512*80*3
 img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #throw img tensor
 label = tf.cast(features['label'], tf.int32) #throw label tensor
 return img, label

4. 将TFRecord中的数据保存为图片


filename_queue = tf.train.string_input_producer(["iris_contact.tfrecords"])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  #return file and file_name
features = tf.parse_single_example(serialized_example,
                 features={
                   'label': tf.FixedLenFeature([], tf.int64),
                   'img_raw' : tf.FixedLenFeature([], tf.string),
                 })
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [512, 80, 3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess:
 init_op = tf.initialize_all_variables()
 sess.run(init_op)
 coord=tf.train.Coordinator()
 threads= tf.train.start_queue_runners(coord=coord)
 for i in range(20):
   example, l = sess.run([image,label])#take out image and label
   img=Image.fromarray(example, 'RGB')
   img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#save image
   print(example, l)
 coord.request_stop()
 coord.join(threads)

来源:https://blog.csdn.net/best_coder/article/details/70146441

标签:Tensorflow,tfrecord
0
投稿

猜你喜欢

  • Matplotlib绘制条形图的方法你知道吗

    2022-12-05 15:23:59
  • 在Python中处理日期和时间的基本知识点整理汇总

    2021-05-13 07:12:14
  • python如何定义带参数的装饰器

    2022-01-07 04:18:30
  • 如何远程连接SQL Server数据库

    2009-06-08 12:41:00
  • python中异常捕获方法详解

    2021-10-30 10:06:09
  • SQL Join的一些总结(实例)

    2012-08-21 10:19:29
  • 网站重构 CSS样式表的优化技巧

    2009-05-12 11:51:00
  • ASP初学者常犯的几个错误

    2007-09-07 10:19:00
  • Python实现翻转数组功能示例

    2022-02-28 09:03:09
  • Python3 pandas 操作列表实例详解

    2021-11-30 14:24:12
  • Bootstrap每天必学之响应式导航、轮播图

    2023-08-15 03:29:45
  • Python基于最小二乘法实现曲线拟合示例

    2021-08-06 15:47:07
  • Python 键盘事件详解

    2022-09-28 20:31:01
  • Python asyncore socket客户端开发基本使用教程

    2021-01-25 11:06:39
  • Firefox 3.5 新增加的支持(整理)

    2009-08-01 12:51:00
  • Python 聊聊socket中的listen()参数(数字)到底代表什么

    2022-10-17 00:49:25
  • Python 爬虫多线程详解及实例代码

    2021-01-25 14:05:03
  • 详解如何利用Python进行客户分群分析

    2023-04-25 16:47:09
  • python中用ggplot绘制画图实例讲解

    2023-07-04 07:25:16
  • 对python中数组的del,remove,pop区别详解

    2021-01-23 09:22:31
  • asp之家 网络编程 m.aspxhome.com