TFRecord 数据格式的转换和读取
作者:互联网
什么TFRecord格式的数据?
Tensorflow支持的一种数据格式,内部使用了“Protocol Buffer”二进制数据编码方案,方便我们模型训练,验证,测试数据集的输入。
为什么提出TFRecord格式的数据?
通常情况下,我们使用Tensorflow搭建好网络模型之后,要输入数据进行训练,验证,测试,其对应的文件夹经常为 train,val, test文件夹,这些文件夹内部往往会存着上百万的数据文件,这些文件散列存放在磁盘上,并且读取时候非常慢,繁琐,会有大量的I/O操作。同时,占用大量内存空间)。而TFRecord格式的文件存储形式会很合理的帮我们存储数据,其内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。
如何生成TFRecord格式的数据?
首先数据文件目录如下图:
------data
----------train
---------------dog
---------------cat
----------validation
---------------dog
---------------cat
上图为我们此次处理数据目录data为根目录,其下有两个文件夹train和validation,在train和validation下分别有dog和cat两个文件夹,其中存放对应图片数据。具体TFRecord格式数据转换如下代码:
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
# 定义函数转化变量类型
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 将数据转化为tf.train.Example格式
def _make_example(label, image):
image_raw = image.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'label': _int64_feature(label),
'image_raw': _bytes_feature(image_raw)}))
return example
# 读取图片
def read_images(sess,path,flag):
# 获取path下所有目录,同时包括path目录
sub_dirs = [x[0] for x in os.walk(path)]
# 去除path目录
is_root_dir = True
设置当前label标记为:0
current_label = 0
print("开始处理训练数据")
#开始生成TFRecord格式数据
with tf.python_io.TFRecordWriter("./data/dogsVScats_%s_.tfrecord" % flag) as writer:
# 读取所有的子目录
for sub_dir in sub_dirs:
if is_root_dir:
is_root_dir = False
continue
# 定义图像类型
extensions = ['jpg','png']
# 存放图像数据
file_list = []
# 获取文件的名字
dir_name = os.path.basename(sub_dir)
for extension in extensions:
# 文件匹配,类似正则表达式
file_glob = os.path.join(path, dir_name, '*.' + extension)
#将匹配数据加入列表
file_list.extend(glob.glob(file_glob))
if not file_list:
continue
print("processing:", dir_name)
i = 0
# 处理图片数据。
for file_name in file_list:
i += 1
image_raw_data = gfile.FastGFile(file_name, 'rb').read()
image = tf.image.decode_jpeg(image_raw_data)
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = tf.image.resize(image, [299, 299])
image_value = sess.run(image)
example = _make_example(current_label, image_value)
writer.write(example.SerializeToString())
print("正在处理{}中的第{}张图片".format(dir_name,i))
current_label += 1
print("TFRecord %s 文件已保存" % flag)
# 执行产生tfrecord文件
with tf.Session() as sess:
read_images(sess,"./data/train","train")
read_images(sess,"./data/validation","validation")
标签:TFRecord,读取,image,train,file,tf,数据格式,dir 来源: https://blog.csdn.net/weixin_44402973/article/details/95009945