diff --git a/t5x/config_utils.py b/t5x/config_utils.py index 94c2b4f8c..23355bed0 100644 --- a/t5x/config_utils.py +++ b/t5x/config_utils.py @@ -18,7 +18,6 @@ import inspect from typing import Callable, Optional, TypeVar -from absl import app from absl import flags from absl import logging from clu import metric_writers @@ -29,6 +28,7 @@ from fiddle.experimental import serialization import jax from t5x import gin_utils +from t5x import utils FLAGS = flags.FLAGS @@ -210,6 +210,6 @@ def flags_parser(args): jax.config.parse_flags_with_absl() if using_fdl(): - app.run(main, flags_parser=flags_parser) + utils.run_main(main, flags_parser=flags_parser) else: gin_utils.run(main) diff --git a/t5x/gin_utils.py b/t5x/gin_utils.py index ff2050599..64f118a21 100644 --- a/t5x/gin_utils.py +++ b/t5x/gin_utils.py @@ -22,6 +22,7 @@ from clu import metric_writers import gin import jax +from t5x import utils import tensorflow as tf @@ -131,7 +132,7 @@ def summarize_gin_config( def run(main): """Wrapper for app.run that rewrites gin args before parsing.""" - app.run( + utils.run_main( main, flags_parser=lambda a: app.parse_flags_with_usage( list(rewrite_gin_args(a)) diff --git a/t5x/utils.py b/t5x/utils.py index b79826039..78df37db9 100644 --- a/t5x/utils.py +++ b/t5x/utils.py @@ -29,6 +29,7 @@ from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union import warnings +from absl import app # pylint: disable=unused-import from absl import flags from absl import logging import airio @@ -2372,3 +2373,7 @@ def find_first_checkpoint_step( return checkpoint_steps_index + + +def run_main(main, flags_parser): + app.run(main, flags_parser=flags_parser)