Skip to content

Commit c335e34

Browse files
committed
lm workload dataset integration in jax
1 parent 4189ae0 commit c335e34

File tree

8 files changed

+261
-209
lines changed

8 files changed

+261
-209
lines changed

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,17 +71,6 @@ def _build_input_queue(
7171
cache,
7272
repeat_final_dataset)
7373

74-
def sync_batch_stats(
75-
self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
76-
"""Sync the batch statistics across replicas."""
77-
# An axis_name is passed to pmap which can then be used by pmean.
78-
# In this case each device has its own version of the batch statistics
79-
# and we average them.
80-
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
81-
new_model_state = model_state.copy()
82-
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
83-
return new_model_state
84-
8574
def init_model_fn(
8675
self,
8776
rng: spec.RandomState,

algoperf/workloads/lm/input_pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,19 @@ def batch_iterator():
8787
tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size)
8888
inputs, targets = tokens[:, :-1], tokens[:, 1:]
8989
inputs, targets = jax.device_put(inputs), jax.device_put(targets)
90-
yield inputs, targets
91-
90+
batch = {
91+
"inputs": inputs,
92+
"targets": targets,
93+
}
94+
yield batch
9295
return batch_iterator()
9396

9497

9598
def get_lm_dataset(data_rng: jax.random.PRNGKey,
9699
split: str,
97100
data_dir: str,
98-
vocab_size: int,
99101
global_batch_size: int,
100-
num_batches: Optional[int] = None,
101-
repeat_final_dataset: bool = False,
102-
vocab_path: Optional[str] = None):
102+
num_batches: Optional[int] = None):
103103
"""Load HF dataset and return a TF dataset."""
104104

105105
dataset_path = os.path.join(data_dir, split)

algoperf/workloads/lm/lm_jax/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
1414
return nn.Dense(
1515
self.vocab_size,
1616
kernel_init=nn.initializers.normal(0.02),
17-
bias_init=nn.initializers.zeros
17+
bias_init=nn.initializers.zeros,
18+
name="output"
1819
)(x)

algoperf/workloads/lm/lm_jax/workload.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,36 @@
22

33
from typing import Dict, Optional, Tuple
44

5+
import jax
56
import jax.numpy as jnp
7+
import optax
68
from flax import jax_utils
79
from algoperf import param_utils
10+
from algoperf import sharding_utils
811
from algoperf import spec
912
from algoperf.workloads.lm.workload import BaseLmWorkload
1013
from algoperf.workloads.lm.lm_jax.models import LinearModel
14+
from algoperf.workloads.lm.input_pipeline import get_hf_dataloader, get_lm_dataset
1115

1216

1317
class LmWorkload(BaseLmWorkload):
1418
"""LM JAX workload."""
19+
def _build_input_queue(self,
                       data_rng: jax.random.PRNGKey,
                       split: str,
                       data_dir: str,
                       global_batch_size: int,
                       num_batches: Optional[int] = None,
                       repeat_final_dataset: bool = False):
  """Return the input source for `split` as produced by `get_lm_dataset`.

  Args:
    data_rng: PRNG key forwarded to `get_lm_dataset`.
    split: Name of the dataset split, forwarded unchanged.
    data_dir: Directory holding the dataset, forwarded unchanged.
    global_batch_size: Batch size forwarded to `get_lm_dataset`.
    num_batches: Accepted for interface compatibility; discarded here.
    repeat_final_dataset: Accepted for interface compatibility; discarded
      here.

  Returns:
    Whatever `get_lm_dataset` returns for the given arguments.
  """
  # Neither of these is forwarded to the dataset builder.
  del num_batches, repeat_final_dataset
  return get_lm_dataset(
      data_rng=data_rng,
      split=split,
      data_dir=data_dir,
      global_batch_size=global_batch_size)
1535

