Commit

addressing feedback
kothasuhas committed Jan 29, 2025
1 parent 1b506ca commit e340079
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions src/levanter/data/mixture.py
@@ -35,7 +35,7 @@ class MixtureDataset(AsyncDataset[T]):
     Args:
         datasets: A dict of datasets, where the key is the name of the dataset and the value is the dataset itself
-        weights: Weights for each dataset. This can be provided in a list of stages, where each stage is a tuple of (start_index, weights).
+        weights: Weights for each dataset. This can be provided in a list of stages, where each stage is a tuple of (start_seq_index, weights). Note that start_seq_index corresponds to the sequence index at which the weights should change, not the training batch index.
         stop_strategy: strategy for stopping the iteration, by default RESTART_STRATEGY. (Currently only RESTART_STRATEGY is supported)
             - FIRST_STOP_STRATEGY: stop when one dataset has been exhausted
             - ALL_STOP_STRATEGY: stop when all datasets have been exhausted
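The staged-weights semantics in the new docstring line can be shown with a small standalone sketch; `active_weights` is a hypothetical helper (not part of this module) that returns the weights in force at a given sequence index.

```python
from bisect import bisect_right

def active_weights(weight_stages, seq_index):
    # stages are [(start_seq_index, {name: weight}), ...], sorted, with the first stage starting at 0
    starts = [start for start, _ in weight_stages]
    stage = bisect_right(starts, seq_index) - 1
    return weight_stages[stage][1]

stages = [
    (0, {"books": 0.5, "web": 0.5}),
    (4096, {"books": 0.2, "web": 0.8}),  # switches at sequence 4096, not at a training batch index
]
assert active_weights(stages, 100) == {"books": 0.5, "web": 0.5}
assert active_weights(stages, 4096) == {"books": 0.2, "web": 0.8}
```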
@@ -60,17 +60,18 @@ def __init__(
             weight_stages = weights
 
         # assert that steps are in sorted order and that the start index of each stage is a multiple of block_size
-        for i, (start_index, _) in enumerate(weight_stages):
+        for i, (start_seq_index, _) in enumerate(weight_stages):
             if i == 0:
-                assert start_index == 0
+                assert start_seq_index == 0
             else:
-                assert (
-                    start_index % block_size == 0
-                ), f"start_index for a stage must be a multiple of block_size, got {start_index=} and {block_size=}"
-                assert start_index > weight_stages[i - 1][0], f"Weights list must be sorted, got {weight_stages}"
+                assert start_seq_index % block_size == 0, (
+                    f"start_seq_index for a stage must be a multiple of block_size, got {start_seq_index=} and"
+                    f" {block_size=}"
+                )
+                assert start_seq_index > weight_stages[i - 1][0], f"Weights list must be sorted, got {weight_stages}"
 
         self.weight_stages = [
-            (start_index, self._normalize_weights(weights)) for start_index, weights in weight_stages
+            (start_seq_index, self._normalize_weights(weights)) for start_seq_index, weights in weight_stages
         ]
         self.datasets = {
             name: dataset
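The assertions above require the first stage to start at sequence 0, later stage boundaries to fall on block boundaries, and stages to be sorted. A self-contained sketch of the same checks, with a hypothetical helper name and an assumed block_size:

```python
def check_stages(weight_stages, block_size):
    # mirrors the constructor's validation of [(start_seq_index, weights), ...]
    for i, (start_seq_index, _) in enumerate(weight_stages):
        if i == 0:
            assert start_seq_index == 0, "first stage must start at sequence 0"
        else:
            assert start_seq_index % block_size == 0, "stage boundary must be a multiple of block_size"
            assert start_seq_index > weight_stages[i - 1][0], "stages must be sorted by start_seq_index"

check_stages([(0, {"a": 1.0}), (4096, {"a": 0.5, "b": 0.5})], block_size=2048)  # passes
# check_stages([(0, {"a": 1.0}), (5000, {"b": 1.0})], block_size=2048)  # would fail: 5000 is not a multiple of 2048
```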
@@ -99,24 +100,33 @@ def __init__(
 
         self.stop_strategy = stop_strategy
 
-        # Compute counts and unpermuted_ids for each stage
-        self._counts_per_block_per_stage = []
-        self._counts_after_stage = []
-        self._unpermuted_ids_per_stage = []
+        # Initialize stage-related counts and IDs
+        (
+            self._counts_per_block_per_stage,
+            self._counts_after_stage,
+            self._unpermuted_ids_per_stage,
+        ) = self._initialize_stage_counts()
+
+    def _initialize_stage_counts(self):
+        counts_per_block_per_stage = []
+        counts_after_stage = []
+        unpermuted_ids_per_stage = []
 
         cumulative_counts = np.zeros(len(self.datasets), dtype=np.int32)
 
-        for stage_idx, (start_index, stage_weights) in enumerate(self.weight_stages):
-            counts_this_stage = self._compute_expected_counts_per_block(stage_weights, block_size)
-            self._counts_per_block_per_stage.append(counts_this_stage)
-            self._unpermuted_ids_per_stage.append(self._compute_unpermuted_ids(counts_this_stage))
+        for stage_idx, (start_seq_index, stage_weights) in enumerate(self.weight_stages):
+            counts_this_stage = self._compute_expected_counts_per_block(stage_weights, self.block_size)
+            counts_per_block_per_stage.append(counts_this_stage)
+            unpermuted_ids_per_stage.append(self._compute_unpermuted_ids(counts_this_stage))
 
             if stage_idx < len(self.weight_stages) - 1:
                 next_start = self.weight_stages[stage_idx + 1][0]
-                num_blocks_in_stage = (next_start - start_index) // block_size
+                num_blocks_in_stage = (next_start - start_seq_index) // self.block_size
                 stage_total_counts = counts_this_stage * num_blocks_in_stage
                 cumulative_counts += stage_total_counts
-                self._counts_after_stage.append(cumulative_counts.copy())
+                counts_after_stage.append(cumulative_counts.copy())
 
+        return counts_per_block_per_stage, counts_after_stage, unpermuted_ids_per_stage
+
     def _compute_expected_counts_per_block(self, weights: dict[str, float], block_size: int):
         _expected_values_per_block = np.zeros(len(self.datasets), dtype=np.int32)
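`_compute_expected_counts_per_block` is cut off in this hunk, so its rounding policy is not visible; the sketch below shows one plausible way (an assumption, not the module's implementation) to turn normalized weights into integer per-dataset counts for a block of block_size sequences, using largest-remainder rounding.

```python
import numpy as np

def expected_counts_per_block(weights: dict[str, float], block_size: int) -> np.ndarray:
    # assumes weights are already normalized to sum to 1
    names = sorted(weights)
    w = np.array([weights[n] for n in names])
    counts = np.floor(w * block_size).astype(np.int32)
    remainder = block_size - counts.sum()
    # hand leftover slots to the datasets with the largest fractional parts (one possible policy)
    order = np.argsort(-(w * block_size - counts))
    counts[order[:remainder]] += 1
    return counts

print(expected_counts_per_block({"books": 0.3, "web": 0.7}, block_size=10))  # [3 7]
```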
@@ -217,8 +227,8 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T]:
         assert len(indices) == len(blocks) == len(block_ids)
 
         for batch_index, (idx, block, block_id) in enumerate(zip(indices, blocks, block_ids)):
-            index_within_block = idx % self.block_size
-            id = block[index_within_block]
+            index_within_block = idx % self.block_size  # which element of the block to get
+            id = block[index_within_block]  # for this block, which dataset+base dataset offset
             dataset_id, dataset_index = self._index_into_dataset_for_id(id, block_id)
             batches_per_dataset[dataset_id].append(dataset_index)
             indices_in_final_batch[dataset_id].append(batch_index)
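Each requested index is resolved to a (dataset_id, dataset_index) pair, grouped per dataset for fetching, and later scattered back to its slot in the final batch. A self-contained sketch of that grouping step (hypothetical helper, not the module's code):

```python
from collections import defaultdict

def group_requests(resolved):
    # resolved: list of (dataset_id, dataset_index) in final-batch order
    batches_per_dataset = defaultdict(list)
    indices_in_final_batch = defaultdict(list)
    for batch_index, (dataset_id, dataset_index) in enumerate(resolved):
        batches_per_dataset[dataset_id].append(dataset_index)
        indices_in_final_batch[dataset_id].append(batch_index)
    return batches_per_dataset, indices_in_final_batch

per_ds, back = group_requests([(0, 5), (1, 2), (0, 6)])
# per_ds == {0: [5, 6], 1: [2]}; back == {0: [0, 2], 1: [1]}
```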
@@ -259,6 +269,9 @@ async def getitem_async(self, index: int) -> T:
         return await dataset.getitem_async(dataset_index)
 
     async def _remap_indices(self, ds, indices_into_ds):
+        """
+        Handles wrap around for datasets that have finite length
+        """
         if self.stop_strategy == StopStrategy.RESTART_STRATEGY:
             if ds.is_finite():
                 max_elem = max(indices_into_ds)
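The new docstring says _remap_indices handles wrap-around for finite datasets; under RESTART_STRATEGY an exhausted dataset restarts from the beginning. A minimal sketch of that idea (modulo remapping is assumed here; the method body is not shown in this hunk):

```python
def remap_indices(dataset_len: int, indices):
    # once any requested index runs past the end, wrap all indices around
    if max(indices) >= dataset_len:
        return [i % dataset_len for i in indices]
    return list(indices)

assert remap_indices(10, [3, 9, 12]) == [3, 9, 2]
```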
