其他分享
首页 > 其他分享> > TFRecord 数据格式的转换和读取

TFRecord 数据格式的转换和读取

作者:互联网

什么TFRecord格式的数据?

       Tensorflow支持的一种数据格式,内部使用了“Protocol Buffer”二进制数据编码方案,方便我们模型训练,验证,测试数据集的输入。

为什么提出TFRecord格式的数据?

       通常情况下,我们使用Tensorflow搭建好网络模型之后,要输入数据进行训练,验证,测试,其对应的文件夹经常为 train,val, test文件夹,这些文件夹内部往往会存着上百万的数据文件,这些文件散列存放在磁盘上,并且读取时候非常慢,繁琐,会有大量的I/O操作。同时,占用大量内存空间)。而TFRecord格式的文件存储形式会很合理的帮我们存储数据,其内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。

如何生成TFRecord格式的数据?

 首先数据文件目录如下图:

------data

----------train

---------------dog

---------------cat

----------validation

---------------dog

---------------cat

上图为我们此次处理数据目录data为根目录,其下有两个文件夹train和validation,在train和validation下分别有dog和cat两个文件夹,其中存放对应图片数据。具体TFRecord格式数据转换如下代码:

import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# 定义函数转化变量类型
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# 将数据转化为tf.train.Example格式
def _make_example(label, image):
    image_raw = image.tostring()
    example = tf.train.Example(features=tf.train.Features(feature={
		'label': _int64_feature(label),
		'image_raw': _bytes_feature(image_raw)}))

    return example


# 读取图片
def read_images(sess,path,flag):
    
    # 获取path下所有目录,同时包括path目录
    sub_dirs = [x[0] for x in os.walk(path)]
    
    # 去除path目录
    is_root_dir = True
    
    设置当前label标记为:0
    current_label = 0
    
    print("开始处理训练数据")
    
    #开始生成TFRecord格式数据
    with tf.python_io.TFRecordWriter("./data/dogsVScats_%s_.tfrecord" % flag) as writer:
       
       # 读取所有的子目录
       for sub_dir in sub_dirs:
           if is_root_dir:
              is_root_dir = False
              continue
           
           # 定义图像类型
	   extensions = ['jpg','png']
           # 存放图像数据
	   file_list = []
	
           # 获取文件的名字
	   dir_name = os.path.basename(sub_dir)
	   for extension in extensions:
               # 文件匹配,类似正则表达式
               file_glob = os.path.join(path, dir_name, '*.' + extension)
               
               #将匹配数据加入列表
               file_list.extend(glob.glob(file_glob))
	       if not file_list:
		   continue
	       print("processing:", dir_name)
	       i = 0
	       # 处理图片数据。
	       for file_name in file_list:
                   i += 1
                   image_raw_data = gfile.FastGFile(file_name, 'rb').read()
		   image = tf.image.decode_jpeg(image_raw_data)
		   if image.dtype != tf.float32:
                       image = tf.image.convert_image_dtype(image, dtype=tf.float32)
		   image = tf.image.resize(image, [299, 299])
		   image_value = sess.run(image)
		   example = _make_example(current_label, image_value)
		   writer.write(example.SerializeToString())
		   print("正在处理{}中的第{}张图片".format(dir_name,i))
	        
               current_label += 1
        
        print("TFRecord %s 文件已保存" % flag)


# 执行产生tfrecord文件
with tf.Session() as sess:
    read_images(sess,"./data/train","train")
    read_images(sess,"./data/validation","validation")

 

标签:TFRecord,读取,image,train,file,tf,数据格式,dir
来源: https://blog.csdn.net/weixin_44402973/article/details/95009945