Getting started with TensorFlow: using tfrecord and tf.data.TFRecordDataset

Author: yeqiustu  Date: 2022-06-29 16:23:40

1. Creating tfrecord files

A tfrecord file supports three data types: string, int64, and float32. Values are written, as lists, into a tf.train.Feature via tf.train.BytesList, tf.train.Int64List, and tf.train.FloatList respectively, as shown below:


tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))  # feature is usually a multi-dimensional array; serialize it to bytes first (tostring(); tobytes() in newer NumPy)
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))   # tostring() drops the shape information, so write the shape as well
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))

Collect the data to be written in a dict as above, build a tf.train.Features from it, and then build a tf.train.Example:


def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

Serialize the tf.train.Example and it can be written to a tfrecord file with tf.python_io.TFRecordWriter:


tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord')  # create the tfrecord writer; the file name is xxx.tfrecord
exmp = get_tfrecords_example(feats[inx], labels[inx])  # pack the data into an Example
exmp_serial = exmp.SerializeToString()  # serialize the Example
tfrecord_wrt.write(exmp_serial)  # write it to the tfrecord file
tfrecord_wrt.close()  # close the writer when done
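
As a quick sanity check (a minimal sketch, not part of the original post), the file just written can be read back eagerly with tf.python_io.tf_record_iterator and each record parsed back into a tf.train.Example:


# count the records in the file and inspect the first one
n_records = 0
for record in tf.python_io.tf_record_iterator('xxx.tfrecord'):
    if n_records == 0:
        exmp = tf.train.Example.FromString(record)  # parse the serialized proto
        print(exmp.features.feature['shape'].int64_list.value)  # the stored shape
    n_records += 1
print('total records: %d' % n_records)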

Putting it all together:


import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

mnist = read_data_sets("MNIST_data/", one_hot=True)

# pack one sample into an Example
def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

# write all samples to a tfrecord file
def make_tfrecord(data, outf_nm='mnist-train'):
    feats, labels = data
    outf_nm += '.tfrecord'
    tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
    ndatas = len(labels)
    for inx in range(ndatas):
        exmp = get_tfrecords_example(feats[inx], labels[inx])
        exmp_serial = exmp.SerializeToString()
        tfrecord_wrt.write(exmp_serial)
    tfrecord_wrt.close()

import random
nDatas = len(mnist.train.labels)
inx_lst = list(range(nDatas))  # list() needed under Python 3
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)

# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]],
        [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')

# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]],
        [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')

# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

2. Using tfrecord files: tf.data.TFRecordDataset

Create a TFRecordDataset from the tfrecord file:


dataset = tf.data.TFRecordDataset('xxx.tfrecord')
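
TFRecordDataset also accepts a list of files, so several shards can be read as one dataset; a minimal sketch (the shard names are hypothetical, and num_parallel_reads is only available in later 1.x releases):


files = ['mnist-train-0.tfrecord', 'mnist-train-1.tfrecord']  # hypothetical shard names
dataset = tf.data.TFRecordDataset(files)
# in later 1.x releases, reads can additionally be parallelized across files:
# dataset = tf.data.TFRecordDataset(files, num_parallel_reads=2)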

Each record in the tfrecord file is a serialized tf.train.Example; it is parsed with tf.parse_single_example:


feats = tf.parse_single_example(serial_exmp, features=data_dict)

Here data_dict is a dict whose keys are the keys used when the tfrecord file was written, and whose values are tf.FixedLenFeature([], tf.string), tf.FixedLenFeature([], tf.int64), or tf.FixedLenFeature([], tf.float32), matching the corresponding data types. Putting it together:


def parse_exmp(serial_exmp):  # [10] for 'label' because each label is a 10-element list; [x] for 'shape' is the length of the stored shape
    feats = tf.parse_single_example(serial_exmp, features={'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32), 'shape': tf.FixedLenFeature([x], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape
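
Because the shape was stored alongside the raw bytes, the flattened tensor can also be restored to its original shape right inside the parse function; a minimal variant of parse_exmp (for the MNIST records above, where 'shape' holds a single dimension):


def parse_exmp_reshaped(serial_exmp):
    feats = tf.parse_single_example(serial_exmp, features={
        'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32),
        'shape': tf.FixedLenFeature([1], tf.int64)})  # MNIST images are stored flat, so 'shape' has one element
    image = tf.decode_raw(feats['feature'], tf.float32)
    image = tf.reshape(image, tf.cast(feats['shape'], tf.int32))  # restore the stored shape
    return image, feats['label']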

To parse every record in the tfrecord file, use the dataset's map method:


dataset = dataset.map(parse_exmp)

map accepts an arbitrary function to transform the elements of the dataset. In addition, the repeat, shuffle, and batch methods repeat, shuffle, and batch the dataset; repeat duplicates the dataset so it can be iterated for multiple epochs:


dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
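
On top of this, the pipeline can usually be sped up by parsing records in parallel and prefetching; a sketch (the argument values are illustrative, and num_parallel_calls/prefetch require TF 1.4 or later):


dataset = tf.data.TFRecordDataset('xxx.tfrecord')
dataset = dataset.map(parse_exmp, num_parallel_calls=4)  # parse several records in parallel
dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
dataset = dataset.prefetch(1)  # prepare the next batch while the current one is being consumed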

Once the data has been parsed, it can be consumed by creating an iterator:


iterator = dataset.make_one_shot_iterator()
batch_image, batch_label, batch_shape = iterator.get_next()
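
A one-shot iterator walks over the (already repeated) dataset exactly once and needs no initialization. If the dataset should be restarted explicitly, e.g. once per epoch, an initializable iterator is an alternative; a minimal sketch, assuming the dataset was built without .repeat():


iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()
with tf.Session() as sess:
    for epoch in range(epochs):
        sess.run(iterator.initializer)  # rewind the dataset at the start of each epoch
        while True:
            try:
                batch_image, batch_label, batch_shape = sess.run(next_batch)
            except tf.errors.OutOfRangeError:  # raised when the dataset is exhausted
                break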

To feed data from different datasets into the same model, first create an iterator handle, i.e. an iterator placeholder:


handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
    dataset_train.output_types, dataset_train.output_shapes)
image, label, shape = iterator.get_next()

Then create a handle for each dataset and pass it to the placeholder through feed_dict:


with tf.Session() as sess:
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    sess.run([loss, train_op], feed_dict={handle: handle_train})

Putting it all together:


import tensorflow as tf

train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]

def parse_exmp(serial_exmp):
    feats = tf.parse_single_example(serial_exmp, features={'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32), 'shape': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape

def get_dataset(fname):
    dataset = tf.data.TFRecordDataset(fname)
    return dataset.map(parse_exmp)  # use the padded_batch method if padding is needed

epochs = 16
batch_size = 50  # if nDatas is not divisible by batch_size (e.g. batch_size = 56),
                 # the last batch will contain fewer than batch_size samples

# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)  # make sure repeat comes before batch
# this is different from dataset.shuffle(1000).batch(batch_size).repeat(epochs),
# where, if nDatas is not divisible by batch_size, every epoch ends with
# a batch containing fewer than batch_size samples.
nBatchs = nDatasTrain*epochs//batch_size

# evaluation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs//100*2)

# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)

# make dataset iterator
iter_train = dataset_train.make_one_shot_iterator()
iter_val  = dataset_val.make_one_shot_iterator()
iter_test  = dataset_test.make_one_shot_iterator()

# make feedable iterator
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
    dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()
train_op, loss, eval_op = model(x, y_)  # model() is a stand-in for the CNN built in section 3
init = tf.global_variables_initializer()  # initialize_all_variables() is deprecated

# summary
logdir = './logs/m4d2a'
def summary_op(datapart='train'):
    tf.summary.scalar(datapart + '-loss', loss)
    tf.summary.scalar(datapart + '-eval', eval_op)
    return tf.summary.merge_all()
summary_op_train = summary_op()
summary_op_test = summary_op('val')

with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    _, cur_loss, cur_train_eval, summary = sess.run([train_op, loss, eval_op, summary_op_train],
        feed_dict={handle: handle_train, keep_prob: 0.5})  # keep_prob is the dropout placeholder created with the model
    cur_val_loss, cur_val_eval, summary = sess.run([loss, eval_op, summary_op_test],
        feed_dict={handle: handle_val, keep_prob: 1.0})

3. MNIST experiment


import tensorflow as tf

train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]

def parse_exmp(serial_exmp):
    feats = tf.parse_single_example(serial_exmp, features={'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32), 'shape': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape

def get_dataset(fname):
    dataset = tf.data.TFRecordDataset(fname)
    return dataset.map(parse_exmp)  # use the padded_batch method if padding is needed

epochs = 16
batch_size = 50  # if nDatas is not divisible by batch_size (e.g. batch_size = 56),
                 # the last batch will contain fewer than batch_size samples

# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)  # make sure repeat comes before batch
# this is different from dataset.shuffle(1000).batch(batch_size).repeat(epochs),
# where, if nDatas is not divisible by batch_size, every epoch ends with
# a batch containing fewer than batch_size samples.
nBatchs = nDatasTrain*epochs//batch_size

# evaluation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs//100*2)

# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)

# make dataset iterator
iter_train = dataset_train.make_one_shot_iterator()
iter_val  = dataset_val.make_one_shot_iterator()
iter_test  = dataset_test.make_one_shot_iterator()

# make feedable iterator, i.e. iterator placeholder
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
    dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()

# cnn
x_image = tf.reshape(x, [-1,28,28,1])
w_init = tf.truncated_normal_initializer(stddev=0.1, seed=9)
b_init = tf.constant_initializer(0.1)
cnn1 = tf.layers.conv2d(x_image, 32, (5,5), padding='same', activation=tf.nn.relu,
    kernel_initializer=w_init, bias_initializer=b_init)
mxpl1 = tf.layers.max_pooling2d(cnn1, 2, strides=2, padding='same')
cnn2 = tf.layers.conv2d(mxpl1, 64, (5,5), padding='same', activation=tf.nn.relu,
    kernel_initializer=w_init, bias_initializer=b_init)
mxpl2 = tf.layers.max_pooling2d(cnn2, 2, strides=2, padding='same')
mxpl2_flat = tf.reshape(mxpl2, [-1,7*7*64])
fc1 = tf.layers.dense(mxpl2_flat, 1024, activation=tf.nn.relu,
    kernel_initializer=w_init, bias_initializer=b_init)
keep_prob = tf.placeholder('float')
fc1_drop = tf.nn.dropout(fc1, keep_prob)
logits = tf.layers.dense(fc1_drop, 10, kernel_initializer=w_init, bias_initializer=b_init)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
optmz = tf.train.AdamOptimizer(1e-4)
train_op = optmz.minimize(loss)

def get_eval_op(logits, labels):
    corr_prd = tf.equal(tf.argmax(logits,1), tf.argmax(labels,1))
    return tf.reduce_mean(tf.cast(corr_prd, 'float'))
eval_op = get_eval_op(logits, y_)

init = tf.global_variables_initializer()  # initialize_all_variables() is deprecated

# summary
logdir = './logs/m4d2a'
def summary_op(datapart='train'):
    tf.summary.scalar(datapart + '-loss', loss)
    tf.summary.scalar(datapart + '-eval', eval_op)
    return tf.summary.merge_all()
summary_op_train = summary_op()
summary_op_val = summary_op('val')

# whether to restore or not
ckpts_dir = 'ckpts/'
ckpt_nm = 'cnn-ckpt'
saver = tf.train.Saver(max_to_keep=50) # defaults to save all variables, using dict {'x':x,...} to save specified ones.
restore_step = ''
start_step = 0
train_steps = nBatchs
best_loss = 1e6
best_step = 0

# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# config = tf.ConfigProto()
# config.gpu_options.per_process_gpu_memory_fraction = 0.9
# config.gpu_options.allow_growth=True # allocate when needed
# with tf.Session(config=config) as sess:
with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    if restore_step:
        ckpt = tf.train.get_checkpoint_state(ckpts_dir)
        if ckpt and ckpt.model_checkpoint_path:  # ckpt.model_checkpoint_path is the latest ckpt
            if restore_step == 'latest':
                ckpt_f = tf.train.latest_checkpoint(ckpts_dir)
                start_step = int(ckpt_f.split('-')[-1]) + 1
            else:
                ckpt_f = ckpts_dir + ckpt_nm + '-' + restore_step
            print('loading wgt file: ' + ckpt_f)
            saver.restore(sess, ckpt_f)
    summary_wrt = tf.summary.FileWriter(logdir, sess.graph)
    if restore_step in ['', 'latest']:
        for i in range(start_step, train_steps):
            _, cur_loss, cur_train_eval, summary = sess.run([train_op, loss, eval_op, summary_op_train],
                feed_dict={handle: handle_train, keep_prob: 0.5})
            # log to stdout and eval validation set
            if i % 100 == 0 or i == train_steps - 1:
                saver.save(sess, ckpts_dir + ckpt_nm, global_step=i)  # save variables
                summary_wrt.add_summary(summary, global_step=i)
                cur_val_loss, cur_val_eval, summary = sess.run([loss, eval_op, summary_op_val],
                    feed_dict={handle: handle_val, keep_prob: 1.0})
                if cur_val_loss < best_loss:
                    best_loss = cur_val_loss
                    best_step = i
                summary_wrt.add_summary(summary, global_step=i)
                print('step %5d: loss %.5f, acc %.5f --- loss val %0.5f, acc val %.5f' % (i,
                    cur_loss, cur_train_eval, cur_val_loss, cur_val_eval))
                # sess.run(init_train)
    with open(ckpts_dir + 'best.step', 'w') as f:
        f.write('best step is %d\n' % best_step)
    print('best step is %d' % best_step)
    # eval test set
    test_loss, test_eval = sess.run([loss, eval_op], feed_dict={handle: handle_test, keep_prob: 1.0})
    print('eval test: loss %.5f, acc %.5f' % (test_loss, test_eval))

Experiment results: (figure in the original post, not reproduced here)

Source: https://blog.csdn.net/yeqiustu/article/details/79793454

Tags: tensorflow, tfrecord, tf.data.TFRecordDataset