Skip to content

Commit 4189ae0

Browse files
committed
Merge branch 'jit_switch' into lm_workload
2 parents f095d4b + 2e4cc9e commit 4189ae0

File tree

31 files changed

+949
-318
lines changed

31 files changed

+949
-318
lines changed

.github/workflows/CI.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
pip install .[pytorch_cpu]
3838
pip install .[full]
3939
pip install -e .
40-
python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
40+
python tests/reference_algorithm_tests.py --workload=wmt --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
4141
wmt_pytorch:
4242
runs-on: ubuntu-latest
4343
steps:
@@ -54,7 +54,7 @@ jobs:
5454
pip install .[pytorch_cpu]
5555
pip install .[full]
5656
pip install -e .
57-
python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
57+
python tests/reference_algorithm_tests.py --workload=wmt --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
5858
imagenet_jax:
5959
runs-on: ubuntu-latest
6060
steps:
@@ -71,8 +71,8 @@ jobs:
7171
pip install .[pytorch_cpu]
7272
pip install .[full]
7373
pip install -e .
74-
python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
75-
python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
74+
python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
75+
python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
7676
imagenet_pytorch:
7777
runs-on: ubuntu-latest
7878
steps:
@@ -89,8 +89,8 @@ jobs:
8989
pip install .[pytorch_cpu]
9090
pip install .[full]
9191
pip install -e .
92-
python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
93-
python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
92+
python tests/reference_algorithm_tests.py --workload=imagenet_resnet --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
93+
python tests/reference_algorithm_tests.py --workload=imagenet_vit --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
9494
# uncomment when https://github.com/mlcommons/algorithmic-efficiency/issues/339 is resolved.
9595
criteo_jax:
9696
runs-on: ubuntu-latest
@@ -142,8 +142,8 @@ jobs:
142142
pip install .[pytorch_cpu]
143143
pip install .[full]
144144
pip install -e .
145-
python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
146-
python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
145+
python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
146+
python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
147147
speech_pytorch:
148148
runs-on: ubuntu-latest
149149
steps:
@@ -160,8 +160,8 @@ jobs:
160160
pip install .[pytorch_cpu]
161161
pip install .[full]
162162
pip install -e .
163-
python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
164-
python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=2 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
163+
python tests/reference_algorithm_tests.py --workload=librispeech_deepspeech --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
164+
python tests/reference_algorithm_tests.py --workload=librispeech_conformer --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
165165
ogbg:
166166
runs-on: ubuntu-latest
167167
steps:

algoperf/checkpoint_utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from flax import jax_utils
1212
from flax.training import checkpoints as flax_checkpoints
1313
from flax.training.checkpoints import latest_checkpoint
14-
import jax
1514
import numpy as np
1615
from tensorflow.io import gfile # pytype: disable=import-error
1716
import torch
@@ -193,10 +192,7 @@ def save_checkpoint(framework: str,
193192
train_state, eval_results, global_step, preemption_count).
194193
"""
195194
if framework == 'jax':
196-
model_params = jax.device_get(jax_utils.unreplicate(model_params))
197195
opt_state, _ = optimizer_state
198-
opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
199-
model_state = jax.device_get(jax_utils.unreplicate(model_state))
200196
else:
201197
if isinstance(
202198
model_params,

algoperf/data_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torch.utils.data import DistributedSampler
1212
from torch.utils.data import Sampler
1313

14+
from algoperf import sharding_utils
1415
from algoperf import spec
1516

1617

@@ -50,6 +51,7 @@ def shard_and_maybe_pad_np(
5051
weights = batch.get('weights')
5152
# The weights will also be padded.
5253
batch['weights'] = np.ones(mask_shape) if weights is None else weights
54+
naive_sharding_spec = sharding_utils.get_naive_sharding_spec()
5355

5456
def _prepare(x):
5557
# Use _numpy() for zero-copy conversion between TF and NumPy.
@@ -60,10 +62,7 @@ def _prepare(x):
6062
if remainder_size != 0 or pad_to_global_batch_size:
6163
x = pad(x, pad_size, padding_value=padding_value)
6264

63-
# Reshape (global_batch_size, ...) to
64-
# (local_device_count, per_device_batch_size, ...).
65-
# Assumes that `global_batch_size % local_device_count == 0`.
66-
return x.reshape((local_device_count, -1, *x.shape[1:]))
65+
return jax.device_put(x, naive_sharding_spec)
6766

6867
return jax.tree.map(_prepare, batch)
6968

algoperf/sharding_utils.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Utilities for dealing with sharding in JAX."""
2+
3+
import jax
4+
from jax.sharding import NamedSharding
5+
from jax.sharding import PartitionSpec
6+
7+
8+
def get_mesh() -> jax.sharding.Mesh:
  """Builds a one-dimensional device mesh.

  The mesh spans every device returned by `jax.devices()` and has a single
  axis named "batch".

  Returns:
    A `jax.sharding.Mesh` with one axis, "batch".
  """
  devices = jax.devices()
  axis_names = ("batch",)
  return jax.sharding.Mesh(devices, axis_names)
