orbax to reg checkpointer conversion #1246
Draft
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Command
Orbax Emergency test:
`axlearn gcp bundle --name=$NAME
--bundler_spec=allow_dirty=True
--bundler_type=artifactregistry
--bundler_spec=dockerfile=Dockerfile
--bundler_spec=image=tpu
--bundler_spec=target=tpu
axlearn gcp launch run --cluster=stoelinga-axlearn
--runner_name gke_tpu_single
--name=$NAME
--instance_type=tpu-v6e-16
--host_mount_spec=name=tmp,host_path=/tmp,mount_path=/host-tmp
--num_replicas=3
--bundler_spec=allow_dirty=True
--bundler_type=artifactregistry --bundler_spec=image=tpu
--bundler_spec=dockerfile=Dockerfile --bundler_spec=target=tpu
-- python3 -m axlearn.common.launch_trainer_main
--init_module=axlearn.common.checkpointer_orbax_emergency:local_ckpt_dir=/host-tmp/checkpoints
--module=text.gpt.c4_trainer
--config=fuji-7B-v2-flash-orbaxem
--trainer_dir=$OUTPUT_DIR
--data_dir=gs://axlearn-public/tensorflow_datasets
--jax_backend=tpu
--mesh_selector=tpu-v6e-16
--trace_at_steps=3`
Orbax test
`axlearn gcp bundle --name=$NAME
--bundler_spec=allow_dirty=True
--bundler_type=artifactregistry
--bundler_spec=dockerfile=Dockerfile
--bundler_spec=image=tpu
--bundler_spec=target=tpu
axlearn gcp launch run --cluster=lkolluru-axlearn
--runner_name=gke_tpu_pathways
--name=$NAME
--instance_type=tpu-v6e-16
--num_replicas=1
--bundler_spec=allow_dirty=True
--bundler_type=artifactregistry --bundler_spec=image=tpu
--bundler_spec=dockerfile=Dockerfile --bundler_spec=target=tpu
-- python3 -m axlearn.common.launch_trainer_main
--init_module=axlearn.common.checkpointer_orbax
--module=text.gpt.c4_trainer
--config=fuji-7B-v2-flash-orbax
--trainer_dir=$OUTPUT_DIR
--data_dir=gs://axlearn-public/tensorflow_datasets
--jax_backend=proxy
--mesh_selector=tpu-v6e-16
--trace_at_steps=3`