Skip to content

Commit 0c22f3d

Browse files
committed
lm workload with linear model
1 parent 4189ae0 commit 0c22f3d

File tree

9 files changed

+187
-118
lines changed

9 files changed

+187
-118
lines changed

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,17 +71,6 @@ def _build_input_queue(
7171
cache,
7272
repeat_final_dataset)
7373

74-
def sync_batch_stats(
75-
self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
76-
"""Sync the batch statistics across replicas."""
77-
# An axis_name is passed to pmap which can then be used by pmean.
78-
# In this case each device has its own version of the batch statistics
79-
# and we average them.
80-
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
81-
new_model_state = model_state.copy()
82-
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
83-
return new_model_state
84-
8574
def init_model_fn(
8675
self,
8776
rng: spec.RandomState,

algoperf/workloads/lm/input_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def batch_iterator():
8787
tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size)
8888
inputs, targets = tokens[:, :-1], tokens[:, 1:]
8989
inputs, targets = jax.device_put(inputs), jax.device_put(targets)
90-
yield inputs, targets
90+
yield {'inputs': inputs, 'targets': targets}
9191

9292
return batch_iterator()
9393

algoperf/workloads/lm/lm_jax/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ class LinearModel(nn.Module):
77
@nn.compact
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
  """Apply a 10-unit dense layer followed by a dense projection to the vocab.

  Both layers use normal(0.02) kernel initialization and zero biases. The
  second layer is explicitly named "output".

  Args:
    inputs: Array fed to the first dense layer.

  Returns:
    The output of the second dense layer; its last axis has
    `self.vocab_size` entries.
  """
  small_normal = nn.initializers.normal(0.02)
  hidden = nn.Dense(
      10,
      kernel_init=small_normal,
      bias_init=nn.initializers.zeros)(inputs)
  logits = nn.Dense(
      self.vocab_size,
      kernel_init=small_normal,
      bias_init=nn.initializers.zeros,
      name="output")(hidden)
  return logits

algoperf/workloads/lm/lm_jax/workload.py

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,57 @@
22

33
from typing import Dict, Optional, Tuple
44

5+
import jax
56
import jax.numpy as jnp
7+
import optax
68
from flax import jax_utils
79
from algoperf import param_utils
10+
from algoperf import sharding_utils
811
from algoperf import spec
912
from algoperf.workloads.lm.workload import BaseLmWorkload
1013
from algoperf.workloads.lm.lm_jax.models import LinearModel
14+
from algoperf.workloads.lm.input_pipeline import get_hf_dataloader
1115

1216

1317
class LmWorkload(BaseLmWorkload):
1418
"""LM JAX workload."""
1519

20+
def _build_input_queue(self,
                       data_rng: jax.random.PRNGKey,
                       split: str,
                       data_dir: str,
                       global_batch_size: int,
                       num_batches: Optional[int] = None,
                       repeat_final_dataset: bool = False):
  """Return the HuggingFace FineWeb data loader for the requested split.

  Args:
    data_rng: PRNG key forwarded to the loader.
    split: Name of the dataset split to load.
    data_dir: Directory handed to the loader as its cache directory.
    global_batch_size: Batch size forwarded to the loader.
    num_batches: Ignored.
    repeat_final_dataset: Ignored.

  Returns:
    Whatever `get_hf_dataloader` returns for the "jax" framework.
  """
  # Both arguments exist only to satisfy the shared signature.
  del num_batches, repeat_final_dataset
  loader_kwargs = dict(
      cache_dir=data_dir,
      data_rng=data_rng,
      batch_size=global_batch_size,
      seq_len=self._seq_len,
      framework="jax",
      split=split)
  return get_hf_dataloader(**loader_kwargs)
38+
1639
def init_model_fn(
    self,
    rng: spec.RandomState,
    dropout_rate: Optional[float] = None,
    aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
  """Build the linear LM, initialize its parameters and record their shapes.

  Stores the model on `self._model` and fills `self._param_shapes` and
  `self._param_types` as side effects.

  Args:
    rng: PRNG key; it is split once and the first half seeds the 'params'
      collection.
    dropout_rate: Not used by this function.
    aux_dropout_rate: Not used by this function.

  Returns:
    A `(params, model_state)` pair. `params` has been passed through
    `sharding_utils.shard_replicated`; `model_state` is always None.
  """
  self._model = LinearModel(vocab_size=self._vocab_size)
  input_shape = (1, self._seq_len, self._vocab_size)
  # Keep the split (even though only one half is consumed) so the key that
  # seeds the parameters is unchanged.
  params_rng, _ = jax.random.split(rng)
  variables = jax.jit(self._model.init)(
      {'params': params_rng}, jnp.ones(input_shape, jnp.float32))
  params = variables['params']
  self._param_shapes = param_utils.jax_param_shapes(params)
  self._param_types = param_utils.jax_param_types(self._param_shapes)
  params = sharding_utils.shard_replicated(params)
  model_state = None
  return params, model_state
3357

3458
def model_fn(
@@ -40,15 +64,51 @@ def model_fn(
4064
rng: spec.RandomState,
4165
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
4266

43-
del mode, rng, update_batch_norm # Not used for linear model
67+
del mode, rng, update_batch_norm, model_state
4468
inputs = batch['inputs']
45-
logits = self._model.apply({'params': params, **model_state}, inputs)
46-
return logits, model_state
69+
logits = self._model.apply({'params': params}, inputs)
70+
return logits, None
71+
72+
def loss_fn(
    self,
    label_batch: spec.Tensor,
    logits_batch: spec.Tensor,
    mask_batch: Optional[spec.Tensor] = None,
    label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]:
  """Compute cross-entropy loss for language modeling in JAX.

  Args:
    label_batch: Either one-hot (or soft) labels with the same rank as
      `logits_batch`, or integer class ids with one dimension fewer.
    logits_batch: Unnormalized scores; the last axis is the vocabulary.
    mask_batch: Optional weights multiplied into the unreduced loss.
    label_smoothing: Python float in [0, 1]. The targets are mixed with a
      uniform distribution over the vocabulary using this weight.

  Returns:
    A dict with:
      'summed': sum of the (masked) unreduced loss.
      'n_valid_examples': `mask_batch.sum()` when a mask is given, otherwise
        `label_batch.shape[0]`.
      'per_example': the (masked) unreduced loss, with the vocabulary axis
        reduced away.
  """
  log_probs = jax.nn.log_softmax(logits_batch, axis=-1)

  if len(label_batch.shape) == len(logits_batch.shape):
    # One-hot labels: reduce over the vocabulary axis only, so the result
    # keeps one loss value per position.
    target_log_probs = jnp.sum(label_batch * log_probs, axis=-1)
  else:
    # Dense labels: pick the log-probability of each label along the
    # vocabulary axis, whatever the number of leading axes.
    target_log_probs = jnp.take_along_axis(
        log_probs, label_batch[..., None], axis=-1)[..., 0]
  loss = -target_log_probs

  if label_smoothing:
    # Cross-entropy against (1 - a) * target + a / vocab_size, written
    # without materializing the smoothed targets.
    uniform_loss = -jnp.mean(log_probs, axis=-1)
    loss = (1.0 - label_smoothing) * loss + label_smoothing * uniform_loss

  if mask_batch is not None:
    loss = loss * mask_batch

  n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0]
  return {
      'summed': loss.sum(),
      'n_valid_examples': n_valid,
      'per_example': loss,
  }
4797

98+
def is_output_params(self, param_name: str) -> bool:
  """Return whether the given parameter is an output parameter.

  A parameter counts as an output parameter when its name contains the
  substring 'output'.

  Args:
    param_name: Name of the parameter to classify.

  Returns:
    True if 'output' occurs anywhere in `param_name`, else False.
  """
  # Python strings have no `.contains()` method; `in` is the substring test.
  return 'output' in param_name
101+
48102
def _eval_batch(self,
                params: spec.ParameterContainer,
                batch: Dict[str, spec.Tensor],
                model_state: spec.ModelAuxiliaryState,
                rng: spec.RandomState) -> spec.Tensor:
  """Return the summed cross-entropy of the model on one batch.

  Runs `self.model_fn` in EVAL mode with batch-norm updates disabled, then
  sums `-targets * log_softmax(logits)` over every element.

  Args:
    params: Model parameters forwarded to `model_fn`.
    batch: Mapping that must contain 'targets' (plus whatever `model_fn`
      reads).
    model_state: Auxiliary state forwarded to `model_fn`.
    rng: PRNG key forwarded to `model_fn`.

  Returns:
    A scalar: the loss summed over the whole batch.
  """
  logits, _unused_state = self.model_fn(
      params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
  targets = batch['targets']
  # Targets multiply the log-probabilities elementwise, so they are
  # presumably one-hot over the last axis — TODO confirm against the loader.
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  return -jnp.sum(targets * log_probs)

algoperf/workloads/lm/lm_pytorch/workload.py

Lines changed: 70 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -66,68 +66,38 @@ def _build_input_queue(
6666
global_batch_size: int,
6767
num_batches: Optional[int] = None,
6868
repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
69-
not_train = split != 'train'
70-
per_device_batch_size = int(global_batch_size / N_GPUS)
71-
72-
seq_len = self._seq_len # TODO: define it somewehere else?
73-
dtype = torch.int32 # TODO: decide between int32 and int64.
74-
75-
# Only create and iterate over tf input pipeline in one Python process to
76-
# avoid creating too many threads.
77-
if RANK == 0:
78-
np_iter = super()._build_input_queue(
79-
data_rng=data_rng,
80-
split=split,
81-
data_dir=data_dir,
82-
global_batch_size=global_batch_size,
83-
num_batches=num_batches,
84-
repeat_final_dataset=repeat_final_dataset)
69+
"""Build an input queue for the given split."""
70+
from algoperf.workloads.lm.input_pipeline import get_hf_dataloader
71+
72+
loader = get_hf_dataloader(
73+
cache_dir=data_dir,
74+
data_rng=data_rng,
75+
batch_size=global_batch_size,
76+
seq_len=self._seq_len,
77+
framework="torch",
78+
split=split)
79+
seq_len = self._seq_len
8580
weights = None
86-
87-
while True:
88-
# Only iterate over tf input pipeline in one Python process to
89-
# avoid creating too many threads.
90-
if RANK == 0:
91-
batch = next(np_iter) # pylint: disable=stop-iteration-return
92-
inputs = torch.as_tensor(
93-
batch['inputs'], dtype=dtype,
94-
device=DEVICE) # (N_GPUS, global_batch_size, seq_len)
95-
targets = torch.as_tensor(
96-
batch['targets'], dtype=dtype,
97-
device=DEVICE) # (N_GPUS, global_batch_size, seq_len)
98-
99-
# Send batch to other devices when using DDP.
100-
if USE_PYTORCH_DDP:
101-
if not_train:
102-
# During eval, the batch size of the remainder might be different.
103-
per_device_batch_size = torch.tensor(
104-
len(targets[0]), dtype=dtype, device=DEVICE)
105-
dist.broadcast(per_device_batch_size, src=0)
106-
# We don't broadcast the shard for RANK 0.
107-
dist.broadcast(inputs[1:], src=0)
108-
dist.broadcast(targets[1:], src=0)
109-
110-
# RANK 0 extracts his shard. If not DDP, this just flattens.
111-
inputs, targets = inputs[0], targets[0]
112-
113-
else:
114-
# Receive batch from rank 0.
115-
if not_train:
116-
# During eval, the batch size of the remainder might be different.
117-
per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE)
81+
82+
dtype = torch.long
83+
is_train = split == 'train'
84+
85+
for batch in loader:
86+
inputs, targets = batch
87+
88+
if USE_PYTORCH_DDP:
89+
if not is_train:
90+
# During eval, the batch size of the remainder might be different
91+
per_device_batch_size = torch.tensor(
92+
len(targets[0]), dtype=dtype, device=DEVICE)
11893
dist.broadcast(per_device_batch_size, src=0)
119-
120-
# N_GPUS - 1 since we don't broadcast the shard for RANK 0.
121-
inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len),
122-
dtype=dtype,
123-
device=DEVICE)
124-
targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len),
125-
dtype=dtype,
126-
device=DEVICE)
94+
95+
# Broadcast to all devices
12796
dist.broadcast(inputs, src=0)
12897
dist.broadcast(targets, src=0)
129-
# RANK - 1 since we don't broadcast the shard for RANK 0.
130-
inputs, targets = inputs[RANK - 1], targets[RANK - 1]
98+
99+
if weights is None:
100+
weights = torch.ones(inputs.shape[0], device=DEVICE)
131101

