diff --git a/src/levanter/data/permutation.py b/src/levanter/data/permutation.py index 04eef3bc0..d7c768bc4 100644 --- a/src/levanter/data/permutation.py +++ b/src/levanter/data/permutation.py @@ -42,7 +42,9 @@ async def getitem_async(self, index: int) -> T_co: async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: permutation = await self._get_permutation() - return await self.dataset.get_batch([int(permutation(i)) for i in indices]) # cast to int to be sure it's python int + return await self.dataset.get_batch( + [int(permutation(i)) for i in indices] + ) # cast to int to be sure it's python int async def _get_permutation(self): if self._permutation is None: