Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into sftpack
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Feb 5, 2025
2 parents f5cba5d + e49702b commit ecd7f48
Show file tree
Hide file tree
Showing 12 changed files with 450 additions and 97 deletions.
8 changes: 4 additions & 4 deletions docs/Getting-Started-TPU-VM.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ accel_env:
docker_repository: levanter # default
zone: us-west4-a # if not set, will use your default zone
tpu_name: test-spin-up-32
tpu_type: "v5litepod-16"
tpu: test-spin-up-32 # name of the TPU
capacity_type: "preemptible"
subnetwork: "default" # default
EOF
Expand Down Expand Up @@ -319,11 +318,12 @@ This will show you a list of files and directories in the repo, sorted by size,

### Automatic Setup

!!! warning
This approach is deprecated and will be removed in the future. Please use `launch.py` or `launch_on_ray.py` instead.

You can use `infra/spin-up-vm.sh` to create a TPU VM instance. In addition to creating the instance, it will set up
the venv on each worker, and it will clone the repo to `~/levanter/`.

**For Public Users**:

```bash
bash infra/spin-up-vm.sh <name> -z <zone> -t <type> -n <subnetwork> [--preemptible] [--use-alpha]
```
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ dependencies = [
"pyarrow>=11.0.0",
"zstandard>=0.20.0",
"datasets>=3.1.0,<4.0",
"gcsfs>=2024.2,<2025",
"gcsfs>=2024.2,<2026",
"braceexpand>=0.1.7",
"jmp>=0.0.3",
"fsspec[http]>=2024.2,<2025",
"fsspec[http]>=2024.2,<2026",
"tensorstore>=0.1.65",
"pytimeparse>=1.1.8",
"humanfriendly==10.0",
Expand Down
48 changes: 22 additions & 26 deletions src/levanter/data/_prp.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import typing

import jax.lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np


# TODO: do we make this a pytree
class Permutation:
# Pseudo-Random Permutation Code
"""A stateless pseudo-random permutation.
This class generates a pseudo-random permutation of a given length. The permutation is generated using a PRNG
with a fixed key. The permutation is generated by finding a random `a` and `b` such that `gcd(a, length) != 1` and
with a fixed key. The permutation is generated by finding a random `a` and `b` such that `gcd(a, length) == 1` and
then computing the permutation as `p(x) = (a * x + b) % length`.
This is not a very good PRP, but it is probably good enough for our purposes.
Expand All @@ -21,40 +19,40 @@ class Permutation:

def __init__(self, length, prng_key):
self.length = length
self.prng_key = prng_key
a_key, b_key = jrandom.split(prng_key)
self._a = jrandom.randint(a_key, (), 1, length)
self._b = jrandom.randint(b_key, (), 0, length)
# Convert jax.random.PRNGKey to numpy.random.Generator
self.rng = np.random.Generator(np.random.PCG64(jrandom.randint(prng_key, (), 0, 2**30).item()))
self.a, self.b = self._generate_permutation_params() # Generate a and b in init

cond = lambda a_and_key: jnp.all(jnp.gcd(a_and_key[0], length) != 1)
def _generate_permutation_params(self):
length = self.length
rng = self.rng

def loop_body(a_and_key):
a, key = a_and_key
this_key, key = jrandom.split(key)
a = jrandom.randint(this_key, (), 1, length)
return a, key
if length == 1:
return 1, 0

self._a, key = jax.lax.while_loop(cond, loop_body, (self._a, a_key))
while True:
a = rng.integers(1, length)
if np.gcd(a, length) == 1:
break

self._a = int(self._a)
self._b = int(self._b)
b = rng.integers(0, length) # b can be in [0, length-1]
return a, b

@typing.overload
def __call__(self, indices: int) -> int:
...

@typing.overload
def __call__(self, indices: jnp.ndarray) -> jnp.ndarray:
def __call__(self, indices: np.ndarray) -> np.ndarray:
...

def __call__(self, indices):
a = self.a
b = self.b
length = self.length

was_int = False
if isinstance(indices, jnp.ndarray):
# TODO: use error_if?
# import equinox as eqx
if jnp.any(indices < 0) or jnp.any(indices >= self.length):
raise IndexError(f"index {indices} is out of bounds for length {self.length}")
elif isinstance(indices, np.ndarray):
if isinstance(indices, np.ndarray | jnp.ndarray):
if np.any(indices < 0) or np.any(indices >= self.length):
raise IndexError(f"index {indices} is out of bounds for length {self.length}")
else:
Expand All @@ -64,9 +62,7 @@ def __call__(self, indices):
indices = np.array(indices)
was_int = True

old_settings = np.seterr(over="raise")
out = (self._a * indices + self._b) % self.length
np.seterr(**old_settings)
out = (a * indices + b) % length # Compute permutation on-the-fly

if was_int:
return int(out)
Expand Down
54 changes: 54 additions & 0 deletions src/levanter/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def map(self, fn: MapFunction[U], *extra_args, **extra_kwargs) -> "MappedAsyncDa
def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> "BatchMappedAsyncDataset[U]":
return BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)

def slice_dataset(self, start_index: Optional[int] = None, end_index: Optional[int] = None):
    """Wrap this dataset in a `SlicedAsyncDataset` bounded by the given indices.

    Args:
        start_index: First index of the underlying dataset to expose; forwarded unchanged
            (including `None`) to `SlicedAsyncDataset`.
        end_index: Upper bound of the slice; forwarded unchanged (including `None`) to
            `SlicedAsyncDataset`.

    Returns:
        A new `SlicedAsyncDataset` built from this dataset and the two bounds.
    """
    sliced = SlicedAsyncDataset(self, start_index, end_index)
    return sliced

def shuffle(self, key: PRNGKey):
import levanter.data.permutation as permutation

Expand Down Expand Up @@ -375,6 +378,57 @@ def _call_fn(self, index, item):
return self.fn(item, *self._extra_args, **kwargs)


class SlicedAsyncDataset(AsyncDataset[U]):
    """A view onto the half-open index range ``[start_index, end_index)`` of another dataset.

    Index ``i`` of this dataset maps to index ``i + start_index`` of the wrapped dataset.
    When ``end_index`` is ``None`` the slice is open-ended and follows the wrapped dataset's length.
    """

    def __init__(
        self,
        dataset: AsyncDataset[U],
        start_index: Optional[int] = None,
        end_index: Optional[int] = None,
    ):
        """Create the slice.

        Args:
            dataset: The dataset to slice.
            start_index: First underlying index exposed by the slice. ``None`` means 0.
            end_index: Exclusive upper bound on underlying indices. ``None`` means unbounded.

        Raises:
            ValueError: If ``end_index`` is given and is smaller than ``start_index``.
        """
        super().__init__()
        if start_index is None:
            start_index = 0
        if end_index is not None and start_index > end_index:
            raise ValueError("End index must come after start index.")

        self.start_index = start_index
        self.end_index = end_index
        self.dataset = dataset
        # With an explicit end the slice length is fixed; otherwise reuse the wrapped dataset's value.
        # NOTE(review): the open-ended case does not subtract start_index -- confirm this is intended.
        self._min_known_len = dataset._min_known_len if end_index is None else (end_index - start_index)

    async def get_batch(self, indices: Sequence[int]) -> Sequence[U]:
        """Fetch the items at ``indices`` (relative to the slice) from the wrapped dataset.

        Raises:
            ValueError: If any shifted index is at or beyond ``end_index``.
        """
        # Negative indices are not validated here; they are shifted like any other index.
        shifted_indices = [(index + self.start_index) for index in indices]

        # end_index is exclusive (async_len reports end_index - start_index), so a shifted index
        # equal to end_index is already out of range. The emptiness guard keeps max() from raising
        # on an empty request, which is simply forwarded to the wrapped dataset.
        if self.end_index is not None and shifted_indices and max(shifted_indices) >= self.end_index:
            raise ValueError("Requested indices beyond the end of the dataset")

        return await self.dataset.get_batch(shifted_indices)

    async def async_len(self) -> int:
        """Return the slice length: ``end_index - start_index``, or the wrapped length minus ``start_index``."""
        underlying_length = await self.dataset.async_len()
        if self.end_index is None:
            return underlying_length - self.start_index
        else:
            return self.end_index - self.start_index

    async def final_length_is_known(self) -> bool:
        """Return True only if the wrapped dataset's final length is known and ``end_index`` is set."""
        # NOTE(review): an open-ended slice of a dataset whose final length is known reports False
        # here -- confirm that is the intended contract.
        underlying_is_known = await self.dataset.final_length_is_known()
        return underlying_is_known and self.end_index is not None

    def is_finite(self) -> bool:
        """Return True only if the wrapped dataset is finite and ``end_index`` is set."""
        # NOTE(review): a bounded slice of a non-finite dataset reports False here -- confirm intended.
        return self.dataset.is_finite() and self.end_index is not None

    async def current_len(self) -> Optional[int]:
        """Return the current slice length, or ``None`` if it cannot be determined yet."""
        underlying_length = await self.dataset.current_len()
        if self.end_index is not None:
            return self.end_index - self.start_index
        elif underlying_length is not None:
            return underlying_length - self.start_index
        else:
            return underlying_length


class BatchMappedAsyncDataset(AsyncDataset[U]):
"""
A dataset that applies a function to each batch of items in the dataset.
Expand Down
100 changes: 80 additions & 20 deletions src/levanter/data/mixture.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import warnings
from typing import Mapping, Optional, Sequence, TypeVar
from typing import List, Mapping, Optional, Sequence, Tuple, TypeVar

import jax
import numpy as np
Expand Down Expand Up @@ -30,12 +30,12 @@ class MixtureDataset(AsyncDataset[T]):
according to the weights.
Creating a random-access MixtureDataset is challenging because we need to keep track of the current index of each
dataset. So solve this, we instead use "block-deterministic" mixtures, where the number of samples from each dataset
in each block is always identical (and we shuffle the order of the dataset ids in each block).
dataset. To solve this, we instead use "block-deterministic" mixtures, where the number of samples from each dataset
in each block is always identical (and we shuffle the order of the dataset ids in each block). To handle the case where the dataset mixture changes over time, we use a list of stages and precompute statistics to accurately compute the index of each dataset in each block.
Args:
datasets: A dict of datasets, where the key is the name of the dataset and the value is the dataset itself
weights: weights for each dataset
weights: Weights for each dataset. This can be provided in a list of stages, where each stage is a tuple of (start_seq_index, weights). Note that start_seq_index corresponds to the sequence index at which the weights should change, not the training batch index.
stop_strategy: strategy for stopping the iteration, by default RESTART_STRATEGY. (Currently only RESTART_STRATEGY is supported)
- FIRST_STOP_STRATEGY: stop when one dataset has been exhausted
- ALL_STOP_STRATEGY: stop when all datasets have been exhausted
Expand All @@ -46,16 +46,38 @@ class MixtureDataset(AsyncDataset[T]):
def __init__(
self,
datasets: Mapping[str, AsyncDataset[T]],
weights: dict[str, float],
weights: dict[str, float] | List[Tuple[int, dict[str, float]]],
block_size: int,
*,
randomize_blocks: bool = True,
key: PRNGKeyArray | int,
stop_strategy: str = StopStrategy.RESTART_STRATEGY,
):
super().__init__()
self.weights = MixtureDataset._normalize_weights(weights)
self.datasets = {name: dataset for name, dataset in datasets.items() if self.weights.get(name, 0) > 0}
if isinstance(weights, dict):
weight_stages = [(0, weights)]
else:
weight_stages = weights

# assert that steps are in sorted order and that the start index of each stage is a multiple of block_size
for i, (start_seq_index, _) in enumerate(weight_stages):
if i == 0:
assert start_seq_index == 0
else:
assert start_seq_index % block_size == 0, (
f"start_seq_index for a stage must be a multiple of block_size, got {start_seq_index=} and"
f" {block_size=}"
)
assert start_seq_index > weight_stages[i - 1][0], f"Weights list must be sorted, got {weight_stages}"

self.weight_stages = [
(start_seq_index, self._normalize_weights(weights)) for start_seq_index, weights in weight_stages
]
self.datasets = {
name: dataset
for name, dataset in datasets.items()
if any(weights.get(name, 0) > 0 for _, weights in self.weight_stages)
}
self.dataset_index = Index(self.datasets.keys())
self.block_size = block_size
# we pack the dataset id and the within-block index into a single 32-bit integer (16 bits each), so block size must be at most 2^16
Expand All @@ -78,25 +100,48 @@ def __init__(

self.stop_strategy = stop_strategy

self._counts_per_block = self._compute_expected_counts_per_block(block_size)
# precompute a list of ids for each block
# the ids contain both the dataset index and the index within the dataset
self._unpermuted_ids = self._compute_unpermuted_ids(self._counts_per_block)
# Initialize stage-related counts and IDs
(
self._counts_per_block_per_stage,
self._counts_after_stage,
self._unpermuted_ids_per_stage,
) = self._initialize_stage_counts()

def _initialize_stage_counts(self):
counts_per_block_per_stage = []
counts_after_stage = []
unpermuted_ids_per_stage = []

cumulative_counts = np.zeros(len(self.datasets), dtype=np.int32)

for stage_idx, (start_seq_index, stage_weights) in enumerate(self.weight_stages):
counts_this_stage = self._compute_expected_counts_per_block(stage_weights, self.block_size)
counts_per_block_per_stage.append(counts_this_stage)
unpermuted_ids_per_stage.append(self._compute_unpermuted_ids(counts_this_stage))

if stage_idx < len(self.weight_stages) - 1:
next_start = self.weight_stages[stage_idx + 1][0]
num_blocks_in_stage = (next_start - start_seq_index) // self.block_size
stage_total_counts = counts_this_stage * num_blocks_in_stage
cumulative_counts += stage_total_counts
counts_after_stage.append(cumulative_counts.copy())

def _compute_expected_counts_per_block(self, block_size):
return counts_per_block_per_stage, counts_after_stage, unpermuted_ids_per_stage

def _compute_expected_counts_per_block(self, weights: dict[str, float], block_size: int):
_expected_values_per_block = np.zeros(len(self.datasets), dtype=np.int32)
for i, dsname in enumerate(self.dataset_index):
_expected_values_per_block[i] = self.weights[dsname] * block_size
_expected_values_per_block[i] = weights.get(dsname, 0) * block_size

# handle remainder by adding to the largest dataset
largest_dataset = np.argmax(_expected_values_per_block)
_expected_values_per_block[largest_dataset] += block_size - _expected_values_per_block.sum()

# check if any dataset has 0 samples (and nonzero weight)
for i, dsname in enumerate(self.dataset_index):
if _expected_values_per_block[i] == 0 and self.weights[dsname] > 0:
if _expected_values_per_block[i] == 0 and weights.get(dsname, 0) > 0:
warnings.warn(
f"Dataset {dsname} has 0 samples in the block, but weight of {self.weights[dsname]}."
f"Dataset {dsname} has 0 samples in the block, but weight of {weights[dsname]}."
" Recommend increasing block size."
)

Expand Down Expand Up @@ -143,20 +188,35 @@ async def current_len(self) -> Optional[int]:

raise NotImplementedError("Length is not known for other strategies")

@alru_cache
def _get_stage_for_block(self, block_id: int) -> int:
block_start = block_id * self.block_size
stage_starts = np.array([start for start, _ in self.weight_stages])
return max(0, np.searchsorted(stage_starts, block_start, side="right") - 1)

@alru_cache(maxsize=32)
async def _get_block(self, index: int) -> Optional[np.ndarray]:
stage = self._get_stage_for_block(index)
if not self.randomize_blocks:
return self._unpermuted_ids
return self._unpermuted_ids_per_stage[stage]

return np.array(_compute_block_assignment(self._unpermuted_ids, index, self.key))
return np.array(_compute_block_assignment(self._unpermuted_ids_per_stage[stage], index, self.key))

def _index_into_dataset_for_id(self, id: int, block_id) -> tuple[int, int]:
def _index_into_dataset_for_id(self, id: int, block_id: int) -> tuple[int, int]:
stage = self._get_stage_for_block(block_id)
dataset_id = id >> 16
dataset_index = id & 0xFFFF
return dataset_id, dataset_index + block_id * self._counts_per_block[dataset_id]

# Get the base offset from previous stages
base_offset = self._counts_after_stage[stage - 1][dataset_id] if stage > 0 else 0
# Add offset within current stage
offset_in_stage = (block_id * self.block_size - self.weight_stages[stage][0]) // self.block_size
current_stage_offset = offset_in_stage * self._counts_per_block_per_stage[stage][dataset_id]

return dataset_id, dataset_index + base_offset + current_stage_offset

async def get_batch(self, indices: Sequence[int]) -> Sequence[T]:
block_ids = np.array([idx // self.block_size for idx in indices])

blocks = [self._get_block(block_id) for block_id in block_ids]
blocks = await asyncio.gather(*blocks)

Expand Down
12 changes: 8 additions & 4 deletions src/levanter/data/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ async def getitem_async(self, index: int) -> T_co:

async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
permutation = await self._get_permutation()
return await self.dataset.get_batch([permutation(i) for i in indices])
return await self.dataset.get_batch(
[int(permutation(i)) for i in indices]
) # cast to int to be sure it's python int

async def _get_permutation(self):
if self._permutation is None:
Expand Down Expand Up @@ -83,10 +85,10 @@ async def gen_era_permutation(era: int) -> Permutation:
# TODO: support epochs
# edge case: final era may be shorter than era_length
current_len = await self.dataset.wait_until_len_at_least((era + 1) * self.era_length)
era_length = min(self.era_length, current_len - era * self.era_length)
era_length_val = min(self.era_length, current_len - era * self.era_length)

mix_key = jax.random.fold_in(key, era)
return Permutation(era_length, mix_key)
return Permutation(era_length_val, mix_key)

self.gen_era_permutation = gen_era_permutation

Expand All @@ -95,7 +97,9 @@ async def _get_index(self, idx: int) -> int:
raise ValueError("Negative indices are not supported")
era = idx // self.era_length
permutation = await self.gen_era_permutation(era)
return permutation(idx - era * self.era_length) + era * self.era_length
out = permutation(idx - era * self.era_length) + era * self.era_length

return out

async def async_len(self) -> int:
return await self.dataset.async_len()
Expand Down
Loading

0 comments on commit ecd7f48

Please sign in to comment.