Support irregular/nondivisible batch sizes by padding (#882)
Also add support for padding the final batch

We use zeros for this padding, which works because a loss_mask of 0 means the padded examples contribute nothing to the loss.
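
A minimal sketch of the padding idea (illustrative only, with a hypothetical helper name and shapes; the real logic lives in src/levanter/data/loader.py below):

import numpy as np

def pad_final_batch(tokens: np.ndarray, loss_mask: np.ndarray, batch_size: int):
    # Pad a short final batch up to batch_size with all-zero examples.
    # The zero rows carry loss_mask == 0, so they contribute nothing to the loss.
    num_padding = batch_size - tokens.shape[0]
    if num_padding <= 0:
        return tokens, loss_mask
    pad_tokens = np.zeros((num_padding, *tokens.shape[1:]), dtype=tokens.dtype)
    pad_mask = np.zeros((num_padding, *loss_mask.shape[1:]), dtype=loss_mask.dtype)
    return np.concatenate([tokens, pad_tokens]), np.concatenate([loss_mask, pad_mask])

# With a masked mean, loss = (per_token_loss * loss_mask).sum() / loss_mask.sum(),
# so the padded examples drop out entirely.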
dlwh authored Feb 12, 2025
1 parent b6dc7d4 commit 7b91583
Showing 14 changed files with 265 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_entry_tests.yaml
@@ -25,4 +25,4 @@ jobs:
pip install soundfile librosa
- name: Run entry tests with pytest
run: |
- JAX_PLATFORMS="cpu" XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=$(pwd)/tests:$(pwd)/src:$(pwd):. pytest tests -m entry
+ JAX_PLATFORMS="cpu" PYTHONPATH=$(pwd)/tests:$(pwd)/src:$(pwd):. pytest -s tests -m entry
5 changes: 5 additions & 0 deletions config/gpt2_small_fast_batch_schedule.yaml
@@ -18,11 +18,16 @@ trainer:
per_device_parallelism: -1

train_batch_size:
- until: 10
value: 7
- until: 50
value: 39
- until: 1000
value: 128
- until: -1
value: 256
num_train_steps: 20000
allow_nondivisible_batch_size: true

# tensor_parallel_axes: ["position", "key_position"]
# tensor_parallel_axes: ["heads", "mlp"]
171 changes: 135 additions & 36 deletions src/levanter/data/loader.py

Large diffs are not rendered by default.
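
The loader diff isn't rendered, but the rest of the commit imports a helper from it, _round_to_nearest_multiple. A rough sketch of what it presumably does (assumed behavior: round up, so batches are padded rather than truncated):

def _round_to_nearest_multiple(x: int, multiple: int) -> int:
    # Assumed behavior: round x up to the next multiple, e.g. 39 -> 40 for multiple=8,
    # so a nondivisible batch can be padded out to a size the device mesh can shard.
    remainder = x % multiple
    return x if remainder == 0 else x + (multiple - remainder)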

9 changes: 9 additions & 0 deletions src/levanter/schedule.py
@@ -42,6 +42,12 @@ def validate_schedule_sorted(schedule: Sequence[ScheduleStep[T]]):
raise ValueError(f"Schedule is not sorted at index {i}")


def distinct_values(schedule: Sequence[ScheduleStep[T]] | T) -> set[T]:
if not isinstance(schedule, Sequence) or (schedule and not isinstance(schedule[0], ScheduleStep)):
return {schedule} # type: ignore
return set(step.value for step in schedule)


@dataclass
class BatchSegment:
start: int # The training step at which this batch size starts.
@@ -121,3 +127,6 @@ def batch_indices_at_step(self, bn):
last = self.segments[-1]
base = last.offset + (bn - last.start) * last.value
return range(base, base + last.value)

def unique_batch_sizes(self):
return set(seg.value for seg in self.segments)
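
For illustration, how the new distinct_values helper behaves (hypothetical values; ScheduleStep is the schedule type used in the tests below, assumed importable from levanter.schedule):

from levanter.schedule import ScheduleStep, distinct_values

# A plain batch size is treated as a one-element "schedule"...
assert distinct_values(128) == {128}

# ...while a real schedule reports every batch size it ever uses.
schedule = [ScheduleStep(until=10, value=7), ScheduleStep(until=-1, value=39)]
assert distinct_values(schedule) == {7, 39}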
15 changes: 13 additions & 2 deletions src/levanter/trainer.py
@@ -37,6 +37,7 @@
from levanter.checkpoint import CheckpointerConfig, load_checkpoint_or_initialize
from levanter.config import JsonAtom
from levanter.data import AsyncDataset, DataLoader
from levanter.data.loader import _round_to_nearest_multiple
from levanter.distributed import DistributedConfig, RayConfig
from levanter.grad_accum import microbatched
from levanter.optim.model_averaging import ModelAveragingConfig
@@ -515,6 +516,7 @@ def data_loader(self, dataset: AsyncDataset[X], batch: Optional[hax.Axis] = None
axis_resources=self.compute_axis_mapping,
prefetch_size=32,
batch_axis_name=batch_name,
allow_nondivisible_batch_size=self.config.allow_nondivisible_batch_size,
)

@cached_property
@@ -636,6 +638,13 @@ class TrainerConfig:
per_device_eval_parallelism: int = -1
"""how many examples to process in parallel on each device. -1 (default) means same as per_device_parallelism"""

allow_nondivisible_batch_size: bool = False
"""
Allow batch sizes to be non-divisible by the number of devices (or data axis size).
This is typically used when you want a specific batch size but have a weird number of devices.
"""

# Config related to duration
num_train_steps: int = 400_000 # number of training steps
steps_per_eval: int = 1_000 # how often to evaluate
@@ -879,8 +888,10 @@ def _validate_and_set_defaults(self):

if self.per_device_eval_parallelism == -1:
if self.per_device_parallelism == -1:
- initial_train_batch_size = value_at_step(self.train_batch_size, 0)
- self.per_device_eval_parallelism = initial_train_batch_size // self.data_axis_size
+ tbs = max(levanter.schedule.distinct_values(self.train_batch_size))
+ self.per_device_eval_parallelism = (
+     _round_to_nearest_multiple(tbs, self.data_axis_size) // self.data_axis_size
+ )
else:
self.per_device_eval_parallelism = self.per_device_parallelism

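Concretely, the new default takes the largest batch size anywhere in the schedule and rounds it to a multiple of the data axis before dividing (hypothetical numbers, assuming the round-up behavior sketched above):

# e.g. the schedule uses batch sizes {7, 39} and data_axis_size == 8
tbs = max({7, 39})                                         # 39
per_device_eval = _round_to_nearest_multiple(tbs, 8) // 8  # 40 // 8 == 5
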
3 changes: 2 additions & 1 deletion src/levanter/trainer_state.py
@@ -57,7 +57,7 @@ class TrainerState(eqx.Module, Generic[M]):
is_trainable: FilterTree = eqx.field(static=True)
mp: jmp.Policy = eqx.field(static=True)

- model_averaging: ModelAveraging[M] | None = None
+ model_averaging: ModelAveraging[M]

@property
def int_step(self) -> int:
@@ -224,6 +224,7 @@ def take_train_step(
train_grads = trainables_only(grads, is_trainable)
overwrites, train_grads = partition_for_grad_overwrite(train_grads)
trainable_model = trainables_only(model, is_trainable)
_, trainable_model = partition_for_grad_overwrite(trainable_model)
updates, opt_state = optimizer.update(train_grads, opt_state, params=trainable_model, obj_fn=obj_fun)
model = apply_updates(model, updates, overwrites)

12 changes: 12 additions & 0 deletions tests/test_audio.py
@@ -1,14 +1,26 @@
import tempfile

import pytest
import ray
from datasets import load_dataset
from transformers import AutoProcessor, AutoTokenizer

from levanter.data.audio import AudioDatasetSourceConfig, AudioIODatasetConfig, BatchAudioProcessor
from levanter.store.cache import SerialCacheWriter
from levanter.utils.py_utils import logical_cpu_core_count
from test_utils import skip_if_hf_model_not_accessible, skip_if_no_soundlibs


def setup_module(module):
ray.init(
"local", num_cpus=max(2 * logical_cpu_core_count(), 8), ignore_reinit_error=True
) # 2x cpu count is faster on my m1


def teardown_module(module):
ray.shutdown()


@skip_if_no_soundlibs
@skip_if_hf_model_not_accessible("openai/whisper-tiny")
def test_whisper_batch_processor():
3 changes: 2 additions & 1 deletion tests/test_checkpoint.py
@@ -41,6 +41,7 @@ def _dummy_step_info(step):
training_key=jax.random.PRNGKey(0),
is_trainable=True,
mp=None,
model_averaging=None,
),
loss=0.0,
step_duration=0.0,
@@ -167,7 +168,7 @@ def _make_state(step, key):
optim = optax.adam(1e-4)
opt_state = optim.init(arrays_only(model))

- return TrainerState(step, model, optim, opt_state, key, is_trainable=True, mp=None)
+ return TrainerState(step, model, optim, opt_state, key, is_trainable=True, mp=None, model_averaging=None)


def test_checkpoint_simple():
4 changes: 2 additions & 2 deletions tests/test_doremi.py
@@ -108,8 +108,8 @@ def test_estimate_mixture_weights():
# 2. x is not predictive of y at all, y is highly random (y ~ N(0, 1))
# 3. x is highly predictive of y, but it's very easy (y = sigmoid([1, 0, 0] x > 0.5)

- Dim = hax.Axis("Dim", 5)
- Batch = hax.Axis("Batch", 32)
+ Dim = hax.Axis("dim", 5)
+ Batch = hax.Axis("batch", 32)

# data loading needs to take place on CPU
with local_cpu_mesh():
3 changes: 1 addition & 2 deletions tests/test_eval_lm.py
@@ -68,7 +68,7 @@ def test_eval_lm_from_hf():
num_layers=2,
num_heads=2,
seq_len=1024,
- hidden_dim=2,
+ hidden_dim=32,
use_flash_attention=True,
)

@@ -86,7 +86,6 @@ def test_eval_lm_from_hf():
config = eval_lm.EvalLmConfig(
data=data_config,
model=model_config,
- hf_checkpoint="sshleifer/tiny-gpt2",
trainer=eval_lm.TrainerConfig(
per_device_eval_parallelism=len(jax.devices()),
max_eval_batches=1,
6 changes: 3 additions & 3 deletions tests/test_flash_attention.py
@@ -23,9 +23,9 @@ def test_flash_attention_acausal():
QPos = hax.Axis("QPos", BLOCK_SIZE * 2)
KPos = hax.Axis("KPos", BLOCK_SIZE * 2)

- q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Key))
- k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key))
- v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key))
+ q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Key)) * 0.2
+ k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key)) * 0.2
+ v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key)) * 0.2

flash_out = flash_attention(QPos, KPos, Key, q, k, v, inference=True, block_size=BLOCK_SIZE)
hax_out = hnn.attention.dot_product_attention(KPos, Key, q, k, v)
4 changes: 2 additions & 2 deletions tests/test_lora.py
@@ -137,8 +137,8 @@ def __call__(self, x):
@staticmethod
def init(*, key):
k1, k2 = jax.random.split(key)
- first = hnn.Linear.init(In, Mid, key=k1)
- second = hnn.Linear.init(Mid, In, key=k2)
+ first = hnn.Linear.init(In, Mid, key=k1, init_scale=0.02)
+ second = hnn.Linear.init(Mid, In, key=k2, init_scale=0.02)
return Module(first, second)

Layers = hax.Axis("Layers", 2)
78 changes: 76 additions & 2 deletions tests/test_new_loader.py
@@ -205,6 +205,36 @@ def test_structured_batches_model_axis_1_with_names():
assert len(batches) == 10


@skip_if_not_enough_devices(2)
def test_structured_batches_model_axis_1_non_divisible_batch_size():
devices = jax.devices()
model_axis_size = 1

mesh = Mesh(
np.array(devices).reshape(-1, model_axis_size),
(ResourceAxis.DATA, ResourceAxis.MODEL),
)
with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}):
Height = Axis("Height", 16)
Width = Axis("Width", 16)
dataset = StructuredDatasetWithNames(Height, Width, 0, 10, 1)
loader = DataLoader(
dataset, 1, max_buffered_batches=0, mesh=mesh, axis_resources=None, allow_nondivisible_batch_size=True
)

batches = list(loader)
for batch in batches:
check_sharded_consistency(batch, check_disjoint_indices_are_different=False)

# check that it crashes if allow_non_divisible_batch_size is False
with pytest.raises(ValueError):
DataLoader(
dataset, 1, max_buffered_batches=0, mesh=mesh, axis_resources=None, allow_nondivisible_batch_size=False
)

assert len(batches) == 10


@skip_if_not_enough_devices(2)
def test_structured_batches_model_axis_2_with_names():
devices = jax.devices()
@@ -265,15 +295,59 @@ def test_loader_with_batch_scheduler(model_axis_size):
with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}):
seq_len = 128
cache = _small_dataset(seq_len, num_sequences=1000)
- loader = DataLoader(cache, schedule, max_buffered_batches=10, mesh=mesh, axis_resources=None)
+ loader = DataLoader(
+     cache, schedule, max_buffered_batches=10, mesh=mesh, axis_resources=None, pad_final_batch=False
+ )

for step, batch in enumerate(loader):
if step < 10:
assert len(batch) == 8
elif step < 20:
assert len(batch) == 16
else:
- assert len(batch) == 32
+ assert len(batch) == 32, f"step: {step} len: {len(batch)}"

# total steps: 10 * 8 + 10 * 16 = 240, (1000 - 240) // 32 = 23
assert step == 20 + 22


@pytest.mark.parametrize("model_axis_size", [1, 2])
def test_padded_final_batch(model_axis_size):
schedule = [ScheduleStep(until=10, value=8), ScheduleStep(until=20, value=16), ScheduleStep(until=-1, value=32)]

if len(jax.devices()) % model_axis_size != 0:
pytest.skip("This test requires the number of devices to divide model_axis_size")

if 32 % (len(jax.devices()) // model_axis_size) != 0:
pytest.skip("This test requires the number of devices to divide 32")

devices = jax.devices()

mesh = Mesh(
np.array(devices).reshape(-1, model_axis_size),
(ResourceAxis.DATA, ResourceAxis.MODEL),
)

with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}):
seq_len = 128
cache = _small_dataset(seq_len, num_sequences=1007)
loader = DataLoader(
cache, schedule, max_buffered_batches=10, mesh=mesh, axis_resources=None, pad_final_batch=True
)

for step, batch in enumerate(loader):
if step < 10:
assert len(batch) == 8
elif step < 20:
assert len(batch) == 16
else:
assert len(batch) == 32

# total steps: 10 * 8 + 10 * 16 = 240, (1007 - 240) // 32 = 23
assert step == 20 + 23

# last batch should be padded
assert len(batch) == 32
# ensure all the padded examples are all 0's
num_padding = 32 - (1007 - 240) % 32
assert np.all(batch[-num_padding:] == 0)
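
Spelling out the arithmetic behind those last assertions:

remaining = 1007 - (10 * 8 + 10 * 16)   # 767 examples left once the batch size reaches 32
full, leftover = divmod(remaining, 32)  # 23 full batches plus 31 leftover examples
num_padding = 32 - leftover             # so the 24th batch is padded with 1 zero example
# 24 batches at steps 20..43 means the loop ends with step == 20 + 23.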
3 changes: 2 additions & 1 deletion tests/test_train_asr.py
@@ -11,10 +11,11 @@
from test_utils import skip_if_no_soundlibs


@pytest.mark.skip
@pytest.mark.entry
@skip_if_no_soundlibs
def test_train_asr():
- # just testing if train_lm has a pulse
+ # just testing if train_asr has a pulse
with tempfile.TemporaryDirectory() as tmpdir:
data_config = tiny_test_corpus.tiny_asr_corpus_config(tmpdir)
try:
