其他分享
首页 > 其他分享> > tensorflow2 加载数据方法总结

tensorflow2 加载数据方法总结

作者:互联网

tensorflow2 加载数据方法总结

1.tfrecord

tfrecord 是将训练数据和label数据打包成二进制文件,然后在训练的时候可以快速的读取,节省io操作

1.1 tfrecord 打包

tfrecord_file = "train.tfrecord"
train_write = tf.io.TFRecordWriter(tfrecord_file)
image = cv2.imread("test.jpg")
label = 1
feature = {
            "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
            "label": tf.train.Feature(float_list=tf.train.Int64List(value=[label]))
        }
features = tf.train.Features(feature=feature)
example = tf.train.Example(features=features)
train_write.write(example.SerializeToString())
train_write.close()

1.2 tfrecord 读取

def load_tfrecord(path, image_size, batch):
    """
    加载 tfrecord 数据
    :param path: tfrecord 文件路径
    :param image_size: 输入图片尺寸
    :param batch: batch 大小
    :return: dataset 迭代器
    """
    raw_dataset = tf.data.TFRecordDataset(path)

    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64)
    }

    def _parse_example(example_string):
        feature_dict = tf.io.parse_example(example_string, feature_description)
        # 在打包图片数据的时候对图片数据进行了归一化在0-1之间,因此需要用 tf.float64
        feature_dict["image"] = tf.io.decode_raw(feature_dict["image"], tf.float64)
        # 编码后的图像要进行reshape
        feature_dict["image"] = tf.cast(tf.reshape(feature_dict["image"], (image_size, image_size, 3)), tf.float32)
        # 还可以在这里进行图片的增强处理
        # tf.image.
        return feature_dict["image"], feature_dict["label"]

    dataset = raw_dataset.map(_parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset = dataset.shuffle(buffer_size=batch * 1000)
    dataset = dataset.batch(batch)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset

2.tf.data.Dataset

这里按照猫狗图片为例

def load_image(path, batch, size):
	"""
	path:图片目录
	batch:batch size
	size:模型输入图片尺寸
	"""
    def _decode_and_resize(filename, label):
    	# 读取图片
        image_string = tf.io.read_file(filename)
        # 解码
        image = tf.image.decode_jpeg(image_string)
        # 随机镜像
        image = tf.image.random_flip_left_right(image)
        # 随机对比度
        image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
        image = tf.image.random_saturation(image, lower=0.8, upper=1.2)
        # 通过pad来对图片进行resize,不改变图片原来宽高比例
        image_resized = tf.image.resize_with_pad(image, size, size) / 255.0
        return image_resized, label

    cats_dir = path + "/cat/"
    dogs_dir = path + "/dog/"

	# 组装路径列表
    cat_filenames = tf.constant([cats_dir + filename for filename in os.listdir(cats_dir)])
    dog_filenames = tf.constant([dogs_dir + filename for filename in os.listdir(dogs_dir)])
    filenames = tf.concat([cat_filenames, dog_filenames], axis=-1)
    # 标签
    labels = tf.concat([
        tf.zeros(cat_filenames.shape, dtype=tf.int32),
        tf.ones(dog_filenames.shape, dtype=tf.int32)],
        axis=-1)

    datasets = tf.data.Dataset.from_tensor_slices((filenames, labels))
    datasets = datasets.map(
        map_func=_decode_and_resize,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    datasets = datasets.shuffle(buffer_size=batch * 100)
    datasets = datasets.batch(batch)
    datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)

    return datasets

3.tf.keras.utils.Sequence

4.tf.keras.preprocessing.image.ImageDataGenerator

主要用于图片分类,自带数据增强功能

标签:tensorflow2,总结,tfrecord,image,batch,feature,train,tf,加载
来源: https://blog.csdn.net/xf8964/article/details/110286003