Skip to content

Commit

Permalink
Use int64 for prp (#870)
Browse files Browse the repository at this point in the history
fixes #869
  • Loading branch information
dlwh authored Jan 29, 2025
1 parent 0ad8c54 commit 04a81ca
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 32 deletions.
48 changes: 22 additions & 26 deletions src/levanter/data/_prp.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import typing

import jax.lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np


# TODO: do we make this a pytree
class Permutation:
# Pseudo-Random Permutation Code
"""A stateless pseudo-random permutation.
This class generates a pseudo-random permutation of a given length. The permutation is generated using a PRNG
with a fixed key. The permutation is generated by finding a random `a` and `b` such that `gcd(a, length) != 1` and
with a fixed key. The permutation is generated by finding a random `a` and `b` such that `gcd(a, length) == 1` and
then computing the permutation as `p(x) = (a * x + b) % length`.
This is not a very good PRP, but it is probably good enough for our purposes.
Expand All @@ -21,40 +19,40 @@ class Permutation:

def __init__(self, length, prng_key):
self.length = length
self.prng_key = prng_key
a_key, b_key = jrandom.split(prng_key)
self._a = jrandom.randint(a_key, (), 1, length)
self._b = jrandom.randint(b_key, (), 0, length)
# Convert jax.random.PRNGKey to numpy.random.Generator
self.rng = np.random.Generator(np.random.PCG64(jrandom.randint(prng_key, (), 0, 2**30).item()))
self.a, self.b = self._generate_permutation_params() # Generate a and b in init

cond = lambda a_and_key: jnp.all(jnp.gcd(a_and_key[0], length) != 1)
def _generate_permutation_params(self):
length = self.length
rng = self.rng

def loop_body(a_and_key):
a, key = a_and_key
this_key, key = jrandom.split(key)
a = jrandom.randint(this_key, (), 1, length)
return a, key
if length == 1:
return 1, 0

self._a, key = jax.lax.while_loop(cond, loop_body, (self._a, a_key))
while True:
a = rng.integers(1, length)
if np.gcd(a, length) == 1:
break

self._a = int(self._a)
self._b = int(self._b)
b = rng.integers(0, length) # b can be in [0, length-1]
return a, b

@typing.overload
def __call__(self, indices: int) -> int:
...

@typing.overload
def __call__(self, indices: jnp.ndarray) -> jnp.ndarray:
def __call__(self, indices: np.ndarray) -> np.ndarray:
...

def __call__(self, indices):
a = self.a
b = self.b
length = self.length

was_int = False
if isinstance(indices, jnp.ndarray):
# TODO: use error_if?
# import equinox as eqx
if jnp.any(indices < 0) or jnp.any(indices >= self.length):
raise IndexError(f"index {indices} is out of bounds for length {self.length}")
elif isinstance(indices, np.ndarray):
if isinstance(indices, np.ndarray | jnp.ndarray):
if np.any(indices < 0) or np.any(indices >= self.length):
raise IndexError(f"index {indices} is out of bounds for length {self.length}")
else:
Expand All @@ -64,9 +62,7 @@ def __call__(self, indices):
indices = np.array(indices)
was_int = True

old_settings = np.seterr(over="raise")
out = (self._a * indices + self._b) % self.length
np.seterr(**old_settings)
out = (a * indices + b) % length # Compute permutation on-the-fly

if was_int:
return int(out)
Expand Down
12 changes: 8 additions & 4 deletions src/levanter/data/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ async def getitem_async(self, index: int) -> T_co:

async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
permutation = await self._get_permutation()
return await self.dataset.get_batch([permutation(i) for i in indices])
return await self.dataset.get_batch(
[int(permutation(i)) for i in indices]
) # cast to int to be sure it's python int

async def _get_permutation(self):
if self._permutation is None:
Expand Down Expand Up @@ -83,10 +85,10 @@ async def gen_era_permutation(era: int) -> Permutation:
# TODO: support epochs
# edge case: final era may be shorter than era_length
current_len = await self.dataset.wait_until_len_at_least((era + 1) * self.era_length)
era_length = min(self.era_length, current_len - era * self.era_length)
era_length_val = min(self.era_length, current_len - era * self.era_length)

mix_key = jax.random.fold_in(key, era)
return Permutation(era_length, mix_key)
return Permutation(era_length_val, mix_key)

self.gen_era_permutation = gen_era_permutation

Expand All @@ -95,7 +97,9 @@ async def _get_index(self, idx: int) -> int:
raise ValueError("Negative indices are not supported")
era = idx // self.era_length
permutation = await self.gen_era_permutation(era)
return permutation(idx - era * self.era_length) + era * self.era_length
out = permutation(idx - era * self.era_length) + era * self.era_length

return out

async def async_len(self) -> int:
return await self.dataset.async_len()
Expand Down
14 changes: 12 additions & 2 deletions tests/test_prp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def test_permutation_creates_valid_instance():
prng_key = jrandom.PRNGKey(0)
permutation = Permutation(length, prng_key)
assert permutation.length == length
assert permutation._a > 0 and permutation._a < length
assert permutation._b >= 0 and permutation._b < length
assert 0 < permutation.a < length
assert 0 <= permutation.b < length


def test_permutation_with_single_index_returns_correct_value():
Expand Down Expand Up @@ -85,3 +85,13 @@ def test_permutation_is_deterministic1():
permutation = Permutation(length, prng_key)
results2 = permutation(indices)
assert jnp.all(results == results2)


def test_permutation_handles_large_length_no_overflow():
large_length = 2**34
prng_key = jrandom.PRNGKey(0)
permutation = Permutation(large_length, prng_key)
index = 2**32 # A large index within the range
result = permutation(index)
assert isinstance(result, int)
assert 0 <= result < large_length

0 comments on commit 04a81ca

Please sign in to comment.