[WIP] Jit switch #869

Open · wants to merge 69 commits into base branch main.

Commits (69)

- ae48ccd: Use jax.jit for sharding initial steps (rka97, Nov 21, 2024)
- eb5cac7: Use jax.jit for adamw (rka97, Nov 21, 2024)
- 82977da: Pass yapf checks (rka97, Dec 9, 2024)
- 99545d4: CIFAR workload sharding (rka97, Dec 9, 2024)
- 018711a: librispeech_conformer now running (rka97, Jan 7, 2025)
- fbeb5f1: fix formatting (rka97, Feb 5, 2025)
- 6e4e7b0: shard default (rka97, Feb 5, 2025)
- 4a2c02d: start imagenet (rka97, Feb 5, 2025)
- 47beba1: remove bn sync in imagenet (jit handles it automatically) (rka97, Feb 5, 2025)
- 3a18f19: ImageNet-ViT also works (rka97, Feb 6, 2025)
- bd0f565: Start working on WMT. OOM error (rka97, Feb 20, 2025)
- 3044efb: post-rebase, still on wmt (rka97, Feb 20, 2025)
- e301c49: cache sharding fix (rka97, Feb 20, 2025)
- e5ed97a: Merge branch 'dev' into jit_switch (priyakasimbeg, Feb 21, 2025)
- 4fcf984: target_setting_algorithms sharding, compilation caching (rka97, Feb 21, 2025)
- d147e39: Update tests to correct batch size (rka97, Feb 21, 2025)
- a2b61be: yapf and isort checks.. (rka97, Feb 21, 2025)
- be11c23: Merge branch 'jit_switch' of https://github.com/mlcommons/algorithmic… (priyakasimbeg, Mar 6, 2025)
- e2a3b5f: Merge branch 'dev' into jit_switch (priyakasimbeg, Mar 6, 2025)
- a80f4ec: switch fastmri from pmap to jit (priyakasimbeg, Mar 7, 2025)
- c39ca51: migrate criteo workload (priyakasimbeg, Mar 7, 2025)
- 06377d9: update utils function used for sharding conformer (priyakasimbeg, Mar 7, 2025)
- 9cbe7d9: update conformer and deepspeech (priyakasimbeg, Mar 8, 2025)
- c6ecd67: debugging (priyakasimbeg, Mar 11, 2025)
- f35690d: debuging (priyakasimbeg, Mar 12, 2025)
- 848b50c: reformatting (priyakasimbeg, Mar 18, 2025)
- fb62eae: reformatting (priyakasimbeg, Mar 18, 2025)
- fe3f9f0: reformatting (priyakasimbeg, Mar 18, 2025)
- 004afbd: reformatting (priyakasimbeg, Mar 18, 2025)
- f1db3d3: reformatting (priyakasimbeg, Mar 18, 2025)
- c208cc7: sharding deepspeech (priyakasimbeg, Mar 19, 2025)
- 2e4cc9e: ogbg jit migration (priyakasimbeg, Mar 19, 2025)
- d3a06fc: deepspeech jit changes (priyakasimbeg, Mar 20, 2025)
- 2cfa2a9: set jax to 0.5.1 (priyakasimbeg, Mar 20, 2025)
- 70705a7: merge (priyakasimbeg, Mar 20, 2025)
- 75d6315: upgrade jax to 0.5.3 (priyakasimbeg, Apr 1, 2025)
- 1df0690: change bsz back (priyakasimbeg, Apr 1, 2025)
- c1d0c66: formatting (priyakasimbeg, Apr 3, 2025)
- 1b9466c: remove debugging statements from submission_runner.py (priyakasimbeg, Apr 3, 2025)
- 7a71cf0: pyproject.toml (priyakasimbeg, Apr 3, 2025)
- 9e1f337: clean up ogbg (priyakasimbeg, Apr 3, 2025)
- a1d0abd: clean up ogbg (priyakasimbeg, Apr 3, 2025)
- adb2b7e: Merge branch 'jit_switch' of github.com:mlcommons/algorithmic-efficie… (priyakasimbeg, Apr 3, 2025)
- 99caa03: clean up mnist workload.py (priyakasimbeg, Apr 3, 2025)
- b14174b: refactoring & clean up (priyakasimbeg, Apr 3, 2025)
- a3a9b9f: simplify changes in cifar jax (priyakasimbeg, Apr 3, 2025)
- 0a340a2: small fix (priyakasimbeg, Apr 3, 2025)
- 60c1cce: rename sharding utils (priyakasimbeg, Apr 3, 2025)
- 1edb724: fix sharding rename (priyakasimbeg, Apr 3, 2025)
- 49864fb: refactoring (priyakasimbeg, Apr 3, 2025)
- 7820ac6: modifications to cifar (priyakasimbeg, Apr 4, 2025)
- 0a2043c: fix (priyakasimbeg, Apr 5, 2025)
- 95037bf: clean up and small fixes (priyakasimbeg, Apr 5, 2025)
- e79c761: add test for sharding invariance (priyakasimbeg, Apr 5, 2025)
- 110e792: fix (priyakasimbeg, Apr 8, 2025)
- 9c91c65: Update pyproject.toml (priyakasimbeg, Apr 14, 2025)
- 21bb997: Update workload.py (priyakasimbeg, Apr 14, 2025)
- eb56919: Update workload.py (priyakasimbeg, Apr 14, 2025)
- c489749: Merge branch 'jit_switch' of github.com:mlcommons/algorithmic-efficie… (priyakasimbeg, Apr 14, 2025)
- 1277cc2: upgrade jax (priyakasimbeg, May 19, 2025)
- def4ac5: update dockerfile (priyakasimbeg, May 19, 2025)
- 450cbee: remove extra installs (priyakasimbeg, May 19, 2025)
- 89718e7: update jax version (priyakasimbeg, May 20, 2025)
- 7dcf5af: update install commands for pytorch cpu only (priyakasimbeg, May 20, 2025)
- 4335688: update dockerfile (priyakasimbeg, May 20, 2025)
- 8d1fe7e: update dockerfile (priyakasimbeg, May 20, 2025)
- 240e2e5: update dockerfile (priyakasimbeg, May 20, 2025)
- cc8d604: update dockerfile (priyakasimbeg, May 20, 2025)
- fe56eaf: update dockerfile (priyakasimbeg, May 20, 2025)
6 changes: 3 additions & 3 deletions README.md
@@ -54,16 +54,16 @@ Both options are described in detail in the [**Getting Started**](/docs/GETTING_
*TL;DR to install the Jax version for GPU run:*

```bash
pip3 install -e '.[pytorch_cpu]'
pip3 install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'
pip3 install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu
pip3 install -e '.[jax_gpu]'
pip3 install -e '.[full]'
```

*TL;DR to install the PyTorch version for GPU run:*

```bash
pip3 install -e '.[jax_cpu]'
pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'
pip3 install -e '.[pytorch_gpu]'
pip3 install -e '.[full]'
```

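Either way, a quick sanity check after installation (an editor-added snippet, not part of the README diff itself) is to confirm that JAX actually sees the intended accelerators:

```python
import jax

# Expect 'gpu' (or 'tpu') here and a non-empty device list; 'cpu' means the
# accelerated jaxlib wheel was not picked up.
print(jax.default_backend())
print(jax.devices())
```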
4 changes: 0 additions & 4 deletions algoperf/checkpoint_utils.py
@@ -11,7 +11,6 @@
from flax import jax_utils
from flax.training import checkpoints as flax_checkpoints
from flax.training.checkpoints import latest_checkpoint
import jax
import numpy as np
from tensorflow.io import gfile # pytype: disable=import-error
import torch
@@ -193,10 +192,7 @@ def save_checkpoint(framework: str,
train_state, eval_results, global_step, preemption_count).
"""
if framework == 'jax':
model_params = jax.device_get(jax_utils.unreplicate(model_params))
opt_state, _ = optimizer_state
opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
model_state = jax.device_get(jax_utils.unreplicate(model_state))
else:
if isinstance(
model_params,
6 changes: 2 additions & 4 deletions algoperf/data_utils.py
@@ -11,6 +11,7 @@
from torch.utils.data import DistributedSampler
from torch.utils.data import Sampler

from algoperf import jax_sharding_utils
from algoperf import spec


@@ -60,10 +61,7 @@ def _prepare(x):
if remainder_size != 0 or pad_to_global_batch_size:
x = pad(x, pad_size, padding_value=padding_value)

# Reshape (global_batch_size, ...) to
# (local_device_count, per_device_batch_size, ...).
# Assumes that `global_batch_size % local_device_count == 0`.
return x.reshape((local_device_count, -1, *x.shape[1:]))
return jax.device_put(x, jax_sharding_utils.get_batch_dim_sharding())

return jax.tree.map(_prepare, batch)

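The change above is the core of the pmap-to-jit switch in the input pipeline: `shard_and_maybe_pad_np` no longer reshapes each array to a leading `(local_device_count, per_device_batch, ...)` axis but keeps the global batch shape and places it with a batch-dimension sharding. A standalone sketch of the two layouts (made-up shapes, single-host assumption):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

x = jnp.ones((128, 32))  # a global batch of 128 examples

# Old pmap-style layout: add a leading device axis on the host.
n_dev = jax.local_device_count()
x_pmap = x.reshape((n_dev, -1, *x.shape[1:]))  # (n_dev, 128 // n_dev, 32)

# New jit-style layout: keep the global shape, shard dim 0 across devices.
mesh = Mesh(jax.devices(), ('batch',))
x_jit = jax.device_put(x, NamedSharding(mesh, P('batch')))

print(x_pmap.shape, x_jit.shape, x_jit.sharding)
```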
37 changes: 37 additions & 0 deletions algoperf/jax_sharding_utils.py
@@ -0,0 +1,37 @@
"""Utilities for dealing with sharding in JAX."""

import jax
from jax.sharding import NamedSharding, PartitionSpec as P


def get_replicate_sharding():
"""Returns a sharding spec that replicates data across all devices."""
mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
return NamedSharding(mesh, P())


def get_batch_dim_sharding():
"""Returns a sharding spec that shards data along the first axis."""
mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
return NamedSharding(mesh, P('batch'))


def shard_along_batch_dim(x):
"""Shards a tensor across all devices."""
mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
return jax.tree.map(
lambda x: jax.device_put(x, NamedSharding(mesh, P('batch'))), x)


def replicate(x):
"""Replicates tensor across all devices."""
mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
return jax.tree.map(
lambda x: jax.device_put(x, NamedSharding(mesh, P())), x)


def display_shard_info(x: jax.Array):
"""Displays shard info of a jax array."""
for shard in x.addressable_shards:
print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:"
f" {shard.replica_id}.\n")
2 changes: 1 addition & 1 deletion algoperf/workloads/cifar/cifar_jax/input_pipeline.py
@@ -171,5 +171,5 @@ def create_input_iter(
functools.partial(
shard_and_maybe_pad_np, global_batch_size=global_batch_size),
ds)
it = jax_utils.prefetch_to_device(it, 2)

return it
69 changes: 43 additions & 26 deletions algoperf/workloads/cifar/cifar_jax/workload.py
@@ -3,7 +3,6 @@
import functools
from typing import Any, Dict, Iterator, Optional, Tuple

from flax import jax_utils
from flax import linen as nn
from flax.core import pop
import jax
@@ -13,6 +12,7 @@
import tensorflow_datasets as tfds

from algoperf import param_utils
from algoperf import jax_sharding_utils
from algoperf import spec
from algoperf.workloads.cifar.cifar_jax import models
from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
@@ -31,6 +31,7 @@ def _build_cifar_dataset(
repeat_final_dataset: Optional[bool] = None
) -> Iterator[Dict[str, spec.Tensor]]:
ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
ds_builder.download_and_prepare()
train = split == 'train'
assert self.num_train_examples + self.num_validation_examples == 50000
if split in ['train', 'eval_train']:
@@ -96,8 +97,8 @@ def init_model_fn(
model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
params = jax_utils.replicate(params)
model_state = jax_sharding_utils.replicate(model_state)
params = jax_sharding_utils.replicate(params)
return params, model_state

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
Expand Down Expand Up @@ -175,35 +176,51 @@ def _compute_metrics(self,
'loss': summed_loss,
'accuracy': accuracy,
}
metrics = lax.psum(metrics, axis_name='batch')
return metrics

@functools.partial(
jax.pmap,
axis_name='batch',
in_axes=(None, 0, 0, 0, None),
static_broadcasted_argnums=(0,))
def _eval_model(
self,
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
self,
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
params,
batch,
model_state,
spec.ForwardPassMode.EVAL,
rng,
update_batch_norm=False)
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
return self._compute_metrics(logits, batch['targets'], weights)

@functools.partial(
jax.jit,
in_shardings=(
jax_sharding_utils.get_replicate_sharding(), # params
jax_sharding_utils.get_batch_dim_sharding(), # batch
jax_sharding_utils.get_replicate_sharding(), # model_state
jax_sharding_utils.get_batch_dim_sharding(), # rng
),
)
def _eval_model_jitted(
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
params,
batch,
model_state,
spec.ForwardPassMode.EVAL,
rng,
update_batch_norm=False)
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
return self._compute_metrics(logits, batch['targets'], weights)

metrics = _eval_model_jitted(params,
batch,
model_state,
rng)
return jax.tree.map(lambda x: x.item(), metrics)

def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str,
Any]) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
return jax.tree.map(lambda x: x / num_examples, total_metrics)
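One detail worth calling out in the diff above: `lax.psum(metrics, axis_name='batch')` disappears because the jitted eval operates on the full global batch, so ordinary reductions already cover every device. A tiny illustrative sketch (hypothetical `summed_accuracy`, not from the workload):

```python
import jax
import jax.numpy as jnp


@jax.jit
def summed_accuracy(logits, targets):
  # logits/targets are global (possibly sharded) arrays, so this sum spans
  # the whole global batch; under pmap the same total required an explicit
  # lax.psum over the 'batch' axis.
  return jnp.sum(jnp.argmax(logits, axis=-1) == targets)
```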
23 changes: 13 additions & 10 deletions algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -11,6 +11,7 @@
from algoperf import param_utils
from algoperf import spec
from algoperf.workloads.criteo1tb.criteo1tb_jax import models
from algoperf import jax_sharding_utils
from algoperf.workloads.criteo1tb.workload import \
BaseCriteo1TbDlrmSmallWorkload

@@ -105,7 +106,7 @@ def init_model_fn(
initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
return jax_utils.replicate(initial_params), None
return jax_sharding_utils.replicate(initial_params), None

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_7'
@@ -129,13 +130,16 @@ def model_fn(
return logits_batch, None

@functools.partial(
jax.pmap,
axis_name='batch',
in_axes=(None, 0, 0),
static_broadcasted_argnums=(0,))
def _eval_batch_pmapped(self,
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor]) -> spec.Tensor:
jax.jit,
in_shardings=(
jax_sharding_utils.get_replicate_sharding(),
jax_sharding_utils.get_batch_dim_sharding(),
),
static_argnums=(0,),
out_shardings=jax_sharding_utils.get_replicate_sharding())
def _eval_batch_jitted(self,
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor]) -> spec.Tensor:
logits, _ = self.model_fn(
params,
batch,
@@ -156,8 +160,7 @@ def _eval_batch(self,
batch: Dict[str, spec.Tensor]) -> spec.Tensor:
# The sum over the global batch happens inside _eval_batch_jitted, so the
# returned value is already the total for this batch.
return np.array(
self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64)
return np.array(self._eval_batch_jitted(params, batch), dtype=np.float64)


class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
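A note on the decorator pattern used here and in the fastmri workload below: the eval function is a bound method, and `static_argnums=(0,)` tells `jax.jit` to treat `self` as a static argument instead of tracing it. A standalone sketch of that pattern (hypothetical class, not repo code):

```python
import functools

import jax
import jax.numpy as jnp


class ToyWorkload:

  @functools.partial(jax.jit, static_argnums=(0,))
  def eval_batch(self, params, batch):
    # Only params and batch are traced; `self` is hashed as a static value,
    # so a retrace happens only if a different workload instance is used.
    return jnp.sum(params['w'] * batch)
```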
19 changes: 10 additions & 9 deletions algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -10,6 +10,7 @@

from algoperf import param_utils
from algoperf import spec
from algoperf import jax_sharding_utils
import algoperf.random_utils as prng
from algoperf.workloads.fastmri.fastmri_jax.models import UNet
from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim
@@ -39,7 +40,7 @@ def init_model_fn(
params = variables['params']
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
params = jax_utils.replicate(params)
params = jax_sharding_utils.replicate(params)
return params, None

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
@@ -94,10 +95,12 @@ def loss_fn(
}

@functools.partial(
jax.pmap,
axis_name='batch',
in_axes=(None, 0, 0, 0),
static_broadcasted_argnums=(0,))
jax.jit,
in_shardings=(jax_sharding_utils.get_replicate_sharding(),
jax_sharding_utils.get_batch_dim_sharding(),
jax_sharding_utils.get_replicate_sharding()),
static_argnums=(0,),
out_shardings=jax_sharding_utils.get_replicate_sharding())
def _eval_model(self,
params: spec.Tensor,
batch: Dict[str, spec.Tensor],
@@ -126,7 +129,6 @@ def _eval_model(self,
'ssim': ssim_sum,
'loss': summed_loss,
}
metrics = jax.lax.psum(metrics, axis_name='batch')
return metrics

def _eval_model_on_split(self,
@@ -154,13 +156,12 @@ def _eval_model_on_split(self,
num_batches=num_batches)

total_metrics = {'ssim': 0., 'loss': 0.}
eval_rngs = prng.split(model_rng, jax.local_device_count())
for _ in range(num_batches):
batch = next(self._eval_iters[split])
# We already sum these metrics across devices inside _eval_model.
synced_metrics = self._eval_model(params, batch, eval_rngs)
synced_metrics = self._eval_model(params, batch, model_rng)
total_metrics = {
k: v + synced_metrics[k][0] for k, v in total_metrics.items()
k: v + synced_metrics[k] for k, v in total_metrics.items()
}
return {k: float(v.item() / num_examples) for k, v in total_metrics.items()}

@@ -399,6 +399,7 @@ def create_input_iter(split: str,
ds)

# Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%.
it = jax_utils.prefetch_to_device(it, 2)
# TODO (kasimbeg): put on device
# it = jax_utils.prefetch_to_device(it, 2)

return iter(it)