Orbax to regular checkpointer conversion #1246


Draft: wants to merge 3 commits into main.
16 changes: 13 additions & 3 deletions axlearn/common/checkpointer_orbax.py
@@ -15,6 +15,7 @@
import jax
import orbax.checkpoint as ocp
import tensorflow as tf
from contextlib import contextmanager
from absl import logging

from axlearn.common import utils
@@ -45,6 +46,15 @@
_GRAIN_INSTALLED = False


@contextmanager
def setup(spec: str):
"""Sets up any values required by Orbax. Currently a no-op placeholder."""
yield

class _TfIteratorHandler(ocp.type_handlers.TypeHandler):
"""Serializes tf.data.Iterator.

@@ -237,7 +247,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
options=ocp.CheckpointManagerOptions(
create=True,
max_to_keep=cfg.keep_last_n,
enable_async_checkpointing=True,
enable_async_checkpointing=False,
step_name_format=self._name_format,
should_save_fn=save_fn_with_summaries,
enable_background_delete=True,
@@ -345,8 +355,8 @@ def _restore_args(x: Any) -> ocp.RestoreArgs:
)
except FileNotFoundError as e:
# Orbax raises FileNotFoundError if there are no checkpoints.
if step is not None:
raise ValueError(f"Failed to restore at step {step}.") from e
# if step is not None:
# raise ValueError(f"Failed to restore at step {step}.") from e
logging.info("Could not find any completed checkpoints under %s: %s", cfg.dir, e)
return None, state # Return the input state.

9 changes: 9 additions & 0 deletions axlearn/common/checkpointer_orbax_emergency.py
@@ -54,6 +54,11 @@

FLAGS = flags.FLAGS

def save_axlearn_checkpoint(step: int, state, directory: str, name: str):
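"""Saves `state` at `step` as a regular (non-Orbax) AXLearn checkpoint under `directory`."""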
cfg = Checkpointer.default_config().set(name=name, dir=directory)
ckpt = cfg.instantiate(parent=None)
ckpt.save(step=step, state=state)
ckpt.wait_until_finished()

@contextmanager
def setup(spec: str):
@@ -819,6 +824,10 @@ def restore(
)
time_diff = time.perf_counter() - start_t
logging.info("Took %ss to restore emergency checkpoint from %s.", time_diff, cfg.dir)

logging.info("Saving an AXLearn tensorstore from the restored Orbax state...")
save_axlearn_checkpoint(step, restored_state, cfg.dir, cfg.name)
Contributor commented on lines +828 to +829:

If I understand it correctly, we're still using an online approach to do the checkpoint conversion here. This means we have to allocate the same resources as (or at least a slice of) the training stage for each checkpoint conversion.

I'm wondering if we can do the conversion offline on a CPU-only node with large memory.
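
One possible shape for that offline flow (a rough sketch, not part of this PR; the restore/save calls mirror the ones in this diff, while the CPU-only launch and the `state_template` argument, an abstract pytree describing the checkpointed state, are assumptions):

# Hypothetical offline converter: restore the Orbax checkpoint on a CPU-only
# host (e.g. launched with JAX_PLATFORMS=cpu), then re-save it as a regular
# AXLearn tensorstore checkpoint. Only host memory needs to fit the state.
from axlearn.common.checkpointer import Checkpointer
from axlearn.common.checkpointer_orbax import OrbaxCheckpointer


def convert_offline(src_dir: str, dst_dir: str, state_template, step=None):
    # `state_template` carries the shapes/dtypes the restore expects.
    src = OrbaxCheckpointer.default_config().set(name="converter_src", dir=src_dir)
    step, state = src.instantiate(parent=None).restore(step=step, state=state_template)

    dst = Checkpointer.default_config().set(name="converter_dst", dir=dst_dir)
    out = dst.instantiate(parent=None)
    out.save(step=step, state=state)
    out.wait_until_finished()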


return step, restored_state

def wait_until_finished(self):
56 changes: 50 additions & 6 deletions axlearn/experiments/text/gpt/common.py
@@ -17,6 +17,7 @@
import jax.numpy as jnp
import tensorflow as tf
from jax.sharding import PartitionSpec
from absl import logging

from axlearn.common import (
base_model,
@@ -643,6 +644,7 @@ def get_trainer_config_fn(
keep_every_n_steps: int = 50_000,
save_every_n_steps: Optional[int] = None,
init_state_builder: Optional[state_builder.Builder.Config] = None,
checkpointer: str = "",
) -> TrainerConfigFn:
"""Builds a TrainerConfigFn according to the model and input specs.

@@ -710,12 +712,54 @@ def config_fn() -> InstantiableConfig:
)
cfg.evalers[name] = evaler_cfg
# Summaries and checkpoints.
cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
n=save_every_n_steps or min(eval_every_n_steps, 5_000),
max_step=max_step,
)
cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
cfg.checkpointer.keep_last_n = 3
calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 500)
logging.info("checkpointer: %s",checkpointer)
if not checkpointer:
logging.info("In no checkpointer")
cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
n=calculated_save_every_n_steps,
max_step=max_step,
)
cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
cfg.checkpointer.keep_last_n = 3
elif checkpointer == "OrbaxCheckpointer":
logging.info("In orbax checkpointer")
from axlearn.common.checkpointer_orbax import OrbaxCheckpointer

ckpt_config: OrbaxCheckpointer.Config = (
OrbaxCheckpointer.default_config()
)
ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
n=calculated_save_every_n_steps,
max_step=max_step,
)
ckpt_config.keep_last_n = 3
cfg.checkpointer = ckpt_config
elif checkpointer == "OrbaxEmergencyCheckpointer":
# Prevent global dependency on Orbax.
# pylint: disable-next=import-outside-toplevel
from axlearn.common.checkpointer_orbax_emergency import OrbaxEmergencyCheckpointer

ckpt_config: OrbaxEmergencyCheckpointer.Config = (
OrbaxEmergencyCheckpointer.default_config()
)
ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
# n=calculated_save_every_n_steps,
# Saving every 15 minutes or more is recommended for persistent checkpoints.
n=200,
max_step=max_step,
)
ckpt_config.local_save_policy = config_for_function(every_n_steps_and_last_policy).set(
# n=calculated_save_every_n_steps,
# Saving roughly every 2 minutes is generally recommended for local checkpoints.
n=30,
max_step=max_step,
)
ckpt_config.local_dir = "/host-tmp/checkpoints"
ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps)
ckpt_config.keep_last_n = 3
ckpt_config.replica_axis_index = 1
cfg.checkpointer = ckpt_config
cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
cfg.summary_writer.max_queue = 1000
if len(mesh_axis_names) != len(mesh_shape):
36 changes: 31 additions & 5 deletions axlearn/experiments/text/gpt/fuji.py
@@ -31,6 +31,7 @@
RoFormerQKVLinear,
StackedTransformerLayer,
)
from absl import logging
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import config_for_function
from axlearn.common.decoder import LmHead
@@ -366,6 +367,9 @@ def get_trainer_kwargs(
),
)
elif model_size == "7B":
import jax

gbs = len(jax.devices())  # Global batch size: one example per device (replaces the precomputed train_batch_size).
trainer_kwargs = dict(
model_kwargs=dict(
num_layers=32,
@@ -378,7 +382,7 @@
),
learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
train_batch_size=gbs,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
mesh_rules=(
@@ -633,6 +637,9 @@ def get_trainer_kwargs(
),
)
elif model_size == "70B":
import jax

devices = len(jax.devices())  # As above: one example per device.
trainer_kwargs = dict(
model_kwargs=dict(
num_layers=80,
@@ -648,7 +655,7 @@
),
learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
train_batch_size=devices*1,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(fsdp=-1),
mesh_rules=(
@@ -914,22 +921,40 @@ def trainer_configs(
"""
arch = "fuji"
config_map = {}
for version, model_size, flash_attention in itertools.product(
Version, MODEL_SIZES, [True, False]
orbax_options = [
(True, False), # use_orbax_emergency_ckpt = True, use_orbax_ckpt = False
(False, True), # use_orbax_emergency_ckpt = False, use_orbax_ckpt = True
(False, False), # Neither is used
]
for version, model_size, flash_attention, (use_orbax_emergency_ckpt, use_orbax_ckpt) in itertools.product(
Version, MODEL_SIZES, [True, False], orbax_options
):
if model_size not in TOTAL_TOKENS[version]: # This combination does not exist.
continue
vocab_size = VOCAB_SIZE[version]

current_suffix_parts = []
if flash_attention:
current_suffix_parts.append("-flash")
if use_orbax_emergency_ckpt:
current_suffix_parts.append("-orbaxem")
elif use_orbax_ckpt:
current_suffix_parts.append("-orbax")

current_suffix = "".join(current_suffix_parts)
logging.info("config suffix: %s", current_suffix)
config_name = make_config_name(
arch=arch,
model_size=model_size,
version=f"v{version.value}",
suffix="-flash" if flash_attention else "",
suffix=current_suffix,
)
kwargs = get_trainer_kwargs(
model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention
)
max_sequence_length = kwargs.pop("max_sequence_length")
checkpointer_str = "OrbaxEmergencyCheckpointer" if use_orbax_emergency_ckpt else ""
checkpointer_str = "OrbaxCheckpointer" if use_orbax_ckpt else ""
# pylint: disable-next=unexpected-keyword-arg,missing-kwoa
config_map[config_name] = get_trainer_config_fn(
train_input_source=train_input_source(
Expand All @@ -939,6 +964,7 @@ def trainer_configs(
evalers=evaler_config_dict(
eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length),
),
checkpointer=checkpointer_str,
**kwargs,
)

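For reference, a quick way to check which checkpointer variants this loop produces (a hypothetical smoke test, not part of this PR; it assumes the trainer module exposes named_trainer_configs() and that make_config_name yields names like fuji-7B-v2-flash-orbaxem):

# List the generated fuji config names and confirm the new -orbax / -orbaxem
# variants are present alongside the existing ones.
from axlearn.experiments.text.gpt import c4_trainer  # assumed entry point

configs = c4_trainer.named_trainer_configs()
orbax_variants = sorted(name for name in configs if name.endswith(("-orbax", "-orbaxem")))
print("\n".join(orbax_variants))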