编程语言
首页 > 编程语言> > python – Tensorflow:如何查找tf.data.Dataset API对象的大小

python – Tensorflow:如何查找tf.data.Dataset API对象的大小

作者:互联网

我理解Dataset API是一种迭代器,它不会将整个数据集加载到内存中,因此无法找到数据集的大小.我正在谈论存储在文本文件或tfRecord文件中的大型数据语料库.通常使用tf.data.TextLineDataset或类似的东西来读取这些文件.使用tf.data.Dataset.from_tensor_slices找到加载的数据集的大小是微不足道的.

我问数据集大小的原因如下:
假设我的数据集大小为1000个元素.批量大小= 50个元素.然后训练步骤/批次(假设1个纪元)= 20.在这20个步骤中,我想将我的学习率从0.1到0.01指数衰减为

tf.train.exponential_decay(
    learning_rate = 0.1,
    global_step = global_step,
    decay_steps = 20,
    decay_rate = 0.1,
    staircase=False,
    name=None
)

在上面的代码中,我有“和”想要设置decay_steps =每个epoch的步数/批次数= num_elements / batch_size.仅当预先知道数据集中的元素数量时,才能计算此值.

提前知道大小的另一个原因是使用tf.data.Dataset.take(),tf.data.Dataset.skip()方法将数据拆分为训练集和测试集.

PS:我不是在寻找蛮力的方法,比如遍历整个数据集并更新计数器以计算元素数量或putting a very large batch size and then finding the size of the resultant dataset等.

解决方法:

您可以选择手动指定数据集的大小吗?

我如何加载我的数据:

sample_id_hldr = tf.placeholder(dtype=tf.int64, shape=(None,), name="samples")

sample_ids = tf.Variable(sample_id_hldr, validate_shape=False, name="samples_cache")
num_samples = tf.size(sample_ids)

data = tf.data.Dataset.from_tensor_slices(sample_ids)
# "load" data by id:
# return (id, data) for each id
data = data.map(
    lambda id: (id, some_load_op(id))
)

在这里,您可以通过使用占位符初始化sample_ids一次来指定所有样本ID.
您的样本ID可以是例如文件路径或简单数字(np.arange(num_elems))

然后在num_samples中提供元素数量.

标签:python,tensorflow,tensorflow-datasets
来源: https://codeday.me/bug/20190910/1800796.html