12+
13+
14+
def get_replicated_sharding(mesh=None):
  """Returns a sharding that replicates data on every device of the mesh.

  Args:
    mesh: Optional mesh to build the sharding on. When None, the mesh from
      `get_mesh()` is used.

  Returns:
    A `NamedSharding` with an empty `PartitionSpec` (no axis is partitioned).
  """
  target_mesh = get_mesh() if mesh is None else mesh
  replicated_spec = PartitionSpec()
  return NamedSharding(target_mesh, replicated_spec)
19+
20+
21+
def shard_replicated(x, mesh=None):
  """Replicates every leaf of a pytree across all devices of the mesh.

  Args:
    x: A pytree whose leaves are placed on device.
    mesh: Optional mesh to replicate over. When None,
      `get_replicated_sharding` falls back to the mesh from `get_mesh()`.

  Returns:
    A pytree with the same structure as `x` whose leaves were passed through
    `jax.device_put` with a replicated sharding.
  """
  # The sharding is the same for every leaf, so build it once instead of once
  # per leaf.
  replicated = get_replicated_sharding(mesh)
  return jax.tree.map(lambda leaf: jax.device_put(leaf, replicated), x)
27+
28+
29+
def get_naive_sharding_spec(mesh=None):
  """Returns a sharding that partitions the first axis over "batch".

  Args:
    mesh: Optional mesh to build the sharding on. When None, the mesh from
      `get_mesh()` is used.

  Returns:
    A `NamedSharding` whose `PartitionSpec` maps the leading axis to the
    "batch" mesh axis.
  """
  target_mesh = get_mesh() if mesh is None else mesh
  batch_spec = PartitionSpec("batch")
  return NamedSharding(target_mesh, batch_spec)
34+
35+
36+
def get_naive_sharding(x, mesh=None):
  """Picks a sharding for `x` on a mesh with a "batch" axis.

  The leading axis of `x` is partitioned over "batch" when `x` has at least
  one dimension and its first dimension is divisible by the size of the
  "batch" mesh axis. Otherwise `x` is replicated.

  Args:
    x: An object exposing a `shape` attribute.
    mesh: Optional mesh. When None, the mesh from `get_mesh()` is used.

  Returns:
    A `NamedSharding` with either `PartitionSpec("batch")` or an empty
    `PartitionSpec`.
  """
  if mesh is None:
    mesh = get_mesh()
  num_shards = mesh.shape["batch"]
  dims = x.shape
  # Scalars and arrays whose leading axis does not split evenly are replicated.
  if len(dims) == 0 or dims[0] % num_shards != 0:
    return NamedSharding(mesh, PartitionSpec())
  return NamedSharding(mesh, PartitionSpec("batch"))
