From 6a20f0d98948fefbdd66c5d3d3f04981f89f12b6 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 27 Jan 2025 21:48:34 -0800 Subject: [PATCH 01/10] fix bad docs in Getting-Started (#867) Fixes #847 --- docs/Getting-Started-TPU-VM.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/Getting-Started-TPU-VM.md b/docs/Getting-Started-TPU-VM.md index 53aaed218..a60ba5b93 100644 --- a/docs/Getting-Started-TPU-VM.md +++ b/docs/Getting-Started-TPU-VM.md @@ -101,8 +101,8 @@ accel_env: docker_repository: levanter # default zone: us-west4-a # if not set, will use your default zone -tpu_name: test-spin-up-32 -tpu_type: "v5litepod-16" +tpu: test-spin-up-32 # name of the TPU +tpu_type: "v4-16" capacity_type: "preemptible" subnetwork: "default" # default EOF @@ -319,11 +319,12 @@ This will show you a list of files and directories in the repo, sorted by size, ### Automatic Setup +!!! warning + This approach is deprecated and will be removed in the future. Please use `launch.py` or `launch_on_ray.py` instead. + You can use `infra/spin-up-vm.sh` to create a TPU VM instance. In addition to creating the instance, it will set up the venv on each worker, and it will clone the repo to `~/levanter/`. -**For Public Users**: - ```bash bash infra/spin-up-vm.sh -z -t -n [--preemptible] [--use-alpha] ``` From 04f434292b3c5c6bbda4a053fc1305cf850ea707 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 27 Jan 2025 21:52:28 -0800 Subject: [PATCH 02/10] increase timeout to bypass initialization issues --- src/levanter/distributed.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/levanter/distributed.py b/src/levanter/distributed.py index 6efd9f0cb..8537620fc 100644 --- a/src/levanter/distributed.py +++ b/src/levanter/distributed.py @@ -320,7 +320,9 @@ def initialize(self): if coordinator_address is None: coordinator_address = LevanterSlurmCluster.get_coordinator_address() - jax.distributed.initialize(coordinator_address, self.num_processes, self.process_id, device_ids) + jax.distributed.initialize( + coordinator_address, self.num_processes, self.process_id, device_ids, initialization_timeout=30 * 60 + ) logger.info( f"Initialized jax.distributed with {jax.device_count()} devices, {jax.process_count()} processes," f" coordinator_address={coordinator_address}, process_id={self.process_id}, my" From d9a0d57cf14c5beb6f76601ec49bef90a4554927 Mon Sep 17 00:00:00 2001 From: Suhas Kotha <38450656+kothasuhas@users.noreply.github.com> Date: Tue, 28 Jan 2025 22:36:19 -0800 Subject: [PATCH 03/10] Supporting varied mixtures over training (#868) ## Description Currently, LM mixture dataset can only handle a static mixture over the course of training. This PR enables varying this mixture over datasets over the course of training. The user can now specify a list of stages and the sequence index at which each should start. Internally, we identify a training block to its stage, which defines its mixing weights. To efficiently translate a data point's index within a block to a respective source dataset, we precompute prefix sums that track how many data points are seen by previous stages. ## Fixes Issues https://github.com/stanford-crfm/marin/issues/81 ## Unit test coverage There are new unit tests in `test_varying_mixture.py` to ensure that the varying mixture behaves as expected. ## Known breaking changes/behaviors The design enables traditional usage of the MixtureDataset class. However, some of the private quantities are different (i.e. 
the expected counts per block now depends on the block and is not a member variable). To my knowledge, these variables are not accessed outside of tests. ## Additional context I have some changes I want to make Marin to enable usage of this new functionality, though these updates are modular and can be seperate PR's. I have spot-checked that training proceeds as expected with this test. This is my first PR so feedback is appreciated :)) --- src/levanter/data/mixture.py | 100 +++++++++++++++++++----- src/levanter/data/text.py | 45 +++++++---- tests/test_mixture.py | 4 +- tests/test_varying_mixture.py | 142 ++++++++++++++++++++++++++++++++++ 4 files changed, 256 insertions(+), 35 deletions(-) create mode 100644 tests/test_varying_mixture.py diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py index 63c623e4b..188e5e426 100644 --- a/src/levanter/data/mixture.py +++ b/src/levanter/data/mixture.py @@ -1,6 +1,6 @@ import asyncio import warnings -from typing import Mapping, Optional, Sequence, TypeVar +from typing import List, Mapping, Optional, Sequence, Tuple, TypeVar import jax import numpy as np @@ -30,12 +30,12 @@ class MixtureDataset(AsyncDataset[T]): according to the weights. Creating a random-access MixtureDataset is challenging because we need to keep track of the current index of each - dataset. So solve this, we instead use "block-deterministic" mixtures, where the number of samples from each dataset - in each block is always identical (and we shuffle the order of the dataset ids in each block). + dataset. To solve this, we instead use "block-deterministic" mixtures, where the number of samples from each dataset + in each block is always identical (and we shuffle the order of the dataset ids in each block). To handle the case where the dataset mixture changes over time, we use a list of stages and precompute statistics to accurately compute the index of each dataset in each block. Args: datasets: A dict of datasets, where the key is the name of the dataset and the value is the dataset itself - weights: weights for each dataset + weights: Weights for each dataset. This can be provided in a list of stages, where each stage is a tuple of (start_seq_index, weights). Note that start_seq_index corresponds to the sequence index at which the weights should change, not the training batch index. stop_strategy: strategy for stopping the iteration, by default RESTART_STRATEGY. 
(Currently only RESTART_STRATEGY is supported) - FIRST_STOP_STRATEGY: stop when one dataset has been exhausted - ALL_STOP_STRATEGY: stop when all datasets have been exhausted @@ -46,7 +46,7 @@ class MixtureDataset(AsyncDataset[T]): def __init__( self, datasets: Mapping[str, AsyncDataset[T]], - weights: dict[str, float], + weights: dict[str, float] | List[Tuple[int, dict[str, float]]], block_size: int, *, randomize_blocks: bool = True, @@ -54,8 +54,30 @@ def __init__( stop_strategy: str = StopStrategy.RESTART_STRATEGY, ): super().__init__() - self.weights = MixtureDataset._normalize_weights(weights) - self.datasets = {name: dataset for name, dataset in datasets.items() if self.weights.get(name, 0) > 0} + if isinstance(weights, dict): + weight_stages = [(0, weights)] + else: + weight_stages = weights + + # assert that steps are in sorted order and that the start index of each stage is a multiple of block_size + for i, (start_seq_index, _) in enumerate(weight_stages): + if i == 0: + assert start_seq_index == 0 + else: + assert start_seq_index % block_size == 0, ( + f"start_seq_index for a stage must be a multiple of block_size, got {start_seq_index=} and" + f" {block_size=}" + ) + assert start_seq_index > weight_stages[i - 1][0], f"Weights list must be sorted, got {weight_stages}" + + self.weight_stages = [ + (start_seq_index, self._normalize_weights(weights)) for start_seq_index, weights in weight_stages + ] + self.datasets = { + name: dataset + for name, dataset in datasets.items() + if any(weights.get(name, 0) > 0 for _, weights in self.weight_stages) + } self.dataset_index = Index(self.datasets.keys()) self.block_size = block_size # we pack index and ds id into a single 32 bit, so block size must be at most 2^16 @@ -78,15 +100,38 @@ def __init__( self.stop_strategy = stop_strategy - self._counts_per_block = self._compute_expected_counts_per_block(block_size) - # precompute a list of ids for each block - # the ids contain both the dataset index and the index within the dataset - self._unpermuted_ids = self._compute_unpermuted_ids(self._counts_per_block) + # Initialize stage-related counts and IDs + ( + self._counts_per_block_per_stage, + self._counts_after_stage, + self._unpermuted_ids_per_stage, + ) = self._initialize_stage_counts() + + def _initialize_stage_counts(self): + counts_per_block_per_stage = [] + counts_after_stage = [] + unpermuted_ids_per_stage = [] + + cumulative_counts = np.zeros(len(self.datasets), dtype=np.int32) + + for stage_idx, (start_seq_index, stage_weights) in enumerate(self.weight_stages): + counts_this_stage = self._compute_expected_counts_per_block(stage_weights, self.block_size) + counts_per_block_per_stage.append(counts_this_stage) + unpermuted_ids_per_stage.append(self._compute_unpermuted_ids(counts_this_stage)) + + if stage_idx < len(self.weight_stages) - 1: + next_start = self.weight_stages[stage_idx + 1][0] + num_blocks_in_stage = (next_start - start_seq_index) // self.block_size + stage_total_counts = counts_this_stage * num_blocks_in_stage + cumulative_counts += stage_total_counts + counts_after_stage.append(cumulative_counts.copy()) - def _compute_expected_counts_per_block(self, block_size): + return counts_per_block_per_stage, counts_after_stage, unpermuted_ids_per_stage + + def _compute_expected_counts_per_block(self, weights: dict[str, float], block_size: int): _expected_values_per_block = np.zeros(len(self.datasets), dtype=np.int32) for i, dsname in enumerate(self.dataset_index): - _expected_values_per_block[i] = self.weights[dsname] * block_size + 
_expected_values_per_block[i] = weights.get(dsname, 0) * block_size # handle remainder by adding to the largest dataset largest_dataset = np.argmax(_expected_values_per_block) @@ -94,9 +139,9 @@ def _compute_expected_counts_per_block(self, block_size): # check if any dataset has 0 samples (and nonzero weight) for i, dsname in enumerate(self.dataset_index): - if _expected_values_per_block[i] == 0 and self.weights[dsname] > 0: + if _expected_values_per_block[i] == 0 and weights.get(dsname, 0) > 0: warnings.warn( - f"Dataset {dsname} has 0 samples in the block, but weight of {self.weights[dsname]}." + f"Dataset {dsname} has 0 samples in the block, but weight of {weights[dsname]}." " Recommend increasing block size." ) @@ -143,20 +188,35 @@ async def current_len(self) -> Optional[int]: raise NotImplementedError("Length is not known for other strategies") - @alru_cache + def _get_stage_for_block(self, block_id: int) -> int: + block_start = block_id * self.block_size + stage_starts = np.array([start for start, _ in self.weight_stages]) + return max(0, np.searchsorted(stage_starts, block_start, side="right") - 1) + + @alru_cache(maxsize=32) async def _get_block(self, index: int) -> Optional[np.ndarray]: + stage = self._get_stage_for_block(index) if not self.randomize_blocks: - return self._unpermuted_ids + return self._unpermuted_ids_per_stage[stage] - return np.array(_compute_block_assignment(self._unpermuted_ids, index, self.key)) + return np.array(_compute_block_assignment(self._unpermuted_ids_per_stage[stage], index, self.key)) - def _index_into_dataset_for_id(self, id: int, block_id) -> tuple[int, int]: + def _index_into_dataset_for_id(self, id: int, block_id: int) -> tuple[int, int]: + stage = self._get_stage_for_block(block_id) dataset_id = id >> 16 dataset_index = id & 0xFFFF - return dataset_id, dataset_index + block_id * self._counts_per_block[dataset_id] + + # Get the base offset from previous stages + base_offset = self._counts_after_stage[stage - 1][dataset_id] if stage > 0 else 0 + # Add offset within current stage + offset_in_stage = (block_id * self.block_size - self.weight_stages[stage][0]) // self.block_size + current_stage_offset = offset_in_stage * self._counts_per_block_per_stage[stage][dataset_id] + + return dataset_id, dataset_index + base_offset + current_stage_offset async def get_batch(self, indices: Sequence[int]) -> Sequence[T]: block_ids = np.array([idx // self.block_size for idx in indices]) + blocks = [self._get_block(block_id) for block_id in block_ids] blocks = await asyncio.gather(*blocks) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 13c7ea44b..8af05698c 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -1154,28 +1154,41 @@ def _convert_id_to_token(self, index: int) -> str: @dataclass class LMMixtureDatasetConfig(LMTaskConfig): - """This class represents a mixture of datasets with their associated weights.""" + """A mixture of language model datasets that supports dynamic weight changes during training. + + Weights can be specified either as a single dictionary for constant mixing ratios, + or as a list of (step, weights) tuples to change mixing ratios during training. + """ cache_dir: Optional[str] = "cache/" - # data source configs and weights configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict) - """ configuration of each dataset source (urls, hf dataset id, etc.) """ - train_weights: Dict[str, float] = field(default_factory=dict) - """ weights for each dataset source. 
They will be normalized to sum to 1. """ + """ Configuration of each dataset source (urls, hf dataset id, etc.) """ + + train_weights: Union[Dict[str, float], List[Tuple[int, Dict[str, float]]]] = field(default_factory=dict) + """ Dataset mixing weights. Either a constant dict[name->weight] or list of (step, weights) tuples """ + stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY) mixture_block_size: int = 2048 - """ block size for the mixture dataset.""" + """ Block size for deterministic mixing """ def __post_init__(self): if len(self.configs) == 0: raise ValueError("At least one dataset must be provided") - if set(self.configs.keys()) != set(self.train_weights.keys()): - raise ValueError( - f"The keys in configs and weights must be the same;got {self.configs.keys()} and" - f" {self.train_weights.keys()}" - ) + if isinstance(self.train_weights, dict): + if not all(name in self.configs for name in self.train_weights): + raise ValueError( + f"Weight keys {self.train_weights.keys()} must be subset of config keys {self.configs.keys()}" + ) + elif isinstance(self.train_weights, list): + for step, weights in self.train_weights: + if not all(name in self.configs for name in weights): + raise ValueError( + f"Weight keys {weights.keys()} must be subset of config keys {self.configs.keys()}" + ) + else: + raise ValueError(f"Invalid train_weights type: {type(self.train_weights)}") def train_set( self, @@ -1218,7 +1231,7 @@ def shuffle_ds(ds, key): weights=self.train_weights, stop_strategy=self.stop_strategy, key=mix_key, - block_size=2048, + block_size=self.mixture_block_size, ) return mixture @@ -1248,9 +1261,13 @@ def build_caches( caches = {} for name, source_config in self.configs.items(): - weight = self.train_weights.get(name, 0) + # Skip datasets with zero weight in all stages + if isinstance(self.train_weights, dict): + has_nonzero_weight = self.train_weights.get(name, 0) > 0 + elif isinstance(self.train_weights, list): + has_nonzero_weight = any(weights.get(name, 0) > 0 for _, weights in self.train_weights) - if weight == 0 and split == "train": + if not has_nonzero_weight and split == "train": continue source_config_dict = dict(**source_config.__dict__) diff --git a/tests/test_mixture.py b/tests/test_mixture.py index e8821e24f..8ae6dbb1b 100644 --- a/tests/test_mixture.py +++ b/tests/test_mixture.py @@ -87,7 +87,9 @@ async def test_mixture_dataset_normalized_weights(): async def test_mixture_dataset_unpermuted_ids(): mixture_ds = MixtureDataset(datasets(), weights(), block_size=10, key=key()) - unpermuted_ids = mixture_ds._compute_unpermuted_ids(mixture_ds._counts_per_block) + unpermuted_ids = mixture_ds._compute_unpermuted_ids( + mixture_ds._compute_expected_counts_per_block(weights(), block_size()) + ) assert len(unpermuted_ids) == 10 assert unpermuted_ids[0] >> 32 in range(3) # Ensure the dataset ID is valid diff --git a/tests/test_varying_mixture.py b/tests/test_varying_mixture.py new file mode 100644 index 000000000..564225b4e --- /dev/null +++ b/tests/test_varying_mixture.py @@ -0,0 +1,142 @@ +import jax +import pytest + +from levanter.data import ListAsyncDataset, MixtureDataset + + +def create_datasets(): + ds1 = ListAsyncDataset([1, 2, 3, 4, 5]) + ds2 = ListAsyncDataset([10, 20, 30, 40, 50]) + ds3 = ListAsyncDataset([100, 200, 300, 400, 500]) + ds1.finalize() + ds2.finalize() + ds3.finalize() + return {"ds1": ds1, "ds2": ds2, "ds3": ds3} + + +@pytest.mark.asyncio +async def test_mixture_dataset_stage_transitions(): + datasets = create_datasets() + # Define three 
stages with different weights + stages = [ + (0, {"ds1": 1.0, "ds2": 0.0, "ds3": 0.0}), # Stage 1: only ds1 + (20, {"ds1": 0.0, "ds2": 1.0, "ds3": 0.0}), # Stage 2: only ds2 + (40, {"ds1": 0.0, "ds2": 0.0, "ds3": 1.0}), # Stage 3: only ds3 + ] + + mixture_ds = MixtureDataset(datasets, stages, block_size=10, key=jax.random.PRNGKey(42), randomize_blocks=False) + + # Test first stage (should only get values from ds1) + batch1 = await mixture_ds.get_batch(list(range(10))) + assert all(x in [1, 2, 3, 4, 5] for x in batch1), f"Unexpected values in first stage: {batch1}" + + # Test second stage (should only get values from ds2) + batch2 = await mixture_ds.get_batch(list(range(20, 30))) + assert all(x in [10, 20, 30, 40, 50] for x in batch2), f"Unexpected values in second stage: {batch2}" + + # Test third stage (should only get values from ds3) + batch3 = await mixture_ds.get_batch(list(range(40, 50))) + assert all(x in [100, 200, 300, 400, 500] for x in batch3), f"Unexpected values in third stage: {batch3}" + + +@pytest.mark.asyncio +async def test_mixture_dataset_gradual_transition(): + datasets = create_datasets() + # Define stages with gradual transitions + stages = [ + (0, {"ds1": 0.8, "ds2": 0.2, "ds3": 0.0}), # Mostly ds1 + (20, {"ds1": 0.2, "ds2": 0.6, "ds3": 0.2}), # Mostly ds2 + (40, {"ds1": 0.0, "ds2": 0.2, "ds3": 0.8}), # Mostly ds3 + ] + + mixture_ds = MixtureDataset(datasets, stages, block_size=10, key=jax.random.PRNGKey(42)) + + # Sample a large batch from each stage and verify proportions + def count_sources(batch): + counts = {"ds1": 0, "ds2": 0, "ds3": 0} + for x in batch: + if x < 10: + counts["ds1"] += 1 + elif x < 100: + counts["ds2"] += 1 + else: + counts["ds3"] += 1 + return counts + + # Test first stage + batch1 = await mixture_ds.get_batch(list(range(20))) + counts1 = count_sources(batch1) + assert counts1["ds1"] > counts1["ds2"] > counts1["ds3"], f"Unexpected distribution in first stage: {counts1}" + + # Test second stage + batch2 = await mixture_ds.get_batch(list(range(20, 40))) + counts2 = count_sources(batch2) + assert ( + counts2["ds2"] > counts2["ds1"] and counts2["ds2"] > counts2["ds3"] + ), f"Unexpected distribution in second stage: {counts2}" + + # Test third stage + batch3 = await mixture_ds.get_batch(list(range(40, 60))) + counts3 = count_sources(batch3) + assert counts3["ds3"] > counts3["ds2"] > counts3["ds1"], f"Unexpected distribution in third stage: {counts3}" + + +@pytest.mark.asyncio +async def test_mixture_dataset_invalid_stage_configurations(): + datasets = create_datasets() + + # Test stages that don't start at 0 + with pytest.raises(AssertionError): + MixtureDataset(datasets, [(10, {"ds1": 1.0})], block_size=10, key=jax.random.PRNGKey(42)) + + # Test stages with start indices not multiple of block_size + with pytest.raises(AssertionError): + MixtureDataset(datasets, [(0, {"ds1": 1.0}), (15, {"ds2": 1.0})], block_size=10, key=jax.random.PRNGKey(42)) + + # Test stages in wrong order + with pytest.raises(AssertionError): + MixtureDataset( + datasets, + [(0, {"ds1": 1.0}), (20, {"ds2": 1.0}), (10, {"ds3": 1.0})], + block_size=10, + key=jax.random.PRNGKey(42), + ) + + +@pytest.mark.asyncio +async def test_mixture_dataset_zero_weight_handling(): + datasets = create_datasets() + # Define stages where some datasets have zero weight + stages = [ + (0, {"ds1": 1.0, "ds2": 0.0, "ds3": 0.0}), + (20, {"ds1": 0.0, "ds2": 1.0, "ds3": 0.0}), + ] + + mixture_ds = MixtureDataset(datasets, stages, block_size=10, key=jax.random.PRNGKey(42), randomize_blocks=False) + + # 
Verify that zero-weight datasets are not sampled + batch1 = await mixture_ds.get_batch(list(range(10))) + assert all(x < 10 for x in batch1), f"Found samples from zero-weight datasets in first stage: {batch1}" + + batch2 = await mixture_ds.get_batch(list(range(20, 30))) + assert all(10 <= x < 100 for x in batch2), f"Found samples from zero-weight datasets in second stage: {batch2}" + + +@pytest.mark.asyncio +async def test_mixture_dataset_block_boundaries(): + datasets = create_datasets() + # Define stages with transition at block boundary + stages = [ + (0, {"ds1": 1.0, "ds2": 0.0, "ds3": 0.0}), + (10, {"ds1": 0.0, "ds2": 1.0, "ds3": 0.0}), + ] + + mixture_ds = MixtureDataset(datasets, stages, block_size=10, key=jax.random.PRNGKey(42), randomize_blocks=False) + + # Test the boundary between stages + batch = await mixture_ds.get_batch(list(range(5, 15))) # Should span both stages + first_half = batch[:5] + second_half = batch[5:] + + assert all(x < 10 for x in first_half), f"Unexpected values at end of first stage: {first_half}" + assert all(10 <= x < 100 for x in second_half), f"Unexpected values at start of second stage: {second_half}" From 0ad8c54ee0fd1da26edcd8f88fe4e059117c0e7c Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 29 Jan 2025 13:12:51 -0800 Subject: [PATCH 04/10] you kill, not cancel, actors (#871) --- src/levanter/infra/ray_tpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index 86ce4223a..fa8970e4a 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -172,7 +172,7 @@ def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResu logger.exception(e) for actor in actors: try: - ray.cancel(actor) + ray.kill(actor) except Exception: logger.exception("Failed to kill actor after primary failure") return futures From 04a81ca2957a1a02ee3bec97f129c09cb442a2ca Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 29 Jan 2025 15:46:40 -0800 Subject: [PATCH 05/10] Use int64 for prp (#870) fixes #869 --- src/levanter/data/_prp.py | 48 +++++++++++++++----------------- src/levanter/data/permutation.py | 12 +++++--- tests/test_prp.py | 14 ++++++++-- 3 files changed, 42 insertions(+), 32 deletions(-) diff --git a/src/levanter/data/_prp.py b/src/levanter/data/_prp.py index 2abc43b80..e30f86b6d 100644 --- a/src/levanter/data/_prp.py +++ b/src/levanter/data/_prp.py @@ -1,18 +1,16 @@ import typing -import jax.lax import jax.numpy as jnp import jax.random as jrandom import numpy as np -# TODO: do we make this a pytree class Permutation: # Pseudo-Random Permutation Code """A stateless pseudo-random permutation. This class generates a pseudo-random permutation of a given length. The permutation is generated using a PRNG - with a fixed key. The permutation is generated by finding a random `a` and `b` such that `gcd(a, length) != 1` and + with a fixed key. The permutation is generated by finding a random `a` and `b` such that `gcd(a, length) == 1` and then computing the permutation as `p(x) = (a * x + b) % length`. This is not a very good PRP, but it is probably good enough for our purposes. 
@@ -21,40 +19,40 @@ class Permutation: def __init__(self, length, prng_key): self.length = length - self.prng_key = prng_key - a_key, b_key = jrandom.split(prng_key) - self._a = jrandom.randint(a_key, (), 1, length) - self._b = jrandom.randint(b_key, (), 0, length) + # Convert jax.random.PRNGKey to numpy.random.Generator + self.rng = np.random.Generator(np.random.PCG64(jrandom.randint(prng_key, (), 0, 2**30).item())) + self.a, self.b = self._generate_permutation_params() # Generate a and b in init - cond = lambda a_and_key: jnp.all(jnp.gcd(a_and_key[0], length) != 1) + def _generate_permutation_params(self): + length = self.length + rng = self.rng - def loop_body(a_and_key): - a, key = a_and_key - this_key, key = jrandom.split(key) - a = jrandom.randint(this_key, (), 1, length) - return a, key + if length == 1: + return 1, 0 - self._a, key = jax.lax.while_loop(cond, loop_body, (self._a, a_key)) + while True: + a = rng.integers(1, length) + if np.gcd(a, length) == 1: + break - self._a = int(self._a) - self._b = int(self._b) + b = rng.integers(0, length) # b can be in [0, length-1] + return a, b @typing.overload def __call__(self, indices: int) -> int: ... @typing.overload - def __call__(self, indices: jnp.ndarray) -> jnp.ndarray: + def __call__(self, indices: np.ndarray) -> np.ndarray: ... def __call__(self, indices): + a = self.a + b = self.b + length = self.length + was_int = False - if isinstance(indices, jnp.ndarray): - # TODO: use error_if? - # import equinox as eqx - if jnp.any(indices < 0) or jnp.any(indices >= self.length): - raise IndexError(f"index {indices} is out of bounds for length {self.length}") - elif isinstance(indices, np.ndarray): + if isinstance(indices, np.ndarray | jnp.ndarray): if np.any(indices < 0) or np.any(indices >= self.length): raise IndexError(f"index {indices} is out of bounds for length {self.length}") else: @@ -64,9 +62,7 @@ def __call__(self, indices): indices = np.array(indices) was_int = True - old_settings = np.seterr(over="raise") - out = (self._a * indices + self._b) % self.length - np.seterr(**old_settings) + out = (a * indices + b) % length # Compute permutation on-the-fly if was_int: return int(out) diff --git a/src/levanter/data/permutation.py b/src/levanter/data/permutation.py index 66a1887fd..ac20a3581 100644 --- a/src/levanter/data/permutation.py +++ b/src/levanter/data/permutation.py @@ -41,7 +41,9 @@ async def getitem_async(self, index: int) -> T_co: async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: permutation = await self._get_permutation() - return await self.dataset.get_batch([permutation(i) for i in indices]) + return await self.dataset.get_batch( + [int(permutation(i)) for i in indices] + ) # cast to int to be sure it's python int async def _get_permutation(self): if self._permutation is None: @@ -83,10 +85,10 @@ async def gen_era_permutation(era: int) -> Permutation: # TODO: support epochs # edge case: final era may be shorter than era_length current_len = await self.dataset.wait_until_len_at_least((era + 1) * self.era_length) - era_length = min(self.era_length, current_len - era * self.era_length) + era_length_val = min(self.era_length, current_len - era * self.era_length) mix_key = jax.random.fold_in(key, era) - return Permutation(era_length, mix_key) + return Permutation(era_length_val, mix_key) self.gen_era_permutation = gen_era_permutation @@ -95,7 +97,9 @@ async def _get_index(self, idx: int) -> int: raise ValueError("Negative indices are not supported") era = idx // self.era_length permutation = await 
self.gen_era_permutation(era) - return permutation(idx - era * self.era_length) + era * self.era_length + out = permutation(idx - era * self.era_length) + era * self.era_length + + return out async def async_len(self) -> int: return await self.dataset.async_len() diff --git a/tests/test_prp.py b/tests/test_prp.py index 6c549eabf..8a0bcac94 100644 --- a/tests/test_prp.py +++ b/tests/test_prp.py @@ -10,8 +10,8 @@ def test_permutation_creates_valid_instance(): prng_key = jrandom.PRNGKey(0) permutation = Permutation(length, prng_key) assert permutation.length == length - assert permutation._a > 0 and permutation._a < length - assert permutation._b >= 0 and permutation._b < length + assert 0 < permutation.a < length + assert 0 <= permutation.b < length def test_permutation_with_single_index_returns_correct_value(): @@ -85,3 +85,13 @@ def test_permutation_is_deterministic1(): permutation = Permutation(length, prng_key) results2 = permutation(indices) assert jnp.all(results == results2) + + +def test_permutation_handles_large_length_no_overflow(): + large_length = 2**34 + prng_key = jrandom.PRNGKey(0) + permutation = Permutation(large_length, prng_key) + index = 2**32 # A large index within the range + result = permutation(index) + assert isinstance(result, int) + assert 0 <= result < large_length From 1d216d1f05db4ce5b3a62da232bb140c3098df7f Mon Sep 17 00:00:00 2001 From: William Held Date: Sat, 1 Feb 2025 02:45:17 -0600 Subject: [PATCH 06/10] Adds Ability to Sub-Sample Data for Data Constrained Scaling Law Experiments (#872) ![image](https://github.com/user-attachments/assets/cc381e88-79b4-4810-bb66-351ddf7c3b04) Allows mixture datasets to specify a target budget and a experiment budget. This then computes what percentage of the data to sample overall in order to enable data constrained experiments like the above figure. 
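A condensed sketch of the sub-sampling logic this patch adds (the helper name and budget figures are illustrative, not part of the patch; `slice_dataset` and `as_sync_dataset` are the dataset methods exercised in the diff below):

```python
# Sketch: shrink every tokenized source by the ratio experiment_budget / target_budget.
# The helper name and budget values are hypothetical; the slicing mirrors the text.py change below.
def simulate_data_budget(token_datasets: dict, experiment_budget: int, target_budget: int) -> dict:
    if experiment_budget > target_budget:
        raise ValueError("Experiment budget should be smaller than target budget")
    ratio = experiment_budget / target_budget  # e.g. 1e9 / 1e10 -> keep 10% of each source
    sliced = {}
    for name, ds in token_datasets.items():
        true_len = len(ds.as_sync_dataset())  # blocks until the cache is fully built
        sliced[name] = ds.slice_dataset(end_index=int(true_len * ratio))
    return sliced
```

Because every source is scaled by the same ratio, the relative mixture weights are left unchanged; only the effective amount of unique data shrinks.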
--- src/levanter/data/dataset.py | 54 ++++++++++++++++++++++++++++++++++++ src/levanter/data/text.py | 22 +++++++++++++++ tests/test_mixture.py | 30 ++++++++++++++++++++ 3 files changed, 106 insertions(+) diff --git a/src/levanter/data/dataset.py b/src/levanter/data/dataset.py index f448ed83b..86e8c78d6 100644 --- a/src/levanter/data/dataset.py +++ b/src/levanter/data/dataset.py @@ -122,6 +122,9 @@ def map(self, fn: MapFunction[U], *extra_args, **extra_kwargs) -> "MappedAsyncDa def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> "BatchMappedAsyncDataset[U]": return BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs) + def slice_dataset(self, start_index: Optional[int] = None, end_index: Optional[int] = None): + return SlicedAsyncDataset(self, start_index, end_index) + def shuffle(self, key: PRNGKey): import levanter.data.permutation as permutation @@ -375,6 +378,57 @@ def _call_fn(self, index, item): return self.fn(item, *self._extra_args, **kwargs) +class SlicedAsyncDataset(AsyncDataset[U]): + def __init__( + self, + dataset: AsyncDataset[U], + start_index: Optional[int] = None, + end_index: Optional[int] = None, + ): + super().__init__() + if start_index is None: + start_index = 0 + if end_index is not None and start_index > end_index: + raise ValueError("End index must come after start index.") + + self.start_index = start_index + self.end_index = end_index + self.dataset = dataset + self._min_known_len = dataset._min_known_len if end_index is None else (end_index - start_index) + + async def get_batch(self, indices: Sequence[int]) -> Sequence[U]: + shifted_indices = [(index + self.start_index) for index in indices] + max_index = max(shifted_indices) + + if self.end_index is not None and max_index > self.end_index: + raise ValueError("Requested indices beyond the end of the dataset") + + return await self.dataset.get_batch(shifted_indices) + + async def async_len(self) -> int: + underlying_length = await self.dataset.async_len() + if self.end_index is None: + return underlying_length - self.start_index + else: + return self.end_index - self.start_index + + async def final_length_is_known(self) -> bool: + underlying_is_known = await self.dataset.final_length_is_known() + return underlying_is_known and self.end_index is not None + + def is_finite(self) -> bool: + return self.dataset.is_finite() and self.end_index is not None + + async def current_len(self) -> Optional[int]: + underlying_length = await self.dataset.current_len() + if self.end_index is not None: + return self.end_index - self.start_index + elif underlying_length is not None: + return underlying_length - self.start_index + else: + return underlying_length + + class BatchMappedAsyncDataset(AsyncDataset[U]): """ A dataset that applies a function to each batch of items in the dataset. diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 8af05698c..6446ad45f 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -1169,6 +1169,11 @@ class LMMixtureDatasetConfig(LMTaskConfig): """ Dataset mixing weights. 
Either a constant dict[name->weight] or list of (step, weights) tuples """ stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY) + + # Configuration for Simulated Epoching + target_budget: Optional[int] = None + experiment_budget: Optional[int] = None + mixture_block_size: int = 2048 """ Block size for deterministic mixing """ @@ -1226,6 +1231,23 @@ def shuffle_ds(ds, key): out_token_datasets[name] = shuffle_ds(ds, next(key_iter)) token_datasets = out_token_datasets + if ( + self.experiment_budget is not None and self.target_budget is not None + ) and self.experiment_budget > self.target_budget: + raise ValueError( + f"Experiment budget should be smaller than target budget, got {self.experiment_budget} >" + f" {self.target_budget}" + ) + if self.experiment_budget is not None and self.target_budget is not None: + simulated_data_ratio = self.experiment_budget / self.target_budget + sliced_token_datasets: Dict[str, TokenSeqDataset] = {} + for name, ds in token_datasets.items(): + # Note(Will): This blocks on datasets being fully processed even for small simulated runs making simulating data size slightly latency inducing but I think that's ok + true_length_of_dataset = len(ds.as_sync_dataset()) + simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio) + sliced_token_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset) + token_datasets = sliced_token_datasets + mixture = MixtureDataset( datasets=token_datasets, weights=self.train_weights, diff --git a/tests/test_mixture.py b/tests/test_mixture.py index 8ae6dbb1b..52652f380 100644 --- a/tests/test_mixture.py +++ b/tests/test_mixture.py @@ -73,6 +73,36 @@ async def test_mixture_dataset_stop_strategy_restart(): await mixture_ds.async_len() +@pytest.mark.asyncio +async def test_mixture_dataset_simulated_data_size(): + weights = {"ds1": 1 / 3, "ds2": 1 / 3, "ds3": 1 / 3} + mixture_ds = MixtureDataset( + {name: dataset.slice_dataset(end_index=1) for name, dataset in datasets().items()}, + weights, + block_size=10, + key=key(), + randomize_blocks=False, + stop_strategy=StopStrategy.RESTART_STRATEGY, + ) + for _ in range(10): + batch = await mixture_ds.get_batch([0, 1, 2]) + assert len(batch) == 3 + assert all(item in [1, 10, 100] for item in batch) + + mixture_ds = MixtureDataset( + {name: dataset.slice_dataset(end_index=2) for name, dataset in datasets().items()}, + weights, + block_size=10, + key=key(), + randomize_blocks=False, + stop_strategy=StopStrategy.RESTART_STRATEGY, + ) + for _ in range(10): + batch = await mixture_ds.get_batch([0, 1, 2]) + assert len(batch) == 3 + assert all(item in [1, 2, 10, 20, 100, 200] for item in batch) + + @pytest.mark.asyncio async def test_mixture_dataset_normalized_weights(): weights = {"ds1": 0, "ds2": 0.5, "ds3": 0.5} From 59d21380a5790b00e3fb34b3029f2b36fae84e12 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 3 Feb 2025 14:07:35 -0800 Subject: [PATCH 07/10] ray_tpu: Catch preemption better, don't hang if there's a c level abort (#877) --- src/levanter/infra/ray_tpu.py | 58 ++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index fa8970e4a..d220ec982 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -8,6 +8,7 @@ import tempfile import time from dataclasses import dataclass +from queue import Empty as QueueEmpty from typing import Callable, Optional, Sequence import draccus @@ -89,24 +90,24 @@ def 
do_run(remote_fn) -> _TpuRunResult: logger.info("TPU job finished") return TpuSuccess(info, out) except RayError as e: - for f in futures: - try: - ray.cancel(f) - except Exception: - logger.exception("Failed to kill job after primary failure") + _cancel_all_futures(futures) return _handle_ray_error(info, e) except Exception as e: - for f in futures: - try: - ray.cancel(f) - except Exception: - logger.exception("Failed to kill job after primary failure") + _cancel_all_futures(futures) return TpuFailed(info, e) return do_run.remote(remote_fn) -def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> ray.ObjectRef: +def _cancel_all_futures(futures): + for f in futures: + try: + ray.cancel(f) + except Exception: + logger.exception("Failed to kill job after primary failure") + + +def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> list[ray.ObjectRef]: """ Run a remote function on multiple TPU slices. @@ -147,18 +148,12 @@ def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResu logger.info("TPU job finished") return TpuSuccess(info, out) except RayError as e: - for f in futures: - try: - ray.cancel(f) - except Exception: - logger.exception("Failed to kill job after primary failure") + logger.exception(f"Ray error {e}. Killing futures for this slice") + _cancel_all_futures(futures) return _handle_ray_error(info, e) except Exception as e: - for f in futures: - try: - ray.cancel(f) - except Exception: - logger.exception("Failed to kill job after primary failure") + logger.exception(f"Exception {e}") + _cancel_all_futures(futures) return TpuFailed(info, e) actors = [MultisliceActor.remote() for _ in range(num_slices)] # type: ignore @@ -310,6 +305,16 @@ def run_on_pod_multislice_resumable( futures = run_on_pod_multislice(remote_fn, tpu_type, num_slices) try: outs = ray.get(futures) + except ray.exceptions.ActorUnavailableError as e: + problem = e + num_preemptions += 1 + logger.warning(f"Preempted {num_preemptions} times, {e}") + continue + except ray.exceptions.ActorDiedError as e: + problem = e + num_preemptions += 1 + logger.warning(f"Preempted {num_preemptions} times, {e}") + continue except ray.exceptions.RayTaskError as e: for f in futures: try: @@ -425,6 +430,9 @@ def _handle_ray_error(tpu_info: _TpuInfo, e: RayError): if isinstance(e, NodeDiedError): logger.exception("Node died", exc_info=e) return TpuPreempted(tpu_info, e) + elif isinstance(e, ray.exceptions.ActorUnavailableError | ray.exceptions.ActorDiedError): + logger.exception("Actor died", exc_info=e) + return TpuPreempted(tpu_info, e) elif isinstance(e, WorkerCrashedError): logger.exception("Worker crashed", exc_info=e) return TpuPreempted(tpu_info, e) @@ -506,7 +514,13 @@ def target_fn(queue, args, kwargs): process.join() # Retrieve the result or error from the queue - success, value = queue.get() + logger.info("Process finished") + try: + success, value = queue.get(timeout=10) + except QueueEmpty: + logger.error("Process timed out") + process.terminate() + raise RuntimeError("Process timed out") if success: return value From 86543ca40ff53688b9debc99a32c732356eb2fd3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 14:20:01 -0800 Subject: [PATCH 08/10] Update gcsfs requirement from <2025,>=2024.2 to >=2024.2,<2026 (#878) Updates the requirements on [gcsfs](https://github.com/fsspec/gcsfs) to permit the latest version.
Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. ---
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 861b222f3..23fc44553 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets>=3.1.0,<4.0", - "gcsfs>=2024.2,<2025", + "gcsfs>=2024.2,<2026", "braceexpand>=0.1.7", "jmp>=0.0.3", "fsspec[http]>=2024.2,<2025", From e6e300e4bcd1b6c0881f2845510a098aa2105259 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 14:20:13 -0800 Subject: [PATCH 09/10] Update fsspec[http] requirement from <2025,>=2024.2 to >=2024.2,<2026 (#879) Updates the requirements on [fsspec[http]](https://github.com/fsspec/filesystem_spec) to permit the latest version.
Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. ---
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 23fc44553..7de2989cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ "gcsfs>=2024.2,<2026", "braceexpand>=0.1.7", "jmp>=0.0.3", - "fsspec[http]>=2024.2,<2025", + "fsspec[http]>=2024.2,<2026", "tensorstore>=0.1.65", "pytimeparse>=1.1.8", "humanfriendly==10.0", From e49702bc8fbf276d36549c1fbcd7e306bec6254d Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Mon, 3 Feb 2025 14:20:41 -0800 Subject: [PATCH 10/10] change docs to work with yaml (#875) Fix minor bug in yaml config --- docs/Getting-Started-TPU-VM.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/Getting-Started-TPU-VM.md b/docs/Getting-Started-TPU-VM.md index a60ba5b93..8ae148692 100644 --- a/docs/Getting-Started-TPU-VM.md +++ b/docs/Getting-Started-TPU-VM.md @@ -102,7 +102,6 @@ accel_env: docker_repository: levanter # default zone: us-west4-a # if not set, will use your default zone tpu: test-spin-up-32 # name of the TPU -tpu_type: "v4-16" capacity_type: "preemptible" subnetwork: "default" # default EOF
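A closing usage sketch for the varied-mixture support introduced in PATCH 03/10, condensed from the new `tests/test_varying_mixture.py` (the toy datasets and schedule are illustrative): the `weights` argument of `MixtureDataset` (and the `train_weights` field of `LMMixtureDatasetConfig`) may be either a single dict or a list of `(start_seq_index, weights)` stages, and each stage start must be a multiple of `block_size`.

```python
import jax

from levanter.data import ListAsyncDataset, MixtureDataset

# Two toy sources; values are chosen so it is obvious which dataset a sample came from.
ds1 = ListAsyncDataset([1, 2, 3, 4, 5])
ds2 = ListAsyncDataset([10, 20, 30, 40, 50])
ds1.finalize()
ds2.finalize()

# Each stage starts at a sequence index that is a multiple of block_size.
stages = [
    (0, {"ds1": 1.0, "ds2": 0.0}),   # first two blocks: only ds1
    (20, {"ds1": 0.0, "ds2": 1.0}),  # from sequence 20 onward: only ds2
]

mixture = MixtureDataset(
    {"ds1": ds1, "ds2": ds2},
    stages,
    block_size=10,
    key=jax.random.PRNGKey(0),
    randomize_blocks=False,
)
```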