Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A handler for JAX arrays that uses colocated python."""

from __future__ import annotations

import asyncio
from typing import Sequence

from absl import logging
import jax
from jax.experimental import colocated_python
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.serialization.type_handlers import ParamInfo
from orbax.checkpoint._src.serialization.type_handlers import RestoreArgs
from orbax.checkpoint._src.serialization.type_handlers import SaveArgs


@colocated_python.colocated_python
async def _serialize(
    info: ParamInfo,
    value: jax.Array,
    args: SaveArgs | None,
) -> None:
  """Serializes a single array using a freshly constructed `ArrayHandler`.

  The function is wrapped with `colocated_python.colocated_python` and is
  meant to be run on a remote host.

  Args:
    info: Parameter metadata; `info.name` identifies the array in log messages.
    value: The array to serialize.
    args: Save arguments forwarded unchanged to `ArrayHandler.serialize`.
  """
  # TODO(b/283161063): remove this logging.
  logging.info('Beginning serialization for %s.', info.name)
  logging.info(
      '[Colocated] _serialize started for param: %s on device: %s',
      info.name,
      # `jax.Array.device` is a property in current JAX releases, so calling it
      # raises; `devices()` is the supported method and also works for arrays
      # that span several devices.
      value.devices(),
  )
  # Must create a new handler on the remote host.
  logging.info('[Colocated] Creating remote ArrayHandler for: %s.', info.name)
  handler = type_handlers.ArrayHandler()
  logging.info(
      '[Colocated] Calling handler.serialize for: %s with args: %s.',
      info.name,
      args,
  )
  commit_futures = await handler.serialize([value], infos=[info], args=[args])
  logging.info(
      '[Colocated] Received %d commit futures for: %s.',
      len(commit_futures),
      info.name,
  )
  # All futures should be awaited. `result()` is a synchronous call, so each
  # commit is waited on in order before the next one is inspected.
  for i, f in enumerate(commit_futures):
    logging.info('[Colocated] Awaiting commit future %d for: %s.', i, info.name)
    f.result()
    logging.info('[Colocated] Commit future %d for %s completed.', i, info.name)
  logging.info('[Colocated] _serialize finished for param: %s.', info.name)


@colocated_python.colocated_python
async def _deserialize(
    info: ParamInfo,
    args: RestoreArgs | None,
) -> jax.Array:
  """Restores one array by delegating to a locally built `ArrayHandler`.

  The function is wrapped with `colocated_python.colocated_python` and is
  meant to be run on a remote host.

  Args:
    info: Parameter metadata describing the array to restore.
    args: Restore arguments forwarded unchanged to `ArrayHandler.deserialize`.

  Returns:
    The first (and only expected) element of the handler's result.

  Raises:
    ValueError: If the handler returns an empty result.
  """
  name = info.name
  logging.info(
      '[Colocated] _deserialize started for param: %s with args: %s.',
      name,
      args,
  )
  # The handler has to be constructed here rather than passed in.
  logging.info('[Colocated] Creating remote ArrayHandler for: %s.', name)
  array_handler = type_handlers.ArrayHandler()
  logging.info('[Colocated] Calling handler.deserialize for: %s.', name)
  arrays = await array_handler.deserialize([info], args=[args])
  logging.info(
      '[Colocated] Deserialization complete for: %s. Result count: %d.',
      name,
      len(arrays),
  )
  if arrays:
    logging.info('[Colocated] _deserialize finished for param: %s.', name)
    return arrays[0]
  raise ValueError(f'Failed to deserialize {name}.')


class ColocatedPythonArrayHandler(type_handlers.ArrayHandler):
  """An implementation of TypeHandler for jax.Array on Pathways.

  Each array is handed to a module-level function wrapped with
  `colocated_python`, which builds its own `ArrayHandler` and performs the
  actual read or write.
  """

  async def serialize(
      self,
      values: Sequence[jax.Array],
      infos: Sequence[ParamInfo],
      args: Sequence[SaveArgs] | None = None,
  ) -> Sequence[future.Future]:
    """Serializes a jax.Array using colocated python.

    Args:
      values: Arrays to save.
      infos: Parameter metadata, one entry per value.
      args: Optional save arguments, one entry per value. Defaults to a
        default-constructed `SaveArgs` for every value.

    Returns:
      A single already-completed future; all serialization work has finished
      by the time this coroutine returns.
    """
    logging.info(
        'ColocatedPythonArrayHandler.serialize called for %d values.',
        len(values),
    )
    args = args or ([SaveArgs()] * len(values))
    type_handlers.check_input_arguments(values, infos, args)
    logging.info('Input arguments checked successfully.')

    logging.info('Dispatching %d colocated serialization tasks.', len(values))
    await asyncio.gather(
        *(_serialize(info, v, arg) for info, v, arg in zip(infos, values, args))
    )
    logging.info('All %d colocated serialization tasks completed.', len(values))

    logging.info('ColocatedPythonArrayHandler.serialize finished.')
    # The actual work is awaited above, so return a future that is already
    # complete. `future.Future` is an interface (Protocol) and cannot be
    # instantiated directly; `NoopFuture` is its trivially-complete
    # implementation.
    return [future.NoopFuture()]

  async def deserialize(
      self,
      infos: Sequence[ParamInfo],
      args: Sequence[RestoreArgs] | None = None,
  ) -> Sequence[jax.Array]:
    """Restores arrays using colocated python.

    Args:
      infos: Parameter metadata, one entry per array to restore.
      args: Optional restore arguments, one entry per info. Defaults to a
        default-constructed `RestoreArgs` for every info.

    Returns:
      The restored arrays, in the same order as `infos`.
    """
    logging.info(
        'ColocatedPythonArrayHandler.deserialize called for %d infos.',
        len(infos),
    )
    args = args or ([RestoreArgs()] * len(infos))
    # Mirror `serialize`: reject mismatched lengths instead of letting `zip`
    # silently drop the trailing entries.
    type_handlers.check_input_arguments(infos, args)
    logging.info(
        'Dispatching %d colocated deserialization tasks.', len(infos)
    )
    tasks = [_deserialize(info, arg) for info, arg in zip(infos, args)]
    results = await asyncio.gather(*tasks)
    logging.info(
        'All %d colocated deserialization tasks completed. Returning results.',
        len(results),
    )
    return results
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from absl.testing import parameterized
from etils import epath
import jax
import numpy as np
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.serialization import cloud_pathways_type_handlers
from orbax.checkpoint._src.serialization import type_handlers

from .learning.brain.research.jax.tests.multiprocess import multiprocess_test


class ColocatedPythonArrayHandlerTest(
    unittest.IsolatedAsyncioTestCase, parameterized.TestCase
):
  """Save/restore round-trip tests for `ColocatedPythonArrayHandler`."""

  def setUp(self):
    super().setUp()
    tmp_dir = self.create_tempdir(name='colocated_test')
    self.directory = epath.Path(tmp_dir.full_path)
    # TODO(b/364139319): Support more devices and processes.
    assert jax.device_count() == 8
    assert jax.process_count() == 1

    test_utils.set_tensorstore_driver_for_test()

    test_utils.sync_global_processes(
        'ColocatedPythonArrayHandlerTest:setup_complete'
    )

  def tearDown(self):
    test_utils.sync_global_processes(
        'ColocatedPythonArrayHandlerTest:tests_complete'
    )
    super().tearDown()

  def _make_sharding(self):
    """Builds a sharding over a 2x4 ('x', 'y') mesh, partitioned along 'y'."""
    device_grid = np.asarray(jax.devices()).reshape(2, 4)
    mesh = jax.sharding.Mesh(devices=device_grid, axis_names=('x', 'y'))
    return jax.sharding.NamedSharding(
        mesh=mesh, spec=jax.sharding.PartitionSpec('y')
    )

  @parameterized.product(
      use_ocdbt=(True, False),
      use_zarr3=(True, False),
  )
  async def test_serialize_deserialize(self, use_ocdbt, use_zarr3):
    handler = cloud_pathways_type_handlers.ColocatedPythonArrayHandler()
    sharding = self._make_sharding()
    original = jax.device_put(np.arange(32), sharding)
    info = test_utils.get_param_info('a', self.directory, is_ocdbt=use_ocdbt)
    info.use_zarr3 = use_zarr3

    # Wait for every returned future before restoring.
    for commit in await handler.serialize([original], [info]):
      commit.result()
    test_utils.sync_global_processes(
        'ColocatedPythonArrayHandlerTest:serialized'
    )

    restore_args = [type_handlers.ArrayRestoreArgs(sharding=sharding)]
    restored = await handler.deserialize([info], restore_args)
    test_utils.assert_array_equal(self, original, restored[0])

  async def test_serialize_no_save_args(self):
    handler = cloud_pathways_type_handlers.ColocatedPythonArrayHandler()
    original = jax.device_put(np.arange(32), self._make_sharding())
    info = test_utils.get_param_info('a', self.directory)

    with self.assertRaises(ValueError):
      await handler.serialize([original], [info], args=None)

  async def test_serialize_deserialize_random_key(self):
    handler = cloud_pathways_type_handlers.ColocatedPythonArrayHandler()
    sharding = self._make_sharding()
    key = jax.random.key(0)
    info = test_utils.get_param_info('a', self.directory)

    # Wait for every returned future before restoring.
    for commit in await handler.serialize([key], [info]):
      commit.result()
    test_utils.sync_global_processes(
        'ColocatedPythonArrayHandlerTest:serialized'
    )

    restore_args = [type_handlers.ArrayRestoreArgs(sharding=sharding)]
    restored = await handler.deserialize([info], restore_args)
    test_utils.assert_array_equal(self, key, restored[0])


# Script entry point: hand control to the multiprocess test runner's `main`.
if __name__ == '__main__':
  multiprocess_test.main()
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from orbax.checkpoint._src.serialization.type_handlers import SingleReplicaArrayRestoreArgs
from orbax.checkpoint._src.serialization.type_handlers import StringHandler
from orbax.checkpoint._src.serialization.type_handlers import TypeHandler
from orbax.checkpoint._src.serialization.cloud_pathways_type_handlers import ColocatedPythonArrayHandler

# TypeHandler Registry
from orbax.checkpoint._src.serialization.type_handlers import TypeHandlerRegistry
Expand Down
2 changes: 1 addition & 1 deletion checkpoint/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies = [
'etils[epath,epy]',
'typing_extensions',
'msgpack',
'jax >= 0.5.0',
'jax >= 0.6.2',
'numpy',
'pyyaml',
'tensorstore >= 0.1.71',
Expand Down
73 changes: 73 additions & 0 deletions docs/guides/checkpoint/colocated_checkpointing.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@

Colocated Python Checkpointing
================================

.. warning::
This is an experimental feature and requires a Pathways environment to run.

Design
------

The ColocatedPythonArrayHandler is an experimental feature that provides a more performant way to save and restore checkpoints in a multi-host, multi-device setting. It leverages Pathways-specific infrastructure to ensure that the checkpointing operations happen on the same devices as the data, avoiding unnecessary data transfer.

How to use the colocated_python annotation
------------------------------------------

To use the ColocatedPythonArrayHandler, you need to annotate your JAX arrays with ``ocp.args.StandardSave(..., colocated=True)``. This tells Orbax to use the ColocatedPythonArrayHandler for that specific array.

.. code-block:: python

import orbax.checkpoint as ocp
import jax.numpy as jnp

# ...

checkpointer = ocp.PyTreeCheckpointer()
data = {'my_array': jnp.ones((10,))}
save_args = ocp.args.StandardSave(colocated=True)
checkpointer.save(path, args=ocp.args.PyTreeSave(tree=data, save_args=save_args))

Example usage
-------------

Here's a complete, self-contained example of how to use the ColocatedPythonArrayHandler. It saves a small tree to a local checkpoint directory and restores it:

.. code-block:: python

import orbax.checkpoint as ocp
from orbax.checkpoint.type_handlers import ColocatedPythonArrayHandler
import jax
import jax.numpy as jnp
import numpy as np
from etils import epath

# Register the handler
ocp.register_type_handler(jax.Array, ColocatedPythonArrayHandler(), override=True)

class Example:
def __init__(self):
self.checkpointer = ocp.PyTreeCheckpointer()
self.path = epath.Path('/tmp/orbax-checkpoint')

def save(self, data):
save_args = ocp.args.StandardSave(colocated=True)
self.checkpointer.save(self.path, args=ocp.args.PyTreeSave(tree=data, save_args=save_args))

def restore(self):
return self.checkpointer.restore(self.path)

# Create some data
data = {'x': jnp.ones((10,)), 'y': np.arange(5)}

# Save the data
example = Example()
example.save(data)

# Restore the data
restored_data = example.restore()

# Verify the data
np.testing.assert_array_equal(restored_data['x'], data['x'])
np.testing.assert_array_equal(restored_data['y'], data['y'])

print("Checkpoint saved and restored successfully!")
3 changes: 2 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ Install from GitHub using the following.
guides/checkpoint/optimized_checkpointing
guides/checkpoint/transformations
guides/checkpoint/preemption_checkpointing
guides/checkpoint/async_checkpointing
guides/checkpoint/async_checkpointing
guides/checkpoint/colocated_checkpointing
guides/checkpoint/debug_guide
api_reference/checkpoint

Expand Down
Loading