132102
if weights is None:
133103
weights = torch.ones(per_device_batch_size, device=DEVICE)
@@ -138,10 +108,51 @@ def _build_input_queue(
138108
}
139109
yield batch
140110

111+
def is_output_params(self, param_name: str) -> bool:
  """Tell whether `param_name` names a weight or bias of the output layer.

  Args:
    param_name: Name of the parameter to classify.

  Returns:
    True if the name contains 'output.weight' or 'output.bias'.
  """
  output_markers = ('output.weight', 'output.bias')
  return any(marker in param_name for marker in output_markers)
114+
141115
def _eval_batch(self,
                params: spec.ParameterContainer,
                batch: Dict[str, spec.Tensor],
                model_state: spec.ModelAuxiliaryState,
                rng: spec.RandomState) -> spec.Tensor:
  """Return the summed cross-entropy of the model on one batch.

  Runs `self.model_fn` in EVAL mode with batch-norm updates disabled, then
  sums `-targets * log_softmax(logits)` over every element.

  Args:
    params: Passed to `model_fn` as the model.
    batch: Mapping that must contain 'targets' (plus whatever `model_fn`
      reads).
    model_state: Auxiliary state forwarded to `model_fn`.
    rng: Random state forwarded to `model_fn`.

  Returns:
    A scalar tensor: the loss summed over the whole batch.
  """
  logits, _unused_state = self.model_fn(
      params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
  targets = batch['targets']
  # Targets multiply the log-probabilities elementwise, so they are
  # presumably one-hot over the last axis — TODO confirm against the loader.
  weighted_log_probs = targets * torch.nn.functional.log_softmax(
      logits, dim=-1)
  return -torch.sum(weighted_log_probs)
130+
def loss_fn(
    self,
    label_batch: spec.Tensor,
    logits_batch: spec.Tensor,
    mask_batch: Optional[spec.Tensor] = None,
    label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]:
  """Compute cross-entropy loss for language modeling in PyTorch.

  Args:
    label_batch: Either one-hot (or soft) labels with the same rank as
      `logits_batch`, or integer class ids with one dimension fewer.
    logits_batch: Unnormalized scores; the last axis is the vocabulary.
    mask_batch: Optional weights multiplied into the unreduced loss.
    label_smoothing: Float in [0, 1]. The targets are mixed with a uniform
      distribution over the vocabulary using this weight.

  Returns:
    A dict with:
      'summed': sum of the (masked) unreduced loss.
      'n_valid_examples': `mask_batch.sum()` when a mask is given, otherwise
        `label_batch.shape[0]`.
      'per_example': the (masked) unreduced loss, with the vocabulary axis
        reduced away.
  """
  vocab_size = logits_batch.shape[-1]

  if len(label_batch.shape) == len(logits_batch.shape):
    # One-hot labels
    log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1)
    targets = label_batch
    if label_smoothing:
      targets = (
          targets * (1.0 - label_smoothing) + label_smoothing / vocab_size)
    loss = -torch.sum(targets * log_probs, dim=-1)
  else:
    # Dense labels. `cross_entropy` expects the class axis at dim 1, so
    # flatten every leading axis first and restore the label shape after.
    loss = torch.nn.functional.cross_entropy(
        logits_batch.reshape(-1, vocab_size),
        label_batch.reshape(-1).long(),
        reduction='none',
        label_smoothing=label_smoothing).reshape(label_batch.shape)

  if mask_batch is not None:
    loss = loss * mask_batch

  n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0]
  return {
      'summed': loss.sum(),
      'n_valid_examples': n_valid,
      'per_example': loss,
  }

0 commit comments

Comments
 (0)