flags等config的代码
作者:互联网
hi，各位大佬好！这个 flags 我之前处理过，应该是在 TF 版 NCF（tf-NCF）的代码里接触过。这里重新看这个框架会发现，它其实就是负责命令行参数输入、环境（日志）打印和报错处理，没有什么高深的东西；但正是这些外围代码干扰了对模型本身的理解，对新手很不友好，像是专门劝退小白的门槛，没意思。
For Recommendation in Deep learning QQ Group 277356808
For deep learning QQ Second Group 629530787
I'm here waiting for you
注册关键 flag（key flag）
def register_key_flags_in_core(f):
    """Wrap a flag-defining function so the flags it returns become key flags.

    ``f`` is called with whatever arguments the wrapper receives; it is
    expected to return an iterable of flag names. Each returned name is then
    passed to ``flags.declare_key_flag``. Because that call happens inside
    the wrapper defined here, it is made from this module rather than from
    the module that defined ``f``.

    Args:
      f: Function that defines flags and returns the names to declare as key
        flags.

    Returns:
      A wrapper with the same call signature as ``f``. The wrapper itself
      returns ``None``, exactly as before.
    """
    import functools  # local import: no top-level import block is visible here

    @functools.wraps(f)  # keep f's __name__/__doc__ on the wrapper
    def core_fn(*args, **kwargs):
        key_flags = f(*args, **kwargs)
        # Plain loop: the previous list comprehension was evaluated only for
        # its side effects and needed a pylint suppression.
        for fl in key_flags:
            flags.declare_key_flag(fl)

    return core_fn
可以看到，其中 f 是一个函数。register_key_flags_in_core 是一个包装函数（装饰器），它返回的新函数会先调用 f，拿到 f 返回的 flag 名称列表，再对每个名称调用 flags.declare_key_flag，把它们注册为关键 flag。示例如下：
# Re-export _base.define_base wrapped by register_key_flags_in_core: calling
# this define_base defines the flags and then passes every returned flag name
# to flags.declare_key_flag from the wrapper.
# NOTE(review): presumably done so key-flag ownership is attributed to this
# module (for flags.adopt_module_key_flags) — confirm against absl docs.
define_base = register_key_flags_in_core(_base.define_base)
def define_base(data_dir=True, model_dir=True, clean=True, train_epochs=True,
                epochs_between_evals=True, stop_threshold=True, batch_size=True,
                num_gpu=True, hooks=True, export_dir=True,
                distribution_strategy=True):
    """Define the common training flags and return the names of those defined.

    Each boolean argument toggles one flag. When the argument is truthy, the
    flag is defined via ``flags.DEFINE_*`` and its name is appended to the
    returned list. The list is meant to be declared as key flags, for
    example by ``register_key_flags_in_core``.

    Args:
      data_dir: Define ``--data_dir`` / ``-dd`` (string, default "/tmp").
      model_dir: Define ``--model_dir`` / ``-md`` (string, default "/tmp").
      clean: Define ``--clean`` (boolean, default False).
      train_epochs: Define ``--train_epochs`` / ``-te`` (integer, default 1).
      epochs_between_evals: Define ``--epochs_between_evals`` / ``-ebe``
        (integer, default 1).
      stop_threshold: Define ``--stop_threshold`` / ``-st`` (float, default
        None).
      batch_size: Define ``--batch_size`` / ``-bs`` (integer, default 32).
      num_gpu: Define ``--num_gpus`` / ``-ng`` (integer; default 1 when
        ``tf.test.is_gpu_available()`` is true, else 0).
      hooks: Define ``--hooks`` / ``-hk`` (list, default
        "LoggingTensorHook").
      export_dir: Define ``--export_dir`` / ``-ed`` (string, default None).
      distribution_strategy: Define ``--distribution_strategy`` / ``-ds``
        (string, default "default").

    Returns:
      List of the names of every flag defined by this call, in definition
      order.
    """
    key_flags = []

    if data_dir:
        flags.DEFINE_string(
            name="data_dir", short_name="dd", default="/tmp",
            help=help_wrap("The location of the input data."))
        key_flags.append("data_dir")

    if model_dir:
        flags.DEFINE_string(
            name="model_dir", short_name="md", default="/tmp",
            help=help_wrap("The location of the model checkpoint files."))
        key_flags.append("model_dir")

    if clean:
        flags.DEFINE_boolean(
            name="clean", default=False,
            help=help_wrap("If set, model_dir will be removed if it exists."))
        key_flags.append("clean")

    if train_epochs:
        flags.DEFINE_integer(
            name="train_epochs", short_name="te", default=1,
            help=help_wrap("The number of epochs used to train."))
        key_flags.append("train_epochs")

    if epochs_between_evals:
        flags.DEFINE_integer(
            name="epochs_between_evals", short_name="ebe", default=1,
            help=help_wrap("The number of training epochs to run between "
                           "evaluations."))
        key_flags.append("epochs_between_evals")

    if stop_threshold:
        flags.DEFINE_float(
            name="stop_threshold", short_name="st",
            default=None,
            help=help_wrap("If passed, training will stop at the earlier of "
                           "train_epochs and when the evaluation metric is "
                           "greater than or equal to stop_threshold."))
        # Previously missing: the flag was defined but never reported, so it
        # was not declared as a key flag like its siblings.
        key_flags.append("stop_threshold")

    if batch_size:
        flags.DEFINE_integer(
            name="batch_size", short_name="bs", default=32,
            help=help_wrap("Batch size for training and evaluation. When using "
                           "multiple gpus, this is the global batch size for "
                           "all devices. For example, if the batch size is 32 "
                           "and there are 4 GPUs, each GPU will get 8 examples on "
                           "each step."))
        key_flags.append("batch_size")

    if num_gpu:
        # NOTE(review): tf.test.is_gpu_available is deprecated in TF 2.x in
        # favour of tf.config.list_physical_devices("GPU"); kept as-is because
        # the TF version this file targets is not visible here.
        flags.DEFINE_integer(
            name="num_gpus", short_name="ng",
            default=1 if tf.test.is_gpu_available() else 0,
            help=help_wrap(
                "How many GPUs to use at each worker with the "
                "DistributionStrategies API. The default is 1 if TensorFlow can "
                "detect a GPU, and 0 otherwise."))
        # Previously missing: same inconsistency as stop_threshold above.
        key_flags.append("num_gpus")

    if hooks:
        # Construct a pretty summary of hooks.
        # NOTE(review): the leading \ufeff characters appear to be used so the
        # indentation survives help_wrap's line wrapping — confirm.
        hook_list_str = (
            u"\ufeff Hook:\n" + u"\n".join([u"\ufeff {}".format(key) for key
                                            in hooks_helper.HOOKS]))
        flags.DEFINE_list(
            name="hooks", short_name="hk", default="LoggingTensorHook",
            help=help_wrap(
                u"A list of (case insensitive) strings to specify the names of "
                u"training hooks.\n{}\n\ufeff Example: `--hooks ProfilerHook,"
                u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper "
                u"for details.".format(hook_list_str))
        )
        key_flags.append("hooks")

    if export_dir:
        flags.DEFINE_string(
            name="export_dir", short_name="ed", default=None,
            help=help_wrap("If set, a SavedModel serialization of the model will "
                           "be exported to this directory at the end of training. "
                           "See the README for more details and relevant links.")
        )
        key_flags.append("export_dir")

    if distribution_strategy:
        flags.DEFINE_string(
            name="distribution_strategy", short_name="ds", default="default",
            help=help_wrap("The Distribution Strategy to use for training. "
                           "Accepted values are 'off', 'default', 'one_device', "
                           "'mirrored', 'parameter_server', 'collective', "
                           "case insensitive. 'off' means not to use "
                           "Distribution Strategy; 'default' means to choose "
                           "from `MirroredStrategy` or `OneDeviceStrategy` "
                           "according to the number of GPUs.")
        )
        key_flags.append("distribution_strategy")

    return key_flags
其中 help 参数传入的是该 flag 的说明文字。之所以又用 help_wrap 包一层，是为了让说明文字自动换行：一行超过 80 个字符就换行（效果见下面的示例）。这没有什么特别的技术，不必深究；如果不需要，也可以直接去掉这些说明文字。
>>> _help_wrap("abcjlllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllll")
'\nabcjlllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllll\nllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllll\nllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllllll'
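上面 num_gpus 的说明是：每个 worker 在 DistributionStrategies API 下使用多少块 GPU。如果 TensorFlow 能检测到 GPU，默认值为 1，否则为 0。这并不是说每个 worker 最多只能用一块 GPU，只是默认值为 1。这部分对多机多卡训练可能有用，先挖个坑，以后再细看。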
这本质上就是参数注册，但还需要在当前模块中调用 flags.adopt_module_key_flags，一层套一层，比较繁琐，没什么意思。
此外，在 main 函数中还必须通过 app.run 来启动，否则命令行参数不会被解析，flags 也就不起作用。差评！！我的建议是全部改成简单的变量赋值，根本不需要这套繁琐的机制。坑爹。
愿我们终有重逢之时,
而你还记得我们曾经讨论的话题。
标签:help,default,代码,flags,key,config,dir,name 来源: https://blog.csdn.net/SPESEG/article/details/123592718