轻轻松松使用StyleGAN2(六):StyleGAN2 Encoder是怎样加载训练数据的?源代码+中文注释,dataset_tool.py和dataset.py
作者:互联网
上一篇文章里,我们简单介绍了StyleGAN2 Encoder的一部分源代码,即:projector.py和project_images.py,内容请参考:
轻轻松松使用StyleGAN2(五):StyleGAN2 Encoder源代码初探+中文注释,projector.py和project_images.py
其中有两个函数,涉及到加载训练数据的功能,在这篇文章里我们花一点时间来看一下。
这两个函数都在project_images.py里,分别是:
(1)dataset_tool.create_from_images(tfrecord_dir, image_dir, shuffle=0)
(2)dataset_obj = dataset.load_dataset(
data_dir=data_dir, tfrecord_dir='tfrecords',
max_label_size=0, repeat=False, shuffle_mb=0
)
下面分别对这两个函数进行介绍:
(一)第一个函数dataset_tool.create_from_images(),在./dataset_tool.py里定义,它是一个模块级函数(并不是TFRecordExporter类中的方法),它在内部创建TFRecordExporter对象,并调用该对象的add_image()方法完成写入,其中:
create_from_images()的功能是按不同的lod(levels of detail),将对应的shape和图像信息序列化存入文件,其源代码如下(含中文注释):
def create_from_images(tfrecord_dir, image_dir, shuffle):
    """Convert a directory of square, power-of-two images into multi-resolution TFRecords.

    The first image (in sorted filename order) determines the resolution and the
    channel count; each image is converted to CHW layout and handed to
    ``TFRecordExporter.add_image``, which writes it at every level of detail.

    Args:
        tfrecord_dir: Output directory passed to ``TFRecordExporter``.
        image_dir: Directory whose entries (``*``) are all opened as images.
        shuffle: If truthy, images are written in the order given by
            ``tfr.choose_shuffled_order()``; otherwise in sorted filename order.
    """
    print('Loading images from "%s"' % image_dir)
    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
    if not image_filenames:
        error('No input images found')

    # Probe the first image for resolution/channels, e.g. shape (1024, 1024, 3).
    # The context manager closes the file instead of leaving the handle to the GC.
    with PIL.Image.open(image_filenames[0]) as image:
        img = np.asarray(image)
    resolution = img.shape[0]
    # A 2-D array is grayscale (1 channel); a 3-D array holds channels in its last axis.
    channels = img.shape[2] if img.ndim == 3 else 1
    if img.shape[1] != resolution:
        error('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')

    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        # Shuffled or sequential write order.
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        for idx in order:
            with PIL.Image.open(image_filenames[idx]) as image:
                img = np.asarray(image)
            if channels == 1:
                img = img[np.newaxis, :, :]  # HW => CHW
            else:
                img = img.transpose([2, 0, 1])  # HWC => CHW
            # add_image writes this image once per level of detail.
            tfr.add_image(img)
