其他分享
首页 > 其他分享> > lingvo/core/base_input_generator分析(一)

lingvo/core/base_input_generator分析(一)

作者:互联网

2021SC@SDUSC

lingvo/core/base_input_generator中的DefineInfeedParams和Params方法

lingvo中的baseInputGenerator分析

简介

base_input_generator是lingvo中的重要部分,其具有庞大的代码体系,其中包含的类有BaseInputGenerator、BaseInputGeneratorFromFiles、BaseSequenceInputGenerator、BaseTinyDatasetInput、TFDataSequenceInputGenerator、BaseDataExampleInputGenerato。
其中很重要的一个概念是设备拆分批量大小,其由Params()定义,表示的是每个设备或者TPU上的批量大小。BaseInputGenerator.params.batch_size 和BaseSequenceInputGenerator.params.bucket_batch_limit确定拆分的大小。

BaseInputGenerator

本次分析我们先从BaseInputGenerator类入手,其中包含众多方法,如下是部分方法截图:
在这里插入图片描述

1.DefineInfeedParams方法:

方法的代码如下:

  @classmethod
  def DefineInfeedParams(cls, p):
    p.Define('use_per_host_infeed', False,
             'Whether run infeed op on each host.')
    p.Define('use_per_core_infeed', False,
             'Whether to shard the infeed per TPU core instead of per replica')
    p.Define('tpu_infeed_parallelism', 1,
             'Uses these many python threads to drive infeed concurrently.')
    p.Define('use_partitioned_infeed_queue', False, 'Use partitioned infeed')
    p.Define(
        'num_partitions', None,
        'Number of partitions to split the model graph into. Used with '
        'model parallelism. When >1, it specifies the number of devices '
        'used to place one replica of the model graph nodes.')

从@classmethod中我们可以看出该函数不需要实例化,其不需要self参数,但是函数的第一个参数必须是cls,表示自身类,用来调用的类的属性、方法和实例化方法等。方法的第二个参数是p,其连续调用五个Define函数,该函数来自lingvo/core/hyperparams.py,该文件中define方法部分的代码如下:

  def Define(self, name: str, default_value: Any, description: str) -> None:
    if self._immutable:
      raise TypeError('This Params instance is immutable.')
    assert name is not None and isinstance(name, str) and (re.match(
        '^[a-z][a-z0-9_]*$', name) is not None)
    if name in self._params:
      raise AttributeError('Parameter %s is already defined' % name)
    self._params[name] = _Param(name, default_value, description)

不难看出Define函数是用来定义参数的,其输入有三个参数:name、default_value和description。让我们先来对这个函数进行分析,其中:
①name:参数名称,其类型为str,只能包含小写字母、下划线和数字,且只能以小写字母开头。
②default_value:参数的默认值,可以为none
③description:参数的描述,str类型。
同样我们可以看出该方法中手动设置了两个异常,TypeError当参数实例不可变时引用,AttributeError当参数名称已经被定义时调用。

现在我们再看DefineInfeedParams方法,其中对define进行了5次不同的调用,其参数分别是use_per_host_infeed、use_per_core_infeed、tpu_infeed_parallelism、use_partitioned_infeed_queue和num_partitions,我们可以根据主机、核心或者TPU的情况进行设置。

2.Params方法

input generators的默认参数,其也是用@classmethod描述,用p来调用多个Define函数,关于每个调用的用法我们可以从其调用时传入的参数得出。以调用参数变量名为eval_samples_per_summary为例,其对于支持 samples_per_summary == 0 以指示使用整个数据集的输入生成器,他必须(1)是可重置的,(2)要抛出tf.errors.OutOfRangeError异常。

   @classmethod
  def Params(cls):
    p = super().Params()
    p.name = 'input'
    p.Define(
        'file_datasource', None,
        'The DataSource that produces input batches for this input generator.')
    p.Define(
        'batch_size', 0, 'Batch size for a device split. This will be '
        'scaled to match the accelarator hardware topology.')
    p.Define(
        'num_samples', 0,
        'If non-zero, the dataset contains these many samples. '
        'For test/eval dataset, if we want the test/evel job evaluate '
        'the whole dataset, this param must be set precisely. Otherwise, '
        'this param is optional.')
    p.Define('resettable', False,
             'If True, the input generator must implement Reset().')
    p.Define(
        'eval_samples_per_summary', None, 'If not None, overrides '
        'task_p.eval.samples_per_summary directly. Allowed to be 0, which '
        'means to use the entire dataset.')
    p.Define(
        'decoder_samples_per_summary', None, 'If not None, overrides '
        'task_p.eval.decoder_samples_per_summary directly. Allowed to be 0, '
        'which means to use the entire dataset.')
    p.Define(
        'filter_sparse_tensors', False,
        'If true, filter out SparseTensors in input_batch before enqueuing '
        'onto TPU.')
    cls.DefineInfeedParams(p)

    p.Define('remote', hyperparams.Params(),
             'Params to configure remote input policy.')
    p.remote.Define(
        'max_inflights_per_target', 32, 'The maximum number of '
        'concurrent inflight remote input fetches per remote target.')

    p.Define(
        'input_stats_summary_interval_steps', 10,
        'Number of steps in between logging of TF scalar summaries for '
        'training related input data stats.')

    p.Define(
        'tpu_embedding_mode', 'train',
        'The mode used to enqueue TPU embedding ids. Valid values are: {'
        'None: no TPU embedding enqueue ops will be generated; '
        '"inference": enqueue ops will be generated, but backprop will be '
        'disabled (i.e. no gradient will be generated and the embedding '
        'tables are freezed); '
        '"train": both enqueue ops and gradient will be generated when '
        'do_eval is False, otherwise fallback to "inference" mode; }.')
    p.Define('cpu_passthrough_keys', [],
             'A list of keys in the input batch to not send to TPU device.')

    return p

小结

本次分析了lingvo/core/base_input_generator.py文件,其中的DefineInfeedParams和Params方法都与参数相关,且都调用了lingvo/core/hyperparams.py中的Define方法,为我们对参数的操作提供了方法。

标签:core,name,generator,per,lingvo,Params,input,infeed,Define
来源: https://blog.csdn.net/hewei000good/article/details/121322486