Commit (tensorflow#4492): Add checklist for official models. Remove file access from flag validator (it was causing issues with BUILD).
Showing 7 changed files with 207 additions and 91 deletions.
# Using flags in official models

1.  **All common flags must be incorporated in the models.**

    Common flags (e.g. `batch_size`, `model_dir`, etc.) are provided by various flag definition
    functions and channeled through `official.utils.flags.core`. For instance, to define common
    supervised learning parameters, one could use the following code:
    ```python
    from absl import app as absl_app
    from absl import flags

    from official.utils.flags import core as flags_core


    def define_flags():
      flags_core.define_base()
      flags.adopt_module_key_flags(flags_core)


    def main(_):
      flags_obj = flags.FLAGS
      print(flags_obj)


    if __name__ == "__main__":
      define_flags()
      absl_app.run(main)
    ```
2.  **Validate flag values.**

    See the [Validators](#validators) section for implementation details.

    Validators in the official model repo should not access the file system (for example, to
    verify that files exist), because of the strict ordering requirements.
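As an illustration of a validator that obeys this rule, the sketch below uses absl's `flags.register_validator` to check only the parsed value, never the file system. The flag name `batch_size` and the helper `define_validators` are chosen here purely for illustration:

```python
from absl import flags

flags.DEFINE_integer("batch_size", 32, "Number of examples per batch.")


def define_validators():
  # Good: the validator only inspects the parsed value. It never touches
  # the file system, so it is safe under strict flag-ordering requirements.
  flags.register_validator(
      "batch_size",
      lambda value: value > 0,
      message="--batch_size must be a positive integer.")
```

A value-based check like this runs at parse time; anything that depends on external state (such as file existence) belongs in the model code instead.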
3.  **Flag values should not be mutated.**

    Instead of mutating flag values, use getter functions to return the desired values. An
    example is the `get_loss_scale` function below:
    ```python
    import tensorflow as tf
    from absl import flags

    # Map string to (TensorFlow dtype, default loss scale)
    DTYPE_MAP = {
        "fp16": (tf.float16, 128),
        "fp32": (tf.float32, 1),
    }


    def get_loss_scale(flags_obj):
      if flags_obj.loss_scale is not None:
        return flags_obj.loss_scale
      return DTYPE_MAP[flags_obj.dtype][1]


    def main(_):
      flags_obj = flags.FLAGS

      # Do not mutate flags_obj:
      #   if flags_obj.loss_scale is None:
      #     flags_obj.loss_scale = DTYPE_MAP[flags_obj.dtype][1]  # Don't do this

      print(get_loss_scale(flags_obj))
      ...
    ```
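The getter pattern can be exercised without TensorFlow. In this hypothetical sketch, plain strings stand in for the TF dtypes, and `FakeFlags` is an invented stand-in for a parsed absl `FLAGS` object:

```python
# Stand-ins for the TF dtypes so the sketch runs without TensorFlow.
DTYPE_MAP = {
    "fp16": ("float16", 128),
    "fp32": ("float32", 1),
}


class FakeFlags:
  """Hypothetical stand-in for a parsed absl FLAGS object."""

  def __init__(self, dtype, loss_scale=None):
    self.dtype = dtype
    self.loss_scale = loss_scale


def get_loss_scale(flags_obj):
  # Return the explicit value if given, otherwise the dtype's default.
  if flags_obj.loss_scale is not None:
    return flags_obj.loss_scale
  return DTYPE_MAP[flags_obj.dtype][1]
```

Because the flag object is never written to, `get_loss_scale(FakeFlags("fp16"))` returns the default `128`, while an explicit `loss_scale` always wins.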
# Logging in official models

This library adds logging functions that print or save tensor values. Official models should
define all common hooks (using the hooks helper) and a benchmark logger.
1.  **Training Hooks**

    Hooks are a TensorFlow concept that defines specific actions at certain points of the
    execution. We use them to obtain and log tensor values during training.

    `hooks_helper.py` provides an easy way to create common hooks. The following hooks are
    currently defined:

    * `LoggingTensorHook`: Logs tensor values.
    * `ProfilerHook`: Writes a timeline JSON that can be loaded into `chrome://tracing`.
    * `ExamplesPerSecondHook`: Logs the number of examples processed per second.
    * `LoggingMetricHook`: Similar to `LoggingTensorHook`, except that the tensors are logged
      in a format defined by our data analysis pipeline.
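To make the bookkeeping behind a hook like `ExamplesPerSecondHook` concrete, here is a minimal, framework-free sketch. The class name and methods are simplified for illustration; the real hook subclasses TensorFlow's `SessionRunHook` and receives its callbacks from the training loop:

```python
import time


class ExamplesPerSecondSketch:
  """Simplified stand-in for ExamplesPerSecondHook's bookkeeping."""

  def __init__(self, batch_size):
    self._batch_size = batch_size
    self._start_time = None
    self._steps = 0

  def begin(self):
    # Called once before training starts.
    self._start_time = time.time()

  def after_run(self):
    # Called after every training step.
    self._steps += 1

  def examples_per_second(self):
    # Guard against a zero interval on very fast runs.
    elapsed = max(time.time() - self._start_time, 1e-6)
    return self._steps * self._batch_size / elapsed
```

The real hook logs this rate at a configurable interval instead of exposing a getter, but the computation (steps times batch size over elapsed wall time) is the same.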
2.  **Benchmarks**

    The benchmark logger provides useful functions for logging environment information and
    evaluation results. The module also contains a context manager that is used to update the
    status of the run.

    Example usage:
    ```python
    from absl import app as absl_app
    from absl import flags

    from official.utils.logs import hooks_helper
    from official.utils.logs import logger


    def model_main(flags_obj):
      estimator = ...

      benchmark_logger = logger.get_benchmark_logger()
      benchmark_logger.log_run_info(...)

      train_hooks = hooks_helper.get_train_hooks(...)

      for epoch in range(10):
        estimator.train(..., hooks=train_hooks)
        eval_results = estimator.evaluate(...)

        # Log a dictionary of metrics
        benchmark_logger.log_evaluation_result(eval_results)

        # Log an individual metric
        benchmark_logger.log_metric(...)


    def main(_):
      with logger.benchmark_context(flags.FLAGS):
        model_main(flags.FLAGS)


    if __name__ == "__main__":
      # define flags
      absl_app.run(main)
    ```