create_from_images()函数的核心功能由add_image()完成。以源图像尺寸为1024x1024为例,add_image()按照lod从0到8,把图像依次缩小为1024x1024、512x512……直到4x4,共9种尺寸(每一级长宽各减半,方法是对2x2像素取平均)。它把每种尺寸对应的shape和图像数据分别序列化,存储到tfrecords-r10.tfrecords到tfrecords-r02.tfrecords这9个临时文件中。其源代码(含中文注释)如下:
def add_image(self, img):
    """Downsample a CHW image to every level of detail and append it to the matching TFRecord file.

    On the first call, ``img.shape`` fixes the exporter's shape: it must have 1 or 3
    channels and a square, power-of-two spatial size. One ``TFRecordWriter`` is then
    opened per level of detail. Every later image must have exactly the same shape.

    Each successive lod halves height and width by averaging 2x2 pixel blocks (in
    float32); every level is rounded, clipped to [0, 255] and stored as uint8
    together with its shape.

    Args:
        img: Image array in CHW layout.
    """
    if self.print_progress and self.cur_images % self.progress_interval == 0:
        print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
    if self.shape is None:
        self.shape = img.shape
        self.resolution_log2 = int(np.log2(self.shape[1]))
        assert self.shape[0] in [1, 3]
        assert self.shape[1] == self.shape[2]
        assert self.shape[1] == 2**self.resolution_log2
        tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
        # One writer per level of detail. For a 1024x1024 image (resolution_log2 = 10),
        # lod runs 0..8 and the file suffix runs r10..r02, e.g. <prefix>-r05.tfrecords.
        for lod in range(self.resolution_log2 - 1):
            tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod)
            self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
    assert img.shape == self.shape
    for lod, tfr_writer in enumerate(self.tfr_writers):
        # lod 0 is written at full size; each later lod halves the previous one.
        if lod:
            img = img.astype(np.float32)
            # Average each 2x2 block: (0,0), (0,1), (1,0), (1,1).
            img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25
        quant = np.rint(img).clip(0, 255).astype(np.uint8)
        # tobytes() replaces the deprecated ndarray.tostring(); the bytes are identical.
        ex = tf.train.Example(features=tf.train.Features(feature={
            'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)),
            'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tobytes()]))}))
        tfr_writer.write(ex.SerializeToString())
    self.cur_images += 1
(二)第二个函数dataset.load_dataset(),在./training/dataset.py里定义,它是一个模块级的helper(助手)函数(并不是TFRecordDataset类中的方法),用于构建dataset对象。当传入了tfrecord_dir且没有指定class_name时,它默认按类名找到本模块中的TFRecordDataset类,并创建该类的对象实例。
这部分代码的核心是TFRecordDataset类,TFRecordDataset类用于从tfrecords文件中加载不同lod(levels of detail,以1024x1024的图像为例,从0--8)的shape、labels(标签)和图像数据,并且定义了这些数据的预处理方法、shuffle缓冲区大小、批大小、预读取缓冲区大小、可复用的迭代器等,其中使用的方法包括:tf.data.Dataset.from_tensor_slices(),tf.data.TFRecordDataset(),tf.data.Iterator.from_structure(),dataset.map(),dataset.shuffle(),dataset.prefetch(),dataset.batch()等等,其源代码(含中文注释)如下:
import os
import glob
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib
#----------------------------------------------------------------------------
# Dataset class that loads data from tfrecords files.
# Builds one tf.data pipeline per level of detail (lod) plus a single reinitializable
# iterator shared by all of them. tf.data APIs used: Dataset.from_tensor_slices(),
# TFRecordDataset(), Iterator.from_structure(), map(), shuffle(), prefetch(), batch().
class TFRecordDataset:
    """Multi-resolution image dataset read from a directory of ``*.tfrecords`` files.

    Each file's first record determines that file's square [C, H, W] resolution.
    The file at the target resolution is lod 0, half that resolution is lod 1, and
    so on; every lod down to 4x4 must be present. For each lod a pipeline
    (parse -> zip with labels -> shuffle -> repeat -> prefetch -> batch) is built,
    and ``configure()`` switches the shared iterator between them.
    """

    def __init__(self,
            tfrecord_dir,           # Directory containing a collection of tfrecords files.
            resolution=None,        # Dataset resolution, None = autodetect.
            label_file=None,        # Relative path of the labels file, None = autodetect.
            max_label_size=0,       # 0 = no labels, 'full' = full labels, <int> = N first label components.
            max_images=None,        # Maximum number of images to use, None = use all images.
            repeat=True,            # Repeat dataset indefinitely?
            shuffle_mb=4096,        # Shuffle data within specified window (megabytes), 0 = disable shuffling.
            prefetch_mb=2048,       # Amount of data to prefetch (megabytes), 0 = disable prefetching.
            buffer_mb=256,          # Read buffer size (megabytes).
            num_threads=2):         # Number of concurrent threads.
        self.tfrecord_dir = tfrecord_dir
        self.resolution = None
        self.resolution_log2 = None
        self.shape = []             # [channels, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None      # components
        self.label_dtype = None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_datasets = dict()
        self._tf_iterator = None
        self._tf_init_ops = dict()
        self._tf_minibatch_np = None
        self._cur_minibatch = -1
        self._cur_lod = -1

        # List tfrecords files and inspect their shapes.
        assert os.path.isdir(self.tfrecord_dir)
        tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords')))
        assert len(tfr_files) >= 1
        tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
        tfr_shapes = []
        for tfr_file in tfr_files:
            # Decode only the first record of each file (the whole first image, not just
            # a header) to learn the file's shape. Assumes all records in one file share
            # that shape, as TFRecordExporter.add_image writes them.
            for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt):
                tfr_shapes.append(self.parse_tfrecord_np(record).shape)
                break
        # An empty file would otherwise silently misalign tfr_files and tfr_shapes in the zips below.
        assert len(tfr_shapes) == len(tfr_files), 'Found an empty tfrecords file in %s' % self.tfrecord_dir

        # Autodetect label filename: take the first *.labels file in tfrecord_dir, or
        # resolve label_file relative to tfrecord_dir when it does not exist as given.
        if self.label_file is None:
            guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels')))
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            guess = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Determine shape and resolution. The largest file (by element count) defines
        # the channel count and, unless overridden, the resolution (shape[1] = height).
        max_shape = max(tfr_shapes, key=np.prod)
        self.resolution = resolution if resolution is not None else max_shape[1]
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [max_shape[0], self.resolution, self.resolution]
        # lod of each file = log2(target resolution) - log2(file resolution); e.g. with a
        # 1024x1024 target the 1024 file is lod 0 and the 4x4 file is lod 8. Files larger
        # than the target get a negative lod and are skipped when building pipelines.
        tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes]
        assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
        assert all(shape[1] == shape[2] for shape in tfr_shapes)
        # Each file's resolution must equal the target resolution halved `lod` times.
        assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods))
        # Every lod from 0 (full resolution) to resolution_log2 - 2 (4x4) must exist.
        assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1))

        # Load labels.
        assert max_label_size == 'full' or max_label_size >= 0
        # Default: a (2**30, 0) placeholder, i.e. a zero-length label for each of up to
        # 2**30 images. It holds no data (size 0); the large row count only keeps the
        # label side of the zip below from ending before the image side.
        self._np_labels = np.zeros([1<<30, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        if max_images is not None and self._np_labels.shape[0] > max_images:
            self._np_labels = self._np_labels[:max_images]
        # Number of label components per image.
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions, with the input pipeline pinned to the CPU.
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            # Minibatch size is fed when an iterator initializer runs (see configure()).
            self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[])
            # Labels as a TF variable (tflib helper; presumably avoids embedding a large
            # constant in the graph - TODO confirm in tflib).
            self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var')
            # One dataset element per label row.
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var)
            for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods):
                if tfr_lod < 0:
                    continue  # Higher resolution than requested; unused.
                # Read buffer of buffer_mb megabytes.
                dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20)
                if max_images is not None:
                    dset = dset.take(max_images)
                # Decode serialized records into image tensors, num_threads at a time.
                dset = dset.map(self.parse_tfrecord_tf, num_parallel_calls=num_threads)
                # Pair images with labels by position.
                dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
                # Bytes per uncompressed image at this lod, used to turn MB budgets into element counts.
                bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
                if shuffle_mb > 0:
                    # Shuffle buffer length = ceil(shuffle budget / item size); a larger buffer shuffles more.
                    dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1)
                if repeat:
                    dset = dset.repeat()
                if prefetch_mb > 0:
                    # Prefetch buffer length = ceil(prefetch budget / item size).
                    dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1)
                # The batch size is the placeholder, so configure() can change it.
                dset = dset.batch(self._tf_minibatch_in)
                self._tf_datasets[tfr_lod] = dset
            # One reinitializable iterator for all lods; its structure comes from lod 0.
            self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes)
            self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()}

    def close(self):
        """No-op."""
        pass

    # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf().
    def configure(self, minibatch_size, lod=0):
        lod = int(np.floor(lod))  # Fractional lods round down.
        assert minibatch_size >= 1 and lod in self._tf_datasets
        # Only run an initializer (which points the shared iterator at this lod's
        # pipeline) when the batch size or lod actually changed.
        if self._cur_minibatch != minibatch_size or self._cur_lod != lod:
            self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size})
            self._cur_minibatch = minibatch_size
            self._cur_lod = lod

    # Get next minibatch as TensorFlow expressions.
    def get_minibatch_tf(self):  # => images, labels
        return self._tf_iterator.get_next()

    # Get next minibatch as NumPy arrays.
    def get_minibatch_np(self, minibatch_size, lod=0):  # => images, labels
        self.configure(minibatch_size, lod)
        with tf.name_scope('Dataset'):
            # Build the get_next() op once and reuse it on later calls.
            if self._tf_minibatch_np is None:
                self._tf_minibatch_np = self.get_minibatch_tf()
            return tflib.run(self._tf_minibatch_np)

    # Get random labels as TensorFlow expression.
    def get_random_labels_tf(self, minibatch_size):  # => labels
        with tf.name_scope('Dataset'):
            if self.label_size > 0:
                with tf.device('/cpu:0'):
                    return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32))
            return tf.zeros([minibatch_size, 0], self.label_dtype)

    # Get random labels as NumPy array.
    def get_random_labels_np(self, minibatch_size):  # => labels
        if self.label_size > 0:
            return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])]
        return np.zeros([minibatch_size, 0], self.label_dtype)

    # Parse individual image from a tfrecords file into TensorFlow expression.
    @staticmethod
    def parse_tfrecord_tf(record):
        features = tf.parse_single_example(record, features={
            'shape': tf.FixedLenFeature([3], tf.int64),
            'data': tf.FixedLenFeature([], tf.string)})
        data = tf.decode_raw(features['data'], tf.uint8)
        return tf.reshape(data, features['shape'])

    # Parse individual image from a tfrecords file into NumPy array.
    @staticmethod
    def parse_tfrecord_np(record):
        ex = tf.train.Example()
        ex.ParseFromString(record)
        shape = ex.features.feature['shape'].int64_list.value # pylint: disable=no-member
        data = ex.features.feature['data'].bytes_list.value[0] # pylint: disable=no-member
        # np.frombuffer replaces the deprecated binary mode of np.fromstring; .copy()
        # keeps the result writable and independent of `data`, like fromstring's was.
        return np.frombuffer(data, dtype=np.uint8).copy().reshape(shape)
