Skip to content

Commit 804b888

Browse files
Hilly12 and recml authors
authored and committed
Add Keras trainer.
PiperOrigin-RevId: 744903627
1 parent bdeb41c commit 804b888

File tree

10 files changed

+1081
-19
lines changed

10 files changed

+1081
-19
lines changed

mlrx/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
from mlrx.training.core import Experiment
2828
from mlrx.training.core import run_experiment
2929
from mlrx.training.core import Trainer
30-
from mlrx.training.jax import JaxState
31-
from mlrx.training.jax import JaxTask
32-
from mlrx.training.jax import JaxTrainer
33-
from mlrx.training.jax import KerasState
30+
from mlrx.training.jax_trainer import JaxState
31+
from mlrx.training.jax_trainer import JaxTask
32+
from mlrx.training.jax_trainer import JaxTrainer
33+
from mlrx.training.jax_trainer import KerasState
34+
from mlrx.training.keras_trainer import KerasTask
35+
from mlrx.training.keras_trainer import KerasTrainer
3436
from mlrx.training.optax_factory import AdagradFactory
3537
from mlrx.training.optax_factory import AdamFactory
3638
from mlrx.training.optax_factory import OptimizerFactory

mlrx/training/core.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""Core training library for Jax."""
1515

1616
import abc
17-
from collections.abc import Mapping
17+
from collections.abc import Mapping, Sequence
1818
import dataclasses
1919
import enum
2020
from typing import Any, Generic, TypeVar
@@ -33,6 +33,8 @@
3333
TRAINING_COMPLETE_MARKER_FILE = "marker.txt"
3434
TRAIN_LOG_DIRNAME = "train"
3535
EVAL_LOG_DIRNAME = "val"
36+
KERAS_MODEL_SAVEFILE = "model.keras"
37+
ORBAX_CHECKPOINT_DEFAULT_KEY = "default"
3638

3739
DEFAULT_RNG_SEED = 0
3840
IN_TRAINER_CONTEXT = False # Set to true when run from the main trainer.
@@ -171,6 +173,15 @@ def get_iterators(
171173
return train_dataset, eval_datasets # pytype: disable=bad-return-type
172174

173175

176+
def get_shape(
177+
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor,
178+
) -> Sequence[int | None]:
179+
"""Gets the shape of a dense / sparse / ragged tensor."""
180+
if isinstance(x, tf.SparseTensor):
181+
return [x.shape[0]] + [None for _ in x.shape[1:]]
182+
return x.shape.as_list()
183+
184+
174185
def in_tracing_context() -> bool:
175186
"""Returns whether the current context is a tracing context."""
176187
return isinstance(jnp.ones(()), jax.core.Tracer)
File renamed without changes.

mlrx/training/jax_quality_test.py renamed to mlrx/training/jax_trainer_quality_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525
import jax.numpy as jnp
2626
import jaxtyping as jt
2727
import optax
28-
from mlrx.training import jax as jax_lib
28+
from mlrx.training import jax_trainer
2929
from mlrx.training import partitioning
3030
import tensorflow as tf
3131
import tensorflow_datasets as tfds
3232

3333

34-
class _MNISTTask(jax_lib.JaxTask):
34+
class _MNISTTask(jax_trainer.JaxTask):
3535
"""Task for fitting a CNN on MNIST."""
3636

3737
def create_datasets(self) -> tuple[tf.data.Dataset, tf.data.Dataset]:
@@ -126,7 +126,7 @@ def setUp(self):
126126
def test_mnist_e2e(self):
127127
model_dir = self.create_tempdir().full_path
128128
task = _MNISTTask()
129-
trainer = jax_lib.JaxTrainer(
129+
trainer = jax_trainer.JaxTrainer(
130130
partitioner=partitioning.DataParallelPartitioner(),
131131
train_steps=1000,
132132
steps_per_eval=50,

mlrx/training/jax_test.py renamed to mlrx/training/jax_trainer_test.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import optax
3131
import orbax.checkpoint as ocp
3232
from mlrx.training import core
33-
from mlrx.training import jax as jax_lib
33+
from mlrx.training import jax_trainer
3434
from mlrx.training import partitioning
3535
import tensorflow as tf
3636

@@ -42,7 +42,7 @@ def __call__(self, inputs: jax.Array) -> jax.Array:
4242
return nn.Dense(1, kernel_init=nn.initializers.constant(-1.0))(inputs)
4343

4444

45-
class _JaxTask(jax_lib.JaxTask):
45+
class _JaxTask(jax_trainer.JaxTask):
4646

4747
def create_datasets(
4848
self,
@@ -90,7 +90,7 @@ def eval_step(
9090
return {"loss": clu_metrics.Average.from_model_output(loss)}
9191

9292

93-
class _KerasJaxTask(jax_lib.JaxTask):
93+
class _KerasJaxTask(jax_trainer.JaxTask):
9494

9595
def create_datasets(self) -> tf.data.Dataset:
9696
def _map_fn(x: int):
@@ -106,7 +106,7 @@ def _map_fn(x: int):
106106

107107
def create_state(
108108
self, batch: jt.PyTree, rng: jax.Array
109-
) -> jax_lib.KerasState:
109+
) -> jax_trainer.KerasState:
110110
x, _ = batch
111111

112112
model = keras.Sequential(
@@ -122,11 +122,11 @@ def create_state(
122122
model.build(x.shape)
123123

124124
optimizer = optax.adagrad(0.1)
125-
return jax_lib.KerasState.create(model=model, tx=optimizer)
125+
return jax_trainer.KerasState.create(model=model, tx=optimizer)
126126

127127
def train_step(
128-
self, batch: jt.PyTree, state: jax_lib.KerasState, rng: jax.Array
129-
) -> tuple[jax_lib.KerasState, Mapping[str, clu_metrics.Metric]]:
128+
self, batch: jt.PyTree, state: jax_trainer.KerasState, rng: jax.Array
129+
) -> tuple[jax_trainer.KerasState, Mapping[str, clu_metrics.Metric]]:
130130
x, y = batch
131131

132132
def _loss_fn(tvars):
@@ -140,7 +140,7 @@ def _loss_fn(tvars):
140140
return state, {"loss": clu_metrics.Average.from_model_output(loss)}
141141

142142
def eval_step(
143-
self, batch: jt.PyTree, state: jax_lib.KerasState
143+
self, batch: jt.PyTree, state: jax_trainer.KerasState
144144
) -> Mapping[str, clu_metrics.Metric]:
145145
x, y = batch
146146
y_pred, _ = state.model.stateless_call(state.tvars, state.ntvars, x)
@@ -208,13 +208,13 @@ def setUp(self):
208208
)
209209
def test_jax_trainer(
210210
self,
211-
task_cls: type[jax_lib.JaxTask],
211+
task_cls: type[jax_trainer.JaxTask],
212212
mode: str,
213213
expected_keys: Sequence[str],
214214
):
215215
model_dir = self.create_tempdir().full_path
216216
task = task_cls()
217-
trainer = jax_lib.JaxTrainer(
217+
trainer = jax_trainer.JaxTrainer(
218218
partitioner=partitioning.DataParallelPartitioner(data_axis="batch"),
219219
train_steps=12,
220220
steps_per_eval=3,
@@ -258,7 +258,7 @@ class State:
258258
),
259259
)
260260
state = State(step=10, opt_state=tx.init({"a": jnp.ones((10, 10))}))
261-
metrics = jax_lib._state_metrics(state)
261+
metrics = jax_trainer._state_metrics(state)
262262
self.assertIn("optimizer/learning_rate", metrics)
263263
self.assertEqual(metrics["optimizer/learning_rate"].compute(), 0.1)
264264

0 commit comments

Comments
 (0)