45+
46+
47+
def shard_params(params, mesh=None):
  """Places a parameter tree on device with naive sharding.

  Each leaf is passed through `jax.device_put` with the sharding chosen by
  `get_naive_sharding` for that leaf.

  Args:
    params: A pytree of parameters.
    mesh: Optional mesh to shard over. When None, the mesh from `get_mesh()`
      is used.

  Returns:
    A pytree with the same structure as `params` whose leaves were placed with
    their naive sharding.
  """
  if mesh is None:
    mesh = get_mesh()
  # Forward the resolved mesh so that a caller-supplied mesh is honored;
  # without it, get_naive_sharding would fall back to a mesh of its own.
  return jax.tree.map(
      lambda leaf: jax.device_put(leaf, get_naive_sharding(leaf, mesh)),
      params)
54+
55+
56+
def shard_naive(x, mesh=None):
  """Alias of `shard_params`: forwards `x` and `mesh` unchanged.

  Args:
    x: A pytree to place on device.
    mesh: Optional mesh, passed through to `shard_params`.

  Returns:
    Whatever `shard_params(x, mesh)` returns.
  """
  return shard_params(params=x, mesh=mesh)
58+
59+
60+
def get_naive_sharding_tree(input_tree, mesh=None):
  """Maps every leaf of a pytree to its naive sharding.

  Args:
    input_tree: A pytree whose leaves are accepted by `get_naive_sharding`.
    mesh: Optional mesh. When None, the mesh from `get_mesh()` is used.

  Returns:
    A pytree with the same structure as `input_tree` whose leaves are the
    shardings returned by `get_naive_sharding`.
  """
  resolved_mesh = get_mesh() if mesh is None else mesh

  def _leaf_sharding(leaf):
    return get_naive_sharding(leaf, resolved_mesh)

  return jax.tree.map(_leaf_sharding, input_tree)
64+
65+
66+
def get_sharding_tree(params, mesh=None):
  """Returns a sharding tree for a parameter tree.

  Args:
    params: A pytree of parameters.
    mesh: Optional mesh. When None, the mesh from `get_mesh()` is used.

  Returns:
    A pytree with the same structure as `params` whose leaves are the
    shardings returned by `get_naive_sharding`.
  """
  # Resolve the default mesh once here; otherwise get_naive_sharding would
  # build a fresh mesh for every leaf.
  if mesh is None:
    mesh = get_mesh()
  return jax.tree.map(lambda leaf: get_naive_sharding(leaf, mesh), params)
69+
70+
71+
def get_empty_sharding(mesh=None):
  """Returns a sharding with an empty partition spec (full replication).

  Args:
    mesh: Optional mesh to build the sharding on. When None, the mesh from
      `get_mesh()` is used.

  Returns:
    A `NamedSharding` built from the mesh and an empty `PartitionSpec`.
  """
  target_mesh = get_mesh() if mesh is None else mesh
  empty_spec = PartitionSpec()
  return NamedSharding(target_mesh, empty_spec)
76+
77+
78+
def disp_shard_info(x: jax.Array):
  """Prints the device, index and replica id of each addressable shard.

  Args:
    x: The jax array whose `addressable_shards` are listed.
  """
  for shard in x.addressable_shards:
    device, index, replica = shard.device, shard.index, shard.replica_id
    print(f"shard.device: {device}, index: {index}, replica_id: {replica}.\n")

algoperf/workloads/cifar/cifar_jax/input_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import functools
99
from typing import Dict, Iterator, Tuple
1010

11-
from flax import jax_utils
1211
import jax
1312
import tensorflow as tf
1413
import tensorflow_datasets as tfds
@@ -171,5 +170,6 @@ def create_input_iter(
171170
functools.partial(
172171
shard_and_maybe_pad_np, global_batch_size=global_batch_size),
173172
ds)
174-
it = jax_utils.prefetch_to_device(it, 2)
173+
# FIXME(rka97): Figure out how to do prefetching+sharding.
174+
# it = jax_utils.prefetch_to_device(it, 2)
175175
return it

0 commit comments

Comments
 (0)