Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use int64 for prp #870

Merged
merged 10 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 22 additions & 26 deletions src/levanter/data/_prp.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import typing

import jax.lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np


# TODO: decide whether this class should be made a pytree
class Permutation:
# Pseudo-Random Permutation Code
"""A stateless pseudo-random permutation.

This class generates a pseudo-random permutation of a given length. The permutation is generated using a PRNG
with a fixed key. The permutation is generated by finding a random `a` and `b` such that `gcd(a, length) != 1` and
with a fixed key. The permutation is generated by finding a random `a` and `b` such that `gcd(a, length) == 1` and
then computing the permutation as `p(x) = (a * x + b) % length`.

This is not a very good PRP (pseudo-random permutation), but it is probably good enough for our purposes.
Expand All @@ -21,40 +19,40 @@ class Permutation:

def __init__(self, length, prng_key):
self.length = length
self.prng_key = prng_key
a_key, b_key = jrandom.split(prng_key)
self._a = jrandom.randint(a_key, (), 1, length)
self._b = jrandom.randint(b_key, (), 0, length)
# Convert jax.random.PRNGKey to numpy.random.Generator
self.rng = np.random.Generator(np.random.PCG64(jrandom.randint(prng_key, (), 0, 2**30).item()))
self.a, self.b = self._generate_permutation_params() # Generate a and b in init

cond = lambda a_and_key: jnp.all(jnp.gcd(a_and_key[0], length) != 1)
def _generate_permutation_params(self):
length = self.length
rng = self.rng

def loop_body(a_and_key):
a, key = a_and_key
this_key, key = jrandom.split(key)
a = jrandom.randint(this_key, (), 1, length)
return a, key
if length == 1:
return 1, 0

self._a, key = jax.lax.while_loop(cond, loop_body, (self._a, a_key))
while True:
a = rng.integers(1, length)
if np.gcd(a, length) == 1:
break

self._a = int(self._a)
self._b = int(self._b)
b = rng.integers(0, length) # b can be in [0, length-1]
return a, b

@typing.overload
def __call__(self, indices: int) -> int:
...

@typing.overload
def __call__(self, indices: jnp.ndarray) -> jnp.ndarray:
def __call__(self, indices: np.ndarray) -> np.ndarray:
...

def __call__(self, indices):
a = self.a
b = self.b
length = self.length

was_int = False
if isinstance(indices, jnp.ndarray):
# TODO: use error_if?
# import equinox as eqx
if jnp.any(indices < 0) or jnp.any(indices >= self.length):
raise IndexError(f"index {indices} is out of bounds for length {self.length}")
elif isinstance(indices, np.ndarray):
if isinstance(indices, np.ndarray | jnp.ndarray):
if np.any(indices < 0) or np.any(indices >= self.length):
raise IndexError(f"index {indices} is out of bounds for length {self.length}")
else:
Expand All @@ -64,9 +62,7 @@ def __call__(self, indices):
indices = np.array(indices)
was_int = True

old_settings = np.seterr(over="raise")
out = (self._a * indices + self._b) % self.length
np.seterr(**old_settings)
out = (a * indices + b) % length # Compute permutation on-the-fly

if was_int:
return int(out)
Expand Down
12 changes: 8 additions & 4 deletions src/levanter/data/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ async def getitem_async(self, index: int) -> T_co:

async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
permutation = await self._get_permutation()
return await self.dataset.get_batch([permutation(i) for i in indices])
return await self.dataset.get_batch(
[int(permutation(i)) for i in indices]
) # cast to int to be sure it's python int

async def _get_permutation(self):
if self._permutation is None:
Expand Down Expand Up @@ -83,10 +85,10 @@ async def gen_era_permutation(era: int) -> Permutation:
# TODO: support epochs
# edge case: final era may be shorter than era_length
current_len = await self.dataset.wait_until_len_at_least((era + 1) * self.era_length)
era_length = min(self.era_length, current_len - era * self.era_length)
era_length_val = min(self.era_length, current_len - era * self.era_length)

mix_key = jax.random.fold_in(key, era)
return Permutation(era_length, mix_key)
return Permutation(era_length_val, mix_key)

self.gen_era_permutation = gen_era_permutation

Expand All @@ -95,7 +97,9 @@ async def _get_index(self, idx: int) -> int:
raise ValueError("Negative indices are not supported")
era = idx // self.era_length
permutation = await self.gen_era_permutation(era)
return permutation(idx - era * self.era_length) + era * self.era_length
out = permutation(idx - era * self.era_length) + era * self.era_length

return out

async def async_len(self) -> int:
return await self.dataset.async_len()
Expand Down
14 changes: 12 additions & 2 deletions tests/test_prp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def test_permutation_creates_valid_instance():
prng_key = jrandom.PRNGKey(0)
permutation = Permutation(length, prng_key)
assert permutation.length == length
assert permutation._a > 0 and permutation._a < length
assert permutation._b >= 0 and permutation._b < length
assert 0 < permutation.a < length
assert 0 <= permutation.b < length


def test_permutation_with_single_index_returns_correct_value():
Expand Down Expand Up @@ -85,3 +85,13 @@ def test_permutation_is_deterministic1():
permutation = Permutation(length, prng_key)
results2 = permutation(indices)
assert jnp.all(results == results2)


def test_permutation_handles_large_length_no_overflow():
    """Check that a permutation over 2**34 elements maps a large index in range.

    Both the length (2**34) and the probed index (2**32) are larger than any
    32-bit integer can hold. The result must come back as a plain Python
    ``int`` and lie inside ``[0, length)``.
    """
    size = 2**34
    perm = Permutation(size, jrandom.PRNGKey(0))

    # A large index that is still within the permutation's domain.
    probe = 2**32
    mapped = perm(probe)

    assert isinstance(mapped, int)
    assert 0 <= mapped < size
Loading