Skip to content

Commit

Permalink
fix padding when num deviecs is 0
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Feb 11, 2025
1 parent 47afb95 commit 2451f64
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/levanter/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def _batchify_local_data(self, batch: _Batch[Ex]) -> Ex:
and creates a global array for each leaf of the example.
"""
cache: dict[tuple[int, int], list[Array | hax.NamedArray]] = {}
padded_batch_size = self.dl._round_batch_size(batch.global_size)
padded_batch_size = self.dl.rounded_batch_size_at_step(batch.index)
Batch = hax.Axis(self.dl.batch_axis_name, padded_batch_size)

def get_local_batch(begin: int, end: int) -> list:
Expand Down

0 comments on commit 2451f64

Please sign in to comment.