1636
def init_model_fn(
1737
self,
@@ -21,14 +41,15 @@ def init_model_fn(
2141

2242
model = LinearModel(vocab_size=self._vocab_size)
2343
input_shape = (1, self._seq_len, self._vocab_size)
24-
variables = model.init(rng, jnp.ones(input_shape, jnp.float32))
25-
model_state, params = variables.pop('params')
26-
44+
params_rng, init_rng = jax.random.split(rng)
45+
variables = jax.jit(model.init)({'params': params_rng},
46+
jnp.ones(input_shape, jnp.float32))
47+
params = variables['params']
2748
self._param_shapes = param_utils.jax_param_shapes(params)
2849
self._param_types = param_utils.jax_param_types(self._param_shapes)
29-
model_state = jax_utils.replicate(model_state)
30-
params = jax_utils.replicate(params)
31-
50+
params = sharding_utils.shard_replicated(params)
51+
model_state = None
52+
self._model = model
3253
return params, model_state
3354

3455
def model_fn(
@@ -40,15 +61,40 @@ def model_fn(
4061
rng: spec.RandomState,
4162
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
4263

43-
del mode, rng, update_batch_norm # Not used for linear model
44-
inputs = batch['inputs']
45-
logits = self._model.apply({'params': params, **model_state}, inputs)
46-
return logits, model_state
64+
del mode, rng, update_batch_norm, model_state
65+
inputs = jax.nn.one_hot(batch['inputs'], self._vocab_size, axis=-1)
66+
logits = self._model.apply({'params': params}, inputs)
67+
return logits, None
68+
69+
def loss_fn(
    self,
    label_batch: spec.Tensor,
    logits_batch: spec.Tensor,
    mask_batch: Optional[spec.Tensor] = None,
    label_smoothing: Optional[float] = 0.0) -> Dict[str, spec.Tensor]:
  """Compute the summed softmax cross-entropy over every flattened position.

  Args:
    label_batch: Labels; expanded here with `jax.nn.one_hot` to
      `self._vocab_size` classes before being compared with the logits.
    logits_batch: Logits; flattened to rows of `self._vocab_size` entries.
    mask_batch: Accepted for interface compatibility; ignored.
    label_smoothing: Accepted for interface compatibility; ignored.

  Returns:
    A dict with 'summed' (the negated sum of target-weighted
    log-probabilities) and 'n_valid_examples' (the number of flattened
    logit rows).
  """
  del mask_batch, label_smoothing
  vocab = self._vocab_size
  log_probs = jax.nn.log_softmax(logits_batch.reshape(-1, vocab), axis=-1)
  one_hot_targets = jax.nn.one_hot(
      label_batch, vocab, axis=-1).reshape(-1, vocab)
  return {
      'summed': -jnp.sum(one_hot_targets * log_probs),
      'n_valid_examples': log_probs.shape[0],
  }
4783

84+
def is_output_params(self, param_name: str) -> bool:
  """Return whether the given parameter is an output parameter.

  Args:
    param_name: Name (or path) of the parameter to classify.

  Returns:
    True when `param_name` contains the substring 'output', else False.
  """
  # `str` has no `.contains()` method (calling it raises AttributeError);
  # substring membership is spelled with the `in` operator.
  return 'output' in param_name
87+
4888
def _eval_batch(self,
                params: spec.ParameterContainer,
                batch: Dict[str, spec.Tensor],
                model_state: spec.ModelAuxiliaryState,
                rng: spec.RandomState) -> spec.Tensor:
  """Evaluate the model on a single batch.

  Runs a forward pass in `spec.ForwardPassMode.EVAL` and returns the summed
  cross-entropy between the resulting logits and `batch['targets']`.

  Args:
    params: Model parameters passed to `model_fn`.
    batch: Batch dict; `model_fn` reads 'inputs' and this method reads
      'targets'.
    model_state: Auxiliary model state passed to `model_fn`.
    rng: Random state passed to `model_fn`.

  Returns:
    The summed cross-entropy loss for the batch.
  """
  logits, _ = self.model_fn(
      params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
  # `model_fn` and `loss_fn` both one-hot encode the batch tensors they
  # receive, so the raw targets cannot be multiplied with the log-probs
  # directly. Delegating to `loss_fn` keeps the target encoding consistent.
  return self.loss_fn(batch['targets'], logits)['summed']

algoperf/workloads/lm/lm_pytorch/workload.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -66,35 +66,30 @@ def _build_input_queue(
6666
global_batch_size: int,
6767
num_batches: Optional[int] = None,
6868
repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
69-
not_train = split != 'train'
70-
per_device_batch_size = int(global_batch_size / N_GPUS)
71-
72-
seq_len = self._seq_len # TODO: define it somewehere else?
73-
dtype = torch.int32 # TODO: decide between int32 and int64.
74-
75-
# Only create and iterate over tf input pipeline in one Python process to
76-
# avoid creating too many threads.
77-
if RANK == 0:
78-
np_iter = super()._build_input_queue(
79-
data_rng=data_rng,
80-
split=split,
81-
data_dir=data_dir,
82-
global_batch_size=global_batch_size,
83-
num_batches=num_batches,
84-
repeat_final_dataset=repeat_final_dataset)
69+
"""Build an input queue for the given split."""
70+
from algoperf.workloads.lm.input_pipeline import get_hf_dataloader
71+
72+
loader = get_hf_dataloader(
73+
cache_dir=data_dir,
74+
data_rng=data_rng,
75+
batch_size=global_batch_size,
76+
seq_len=self._seq_len,
77+
framework="torch",
78+
split=split)
79+
seq_len = self._seq_len
8580
weights = None
86-
81+
8782
while True:
8883
# Only iterate over tf input pipeline in one Python process to
8984
# avoid creating too many threads.
9085
if RANK == 0:
91-
batch = next(np_iter) # pylint: disable=stop-iteration-return
86+
batch = next(dataset_iter) # pylint: disable=stop-iteration-return
9287
inputs = torch.as_tensor(
9388
batch['inputs'], dtype=dtype,
94-
device=DEVICE) # (N_GPUS, global_batch_size, seq_len)
89+
device=DEVICE) # (N_GPUS, per_device_batch_size, seq_len)
9590
targets = torch.as_tensor(
9691
batch['targets'], dtype=dtype,
97-
device=DEVICE) # (N_GPUS, global_batch_size, seq_len)
92+
device=DEVICE) # (N_GPUS, per_device_batch_size, seq_len)
9893

9994
# Send batch to other devices when using DDP.
10095
if USE_PYTORCH_DDP:
@@ -138,10 +133,22 @@ def _build_input_queue(
138133
}
139134
yield batch
140135

136+
def is_output_params(self, param_name: str) -> bool:
  """Tell whether `param_name` refers to the output layer's weight or bias.

  Args:
    param_name: Name of the parameter to classify.

  Returns:
    True if the name contains 'output.weight' or 'output.bias'.
  """
  output_markers = ('output.weight', 'output.bias')
  return any(marker in param_name for marker in output_markers)
139+
141140
def _eval_batch(self,
                params: spec.ParameterContainer,
                batch: Dict[str, spec.Tensor],
                model_state: spec.ModelAuxiliaryState,
                rng: spec.RandomState) -> spec.Tensor:
  """Run an eval-mode forward pass on one batch and return its summed loss.

  Args:
    params: Passed to `model_fn` as the model.
    batch: Batch dict; 'targets' is read here and the whole dict is handed
      to `model_fn`.
    model_state: Auxiliary model state passed to `model_fn`.
    rng: Random state passed to `model_fn`.

  Returns:
    The negated sum of `batch['targets']` multiplied element-wise with the
    log-softmax of the logits over the last dimension.
  """
  logits, _unused_state = self.model_fn(
      params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
  log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
  # Cross-entropy: targets weight the log-probabilities, summed over all
  # elements.
  weighted_log_probs = batch['targets'] * log_probs
  return -torch.sum(weighted_log_probs)

0 commit comments

Comments
 (0)