#----------------------------------------------------------------------------
# Helper func for constructing a dataset object using the given options.
def load_dataset(class_name=None, data_dir=None, verbose=False, **kwargs):
    """Instantiate a dataset class looked up by its fully qualified name.

    Args:
        class_name: Dotted name resolved with ``dnnlib.util.get_obj_by_name``. When it
            is None and ``tfrecord_dir`` is among the keyword arguments, this module's
            ``TFRecordDataset`` is used.
        data_dir: Optional base directory; when given together with ``tfrecord_dir``,
            ``tfrecord_dir`` is joined onto it.
        verbose: Print the class name before construction and the dataset's shape,
            dynamic range and label size afterwards.
        **kwargs: Constructor arguments forwarded to the dataset class.

    Returns:
        The constructed dataset object.
    """
    ctor_args = dict(kwargs)
    has_tfrecords = 'tfrecord_dir' in ctor_args
    if has_tfrecords and class_name is None:
        # __name__ is this module's import name, so this selects its TFRecordDataset class.
        class_name = __name__ + '.TFRecordDataset'
    if has_tfrecords and data_dir is not None:
        ctor_args['tfrecord_dir'] = os.path.join(data_dir, ctor_args['tfrecord_dir'])
    assert class_name is not None

    if verbose:
        print('Streaming data using %s...' % class_name)
    dataset_class = dnnlib.util.get_obj_by_name(class_name)
    dataset = dataset_class(**ctor_args)
    if verbose:
        print('Dataset shape =', np.int32(dataset.shape).tolist())
        print('Dynamic range =', dataset.dynamic_range)
        print('Label size =', dataset.label_size)
    return dataset
亲爱的读者,如果你真的愿意花时间看到这里,我一定要称赞你:真是一个爱学习的好同学!
(完)
amao93 发布了32 篇原创文章 · 获赞 75 · 访问量 3万+ 私信 关注
标签:labels,self,py,dataset,np,shape,tf,StyleGAN2,tfr 来源: https://blog.csdn.net/weixin_41943311/article/details/104444349