Skip to content

Commit

Permalink
internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605069831
  • Loading branch information
gauravmishra authored and t5-copybara committed Feb 7, 2024
1 parent cd94b76 commit 927cda6
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
4 changes: 2 additions & 2 deletions t5x/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import inspect
from typing import Callable, Optional, TypeVar

from absl import app
from absl import flags
from absl import logging
from clu import metric_writers
Expand All @@ -29,6 +28,7 @@
from fiddle.experimental import serialization
import jax
from t5x import gin_utils
from t5x import utils


FLAGS = flags.FLAGS
Expand Down Expand Up @@ -210,6 +210,6 @@ def flags_parser(args):

jax.config.parse_flags_with_absl()
if using_fdl():
app.run(main, flags_parser=flags_parser)
utils.run_main(main, flags_parser=flags_parser)
else:
gin_utils.run(main)
3 changes: 2 additions & 1 deletion t5x/gin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from clu import metric_writers
import gin
import jax
from t5x import utils
import tensorflow as tf


Expand Down Expand Up @@ -131,7 +132,7 @@ def summarize_gin_config(

def run(main):
"""Wrapper for app.run that rewrites gin args before parsing."""
app.run(
utils.run_main(
main,
flags_parser=lambda a: app.parse_flags_with_usage(
list(rewrite_gin_args(a))
Expand Down
5 changes: 5 additions & 0 deletions t5x/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union
import warnings

from absl import app # pylint: disable=unused-import
from absl import flags
from absl import logging
import airio
Expand Down Expand Up @@ -2372,3 +2373,7 @@ def find_first_checkpoint_step(
return checkpoint_steps_index




def run_main(main, flags_parser):
app.run(main, flags_parser=flags_parser)

0 comments on commit 927cda6

Please sign in to comment.