diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py
index d73f01081..188e5e426 100644
--- a/src/levanter/data/mixture.py
+++ b/src/levanter/data/mixture.py
@@ -35,7 +35,7 @@ class MixtureDataset(AsyncDataset[T]):
 
     Args:
         datasets: A dict of datasets, where the key is the name of the dataset and the value is the dataset itself
-        weights: Weights for each dataset. This can be provided in a list of stages, where each stage is a tuple of (start_index, weights).
+        weights: Weights for each dataset. This can be provided in a list of stages, where each stage is a tuple of (start_seq_index, weights). Note that start_seq_index corresponds to the sequence index at which the weights should change, not the training batch index.
         stop_strategy: strategy for stopping the iteration, by default RESTART_STRATEGY. (Currently only RESTART_STRATEGY is supported)
             - FIRST_STOP_STRATEGY: stop when one dataset has been exhausted
             - ALL_STOP_STRATEGY: stop when all datasets have been exhausted
@@ -60,17 +60,18 @@ def __init__(
         weight_stages = weights
 
         # assert that steps are in sorted order and that the start index of each stage is a multiple of block_size
-        for i, (start_index, _) in enumerate(weight_stages):
+        for i, (start_seq_index, _) in enumerate(weight_stages):
             if i == 0:
-                assert start_index == 0
+                assert start_seq_index == 0
             else:
-                assert (
-                    start_index % block_size == 0
-                ), f"start_index for a stage must be a multiple of block_size, got {start_index=} and {block_size=}"
-                assert start_index > weight_stages[i - 1][0], f"Weights list must be sorted, got {weight_stages}"
+                assert start_seq_index % block_size == 0, (
+                    f"start_seq_index for a stage must be a multiple of block_size, got {start_seq_index=} and"
+                    f" {block_size=}"
+                )
+                assert start_seq_index > weight_stages[i - 1][0], f"Weights list must be sorted, got {weight_stages}"
 
         self.weight_stages = [
-            (start_index, self._normalize_weights(weights)) for start_index, weights in weight_stages
+            (start_seq_index, self._normalize_weights(weights)) for start_seq_index, weights in weight_stages
         ]
         self.datasets = {
             name: dataset
@@ -99,24 +100,33 @@ def __init__(
 
         self.stop_strategy = stop_strategy
 
-        # Compute counts and unpermuted_ids for each stage
-        self._counts_per_block_per_stage = []
-        self._counts_after_stage = []
-        self._unpermuted_ids_per_stage = []
+        # Initialize stage-related counts and IDs
+        (
+            self._counts_per_block_per_stage,
+            self._counts_after_stage,
+            self._unpermuted_ids_per_stage,
+        ) = self._initialize_stage_counts()
+
+    def _initialize_stage_counts(self):
+        counts_per_block_per_stage = []
+        counts_after_stage = []
+        unpermuted_ids_per_stage = []
         cumulative_counts = np.zeros(len(self.datasets), dtype=np.int32)
-        for stage_idx, (start_index, stage_weights) in enumerate(self.weight_stages):
-            counts_this_stage = self._compute_expected_counts_per_block(stage_weights, block_size)
-            self._counts_per_block_per_stage.append(counts_this_stage)
-            self._unpermuted_ids_per_stage.append(self._compute_unpermuted_ids(counts_this_stage))
+        for stage_idx, (start_seq_index, stage_weights) in enumerate(self.weight_stages):
+            counts_this_stage = self._compute_expected_counts_per_block(stage_weights, self.block_size)
+            counts_per_block_per_stage.append(counts_this_stage)
+            unpermuted_ids_per_stage.append(self._compute_unpermuted_ids(counts_this_stage))
             if stage_idx < len(self.weight_stages) - 1:
                 next_start = self.weight_stages[stage_idx + 1][0]
-                num_blocks_in_stage = (next_start - start_index) // block_size
+                num_blocks_in_stage = (next_start - start_seq_index) // self.block_size
                 stage_total_counts = counts_this_stage * num_blocks_in_stage
                 cumulative_counts += stage_total_counts
-                self._counts_after_stage.append(cumulative_counts.copy())
+                counts_after_stage.append(cumulative_counts.copy())
+
+        return counts_per_block_per_stage, counts_after_stage, unpermuted_ids_per_stage
 
     def _compute_expected_counts_per_block(self, weights: dict[str, float], block_size: int):
         _expected_values_per_block = np.zeros(len(self.datasets), dtype=np.int32)
@@ -217,8 +227,8 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T]:
         assert len(indices) == len(blocks) == len(block_ids)
 
         for batch_index, (idx, block, block_id) in enumerate(zip(indices, blocks, block_ids)):
-            index_within_block = idx % self.block_size
-            id = block[index_within_block]
+            index_within_block = idx % self.block_size  # which element of the block to get
+            id = block[index_within_block]  # for this block, which dataset + base dataset offset
             dataset_id, dataset_index = self._index_into_dataset_for_id(id, block_id)
             batches_per_dataset[dataset_id].append(dataset_index)
             indices_in_final_batch[dataset_id].append(batch_index)
@@ -259,6 +269,9 @@ async def getitem_async(self, index: int) -> T:
         return await dataset.getitem_async(dataset_index)
 
     async def _remap_indices(self, ds, indices_into_ds):
+        """
+        Handles wrap around for datasets that have finite length
+        """
         if self.stop_strategy == StopStrategy.RESTART_STRATEGY:
             if ds.is_finite():
                 max_elem = max(indices_into_ds)
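
A quick usage sketch of the staged-weights API this patch renames, for anyone reviewing without the full file open. The dataset names and sizes are invented, the import path is assumed, and any constructor argument beyond the `datasets`, `weights`, `block_size`, and `stop_strategy` names visible in this diff is an assumption about the surrounding code, not something the diff shows:

```python
from levanter.data.mixture import MixtureDataset, StopStrategy  # assumed import path

# Hypothetical component datasets; any AsyncDataset would do here.
datasets = {"web": web_dataset, "code": code_dataset}

# Each stage is (start_seq_index, weights). start_seq_index is the *sequence*
# index at which the new weights take effect (not a training batch index), and
# every stage boundary after the first must be a multiple of block_size.
weights = [
    (0, {"web": 0.9, "code": 0.1}),     # sequences [0, 4096): 90/10 mix
    (4096, {"web": 0.5, "code": 0.5}),  # sequence 4096 onward: 50/50 mix
]

mixture = MixtureDataset(
    datasets=datasets,
    weights=weights,
    block_size=64,  # 4096 % 64 == 0, so the stage boundary above is legal
    key=0,          # assumed: a seed/PRNG argument for shuffling ids within blocks
    stop_strategy=StopStrategy.RESTART_STRATEGY,
)
```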
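To see what the new `_initialize_stage_counts` is accumulating, here is a simplified, self-contained sketch of the same bookkeeping. It assumes weights that divide the block evenly, so it skips whatever rounding `_compute_expected_counts_per_block` does in the general case; the cumulative totals are presumably what lets `_index_into_dataset_for_id` turn a global sample id into an offset within one dataset:

```python
import numpy as np

block_size = 64
# Two stages: a 75/25 mix for the first 4096 sequences, then 50/50 afterwards.
stages = [(0, np.array([0.75, 0.25])), (4096, np.array([0.5, 0.5]))]

cumulative_counts = np.zeros(2, dtype=np.int64)
counts_after_stage = []
for stage_idx, (start_seq_index, stage_weights) in enumerate(stages):
    # Per-block quota for each dataset: [48, 16] in stage 0, [32, 32] in stage 1.
    counts_this_stage = (stage_weights * block_size).astype(np.int64)
    if stage_idx < len(stages) - 1:
        next_start = stages[stage_idx + 1][0]
        num_blocks_in_stage = (next_start - start_seq_index) // block_size  # 64 blocks
        cumulative_counts += counts_this_stage * num_blocks_in_stage
        counts_after_stage.append(cumulative_counts.copy())

print(counts_after_stage)  # [array([3072, 1024])]: samples drawn per dataset before stage 1
```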
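Finally, the docstring added to `_remap_indices` is terse, so here is the behavior it describes as a hypothetical standalone helper. The `async_len` call is an assumption about the `AsyncDataset` interface, which this diff only shows using `is_finite` and `getitem_async`:

```python
async def remap_indices_sketch(ds, indices_into_ds):
    # Under RESTART_STRATEGY, a finite dataset is reused once exhausted:
    # any index past the end wraps around to the beginning via modulo.
    length = await ds.async_len()  # assumed AsyncDataset method
    if max(indices_into_ds) >= length:
        indices_into_ds = [i % length for i in indices_into_ds]
    return indices_into_ds
```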