Skip to content

Commit 9632ad6

Browse files
committed
Simplify based on Ricardo input
1 parent 4f7ae9f commit 9632ad6

File tree

1 file changed

+14
-102
lines changed

1 file changed

+14
-102
lines changed

pytensor/link/mlx/dispatch/subtensor.py

Lines changed: 14 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from copy import deepcopy
22

3+
import mlx.core as mx
34
import numpy as np
45

56
from pytensor.link.mlx.dispatch.basic import mlx_funcify
@@ -16,111 +17,34 @@
1617

1718

1819
def normalize_indices_for_mlx(ilist, idx_list):
19-
"""Convert indices to MLX-compatible format.
20-
21-
MLX has strict requirements for indexing:
22-
- Integer indices must be Python int, not np.int64 or other NumPy integer types
23-
- Slice components (start, stop, step) must be Python int or None, not np.int64
24-
- MLX arrays created from scalars need to be converted back to Python int
25-
- Array indices for advanced indexing are handled separately
26-
27-
This function converts all integer-like indices and slice components to Python int
28-
while preserving None values and passing through array indices unchanged.
29-
30-
This conversion is necessary because MLX's C++ indexing implementation
31-
does not recognize NumPy scalar types, raising ValueError when encountered.
32-
Additionally, mlx_typify converts NumPy scalars to MLX arrays, which also
33-
need to be converted back to Python int for use in indexing operations.
34-
Converting to Python int is zero-cost for Python int inputs and minimal
35-
overhead for NumPy scalars and MLX scalar arrays.
36-
37-
Parameters
38-
----------
39-
ilist : tuple
40-
Runtime index values to be passed to indices_from_subtensor
41-
idx_list : tuple
42-
Static index specification from the Op's idx_list attribute
43-
44-
Returns
45-
-------
46-
tuple
47-
Normalized indices compatible with MLX array indexing
48-
49-
Examples
50-
--------
51-
>>> # Single np.int64 index converted to Python int
52-
>>> normalize_indices_for_mlx((np.int64(1),), (True,))
53-
(1,)
54-
55-
>>> # Slice with np.int64 components
56-
>>> indices = indices_from_subtensor(
57-
... (np.int64(0), np.int64(2)), (slice(None, None),)
58-
... )
59-
>>> # After normalization, slice components are Python int
20+
"""Convert numpy integers to Python integers for MLX indexing.
21+
22+
MLX requires index values to be Python int, not np.int64 or other NumPy types.
6023
"""
61-
import mlx.core as mx
6224

6325
def normalize_element(element):
64-
"""Convert a single index element to MLX-compatible format."""
6526
if element is None:
66-
# None is valid in slices (e.g., x[None:5] or x[:None])
6727
return None
6828
elif isinstance(element, slice):
69-
# Recursively normalize slice components
7029
return slice(
7130
normalize_element(element.start),
7231
normalize_element(element.stop),
7332
normalize_element(element.step),
7433
)
75-
elif isinstance(element, mx.array):
76-
# MLX arrays from mlx_typify need special handling
77-
# If it's a 0-d array (scalar), convert to Python int/float
78-
if element.ndim == 0:
79-
# Extract the scalar value
80-
item = element.item()
81-
# Convert to Python int if it's an integer type
82-
if element.dtype in (
83-
mx.int8,
84-
mx.int16,
85-
mx.int32,
86-
mx.int64,
87-
mx.uint8,
88-
mx.uint16,
89-
mx.uint32,
90-
mx.uint64,
91-
):
92-
return int(item)
93-
else:
94-
return float(item)
95-
else:
96-
# Multi-dimensional array for advanced indexing - pass through
97-
return element
98-
elif isinstance(element, (np.integer, np.floating)):
99-
# Convert NumPy scalar to Python int/float
100-
# This handles np.int64, np.int32, np.float64, etc.
101-
return int(element) if isinstance(element, np.integer) else float(element)
102-
elif isinstance(element, (int, float)):
103-
# Python int/float are already compatible
104-
return element
34+
elif isinstance(element, mx.array) and element.ndim == 0:
35+
return int(element.item())
36+
elif isinstance(element, np.integer):
37+
return int(element)
10538
else:
106-
# Pass through other types (arrays for advanced indexing, etc.)
10739
return element
10840

109-
# Get indices from PyTensor's subtensor utility
110-
raw_indices = indices_from_subtensor(ilist, idx_list)
111-
112-
# Normalize each index element
113-
normalized = tuple(normalize_element(idx) for idx in raw_indices)
114-
115-
return normalized
41+
indices = indices_from_subtensor(ilist, idx_list)
42+
return tuple(normalize_element(idx) for idx in indices)
11643

11744

11845
@mlx_funcify.register(Subtensor)
11946
def mlx_funcify_Subtensor(op, node, **kwargs):
120-
"""MLX implementation of Subtensor operation.
121-
122-
Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX.
123-
"""
47+
"""MLX implementation of Subtensor."""
12448
idx_list = getattr(op, "idx_list", None)
12549

12650
def subtensor(x, *ilists):
@@ -137,11 +61,7 @@ def subtensor(x, *ilists):
13761
@mlx_funcify.register(AdvancedSubtensor)
13862
@mlx_funcify.register(AdvancedSubtensor1)
13963
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
140-
"""MLX implementation of AdvancedSubtensor operation.
141-
142-
Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX,
143-
including handling np.int64 in mixed basic/advanced indexing scenarios.
144-
"""
64+
"""MLX implementation of AdvancedSubtensor."""
14565
idx_list = getattr(op, "idx_list", None)
14666

14767
def advanced_subtensor(x, *ilists):
@@ -158,11 +78,7 @@ def advanced_subtensor(x, *ilists):
15878
@mlx_funcify.register(IncSubtensor)
15979
@mlx_funcify.register(AdvancedIncSubtensor1)
16080
def mlx_funcify_IncSubtensor(op, node, **kwargs):
161-
"""MLX implementation of IncSubtensor operation.
162-
163-
Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX.
164-
Handles both set_instead_of_inc=True (assignment) and False (increment).
165-
"""
81+
"""MLX implementation of IncSubtensor."""
16682
idx_list = getattr(op, "idx_list", None)
16783

16884
if getattr(op, "set_instead_of_inc", False):
@@ -195,11 +111,7 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
195111

196112
@mlx_funcify.register(AdvancedIncSubtensor)
197113
def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
198-
"""MLX implementation of AdvancedIncSubtensor operation.
199-
200-
Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX.
201-
Note: For advanced indexing, ilist contains the actual array indices.
202-
"""
114+
"""MLX implementation of AdvancedIncSubtensor."""
203115
idx_list = getattr(op, "idx_list", None)
204116

205117
if getattr(op, "set_instead_of_inc", False):

0 commit comments

Comments
 (0)