diff --git a/docs/Getting-Started-TPU-VM.md b/docs/Getting-Started-TPU-VM.md
index 53aaed218..8ae148692 100644
--- a/docs/Getting-Started-TPU-VM.md
+++ b/docs/Getting-Started-TPU-VM.md
@@ -101,8 +101,7 @@ accel_env:
 docker_repository: levanter # default
 zone: us-west4-a # if not set, will use your default zone
-tpu_name: test-spin-up-32
-tpu_type: "v5litepod-16"
+tpu: test-spin-up-32 # name of the TPU
 capacity_type: "preemptible"
 subnetwork: "default" # default
 EOF

@@ -319,11 +318,12 @@ This will show you a list of files and directories in the repo, sorted by size,

 ### Automatic Setup

+!!! warning
+    This approach is deprecated and will be removed in the future. Please use `launch.py` or `launch_on_ray.py` instead.
+
 You can use `infra/spin-up-vm.sh` to create a TPU VM instance. In addition to creating the instance, it will set up
 the venv on each worker, and it will clone the repo to `~/levanter/`.

-**For Public Users**:
-
 ```bash
 bash infra/spin-up-vm.sh -z <zone> -t <type> -n <name> [--preemptible] [--use-alpha]
 ```
diff --git a/pyproject.toml b/pyproject.toml
index 861b222f3..7de2989cb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,10 +36,10 @@ dependencies = [
     "pyarrow>=11.0.0",
     "zstandard>=0.20.0",
     "datasets>=3.1.0,<4.0",
-    "gcsfs>=2024.2,<2025",
+    "gcsfs>=2024.2,<2026",
     "braceexpand>=0.1.7",
    "jmp>=0.0.3",
-    "fsspec[http]>=2024.2,<2025",
+    "fsspec[http]>=2024.2,<2026",
     "tensorstore>=0.1.65",
     "pytimeparse>=1.1.8",
     "humanfriendly==10.0",
diff --git a/src/levanter/data/_prp.py b/src/levanter/data/_prp.py
index 2abc43b80..e30f86b6d 100644
--- a/src/levanter/data/_prp.py
+++ b/src/levanter/data/_prp.py
@@ -1,18 +1,16 @@
 import typing

-import jax.lax
 import jax.numpy as jnp
 import jax.random as jrandom
 import numpy as np


-# TODO: do we make this a pytree
 class Permutation:
     # Pseudo-Random Permutation Code
     """A stateless pseudo-random permutation.

     This class generates a pseudo-random permutation of a given length. The permutation is generated using a PRNG
-    with a fixed key. The permutation is generated by finding a random `a` and `b` such that `gcd(a, length) != 1` and
+    with a fixed key. The permutation is generated by finding a random `a` and `b` such that `gcd(a, length) == 1` and
     then computing the permutation as `p(x) = (a * x + b) % length`.

     This is not a very good PRP, but it is probably good enough for our purposes.
@@ -21,40 +19,40 @@ class Permutation:

     def __init__(self, length, prng_key):
         self.length = length
-        self.prng_key = prng_key
-        a_key, b_key = jrandom.split(prng_key)
-        self._a = jrandom.randint(a_key, (), 1, length)
-        self._b = jrandom.randint(b_key, (), 0, length)
+        # Convert jax.random.PRNGKey to numpy.random.Generator
+        self.rng = np.random.Generator(np.random.PCG64(jrandom.randint(prng_key, (), 0, 2**30).item()))
+        self.a, self.b = self._generate_permutation_params()  # Generate a and b in init

-        cond = lambda a_and_key: jnp.all(jnp.gcd(a_and_key[0], length) != 1)
+    def _generate_permutation_params(self):
+        length = self.length
+        rng = self.rng

-        def loop_body(a_and_key):
-            a, key = a_and_key
-            this_key, key = jrandom.split(key)
-            a = jrandom.randint(this_key, (), 1, length)
-            return a, key
+        if length == 1:
+            return 1, 0

-        self._a, key = jax.lax.while_loop(cond, loop_body, (self._a, a_key))
+        while True:
+            a = rng.integers(1, length)
+            if np.gcd(a, length) == 1:
+                break

-        self._a = int(self._a)
-        self._b = int(self._b)
+        b = rng.integers(0, length)  # b can be in [0, length - 1]
+        return a, b

     @typing.overload
     def __call__(self, indices: int) -> int:
         ...

     @typing.overload
-    def __call__(self, indices: jnp.ndarray) -> jnp.ndarray:
+    def __call__(self, indices: np.ndarray) -> np.ndarray:
         ...

     def __call__(self, indices):
+        a = self.a
+        b = self.b
+        length = self.length
+
         was_int = False
-        if isinstance(indices, jnp.ndarray):
-            # TODO: use error_if?
-            # import equinox as eqx
-            if jnp.any(indices < 0) or jnp.any(indices >= self.length):
-                raise IndexError(f"index {indices} is out of bounds for length {self.length}")
-        elif isinstance(indices, np.ndarray):
+        if isinstance(indices, np.ndarray | jnp.ndarray):
             if np.any(indices < 0) or np.any(indices >= self.length):
                 raise IndexError(f"index {indices} is out of bounds for length {self.length}")
         else:
@@ -64,9 +62,7 @@ def __call__(self, indices):
             indices = np.array(indices)
             was_int = True

-        old_settings = np.seterr(over="raise")
-        out = (self._a * indices + self._b) % self.length
-        np.seterr(**old_settings)
+        out = (a * indices + b) % length  # Compute permutation on-the-fly

         if was_int:
             return int(out)
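The `Permutation` docstring above describes an affine PRP. As a quick illustration of why the `gcd(a, length) == 1` condition makes `p(x) = (a * x + b) % length` a bijection — and how it can be inverted — here is a small self-contained sketch. It is illustrative only; the values of `a`, `b`, and `length` are hypothetical, not taken from the patch:

```python
import math

import numpy as np

length, a, b = 10, 3, 7  # hypothetical values; gcd(3, 10) == 1
assert math.gcd(a, length) == 1

# The affine map hits every residue exactly once when gcd(a, length) == 1,
# so it is a permutation of range(length).
perm = (a * np.arange(length) + b) % length
assert sorted(perm.tolist()) == list(range(length))

# Because gcd(a, length) == 1, a has a modular inverse, so p is invertible.
a_inv = pow(a, -1, length)  # 7, since 3 * 7 == 21 ≡ 1 (mod 10)
x = 4
assert (a_inv * (int(perm[x]) - b)) % length == x
```

Sampling `a` with a host-side numpy `Generator`, as the patch does, avoids tracing a `jax.lax.while_loop` just to produce two Python integers.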
diff --git a/src/levanter/data/dataset.py b/src/levanter/data/dataset.py
index f448ed83b..86e8c78d6 100644
--- a/src/levanter/data/dataset.py
+++ b/src/levanter/data/dataset.py
@@ -122,6 +122,9 @@ def map(self, fn: MapFunction[U], *extra_args, **extra_kwargs) -> "MappedAsyncDa
     def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> "BatchMappedAsyncDataset[U]":
         return BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)

+    def slice_dataset(self, start_index: Optional[int] = None, end_index: Optional[int] = None):
+        return SlicedAsyncDataset(self, start_index, end_index)
+
     def shuffle(self, key: PRNGKey):
         import levanter.data.permutation as permutation

@@ -375,6 +378,57 @@ def _call_fn(self, index, item):
         return self.fn(item, *self._extra_args, **kwargs)


+class SlicedAsyncDataset(AsyncDataset[U]):
+    def __init__(
+        self,
+        dataset: AsyncDataset[U],
+        start_index: Optional[int] = None,
+        end_index: Optional[int] = None,
+    ):
+        super().__init__()
+        if start_index is None:
+            start_index = 0
+        if end_index is not None and start_index > end_index:
+            raise ValueError("End index must come after start index.")
+
+        self.start_index = start_index
+        self.end_index = end_index
+        self.dataset = dataset
+        self._min_known_len = dataset._min_known_len if end_index is None else (end_index - start_index)
+
+    async def get_batch(self, indices: Sequence[int]) -> Sequence[U]:
+        shifted_indices = [(index + self.start_index) for index in indices]
+        max_index = max(shifted_indices)
+
+        if self.end_index is not None and max_index > self.end_index:
+            raise ValueError("Requested indices beyond the end of the dataset")
+
+        return await self.dataset.get_batch(shifted_indices)
+
+    async def async_len(self) -> int:
+        underlying_length = await self.dataset.async_len()
+        if self.end_index is None:
+            return underlying_length - self.start_index
+        else:
+            return self.end_index - self.start_index
+
+    async def final_length_is_known(self) -> bool:
+        underlying_is_known = await self.dataset.final_length_is_known()
+        return underlying_is_known and self.end_index is not None
+
+    def is_finite(self) -> bool:
+        return self.dataset.is_finite() and self.end_index is not None
+
+    async def current_len(self) -> Optional[int]:
+        underlying_length = await self.dataset.current_len()
+        if self.end_index is not None:
+            return self.end_index - self.start_index
+        elif underlying_length is not None:
+            return underlying_length - self.start_index
+        else:
+            return underlying_length
+
+
 class BatchMappedAsyncDataset(AsyncDataset[U]):
     """
     A dataset that applies a function to each batch of items in the dataset.
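The new `slice_dataset` helper returns a logical window onto the underlying dataset: requested indices are shifted by `start_index` before being delegated. A minimal usage sketch, assuming `ListAsyncDataset` (which the tests later in this patch use); the `async` wrapper is just for illustration:

```python
import asyncio

from levanter.data import ListAsyncDataset


async def demo():
    ds = ListAsyncDataset([1, 2, 3, 4, 5])
    ds.finalize()

    # Keep items at indices 1..3 of the underlying dataset.
    sliced = ds.slice_dataset(start_index=1, end_index=4)

    assert await sliced.async_len() == 3
    # get_batch shifts the requested indices by start_index under the hood.
    assert list(await sliced.get_batch([0, 1, 2])) == [2, 3, 4]


asyncio.run(demo())
```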
diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py
index 63c623e4b..188e5e426 100644
--- a/src/levanter/data/mixture.py
+++ b/src/levanter/data/mixture.py
@@ -1,6 +1,6 @@
 import asyncio
 import warnings
-from typing import Mapping, Optional, Sequence, TypeVar
+from typing import List, Mapping, Optional, Sequence, Tuple, TypeVar

 import jax
 import numpy as np
@@ -30,12 +30,12 @@ class MixtureDataset(AsyncDataset[T]):
     according to the weights.

     Creating a random-access MixtureDataset is challenging because we need to keep track of the current index of each
-    dataset. So solve this, we instead use "block-deterministic" mixtures, where the number of samples from each dataset
-    in each block is always identical (and we shuffle the order of the dataset ids in each block).
+    dataset. To solve this, we instead use "block-deterministic" mixtures, where the number of samples from each dataset
+    in each block is always identical (and we shuffle the order of the dataset ids in each block). To handle the case where the dataset mixture changes over time, we use a list of stages and precompute statistics to accurately compute the index of each dataset in each block.

     Args:
         datasets: A dict of datasets, where the key is the name of the dataset and the value is the dataset itself
-        weights: weights for each dataset
+        weights: Weights for each dataset. This can be provided in a list of stages, where each stage is a tuple of (start_seq_index, weights). Note that start_seq_index corresponds to the sequence index at which the weights should change, not the training batch index.
         stop_strategy: strategy for stopping the iteration, by default RESTART_STRATEGY. (Currently only RESTART_STRATEGY is supported)
            - FIRST_STOP_STRATEGY: stop when one dataset has been exhausted
            - ALL_STOP_STRATEGY: stop when all datasets have been exhausted
@@ -46,7 +46,7 @@ class MixtureDataset(AsyncDataset[T]):
     def __init__(
         self,
         datasets: Mapping[str, AsyncDataset[T]],
-        weights: dict[str, float],
+        weights: dict[str, float] | List[Tuple[int, dict[str, float]]],
         block_size: int,
         *,
         randomize_blocks: bool = True,
@@ -54,8 +54,30 @@ def __init__(
         stop_strategy: str = StopStrategy.RESTART_STRATEGY,
     ):
         super().__init__()
-        self.weights = MixtureDataset._normalize_weights(weights)
-        self.datasets = {name: dataset for name, dataset in datasets.items() if self.weights.get(name, 0) > 0}
+        if isinstance(weights, dict):
+            weight_stages = [(0, weights)]
+        else:
+            weight_stages = weights
+
+        # assert that steps are in sorted order and that the start index of each stage is a multiple of block_size
+        for i, (start_seq_index, _) in enumerate(weight_stages):
+            if i == 0:
+                assert start_seq_index == 0
+            else:
+                assert start_seq_index % block_size == 0, (
+                    f"start_seq_index for a stage must be a multiple of block_size, got {start_seq_index=} and"
+                    f" {block_size=}"
+                )
+                assert start_seq_index > weight_stages[i - 1][0], f"Weights list must be sorted, got {weight_stages}"
+
+        self.weight_stages = [
+            (start_seq_index, self._normalize_weights(weights)) for start_seq_index, weights in weight_stages
+        ]
+        self.datasets = {
+            name: dataset
+            for name, dataset in datasets.items()
+            if any(weights.get(name, 0) > 0 for _, weights in self.weight_stages)
+        }
         self.dataset_index = Index(self.datasets.keys())
         self.block_size = block_size
         # we pack index and ds id into a single 32 bit, so block size must be at most 2^16
@@ -78,15 +100,38 @@ def __init__(

         self.stop_strategy = stop_strategy

-        self._counts_per_block = self._compute_expected_counts_per_block(block_size)
-        # precompute a list of ids for each block
-        # the ids contain both the dataset index and the index within the dataset
-        self._unpermuted_ids = self._compute_unpermuted_ids(self._counts_per_block)
+        # Initialize stage-related counts and IDs
+        (
+            self._counts_per_block_per_stage,
+            self._counts_after_stage,
+            self._unpermuted_ids_per_stage,
+        ) = self._initialize_stage_counts()
+
+    def _initialize_stage_counts(self):
+        counts_per_block_per_stage = []
+        counts_after_stage = []
+        unpermuted_ids_per_stage = []
+
+        cumulative_counts = np.zeros(len(self.datasets), dtype=np.int32)
+
+        for stage_idx, (start_seq_index, stage_weights) in enumerate(self.weight_stages):
+            counts_this_stage = self._compute_expected_counts_per_block(stage_weights, self.block_size)
+            counts_per_block_per_stage.append(counts_this_stage)
+            unpermuted_ids_per_stage.append(self._compute_unpermuted_ids(counts_this_stage))
+
+            if stage_idx < len(self.weight_stages) - 1:
+                next_start = self.weight_stages[stage_idx + 1][0]
+                num_blocks_in_stage = (next_start - start_seq_index) // self.block_size
+                stage_total_counts = counts_this_stage * num_blocks_in_stage
+                cumulative_counts += stage_total_counts
+                counts_after_stage.append(cumulative_counts.copy())

-    def _compute_expected_counts_per_block(self, block_size):
+        return counts_per_block_per_stage, counts_after_stage, unpermuted_ids_per_stage
+
+    def _compute_expected_counts_per_block(self, weights: dict[str, float], block_size: int):
         _expected_values_per_block = np.zeros(len(self.datasets), dtype=np.int32)
         for i, dsname in enumerate(self.dataset_index):
-            _expected_values_per_block[i] = self.weights[dsname] * block_size
+            _expected_values_per_block[i] = weights.get(dsname, 0) * block_size

         # handle remainder by adding to the largest dataset
         largest_dataset = np.argmax(_expected_values_per_block)
@@ -94,9 +139,9 @@

         # check if any dataset has 0 samples (and nonzero weight)
         for i, dsname in enumerate(self.dataset_index):
-            if _expected_values_per_block[i] == 0 and self.weights[dsname] > 0:
+            if _expected_values_per_block[i] == 0 and weights.get(dsname, 0) > 0:
                 warnings.warn(
-                    f"Dataset {dsname} has 0 samples in the block, but weight of {self.weights[dsname]}."
+                    f"Dataset {dsname} has 0 samples in the block, but weight of {weights[dsname]}."
                     " Recommend increasing block size."
                 )
@@ -143,20 +188,35 @@ async def current_len(self) -> Optional[int]:

         raise NotImplementedError("Length is not known for other strategies")

-    @alru_cache
+    def _get_stage_for_block(self, block_id: int) -> int:
+        block_start = block_id * self.block_size
+        stage_starts = np.array([start for start, _ in self.weight_stages])
+        return max(0, np.searchsorted(stage_starts, block_start, side="right") - 1)
+
+    @alru_cache(maxsize=32)
     async def _get_block(self, index: int) -> Optional[np.ndarray]:
+        stage = self._get_stage_for_block(index)
         if not self.randomize_blocks:
-            return self._unpermuted_ids
+            return self._unpermuted_ids_per_stage[stage]

-        return np.array(_compute_block_assignment(self._unpermuted_ids, index, self.key))
+        return np.array(_compute_block_assignment(self._unpermuted_ids_per_stage[stage], index, self.key))

-    def _index_into_dataset_for_id(self, id: int, block_id) -> tuple[int, int]:
+    def _index_into_dataset_for_id(self, id: int, block_id: int) -> tuple[int, int]:
+        stage = self._get_stage_for_block(block_id)
         dataset_id = id >> 16
         dataset_index = id & 0xFFFF
-        return dataset_id, dataset_index + block_id * self._counts_per_block[dataset_id]
+
+        # Get the base offset from previous stages
+        base_offset = self._counts_after_stage[stage - 1][dataset_id] if stage > 0 else 0
+        # Add offset within current stage
+        offset_in_stage = (block_id * self.block_size - self.weight_stages[stage][0]) // self.block_size
+        current_stage_offset = offset_in_stage * self._counts_per_block_per_stage[stage][dataset_id]
+
+        return dataset_id, dataset_index + base_offset + current_stage_offset

     async def get_batch(self, indices: Sequence[int]) -> Sequence[T]:
         block_ids = np.array([idx // self.block_size for idx in indices])
+
         blocks = [self._get_block(block_id) for block_id in block_ids]
         blocks = await asyncio.gather(*blocks)
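To make the stage bookkeeping above concrete, here is a worked example (with hypothetical numbers) of the offset arithmetic in `_index_into_dataset_for_id`: the index a dataset contributes to a block is its cumulative count from all finished stages plus a per-block offset within the current stage:

```python
import numpy as np

block_size = 4
stage_starts = [0, 16]  # stage 1 begins at sequence index 16, a multiple of block_size
# expected per-block counts for two datasets [A, B] in each stage
counts_per_block_per_stage = [np.array([3, 1]), np.array([1, 3])]

# counts consumed per dataset by the end of stage 0 (4 blocks * [3, 1])
counts_after_stage_0 = counts_per_block_per_stage[0] * (stage_starts[1] // block_size)
assert counts_after_stage_0.tolist() == [12, 4]

# Which index of dataset A backs an id in block 6 (sequence indices 24..27, i.e. stage 1)?
block_id, dataset_id, stage = 6, 0, 1
offset_in_stage = (block_id * block_size - stage_starts[stage]) // block_size  # 2 blocks into stage 1
base_offset = counts_after_stage_0[dataset_id]  # 12 items of A were used in stage 0
current_stage_offset = offset_in_stage * counts_per_block_per_stage[stage][dataset_id]  # 2 * 1
assert base_offset + current_stage_offset == 14  # A's samples in block 6 start at index 14
```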
diff --git a/src/levanter/data/permutation.py b/src/levanter/data/permutation.py
index 66a1887fd..ac20a3581 100644
--- a/src/levanter/data/permutation.py
+++ b/src/levanter/data/permutation.py
@@ -41,7 +41,9 @@ async def getitem_async(self, index: int) -> T_co:

     async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
         permutation = await self._get_permutation()
-        return await self.dataset.get_batch([permutation(i) for i in indices])
+        return await self.dataset.get_batch(
+            [int(permutation(i)) for i in indices]
+        )  # cast to int to make sure it's a Python int

     async def _get_permutation(self):
         if self._permutation is None:
@@ -83,10 +85,10 @@ async def gen_era_permutation(era: int) -> Permutation:
             # TODO: support epochs
             # edge case: final era may be shorter than era_length
             current_len = await self.dataset.wait_until_len_at_least((era + 1) * self.era_length)
-            era_length = min(self.era_length, current_len - era * self.era_length)
+            era_length_val = min(self.era_length, current_len - era * self.era_length)

             mix_key = jax.random.fold_in(key, era)
-            return Permutation(era_length, mix_key)
+            return Permutation(era_length_val, mix_key)

         self.gen_era_permutation = gen_era_permutation

@@ -95,7 +97,9 @@ async def _get_index(self, idx: int) -> int:
             raise ValueError("Negative indices are not supported")
         era = idx // self.era_length
         permutation = await self.gen_era_permutation(era)
-        return permutation(idx - era * self.era_length) + era * self.era_length
+        out = permutation(idx - era * self.era_length) + era * self.era_length
+
+        return out

     async def async_len(self) -> int:
         return await self.dataset.async_len()
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index b48d558de..c35bf2f52 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -1185,28 +1185,46 @@ def _convert_id_to_token(self, index: int) -> str:

 @dataclass
 class LMMixtureDatasetConfig(LMTaskConfig):
-    """This class represents a mixture of datasets with their associated weights."""
+    """A mixture of language model datasets that supports dynamic weight changes during training.
+
+    Weights can be specified either as a single dictionary for constant mixing ratios,
+    or as a list of (step, weights) tuples to change mixing ratios during training.
+    """

     cache_dir: Optional[str] = "cache/"

-    # data source configs and weights
     configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict)
-    """ configuration of each dataset source (urls, hf dataset id, etc.) """
-    train_weights: Dict[str, float] = field(default_factory=dict)
-    """ weights for each dataset source. They will be normalized to sum to 1. """
+    """ Configuration of each dataset source (urls, hf dataset id, etc.) """
+
+    train_weights: Union[Dict[str, float], List[Tuple[int, Dict[str, float]]]] = field(default_factory=dict)
+    """ Dataset mixing weights. Either a constant dict[name->weight] or list of (step, weights) tuples """
+
     stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY)
+
+    # Configuration for Simulated Epoching
+    target_budget: Optional[int] = None
+    experiment_budget: Optional[int] = None
+
     mixture_block_size: int = 2048
-    """ block size for the mixture dataset."""
+    """ Block size for deterministic mixing """

     def __post_init__(self):
         if len(self.configs) == 0:
             raise ValueError("At least one dataset must be provided")

-        if set(self.configs.keys()) != set(self.train_weights.keys()):
-            raise ValueError(
-                f"The keys in configs and weights must be the same;got {self.configs.keys()} and"
-                f" {self.train_weights.keys()}"
-            )
+        if isinstance(self.train_weights, dict):
+            if not all(name in self.configs for name in self.train_weights):
+                raise ValueError(
+                    f"Weight keys {self.train_weights.keys()} must be subset of config keys {self.configs.keys()}"
+                )
+        elif isinstance(self.train_weights, list):
+            for step, weights in self.train_weights:
+                if not all(name in self.configs for name in weights):
+                    raise ValueError(
+                        f"Weight keys {weights.keys()} must be subset of config keys {self.configs.keys()}"
+                    )
+        else:
+            raise ValueError(f"Invalid train_weights type: {type(self.train_weights)}")

     def train_set(
         self,
@@ -1244,12 +1262,29 @@ def shuffle_ds(ds, key):
                 out_token_datasets[name] = shuffle_ds(ds, next(key_iter))
             token_datasets = out_token_datasets

+        if (
+            self.experiment_budget is not None and self.target_budget is not None
+        ) and self.experiment_budget > self.target_budget:
+            raise ValueError(
+                f"Experiment budget should be smaller than target budget, got {self.experiment_budget} >"
+                f" {self.target_budget}"
+            )
+        if self.experiment_budget is not None and self.target_budget is not None:
+            simulated_data_ratio = self.experiment_budget / self.target_budget
+            sliced_token_datasets: Dict[str, TokenSeqDataset] = {}
+            for name, ds in token_datasets.items():
+                # Note(Will): This blocks on datasets being fully processed even for small simulated runs, which makes simulating data size slightly latency-inducing, but I think that's OK
+                true_length_of_dataset = len(ds.as_sync_dataset())
+                simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio)
+                sliced_token_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset)
+            token_datasets = sliced_token_datasets
+
         mixture = MixtureDataset(
             datasets=token_datasets,
             weights=self.train_weights,
             stop_strategy=self.stop_strategy,
             key=mix_key,
-            block_size=2048,
+            block_size=self.mixture_block_size,
         )

         return mixture
@@ -1279,9 +1314,13 @@ def build_caches(

         caches = {}
         for name, source_config in self.configs.items():
-            weight = self.train_weights.get(name, 0)
+            # Skip datasets with zero weight in all stages
+            if isinstance(self.train_weights, dict):
+                has_nonzero_weight = self.train_weights.get(name, 0) > 0
+            elif isinstance(self.train_weights, list):
+                has_nonzero_weight = any(weights.get(name, 0) > 0 for _, weights in self.train_weights)

-            if weight == 0 and split == "train":
+            if not has_nonzero_weight and split == "train":
                 continue

             source_config_dict = dict(**source_config.__dict__)
diff --git a/src/levanter/distributed.py b/src/levanter/distributed.py
index 6efd9f0cb..8537620fc 100644
--- a/src/levanter/distributed.py
+++ b/src/levanter/distributed.py
@@ -320,7 +320,9 @@ def initialize(self):
             if coordinator_address is None:
                 coordinator_address = LevanterSlurmCluster.get_coordinator_address()

-            jax.distributed.initialize(coordinator_address, self.num_processes, self.process_id, device_ids)
+            jax.distributed.initialize(
+                coordinator_address, self.num_processes, self.process_id, device_ids, initialization_timeout=30 * 60
+            )
             logger.info(
                 f"Initialized jax.distributed with {jax.device_count()} devices, {jax.process_count()} processes,"
                 f" coordinator_address={coordinator_address}, process_id={self.process_id}, my"
diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py
index 86ce4223a..d220ec982 100644
--- a/src/levanter/infra/ray_tpu.py
+++ b/src/levanter/infra/ray_tpu.py
@@ -8,6 +8,7 @@ import tempfile
 import time
 from dataclasses import dataclass
+from queue import Empty as QueueEmpty
 from typing import Callable, Optional, Sequence

 import draccus
@@ -89,24 +90,24 @@ def do_run(remote_fn) -> _TpuRunResult:
             logger.info("TPU job finished")
             return TpuSuccess(info, out)
         except RayError as e:
-            for f in futures:
-                try:
-                    ray.cancel(f)
-                except Exception:
-                    logger.exception("Failed to kill job after primary failure")
+            _cancel_all_futures(futures)
             return _handle_ray_error(info, e)
         except Exception as e:
-            for f in futures:
-                try:
-                    ray.cancel(f)
-                except Exception:
-                    logger.exception("Failed to kill job after primary failure")
+            _cancel_all_futures(futures)
             return TpuFailed(info, e)

     return do_run.remote(remote_fn)


-def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> ray.ObjectRef:
+def _cancel_all_futures(futures):
+    for f in futures:
+        try:
+            ray.cancel(f)
+        except Exception:
+            logger.exception("Failed to kill job after primary failure")
+
+
+def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> list[ray.ObjectRef]:
     """
     Run a remote function on multiple TPU slices.

@@ -147,18 +148,12 @@ def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResu
                 logger.info("TPU job finished")
                 return TpuSuccess(info, out)
             except RayError as e:
-                for f in futures:
-                    try:
-                        ray.cancel(f)
-                    except Exception:
-                        logger.exception("Failed to kill job after primary failure")
+                logger.exception(f"Ray error {e}. Killing futures for this slice")
+                _cancel_all_futures(futures)
                 return _handle_ray_error(info, e)
             except Exception as e:
-                for f in futures:
-                    try:
-                        ray.cancel(f)
-                    except Exception:
-                        logger.exception("Failed to kill job after primary failure")
+                logger.exception(f"Exception {e}")
+                _cancel_all_futures(futures)
                 return TpuFailed(info, e)

     actors = [MultisliceActor.remote() for _ in range(num_slices)]  # type: ignore
@@ -172,7 +167,7 @@ def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResu
             logger.exception(e)
             for actor in actors:
                 try:
-                    ray.cancel(actor)
+                    ray.kill(actor)
                 except Exception:
                     logger.exception("Failed to kill actor after primary failure")
             return futures
@@ -310,6 +305,16 @@ def run_on_pod_multislice_resumable(
         futures = run_on_pod_multislice(remote_fn, tpu_type, num_slices)
         try:
             outs = ray.get(futures)
+        except ray.exceptions.ActorUnavailableError as e:
+            problem = e
+            num_preemptions += 1
+            logger.warning(f"Preempted {num_preemptions} times, {e}")
+            continue
+        except ray.exceptions.ActorDiedError as e:
+            problem = e
+            num_preemptions += 1
+            logger.warning(f"Preempted {num_preemptions} times, {e}")
+            continue
         except ray.exceptions.RayTaskError as e:
             for f in futures:
                 try:
@@ -425,6 +430,9 @@ def _handle_ray_error(tpu_info: _TpuInfo, e: RayError):
     if isinstance(e, NodeDiedError):
         logger.exception("Node died", exc_info=e)
         return TpuPreempted(tpu_info, e)
+    elif isinstance(e, ray.exceptions.ActorUnavailableError | ray.exceptions.ActorDiedError):
+        logger.exception("Actor died", exc_info=e)
+        return TpuPreempted(tpu_info, e)
     elif isinstance(e, WorkerCrashedError):
         logger.exception("Worker crashed", exc_info=e)
         return TpuPreempted(tpu_info, e)
@@ -506,7 +514,13 @@ def target_fn(queue, args, kwargs):
     process.join()

     # Retrieve the result or error from the queue
-    success, value = queue.get()
+    logger.info("Process finished")
+    try:
+        success, value = queue.get(timeout=10)
+    except QueueEmpty:
+        logger.error("Process timed out")
+        process.terminate()
+        raise RuntimeError("Process timed out")

     if success:
         return value
diff --git a/tests/test_mixture.py b/tests/test_mixture.py
index e8821e24f..52652f380 100644
--- a/tests/test_mixture.py
+++ b/tests/test_mixture.py
@@ -73,6 +73,36 @@ async def test_mixture_dataset_stop_strategy_restart():
     await mixture_ds.async_len()


+@pytest.mark.asyncio
+async def test_mixture_dataset_simulated_data_size():
+    weights = {"ds1": 1 / 3, "ds2": 1 / 3, "ds3": 1 / 3}
+    mixture_ds = MixtureDataset(
+        {name: dataset.slice_dataset(end_index=1) for name, dataset in datasets().items()},
+        weights,
+        block_size=10,
+        key=key(),
+        randomize_blocks=False,
+        stop_strategy=StopStrategy.RESTART_STRATEGY,
+    )
+    for _ in range(10):
+        batch = await mixture_ds.get_batch([0, 1, 2])
+        assert len(batch) == 3
+        assert all(item in [1, 10, 100] for item in batch)
+
+    mixture_ds = MixtureDataset(
+        {name: dataset.slice_dataset(end_index=2) for name, dataset in datasets().items()},
+        weights,
+        block_size=10,
+        key=key(),
+        randomize_blocks=False,
+        stop_strategy=StopStrategy.RESTART_STRATEGY,
+    )
+    for _ in range(10):
+        batch = await mixture_ds.get_batch([0, 1, 2])
+        assert len(batch) == 3
+        assert all(item in [1, 2, 10, 20, 100, 200] for item in batch)
+
+
 @pytest.mark.asyncio
 async def test_mixture_dataset_normalized_weights():
     weights = {"ds1": 0, "ds2": 0.5, "ds3": 0.5}
@@ -87,7 +117,9 @@ async def test_mixture_dataset_normalized_weights():


 async def test_mixture_dataset_unpermuted_ids():
     mixture_ds = MixtureDataset(datasets(), weights(), block_size=10, key=key())
-    unpermuted_ids = mixture_ds._compute_unpermuted_ids(mixture_ds._counts_per_block)
+    unpermuted_ids = mixture_ds._compute_unpermuted_ids(
+        mixture_ds._compute_expected_counts_per_block(weights(), block_size())
+    )
     assert len(unpermuted_ids) == 10
     assert unpermuted_ids[0] >> 32 in range(3)  # Ensure the dataset ID is valid
diff --git a/tests/test_prp.py b/tests/test_prp.py
index 6c549eabf..8a0bcac94 100644
--- a/tests/test_prp.py
+++ b/tests/test_prp.py
@@ -10,8 +10,8 @@ def test_permutation_creates_valid_instance():
     prng_key = jrandom.PRNGKey(0)
     permutation = Permutation(length, prng_key)
     assert permutation.length == length
-    assert permutation._a > 0 and permutation._a < length
-    assert permutation._b >= 0 and permutation._b < length
+    assert 0 < permutation.a < length
+    assert 0 <= permutation.b < length


 def test_permutation_with_single_index_returns_correct_value():
@@ -85,3 +85,13 @@ def test_permutation_is_deterministic1():
     permutation = Permutation(length, prng_key)
     results2 = permutation(indices)
     assert jnp.all(results == results2)
+
+
+def test_permutation_handles_large_length_no_overflow():
+    large_length = 2**34
+    prng_key = jrandom.PRNGKey(0)
+    permutation = Permutation(large_length, prng_key)
+    index = 2**32  # A large index within the range
+    result = permutation(index)
+    assert isinstance(result, int)
+    assert 0 <= result < large_length
diff --git a/tests/test_varying_mixture.py b/tests/test_varying_mixture.py
new file mode 100644
index 000000000..564225b4e
--- /dev/null
+++ b/tests/test_varying_mixture.py
@@ -0,0 +1,142 @@
+import jax
+import pytest
+
+from levanter.data import ListAsyncDataset, MixtureDataset
+
+
+def create_datasets():
+    ds1 = ListAsyncDataset([1, 2, 3, 4, 5])
+    ds2 = ListAsyncDataset([10, 20, 30, 40, 50])
+    ds3 = ListAsyncDataset([100, 200, 300, 400, 500])
+    ds1.finalize()
+    ds2.finalize()
+    ds3.finalize()
+    return {"ds1": ds1, "ds2": ds2, "ds3": ds3}
+
+
+@pytest.mark.asyncio
+async def test_mixture_dataset_stage_transitions():
+    datasets = create_datasets()
+    # Define three stages with different weights
+    stages = [
+        (0, {"ds1": 1.0, "ds2": 0.0, "ds3": 0.0}),  # Stage 1: only ds1
+        (20, {"ds1": 0.0, "ds2": 1.0, "ds3": 0.0}),  # Stage 2: only ds2
+        (40, {"ds1": 0.0, "ds2": 0.0, "ds3": 1.0}),  # Stage 3: only ds3
+    ]
+
+    mixture_ds = MixtureDataset(datasets, stages, block_size=10, key=jax.random.PRNGKey(42), randomize_blocks=False)
+
+    # Test first stage (should only get values from ds1)
+    batch1 = await mixture_ds.get_batch(list(range(10)))
+    assert all(x in [1, 2, 3, 4, 5] for x in batch1), f"Unexpected values in first stage: {batch1}"
+
+    # Test second stage (should only get values from ds2)
+    batch2 = await mixture_ds.get_batch(list(range(20, 30)))
+    assert all(x in [10, 20, 30, 40, 50] for x in batch2), f"Unexpected values in second stage: {batch2}"
+
+    # Test third stage (should only get values from ds3)
+    batch3 = await mixture_ds.get_batch(list(range(40, 50)))
+    assert all(x in [100, 200, 300, 400, 500] for x in batch3), f"Unexpected values in third stage: {batch3}"
+
+
+@pytest.mark.asyncio
+async def test_mixture_dataset_gradual_transition():
+    datasets = create_datasets()
+    # Define stages with gradual transitions
+    stages = [
+        (0, {"ds1": 0.8, "ds2": 0.2, "ds3": 0.0}),  # Mostly ds1
+        (20, {"ds1": 0.2, "ds2": 0.6, "ds3": 0.2}),  # Mostly ds2
+        (40, {"ds1": 0.0, "ds2": 0.2, "ds3": 0.8}),  # Mostly ds3
+    ]
+
+    mixture_ds = MixtureDataset(datasets, stages, block_size=10, key=jax.random.PRNGKey(42))
+
+    # Sample a large batch from each stage and verify proportions
+    def count_sources(batch):
+        counts = {"ds1": 0, "ds2": 0, "ds3": 0}
+        for x in batch:
+            if x < 10:
+                counts["ds1"] += 1
+            elif x < 100:
+                counts["ds2"] += 1
+            else:
+                counts["ds3"] += 1
+        return counts
+
+    # Test first stage
+    batch1 = await mixture_ds.get_batch(list(range(20)))
+    counts1 = count_sources(batch1)
+    assert counts1["ds1"] > counts1["ds2"] > counts1["ds3"], f"Unexpected distribution in first stage: {counts1}"
+
+    # Test second stage
+    batch2 = await mixture_ds.get_batch(list(range(20, 40)))
+    counts2 = count_sources(batch2)
+    assert (
+        counts2["ds2"] > counts2["ds1"] and counts2["ds2"] > counts2["ds3"]
+    ), f"Unexpected distribution in second stage: {counts2}"
+
+    # Test third stage
+    batch3 = await mixture_ds.get_batch(list(range(40, 60)))
+    counts3 = count_sources(batch3)
+    assert counts3["ds3"] > counts3["ds2"] > counts3["ds1"], f"Unexpected distribution in third stage: {counts3}"
+
+
+@pytest.mark.asyncio
+async def test_mixture_dataset_invalid_stage_configurations():
+    datasets = create_datasets()
+
+    # Test stages that don't start at 0
+    with pytest.raises(AssertionError):
+        MixtureDataset(datasets, [(10, {"ds1": 1.0})], block_size=10, key=jax.random.PRNGKey(42))
+
+    # Test stages with start indices not multiple of block_size
+    with pytest.raises(AssertionError):
+        MixtureDataset(datasets, [(0, {"ds1": 1.0}), (15, {"ds2": 1.0})], block_size=10, key=jax.random.PRNGKey(42))
+
+    # Test stages in wrong order
+    with pytest.raises(AssertionError):
+        MixtureDataset(
+            datasets,
+            [(0, {"ds1": 1.0}), (20, {"ds2": 1.0}), (10, {"ds3": 1.0})],
+            block_size=10,
+            key=jax.random.PRNGKey(42),
+        )
+
+
+@pytest.mark.asyncio
+async def test_mixture_dataset_zero_weight_handling():
+    datasets = create_datasets()
+    # Define stages where some datasets have zero weight
+    stages = [
+        (0, {"ds1": 1.0, "ds2": 0.0, "ds3": 0.0}),
+        (20, {"ds1": 0.0, "ds2": 1.0, "ds3": 0.0}),
+    ]
+
+    mixture_ds = MixtureDataset(datasets, stages, block_size=10, key=jax.random.PRNGKey(42), randomize_blocks=False)
+
+    # Verify that zero-weight datasets are not sampled
+    batch1 = await mixture_ds.get_batch(list(range(10)))
+    assert all(x < 10 for x in batch1), f"Found samples from zero-weight datasets in first stage: {batch1}"
+
+    batch2 = await mixture_ds.get_batch(list(range(20, 30)))
+    assert all(10 <= x < 100 for x in batch2), f"Found samples from zero-weight datasets in second stage: {batch2}"
+
+
+@pytest.mark.asyncio
+async def test_mixture_dataset_block_boundaries():
+    datasets = create_datasets()
+    # Define stages with transition at block boundary
+    stages = [
+        (0, {"ds1": 1.0, "ds2": 0.0, "ds3": 0.0}),
+        (10, {"ds1": 0.0, "ds2": 1.0, "ds3": 0.0}),
+    ]
+
+    mixture_ds = MixtureDataset(datasets, stages, block_size=10, key=jax.random.PRNGKey(42), randomize_blocks=False)
+
+    # Test the boundary between stages
+    batch = await mixture_ds.get_batch(list(range(5, 15)))  # Should span both stages
+    first_half = batch[:5]
+    second_half = batch[5:]
+
+    assert all(x < 10 for x in first_half), f"Unexpected values at end of first stage: {first_half}"
+    assert all(10 <= x < 100 for x in second_half), f"Unexpected values at start of second stage: {second_half}"