tensorflow2 加载数据方法总结
作者:互联网
tensorflow2 加载数据方法总结
- 1.tfrecord
- 2.tf.data.Dataset
- 3.tf.keras.utils.Sequence
- 4.tf.keras.preprocessing.image.ImageDataGenerator
1.tfrecord
TFRecord 是一种二进制文件格式:把训练数据和标签(label)序列化后打包到同一个文件中,训练时可以顺序快速读取,减少大量小文件带来的 IO 开销。
1.1 tfrecord 打包
# Serialize one image and its label into a TFRecord file.
#
# load_tfrecord() decodes the "image" bytes with tf.float64 and reshapes them
# to (image_size, image_size, 3), so the image written here must be float64
# normalized to [0, 1] and already image_size x image_size.
tfrecord_file = "train.tfrecord"
image = cv2.imread("test.jpg")
if image is None:
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising; fail loudly instead of crashing later on .astype/.tobytes.
    raise FileNotFoundError("cannot read image: test.jpg")
# NOTE: resize here (e.g. with cv2.resize) if test.jpg is not already the
# square size the reader reshapes to.
# Normalize to [0, 1] as float64 so the stored bytes match the reader's
# tf.io.decode_raw(..., tf.float64).
image = image.astype("float64") / 255.0
label = 1
feature = {
    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
    # An Int64List belongs in the int64_list field; passing it as float_list
    # is rejected by protobuf. int64 also matches the reader's tf.int64 label.
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
}
features = tf.train.Features(feature=feature)
example = tf.train.Example(features=features)
# The context manager guarantees the file is flushed and closed even if the
# write raises.
with tf.io.TFRecordWriter(tfrecord_file) as train_write:
    train_write.write(example.SerializeToString())
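- tf.train.BytesList:字符串或原始字节数据(比如图片的像素数据),对应 tf.train.Feature 的 bytes_list 参数
- tf.train.FloatList:浮点数列表,对应 float_list 参数
- tf.train.Int64List:64 位整数列表,对应 int64_list 参数(列表类型必须与参数名一致,例如不能把 Int64List 传给 float_list,否则会报 TypeError)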
1.2 tfrecord 读取
def load_tfrecord(path, image_size, batch):
    """Build a shuffled, batched, prefetched dataset from a TFRecord file.

    Each record is expected to contain an "image" bytes feature holding raw
    float64 pixels (normalized to [0, 1] when the file was written) and an
    int64 "label" feature.

    :param path: path to the TFRecord file
    :param image_size: side length of the square input image
    :param batch: batch size
    :return: tf.data.Dataset yielding (image, label) batches; images are
        float32 with per-example shape (image_size, image_size, 3)
    """
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    target_shape = (image_size, image_size, 3)

    def _parse_example(record):
        parsed = tf.io.parse_example(record, schema)
        # Pixels were normalized to [0, 1] at write time and stored as
        # float64, so they must be decoded with the same dtype.
        pixels = tf.io.decode_raw(parsed["image"], tf.float64)
        # decode_raw yields a flat vector; restore the HWC layout, then
        # downcast to float32 for the model.
        image = tf.cast(tf.reshape(pixels, target_shape), tf.float32)
        # Image augmentation (tf.image.*) could also be applied here.
        return image, parsed["label"]

    autotune = tf.data.experimental.AUTOTUNE
    return (
        tf.data.TFRecordDataset(path)
        .map(_parse_example, num_parallel_calls=autotune)
        .shuffle(buffer_size=batch * 1000)
        .batch(batch)
        .prefetch(autotune)
    )
2.tf.data.Dataset
这里以猫狗分类图片为例,目录结构为 path/cat/(标签 0)和 path/dog/(标签 1):
def load_image(path, batch, size):
    """Build a shuffled, augmented cat/dog image dataset.

    Expects image files under ``path/cat/`` (label 0) and ``path/dog/``
    (label 1).

    :param path: image root directory
    :param batch: batch size
    :param size: side length of the square model input
    :return: tf.data.Dataset of (image, label) batches; images are float32
        in [0, 1] with per-example shape (size, size, 3), labels are int32
    """

    def _decode_and_resize(filename, label):
        # Read the raw file contents.
        image_string = tf.io.read_file(filename)
        # channels=3 forces RGB output. Without it a grayscale JPEG decodes
        # to 1 channel, which random_saturation rejects and which cannot be
        # batched together with 3-channel images.
        image = tf.image.decode_jpeg(image_string, channels=3)
        # Random horizontal mirroring.
        image = tf.image.random_flip_left_right(image)
        # Random contrast and saturation jitter.
        image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
        image = tf.image.random_saturation(image, lower=0.8, upper=1.2)
        # Pad-and-resize to size x size without changing the aspect ratio;
        # the result is float, scaled here to [0, 1].
        image_resized = tf.image.resize_with_pad(image, size, size) / 255.0
        return image_resized, label

    cats_dir = path + "/cat/"
    dogs_dir = path + "/dog/"
    # Assemble the path and label lists in Python; explicit dtypes below keep
    # the tensors well-typed even if a directory is empty.
    cat_filenames = [cats_dir + filename for filename in os.listdir(cats_dir)]
    dog_filenames = [dogs_dir + filename for filename in os.listdir(dogs_dir)]
    filenames = cat_filenames + dog_filenames
    labels = [0] * len(cat_filenames) + [1] * len(dog_filenames)

    datasets = tf.data.Dataset.from_tensor_slices((
        tf.constant(filenames, dtype=tf.string),
        tf.constant(labels, dtype=tf.int32),
    ))
    # The list is ordered all-cats-then-all-dogs, so a shuffle buffer smaller
    # than the dataset would produce single-class batches. Shuffling the cheap
    # (filename, label) pairs over the whole list, before decoding, gives a
    # uniform shuffle without holding decoded images in the buffer.
    datasets = datasets.shuffle(buffer_size=max(len(filenames), 1))
    datasets = datasets.map(
        map_func=_decode_and_resize,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    datasets = datasets.batch(batch)
    datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
    return datasets
3.tf.keras.utils.Sequence
4.tf.keras.preprocessing.image.ImageDataGenerator
主要用于图片分类,自带数据增强功能
标签:tensorflow2,总结,tfrecord,image,batch,feature,train,tf,加载 来源: https://blog.csdn.net/xf8964/article/details/110286003