
The flags and other config code

Author: 互联网

Hi everyone. I have dealt with these flags before; I first ran into them in tf-NCF, and now I am looking at this framework again. This machinery is really just argument input, environment/log printing, and error handling; nothing sophisticated. Yet it is exactly this kind of thing that gets in the way of a clean understanding of the model itself, and it effectively shuts out beginners. Not fun.

For Recommendation in Deep learning QQ Group 277356808

For deep learning QQ Second Group 629530787

I'm here waiting for you

Registering key flags

def register_key_flags_in_core(f):
  def core_fn(*args, **kwargs):
    key_flags = f(*args, **kwargs)
    [flags.declare_key_flag(fl) for fl in key_flags]  # pylint: disable=expression-not-assigned
  return core_fn

Here f is a function, so this is a decorator-style wrapper: it calls f, takes the list of flag names f returns, and registers each one as a key flag via flags.declare_key_flag. It is used like this:

define_base = register_key_flags_in_core(_base.define_base)
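The wrap-and-register pattern can be sketched without absl at all. In the minimal stand-in below, a plain set plays the role of the flag registry; the names `KEY_FLAGS` and `_define_base` are illustrative, not part of the real flags module:

```python
# Minimal sketch of the wrap-and-register decorator pattern, with a plain
# set standing in for absl's key-flag registry (illustrative names only).
KEY_FLAGS = set()  # stand-in for flags.declare_key_flag bookkeeping


def register_key_flags_in_core(f):
    def core_fn(*args, **kwargs):
        key_flags = f(*args, **kwargs)  # f returns the names it defined
        for fl in key_flags:            # register each one as a "key" flag
            KEY_FLAGS.add(fl)
    return core_fn


def _define_base(data_dir=True, model_dir=True):
    # Toy version of define_base: only collect names, no real flag objects.
    key_flags = []
    if data_dir:
        key_flags.append("data_dir")
    if model_dir:
        key_flags.append("model_dir")
    return key_flags


define_base = register_key_flags_in_core(_define_base)
define_base(model_dir=False)
print(sorted(KEY_FLAGS))  # ['data_dir']
```

In the real code, `declare_key_flag` makes the flag show up in the calling module's `--help` output; the wrapper exists only so every `define_*` function gets that registration for free.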

def define_base(data_dir=True, model_dir=True, clean=True, train_epochs=True,
                epochs_between_evals=True, stop_threshold=True, batch_size=True,
                num_gpu=True, hooks=True, export_dir=True,
                distribution_strategy=True):
  key_flags = []

  if data_dir:
    flags.DEFINE_string(
        name="data_dir", short_name="dd", default="/tmp",
        help=help_wrap("The location of the input data."))
    key_flags.append("data_dir")

  if model_dir:
    flags.DEFINE_string(
        name="model_dir", short_name="md", default="/tmp",
        help=help_wrap("The location of the model checkpoint files."))
    key_flags.append("model_dir")

  if clean:
    flags.DEFINE_boolean(
        name="clean", default=False,
        help=help_wrap("If set, model_dir will be removed if it exists."))
    key_flags.append("clean")

  if train_epochs:
    flags.DEFINE_integer(
        name="train_epochs", short_name="te", default=1,
        help=help_wrap("The number of epochs used to train."))
    key_flags.append("train_epochs")

  if epochs_between_evals:
    flags.DEFINE_integer(
        name="epochs_between_evals", short_name="ebe", default=1,
        help=help_wrap("The number of training epochs to run between "
                       "evaluations."))
    key_flags.append("epochs_between_evals")

  if stop_threshold:
    flags.DEFINE_float(
        name="stop_threshold", short_name="st",
        default=None,
        help=help_wrap("If passed, training will stop at the earlier of "
                       "train_epochs and when the evaluation metric is  "
                       "greater than or equal to stop_threshold."))

  if batch_size:
    flags.DEFINE_integer(
        name="batch_size", short_name="bs", default=32,
        help=help_wrap("Batch size for training and evaluation. When using "
                       "multiple gpus, this is the global batch size for "
                       "all devices. For example, if the batch size is 32 "
                       "and there are 4 GPUs, each GPU will get 8 examples on "
                       "each step."))
    key_flags.append("batch_size")

  if num_gpu:
    flags.DEFINE_integer(
        name="num_gpus", short_name="ng",
        default=1 if tf.test.is_gpu_available() else 0,
        help=help_wrap(
            "How many GPUs to use at each worker with the "
            "DistributionStrategies API. The default is 1 if TensorFlow can "
            "detect a GPU, and 0 otherwise."))

  if hooks:
    # Construct a pretty summary of hooks.
    hook_list_str = (
        u"\ufeff  Hook:\n" + u"\n".join([u"\ufeff    {}".format(key) for key
                                         in hooks_helper.HOOKS]))
    flags.DEFINE_list(
        name="hooks", short_name="hk", default="LoggingTensorHook",
        help=help_wrap(
            u"A list of (case insensitive) strings to specify the names of "
            u"training hooks.\n{}\n\ufeff  Example: `--hooks ProfilerHook,"
            u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper "
            u"for details.".format(hook_list_str))
    )
    key_flags.append("hooks")

  if export_dir:
    flags.DEFINE_string(
        name="export_dir", short_name="ed", default=None,
        help=help_wrap("If set, a SavedModel serialization of the model will "
                       "be exported to this directory at the end of training. "
                       "See the README for more details and relevant links.")
    )
    key_flags.append("export_dir")

  if distribution_strategy:
    flags.DEFINE_string(
        name="distribution_strategy", short_name="ds", default="default",
        help=help_wrap("The Distribution Strategy to use for training. "
                       "Accepted values are 'off', 'default', 'one_device', "
                       "'mirrored', 'parameter_server', 'collective', "
                       "case insensitive. 'off' means not to use "
                       "Distribution Strategy; 'default' means to choose "
                       "from `MirroredStrategy` or `OneDeviceStrategy` "
                       "according to the number of GPUs.")
    )

  return key_flags

The help parameter takes an explanatory string. As for why yet another function (help_wrap) is defined for it: it simply wraps the help text so that lines break once they pass 80 characters, nothing high-tech, as the session below shows. It is not worth studying, and you could just as well drop these help strings entirely.

>>> _help_wrap("abcjlllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllll")
'\nabcjlllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllll\nllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllll\nllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllll'
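The behaviour shown above can be reproduced with the standard library's textwrap module. This is a rough stand-in, assuming only that the real helper wraps at 80 columns and prefixes a newline as seen in the session output; it is not the project's actual implementation:

```python
import textwrap


def help_wrap(text, length=80):
    # Rough stand-in for the project's help_wrap: break the help string
    # into lines of at most `length` characters and prefix a newline,
    # matching the interpreter session shown above.
    return "\n" + textwrap.fill(text, width=length)


wrapped = help_wrap("l" * 200)
# Every line fits in 80 columns and no characters are lost.
print(all(len(line) <= 80 for line in wrapped.splitlines()))
```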

The help text above talks about how many GPUs each worker uses; how should that be understood? I will leave that as an open question, since it may matter for multi-machine, multi-GPU training.

So this is all just flag registration, and on top of it you still have to call flags.adopt_module_key_flags on the current module. Layer wrapped inside layer; tedious.

And in the main function you still have to go through app.run, or none of this takes effect. Thumbs down! Replace it all with plain assignments and none of this junk is needed.
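To see why nothing works without that final step: app.run(main) is what actually parses sys.argv into FLAGS before calling main. A small sketch of that behaviour, assuming the absl-py package (the flag name here is illustrative):

```python
from absl import flags

flags.DEFINE_integer("batch_size", 32, "Batch size for training.")
FLAGS = flags.FLAGS

# Reading a flag before any parse has happened raises an error: this is
# why skipping app.run() leaves the whole registry useless.
try:
    _ = FLAGS.batch_size
except flags.UnparsedFlagAccessError:
    print("flags not parsed yet")

# app.run(main) performs this parse on sys.argv before calling main();
# here we trigger it explicitly with a hand-built argv.
FLAGS(["prog", "--batch_size=16"])
print(FLAGS.batch_size)  # 16
```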

May we meet again someday,

and may you still remember the topics we once discussed.

Source: https://blog.csdn.net/SPESEG/article/details/123592718