This repository has been archived by the owner on Dec 9, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 635
/
Copy pathflags.py
84 lines (63 loc) · 3.44 KB
/
flags.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains functions to define flags and params.
Calling a DEFINE_* function will add a ParamSpec namedtuple to the param_spec
dict. The DEFINE_* arguments match those in absl. Calling define_flags() creates
a command-line flag for every ParamSpec defined by the DEFINE_* functions.
The reason we don't use absl flags directly is that we want to be able to use
tf_cnn_benchmarks as a library. When using it as a library, we don't want to
define any flags, but instead pass parameters to the BenchmarkCNN constructor.
"""
from collections import namedtuple
from absl import flags as absl_flags
# ParamSpec describes one of benchmark_cnn.BenchmarkCNN's parameters. Fields:
#   flag_type:     'string', 'boolean', 'integer', 'float', 'enum' or 'list';
#                  selects the absl DEFINE_* function used by define_flags().
#   default_value: the parameter's default.
#   description:   help text for the parameter.
#   kwargs:        extra keyword arguments forwarded to the absl DEFINE_*
#                  function (e.g. bounds or enum values).
ParamSpec = namedtuple('_ParamSpec',
                       'flag_type default_value description kwargs')

# Registry from parameter name to its ParamSpec; populated by the DEFINE_*
# helpers below and consumed by define_flags().
param_specs = dict()
def DEFINE_string(name, default, help):  # pylint: disable=invalid-name,redefined-builtin
  """Registers a string parameter in `param_specs`.

  Args:
    name: Parameter name; define_flags() uses it as the flag name.
    default: Default value of the parameter.
    help: Description of the parameter.
  """
  param_specs[name] = ParamSpec(flag_type='string', default_value=default,
                                description=help, kwargs={})
def DEFINE_boolean(name, default, help):  # pylint: disable=invalid-name,redefined-builtin
  """Registers a boolean parameter in `param_specs`.

  Args:
    name: Parameter name; define_flags() uses it as the flag name.
    default: Default value of the parameter.
    help: Description of the parameter.
  """
  param_specs[name] = ParamSpec(flag_type='boolean', default_value=default,
                                description=help, kwargs={})
def DEFINE_integer(name, default, help, lower_bound=None, upper_bound=None):  # pylint: disable=invalid-name,redefined-builtin
  """Registers an integer parameter, with optional bounds, in `param_specs`.

  Args:
    name: Parameter name; define_flags() uses it as the flag name.
    default: Default value of the parameter.
    help: Description of the parameter.
    lower_bound: Forwarded as `lower_bound` to absl's DEFINE_integer.
    upper_bound: Forwarded as `upper_bound` to absl's DEFINE_integer.
  """
  bounds = dict(lower_bound=lower_bound, upper_bound=upper_bound)
  param_specs[name] = ParamSpec(flag_type='integer', default_value=default,
                                description=help, kwargs=bounds)
def DEFINE_float(name, default, help, lower_bound=None, upper_bound=None):  # pylint: disable=invalid-name,redefined-builtin
  """Registers a float parameter, with optional bounds, in `param_specs`.

  Args:
    name: Parameter name; define_flags() uses it as the flag name.
    default: Default value of the parameter.
    help: Description of the parameter.
    lower_bound: Forwarded as `lower_bound` to absl's DEFINE_float.
    upper_bound: Forwarded as `upper_bound` to absl's DEFINE_float.
  """
  bounds = dict(lower_bound=lower_bound, upper_bound=upper_bound)
  param_specs[name] = ParamSpec(flag_type='float', default_value=default,
                                description=help, kwargs=bounds)
def DEFINE_enum(name, default, enum_values, help):  # pylint: disable=invalid-name,redefined-builtin
  """Registers an enum parameter in `param_specs`.

  Args:
    name: Parameter name; define_flags() uses it as the flag name.
    default: Default value of the parameter.
    enum_values: Allowed values; forwarded as `enum_values` to absl's
      DEFINE_enum.
    help: Description of the parameter.
  """
  param_specs[name] = ParamSpec(flag_type='enum', default_value=default,
                                description=help,
                                kwargs=dict(enum_values=enum_values))
def DEFINE_list(name, default, help):  # pylint: disable=invalid-name,redefined-builtin
  """Registers a list parameter in `param_specs`.

  Args:
    name: Parameter name; define_flags() uses it as the flag name.
    default: Default value of the parameter.
    help: Description of the parameter.
  """
  param_specs[name] = ParamSpec(flag_type='list', default_value=default,
                                description=help, kwargs={})
def define_flags(specs=None):
  """Define a command line flag for each ParamSpec in flags.param_specs.

  Args:
    specs: Optional dict mapping flag name to ParamSpec. When None (the
      default), the module-level `param_specs` registry is used. An empty dict
      defines no flags.

  Raises:
    ValueError: If any ParamSpec has a flag_type with no matching absl
      DEFINE_* function. All specs are validated before any flag is defined,
      so no flags are defined when this is raised.
  """
  # Compare against None explicitly: `specs or param_specs` would silently
  # substitute the whole global registry when the caller passes an empty dict.
  if specs is None:
    specs = param_specs
  define_flag = {
      'boolean': absl_flags.DEFINE_boolean,
      'float': absl_flags.DEFINE_float,
      'integer': absl_flags.DEFINE_integer,
      'string': absl_flags.DEFINE_string,
      'enum': absl_flags.DEFINE_enum,
      'list': absl_flags.DEFINE_list,
  }
  # Validate up front so a bad spec does not leave a partially-defined set of
  # flags behind.
  for param_spec in specs.values():
    if param_spec.flag_type not in define_flag:
      raise ValueError('Unknown flag_type %s' % param_spec.flag_type)
  for name, param_spec in specs.items():
    define_flag[param_spec.flag_type](name, param_spec.default_value,
                                      help=param_spec.description,
                                      **param_spec.kwargs)