[WIP] LM Workload #860

Draft pull request: wants to merge 62 commits into base: main.

Commits (62)
ae48ccd  Use jax.jit for sharding initial steps  (rka97, Nov 21, 2024)
eb5cac7  Use jax.jit for adamw  (rka97, Nov 21, 2024)
82977da  Pass yapf checks  (rka97, Dec 9, 2024)
99545d4  CIFAR workload sharding  (rka97, Dec 9, 2024)
018711a  librispeech_conformer now running  (rka97, Jan 7, 2025)
fbeb5f1  fix formatting  (rka97, Feb 5, 2025)
6e4e7b0  shard default  (rka97, Feb 5, 2025)
4a2c02d  start imagenet  (rka97, Feb 5, 2025)
47beba1  remove bn sync in imagenet (jit handles it automatically)  (rka97, Feb 5, 2025)
3a18f19  ImageNet-ViT also works  (rka97, Feb 6, 2025)
bd0f565  Start working on WMT. OOM error  (rka97, Feb 20, 2025)
3044efb  post-rebase, still on wmt  (rka97, Feb 20, 2025)
e301c49  cache sharding fix  (rka97, Feb 20, 2025)
e5ed97a  Merge branch 'dev' into jit_switch  (priyakasimbeg, Feb 21, 2025)
4fcf984  target_setting_algorithms sharding, compilation caching  (rka97, Feb 21, 2025)
d147e39  Update tests to correct batch size  (rka97, Feb 21, 2025)
a2b61be  yapf and isort checks..  (rka97, Feb 21, 2025)
be11c23  Merge branch 'jit_switch' of https://github.com/mlcommons/algorithmic…  (priyakasimbeg, Mar 6, 2025)
e2a3b5f  Merge branch 'dev' into jit_switch  (priyakasimbeg, Mar 6, 2025)
a80f4ec  switch fastmri from pmap to jit  (priyakasimbeg, Mar 7, 2025)
c39ca51  migrate criteo workload  (priyakasimbeg, Mar 7, 2025)
06377d9  update utils function used for sharding conformer  (priyakasimbeg, Mar 7, 2025)
9cbe7d9  update conformer and deepspeech  (priyakasimbeg, Mar 8, 2025)
c6ecd67  debugging  (priyakasimbeg, Mar 11, 2025)
f35690d  debuging  (priyakasimbeg, Mar 12, 2025)
da5f85a  first LM commit  (Niccolo-Ajroldi, Mar 11, 2025)
a12a364  lm data pipeline  (Niccolo-Ajroldi, Mar 12, 2025)
ca83ab8  testing  (Niccolo-Ajroldi, Mar 14, 2025)
e3e78dc  LM workload tested torch pipeline  (Niccolo-Ajroldi, Mar 17, 2025)
e619495  LM workload - fix torch tests  (Niccolo-Ajroldi, Mar 17, 2025)
d8e9c56  add LM tests, remove dev files  (Niccolo-Ajroldi, Mar 18, 2025)
6b4ff12  add LM tests, remove dev files  (Niccolo-Ajroldi, Mar 18, 2025)
3c5c847  Stop tracking .gitignore  (Niccolo-Ajroldi, Mar 18, 2025)
20d841b  Remove dev/ from repo, keep locally  (Niccolo-Ajroldi, Mar 18, 2025)
f3ba059  fix comments  (Niccolo-Ajroldi, Mar 18, 2025)
381451f  add class specifications  (Niccolo-Ajroldi, Mar 18, 2025)
f111d2e  add workload LM info  (Niccolo-Ajroldi, Mar 18, 2025)
808d398  restore data_utils.py tree map  (Niccolo-Ajroldi, Mar 18, 2025)
35f8f89  fixed NFS bug  (Niccolo-Ajroldi, Mar 18, 2025)
cbb6ee6  train/val split before concat  (Niccolo-Ajroldi, Mar 18, 2025)
848b50c  reformatting  (priyakasimbeg, Mar 18, 2025)
fb62eae  reformatting  (priyakasimbeg, Mar 18, 2025)
fe3f9f0  reformatting  (priyakasimbeg, Mar 18, 2025)
004afbd  reformatting  (priyakasimbeg, Mar 18, 2025)
f1db3d3  reformatting  (priyakasimbeg, Mar 18, 2025)
868987c  renamed datasets to avoid conflict with HF  (Niccolo-Ajroldi, Mar 19, 2025)
8191f6d  Merge remote-tracking branch 'upstream/lm_workload' into lm_workload  (Niccolo-Ajroldi, Mar 19, 2025)
dd59ded  renamed datasets to dataset  (Niccolo-Ajroldi, Mar 19, 2025)
c208cc7  sharding deepspeech  (priyakasimbeg, Mar 19, 2025)
2e4cc9e  ogbg jit migration  (priyakasimbeg, Mar 19, 2025)
496b9c3  fix style  (Niccolo-Ajroldi, Mar 20, 2025)
50989eb  fix formatting  (Niccolo-Ajroldi, Mar 20, 2025)
5af0fdc  fix style  (Niccolo-Ajroldi, Mar 20, 2025)
2683099  fix style  (Niccolo-Ajroldi, Mar 20, 2025)
6b7ee29  fix yapf  (Niccolo-Ajroldi, Mar 20, 2025)
46b645b  fix style  (Niccolo-Ajroldi, Mar 20, 2025)
b3ae647  HF datasets pipeline  (rka97, Mar 27, 2025)
f095d4b  Testing with linear model  (rka97, Mar 27, 2025)
4189ae0  Merge branch 'jit_switch' into lm_workload  (rka97, Mar 27, 2025)
0c22f3d  lm workload with linear model  (rka97, Apr 3, 2025)
99c7b9b  add nanodo model  (rka97, Apr 3, 2025)
706d9f7  torch model  (rka97, Apr 3, 2025)
20 changes: 10 additions & 10 deletions .github/workflows/CI.yml
@@ -37,7 +37,7 @@ jobs:
pip install .[pytorch_cpu]
pip install .[full]
pip install -e .
-python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
+python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
wmt_pytorch:
runs-on: ubuntu-latest
steps:
@@ -54,7 +54,7 @@ jobs:
pip install .[pytorch_cpu]
pip install .[full]
pip install -e .
-python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
+python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
imagenet_jax:
runs-on: ubuntu-latest
steps:
@@ -71,8 +71,8 @@ jobs:
pip install .[pytorch_cpu]
pip install .[full]
pip install -e .
-python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
-python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
+python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
+python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
imagenet_pytorch:
runs-on: ubuntu-latest
steps:
@@ -89,8 +89,8 @@ jobs:
pip install .[pytorch_cpu]
pip install .[full]
pip install -e .
-python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
-python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
+python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
+python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
# uncomment when https://github.com/mlcommons/algorithmic-efficiency/issues/339 is resolved.
criteo_jax:
runs-on: ubuntu-latest
@@ -142,8 +142,8 @@ jobs:
pip install .[pytorch_cpu]
pip install .[full]
pip install -e .
-python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
-python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
+python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
+python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
speech_pytorch:
runs-on: ubuntu-latest
steps:
@@ -160,8 +160,8 @@ jobs:
pip install .[pytorch_cpu]
pip install .[full]
pip install -e .
-python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
-python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
+python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
+python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
ogbg:
runs-on: ubuntu-latest
steps:
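Note on the batch sizes above: every smoke test moves from --global_batch_size=2 to --global_batch_size=8. One constraint this interacts with is that the new batch-axis sharding in algoperf/sharding_utils.py only splits an array whose leading dimension divides evenly over the device mesh; otherwise it silently falls back to replication. A minimal check (illustrative sketch, not part of this PR):

import jax.numpy as jnp

from algoperf import sharding_utils

mesh = sharding_utils.get_mesh()
batch = jnp.ones((8, 4))
# Sharded along "batch" only if 8 % number_of_devices == 0; replicated otherwise.
print(sharding_utils.get_naive_sharding(batch, mesh))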
2 changes: 1 addition & 1 deletion .gitignore
@@ -25,4 +25,4 @@ scoring/plots/
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv

-algoperf/_version.py
+algoperf/_version.py
4 changes: 0 additions & 4 deletions algoperf/checkpoint_utils.py
@@ -11,7 +11,6 @@
from flax import jax_utils
from flax.training import checkpoints as flax_checkpoints
from flax.training.checkpoints import latest_checkpoint
-import jax
import numpy as np
from tensorflow.io import gfile # pytype: disable=import-error
import torch
@@ -193,10 +192,7 @@ def save_checkpoint(framework: str,
train_state, eval_results, global_step, preemption_count).
"""
if framework == 'jax':
-model_params = jax.device_get(jax_utils.unreplicate(model_params))
opt_state, _ = optimizer_state
-opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
-model_state = jax.device_get(jax_utils.unreplicate(model_state))
else:
if isinstance(
model_params,
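The deletions above reflect the pmap-to-jit migration this branch carries out: under pmap, parameters held a leading device axis and had to be unreplicated before saving, whereas jit-sharded parameters are ordinary jax.Arrays that can be fetched to host directly. A minimal sketch of the new path (illustrative, assuming the sharding_utils module added in this PR):

import jax
import jax.numpy as jnp

from algoperf import sharding_utils

# Replicated (or sharded) parameters come back as plain NumPy arrays;
# no jax_utils.unreplicate step is needed.
params = sharding_utils.shard_replicated({'w': jnp.ones((4, 4))})
host_params = jax.device_get(params)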
7 changes: 3 additions & 4 deletions algoperf/data_utils.py
@@ -11,6 +11,7 @@
from torch.utils.data import DistributedSampler
from torch.utils.data import Sampler

+from algoperf import sharding_utils
from algoperf import spec


@@ -50,6 +51,7 @@ def shard_and_maybe_pad_np(
weights = batch.get('weights')
# The weights will also be padded.
batch['weights'] = np.ones(mask_shape) if weights is None else weights
+naive_sharding_spec = sharding_utils.get_naive_sharding_spec()

def _prepare(x):
# Use _numpy() for zero-copy conversion between TF and NumPy.
@@ -60,10 +62,7 @@ def _prepare(x):
if remainder_size != 0 or pad_to_global_batch_size:
x = pad(x, pad_size, padding_value=padding_value)

-# Reshape (global_batch_size, ...) to
-# (local_device_count, per_device_batch_size, ...).
-# Assumes that `global_batch_size % local_device_count == 0`.
-return x.reshape((local_device_count, -1, *x.shape[1:]))
+return jax.device_put(x, naive_sharding_spec)

return jax.tree.map(_prepare, batch)

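For orientation, a side-by-side sketch of the old and new _prepare behavior (helper names below are illustrative, not code from the PR): the pmap pipeline reshaped the global batch to a per-device leading axis, while the jit pipeline keeps the global shape and lets a NamedSharding split the leading axis across the mesh.

import jax
import numpy as np

from algoperf import sharding_utils


def prepare_for_pmap(x: np.ndarray) -> np.ndarray:
  # Old layout: (local_device_count, per_device_batch_size, ...).
  n = jax.local_device_count()
  return x.reshape((n, -1, *x.shape[1:]))


def prepare_for_jit(x: np.ndarray) -> jax.Array:
  # New layout: keep (global_batch_size, ...); shard the leading axis instead.
  return jax.device_put(x, sharding_utils.get_naive_sharding_spec())


batch = {'inputs': np.ones((8, 4), np.float32)}
jit_batch = jax.tree.map(prepare_for_jit, batch)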
2 changes: 2 additions & 0 deletions algoperf/param_utils.py
@@ -43,6 +43,8 @@ def pytorch_param_types(
param_types[name] = spec.ParameterType.ATTENTION_BIAS
elif 'in_proj' in name:
param_types[name] = spec.ParameterType.ATTENTION_QKV
+elif 'qkv' in name:
+param_types[name] = spec.ParameterType.ATTENTION_QKV
elif 'kv_proj' in name:
param_types[name] = spec.ParameterType.ATTENTION_KV
elif 'k_proj' in name or 'key' in name:
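For context, a hypothetical module whose parameter names would reach the new branch (the class below is illustrative, not taken from this PR's LM model): a fused attention projection registered as qkv yields names like "blocks.0.attn.qkv.weight", which the added elif 'qkv' in name clause maps to spec.ParameterType.ATTENTION_QKV, provided no earlier branch matches first.

import torch.nn as nn


class FusedSelfAttention(nn.Module):
  """Illustrative only: fused QKV projection named so that 'qkv' appears in its parameter names."""

  def __init__(self, dim: int):
    super().__init__()
    self.qkv = nn.Linear(dim, 3 * dim)    # parameter names contain "qkv"
    self.out_proj = nn.Linear(dim, dim)   # classified by other branches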
82 changes: 82 additions & 0 deletions algoperf/sharding_utils.py
@@ -0,0 +1,82 @@
"""Utilities for dealing with sharding in JAX."""

import jax
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec


def get_mesh() -> jax.sharding.Mesh:
"""Creates a mesh from all available GPUs.
Here, we simply create a one-dimensional mesh."""
return jax.sharding.Mesh(jax.devices(), ("batch",))


def get_replicated_sharding(mesh=None):
"""Returns a sharding spec that replicates data across all devices."""
if mesh is None:
mesh = get_mesh()
return NamedSharding(mesh, PartitionSpec())


def shard_replicated(x, mesh=None):
"""Shards a tensor across all devices."""
if mesh is None:
mesh = get_mesh()
return jax.tree.map(
lambda x: jax.device_put(x, get_replicated_sharding(mesh)), x)


def get_naive_sharding_spec(mesh=None):
"""Returns a sharding spec that shards data along the first axis."""
if mesh is None:
mesh = get_mesh()
return NamedSharding(mesh, PartitionSpec("batch"))


def get_naive_sharding(x, mesh=None):
"""Given a 1D mesh and a tensor, try to shard along the appropriate axis."""
if mesh is None:
mesh = get_mesh()
grid_size = mesh.shape["batch"]
if len(x.shape) > 0 and x.shape[0] % grid_size == 0:
return NamedSharding(mesh, PartitionSpec("batch"))
else:
return NamedSharding(mesh, PartitionSpec())


def shard_params(params, mesh=None):
"""Shards a parameter tree across all devices
with naive sharding (see get_naive_sharding)."""
if mesh is None:
mesh = get_mesh()
return jax.tree.map(lambda x: jax.device_put(x, get_naive_sharding(x)),
params)


def shard_naive(x, mesh=None):
return shard_params(x, mesh)


def get_naive_sharding_tree(input_tree, mesh=None):
if mesh is None:
mesh = get_mesh()
return jax.tree.map(lambda x: get_naive_sharding(x, mesh), input_tree)


def get_sharding_tree(params, mesh=None):
"""Returns a sharding tree for a parameter tree."""
return jax.tree.map(lambda x: get_naive_sharding(x, mesh), params)


def get_empty_sharding(mesh=None):
"""Returns a sharding spec that replicates data across all devices."""
if mesh is None:
mesh = get_mesh()
return NamedSharding(mesh, PartitionSpec())


def disp_shard_info(x: jax.Array):
"""Displays shard info of a jax array."""
for shard in x.addressable_shards:
print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:"
f" {shard.replica_id}.\n")
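A short usage sketch of these helpers (illustrative, not code from this PR): build the mesh once, replicate the parameters, place the batch along the "batch" axis, and let jax.jit consume the sharded inputs.

import jax
import jax.numpy as jnp

from algoperf import sharding_utils

mesh = sharding_utils.get_mesh()

# Parameters are replicated on every device; the batch is split on axis 0.
params = sharding_utils.shard_replicated(
    {'w': jnp.zeros((32, 10)), 'b': jnp.zeros((10,))}, mesh)
batch = sharding_utils.shard_naive(
    {'inputs': jnp.ones((8, 32)),
     'targets': jnp.zeros((8,), dtype=jnp.int32)}, mesh)


@jax.jit
def loss_fn(params, batch):
  logits = batch['inputs'] @ params['w'] + params['b']
  log_probs = jax.nn.log_softmax(logits)
  one_hot = jax.nn.one_hot(batch['targets'], logits.shape[-1])
  return -jnp.mean(jnp.sum(log_probs * one_hot, axis=-1))


# jit infers the output sharding from the input shardings.
print(loss_fn(params, batch))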
4 changes: 2 additions & 2 deletions algoperf/workloads/cifar/cifar_jax/input_pipeline.py
@@ -8,7 +8,6 @@
import functools
from typing import Dict, Iterator, Tuple

-from flax import jax_utils
import jax
import tensorflow as tf
import tensorflow_datasets as tfds
@@ -171,5 +170,6 @@ def create_input_iter(
functools.partial(
shard_and_maybe_pad_np, global_batch_size=global_batch_size),
ds)
-it = jax_utils.prefetch_to_device(it, 2)
+# FIXME(rka97): Figure out how to do prefetching+sharding.
+# it = jax_utils.prefetch_to_device(it, 2)
return it
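On the FIXME above, one possible direction (a sketch under assumed use of this PR's sharding_utils, not something the PR implements): emulate flax.jax_utils.prefetch_to_device by eagerly device_put-ting the next few batches with the target sharding while the current batch is consumed.

import collections
import itertools

import jax

from algoperf import sharding_utils


def prefetch_to_sharding(iterator, size: int = 2):
  """Yields batches from `iterator`, keeping `size` of them resident on device."""
  sharding = sharding_utils.get_naive_sharding_spec()
  queue = collections.deque()

  def enqueue(n):
    for batch in itertools.islice(iterator, n):
      queue.append(jax.tree.map(lambda x: jax.device_put(x, sharding), batch))

  enqueue(size)  # Fill the lookahead buffer.
  while queue:
    yield queue.popleft()
    enqueue(1)  # Top the buffer back up.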