
StyleGAN2 Made Easy (6): How Does the StyleGAN2 Encoder Load Training Data? Annotated Source Code for dataset_tool.py and dataset.py


In the previous article, we briefly introduced part of the StyleGAN2 Encoder source code, namely projector.py and project_images.py; see:

StyleGAN2 Made Easy (5): A First Look at the StyleGAN2 Encoder Source Code, with Annotations, projector.py and project_images.py

Two of the functions called there are responsible for loading training data, and in this article we take some time to look at them.

Both calls appear in project_images.py:

(1) dataset_tool.create_from_images(tfrecord_dir, image_dir, shuffle=0)

(2) dataset_obj = dataset.load_dataset(
        data_dir=data_dir, tfrecord_dir='tfrecords',
        max_label_size=0, repeat=False, shuffle_mb=0
    )

Below we introduce each of these two functions in turn:

(I) The first function, dataset_tool.create_from_images(), is defined in ./dataset_tool.py. It is a module-level function rather than a method of the TFRecordExporter class, although it does its work through a TFRecordExporter instance.

create_from_images() loads the images from image_dir, validates them (same square, power-of-two resolution; RGB or grayscale), and serializes the shape and image data at each lod (level of detail) to disk. Its source code (with comments) follows:

def create_from_images(tfrecord_dir, image_dir, shuffle):
    print('Loading images from "%s"' % image_dir)
    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
    if len(image_filenames) == 0:
        error('No input images found')

    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1 # img.ndim is the number of array dimensions: 2 for grayscale, 3 for color
    # e.g. a 1024x1024 RGB image has shape (1024, 1024, 3)
    if img.shape[1] != resolution:
        error('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')

    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        # shuffled order or sequential order
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            if channels == 1:
                img = img[np.newaxis, :, :] # HW => CHW: add a channel axis for grayscale
            else:
                img = img.transpose([2, 0, 1]) # HWC => CHW
            # serialize the shape and image data at every lod into the TFRecords files
            tfr.add_image(img)
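
Before diving into add_image(), here is a minimal usage sketch of the exporter; the directory names are hypothetical, and it assumes dataset_tool.py is importable as in the stylegan2encoder repository:

import dataset_tool

# Hypothetical paths: a folder of same-sized, power-of-two square images goes in,
# and a set of tfrecords-r02 ... tfrecords-r10 files comes out under the target directory.
dataset_tool.create_from_images('datasets/my_tfrecords', 'raw_images', shuffle=0)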

The core work of create_from_images() is done by add_image(). Taking a 1024x1024 source image as an example, it walks lod 0 through 8, producing a shape and image buffer at each successively halved resolution, and serializes them into nine files, from tfrecords-r10.tfrecords down to tfrecords-r02.tfrecords. Its source code (with comments) follows:

    def add_image(self, img):
        if self.print_progress and self.cur_images % self.progress_interval == 0:
            print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
        if self.shape is None:
            self.shape = img.shape
            self.resolution_log2 = int(np.log2(self.shape[1]))
            assert self.shape[0] in [1, 3]
            assert self.shape[1] == self.shape[2]
            assert self.shape[1] == 2**self.resolution_log2
            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            # e.g. for 1024x1024 images, lod runs from 0 to 8 and the files are numbered from 10 down to 2, e.g. tfrecords-r05.tfrecords
            for lod in range(self.resolution_log2 - 1):
                tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod) # 10 - lod
                # TFRecordWriter writes records into a TFRecords file;
                # one writer per lod (level of detail), collected in a list
                self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
        assert img.shape == self.shape
        for lod, tfr_writer in enumerate(self.tfr_writers):
            # lod 0 keeps the original image size; each subsequent lod halves the resolution
            if lod:
                img = img.astype(np.float32)
                # average each 2x2 block: pixels (0,0), (0,1), (1,0), (1,1) are summed and divided by 4,
                # halving the image height and width
                img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25
            # round img, clip to [0, 255], and convert to uint8
            quant = np.rint(img).clip(0, 255).astype(np.uint8)
            # tf.train.Example() packs the shape and raw bytes into a protobuf record
            ex = tf.train.Example(features=tf.train.Features(feature={
                'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)),
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tostring()]))}))
            # SerializeToString() turns ex into a byte string, which is written to the tfr_writer file
            tfr_writer.write(ex.SerializeToString())
        self.cur_images += 1
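
The per-lod downsampling in add_image() is plain 2x2 average pooling. Here is a self-contained NumPy sketch of one halving step, on hypothetical toy data:

import numpy as np

# a toy CHW image: 3 channels, 4x4 pixels, values 0..47
img = np.arange(48, dtype=np.float32).reshape(3, 4, 4)
# one lod step: average each 2x2 block, halving height and width
half = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] +
        img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25
print(half.shape)  # (3, 2, 2)
print(half[0])     # [[ 2.5  4.5] [10.5 12.5]] -- each entry is the mean of one 2x2 block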

(II) The second function, dataset.load_dataset(), is defined in ./training/dataset.py. It is a module-level helper that constructs a dataset object; the real work happens when the TFRecordDataset class is instantiated.

The core of this code is the TFRecordDataset class. It loads the shapes, labels, and image data for each lod (level of detail; 0 through 8 for a 1024x1024 image) from the tfrecords files, and it defines how that data is preprocessed, the shuffle buffer size, the batch size, the prefetch buffer size, a reusable iterator, and so on. The methods used include tf.data.Dataset.from_tensor_slices(), tf.data.TFRecordDataset(), tf.data.Iterator.from_structure(), dataset.map(), dataset.shuffle(), dataset.prefetch(), dataset.batch(), etc. The source code (with comments) follows:

import os
import glob
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib

#----------------------------------------------------------------------------
# Dataset class that loads data from tfrecords files.
# Methods used include tf.data.Dataset.from_tensor_slices(), tf.data.TFRecordDataset(), tf.data.Iterator.from_structure(),
# dataset.map(), dataset.shuffle(), dataset.prefetch(), dataset.batch(), etc.

class TFRecordDataset:
    def __init__(self,
        tfrecord_dir,               # Directory containing a collection of tfrecords files.
        resolution      = None,     # Dataset resolution, None = autodetect.
        label_file      = None,     # Relative path of the labels file, None = autodetect.
        max_label_size  = 0,        # 0 = no labels, 'full' = full labels, <int> = N first label components.
        max_images      = None,     # Maximum number of images to use, None = use all images.
        repeat          = True,     # Repeat dataset indefinitely?
        shuffle_mb      = 4096,     # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb     = 2048,     # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb       = 256,      # Read buffer size (megabytes).
        num_threads     = 2):       # Number of concurrent threads.

        self.tfrecord_dir       = tfrecord_dir
        self.resolution         = None
        self.resolution_log2    = None
        self.shape              = []        # [channels, height, width]
        self.dtype              = 'uint8'
        self.dynamic_range      = [0, 255]
        self.label_file         = label_file
        self.label_size         = None      # components
        self.label_dtype        = None
        self._np_labels         = None
        self._tf_minibatch_in   = None
        self._tf_labels_var     = None
        self._tf_labels_dataset = None
        self._tf_datasets       = dict()
        self._tf_iterator       = None
        self._tf_init_ops       = dict()
        self._tf_minibatch_np   = None
        self._cur_minibatch     = -1
        self._cur_lod           = -1

        # List tfrecords files and inspect their shapes.
        # collect all .tfrecords files under tfrecord_dir into tfr_files
        assert os.path.isdir(self.tfrecord_dir)
        tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords')))
        assert len(tfr_files) >= 1
        tfr_shapes = []
        # iterate over the tfr_files list
        for tfr_file in tfr_files:
            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            # read records from tfr_file
            for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt):
                # Parse the first record of the file and append its shape to tfr_shapes (used below to build the TF expressions).
                # Only the first record is parsed here, to save cost; the full image data is read later, when the TF expressions are built.
                tfr_shapes.append(self.parse_tfrecord_np(record).shape)
                break

        # Autodetect label filename.
        # The label file annotates the images with attribute labels.
        if self.label_file is None:
            guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels')))
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            guess = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Determine shape and resolution.
        max_shape = max(tfr_shapes, key=np.prod)
        self.resolution = resolution if resolution is not None else max_shape[1] # shape[1]: height
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [max_shape[0], self.resolution, self.resolution] # shape[0]: channels
        # the tfr_lods list is used when building the TF expressions below;
        # e.g. for 1024x1024 images resolution_log2 is 10, and log2(shape[1]) runs from 10 down to 2
        tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes]
        assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
        assert all(shape[1] == shape[2] for shape in tfr_shapes)
        # the next line verifies the mathematical relation between resolution and lod
        assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods))
        assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1))

        # Load labels.
        assert max_label_size == 'full' or max_label_size >= 0
        # allocate _np_labels as a placeholder with a very large number of rows (1<<30 = 2**30) and zero label components
        self._np_labels = np.zeros([1<<30, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        if max_images is not None and self._np_labels.shape[0] > max_images:
            self._np_labels = self._np_labels[:max_images]
        # the number of label components is stored in label_size
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions.
        # Work inside a name scope called 'Dataset', placed on the CPU
        # (keeping the input pipeline off the GPU to reduce GPU memory pressure)
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            # placeholder for the minibatch size
            self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[])
            # create the labels_var TF variable, initialized from the (possibly huge) _np_labels array
            self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var')
            # wrap _tf_labels_var in a tf.data Dataset of per-row slices
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var)
            # iterate over the tfrecords files and build one dataset (dset) per lod
            for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods):
                if tfr_lod < 0:
                    continue
                # define a dataset for this tfr_file, with a read buffer of buffer_mb megabytes
                dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20)
                # take at most max_images records
                if max_images is not None:
                    dset = dset.take(max_images)
                # parse each serialized record into shape and image data via parse_tfrecord_tf;
                # parallelized across num_threads calls
                dset = dset.map(self.parse_tfrecord_tf, num_parallel_calls=num_threads)
                # zip the image data with _tf_labels_dataset, re-assigning the result to dset
                dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
                # storage needed per image item at this lod, e.g. 3x1024x1024 bytes for a 1024x1024 RGB image (uint8 itemsize = 1)
                bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
                if shuffle_mb > 0:
                    # convert the shuffle_mb budget into a buffer capacity in items;
                    # samples are drawn from dset in order to fill the shuffle buffer, whose contents are then shuffled;
                    # as iteration drains the buffer below capacity, it is refilled from dset in order and reshuffled.
                    # A larger shuffle buffer means stronger shuffling; a buffer of size 1 means no shuffling at all.
                    dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1)
                if repeat:
                    # repeat the dataset indefinitely, so iteration never runs out of data
                    dset = dset.repeat()
                if prefetch_mb > 0:
                    # prefetch items so the next batch is ready for the GPU, avoiding GPU stalls
                    dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1)
                # take _tf_minibatch_in items (in order, out of the shuffle buffer) per batch
                dset = dset.batch(self._tf_minibatch_in)
                # store the finished pipeline under its lod, building up self._tf_datasets
                self._tf_datasets[tfr_lod] = dset
            # define a single reusable iterator
            self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes)
            # the datasets for different lods share this iterator; build one initializer op per lod
            self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()}

    def close(self):
        pass

    # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf().
    def configure(self, minibatch_size, lod=0):
        lod = int(np.floor(lod)) # floor: the largest integer not greater than lod
        assert minibatch_size >= 1 and lod in self._tf_datasets
        if self._cur_minibatch != minibatch_size or self._cur_lod != lod:
            self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size})
            self._cur_minibatch = minibatch_size
            self._cur_lod = lod

    # Get next minibatch as TensorFlow expressions.
    def get_minibatch_tf(self): # => images, labels
        return self._tf_iterator.get_next()

    # Get next minibatch as NumPy arrays.
    def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels
        self.configure(minibatch_size, lod)
        with tf.name_scope('Dataset'):
            if self._tf_minibatch_np is None:
                self._tf_minibatch_np = self.get_minibatch_tf()
            return tflib.run(self._tf_minibatch_np)

    # Get random labels as TensorFlow expression.
    def get_random_labels_tf(self, minibatch_size): # => labels
        with tf.name_scope('Dataset'):
            if self.label_size > 0:
                with tf.device('/cpu:0'):
                    return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32))
            return tf.zeros([minibatch_size, 0], self.label_dtype)

    # Get random labels as NumPy array.
    def get_random_labels_np(self, minibatch_size): # => labels
        if self.label_size > 0:
            return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])]
        return np.zeros([minibatch_size, 0], self.label_dtype)

    # Parse individual image from a tfrecords file into TensorFlow expression.
    @staticmethod
    def parse_tfrecord_tf(record):
        features = tf.parse_single_example(record, features={
            'shape': tf.FixedLenFeature([3], tf.int64),
            'data': tf.FixedLenFeature([], tf.string)})
        data = tf.decode_raw(features['data'], tf.uint8)
        return tf.reshape(data, features['shape'])

    # Parse individual image from a tfrecords file into NumPy array.
    @staticmethod
    def parse_tfrecord_np(record):
        ex = tf.train.Example()
        ex.ParseFromString(record)
        shape = ex.features.feature['shape'].int64_list.value # pylint: disable=no-member
        data = ex.features.feature['data'].bytes_list.value[0] # pylint: disable=no-member
        return np.fromstring(data, np.uint8).reshape(shape)

#----------------------------------------------------------------------------
# Helper func for constructing a dataset object using the given options.
# dnnlib.util.get_obj_by_name(class_name) locates and loads the class named 'TFRecordDataset'.

def load_dataset(class_name=None, data_dir=None, verbose=False, **kwargs):
    kwargs = dict(kwargs)
    if 'tfrecord_dir' in kwargs:
        if class_name is None:
            class_name = __name__ + '.TFRecordDataset' # when called as a module, __name__ is this module's name
        if data_dir is not None:
            kwargs['tfrecord_dir'] = os.path.join(data_dir, kwargs['tfrecord_dir'])

    assert class_name is not None
    if verbose:
        print('Streaming data using %s...' % class_name)
    dataset = dnnlib.util.get_obj_by_name(class_name)(**kwargs)
    if verbose:
        print('Dataset shape =', np.int32(dataset.shape).tolist())
        print('Dynamic range =', dataset.dynamic_range)
        print('Label size    =', dataset.label_size)
    return dataset
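
Putting the two halves together, here is a minimal usage sketch, assuming the stylegan2encoder layout and the hypothetical directories from the exporter sketch above (tflib.init_tf() is needed because get_minibatch_np() evaluates TF expressions):

import dnnlib.tflib as tflib
from training import dataset

tflib.init_tf()

# Hypothetical layout: <data_dir>/tfrecords/tfrecords-r02 ... tfrecords-r10
dataset_obj = dataset.load_dataset(
    data_dir='datasets/my_dataset', tfrecord_dir='tfrecords',
    max_label_size=0, repeat=False, shuffle_mb=0, verbose=True)

# One minibatch of 4 images at full resolution (lod=0), returned as NumPy arrays (NCHW, uint8).
images, labels = dataset_obj.get_minibatch_np(4, lod=0)
print(images.shape, images.dtype)  # e.g. (4, 3, 1024, 1024) uint8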

Dear reader, if you have truly taken the time to read all the way to here, I must commend you: what a diligent student you are!

(The End)

 


Source: https://blog.csdn.net/weixin_41943311/article/details/104444349