def test_histogram(self):
  """Compare _histogram on an XLA tensor with torch.histc on several ranges."""
  samples = [1, 4, 4, 1, 2, 3]
  cases = [
      {'input': samples, 'min': 1, 'max': 4},
      {'input': samples, 'min': 2, 'max': 3},
      {'input': samples, 'min': 0, 'max': 5},
      {'input': [], 'min': 0, 'max': 5},
  ]

  for case in cases:
    values, lo, hi = case['input'], case['min'], case['max']

    # Reference counts: torch.histc with as many bins as there are integers
    # in [lo, hi].
    reference = torch.histc(
        torch.tensor(values, dtype=torch.float),
        bins=hi - lo + 1,
        min=lo,
        max=hi,
    )

    xla_input = torch.tensor(values, dtype=torch.int32).to("xla")
    counts, _ = _histogram(xla_input, min=lo, max=hi)

    self.assertTrue(torch.all(reference == counts.cpu()))

def test_histogram_raise(self):
  """Check that _histogram rejects a non-int32 input and a min above max."""
  samples = [1, 4, 4, 1, 2, 3]

  float_input = torch.tensor(samples, dtype=torch.float)
  with self.assertRaisesRegex(AssertionError,
                              "input must be of torch.int32 dtype."):
    _histogram(float_input, min=4, max=5)

  int_input = torch.tensor(samples, dtype=torch.int32)
  with self.assertRaisesRegex(AssertionError, "min must be less than max."):
    _histogram(int_input, min=4, max=3)
b/torch_xla/experimental/custom_kernel.py @@ -496,6 +496,22 @@ def _calculate_num_tiles(x: int, tx: int) -> int: return tiles +def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: + """ + Compute the histogram of a int32 tensor. The bin edges are defined by the min and max values, with step = 1. + """ + assert input.dtype == torch.int32, "input must be of torch.int32 dtype." + assert min < max, "min must be less than max." + + def searchsorted(sorted_sequence: torch.Tensor, + values_to_search: torch.Tensor) -> torch.Tensor: + return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1) + + bin_edges = torch.linspace( + min, max, max - min + 1, dtype=input.dtype).to(input.device) + return searchsorted(bin_edges, input), bin_edges + + # This can only be ran in cpu now as repeat_interleave is not lowered to xla. def _make_group_metadata( *,