11from copy import deepcopy
22
3+ import mlx .core as mx
34import numpy as np
45
56from pytensor .link .mlx .dispatch .basic import mlx_funcify
1617
1718
1819def normalize_indices_for_mlx (ilist , idx_list ):
19- """Convert indices to MLX-compatible format.
20-
21- MLX has strict requirements for indexing:
22- - Integer indices must be Python int, not np.int64 or other NumPy integer types
23- - Slice components (start, stop, step) must be Python int or None, not np.int64
24- - MLX arrays created from scalars need to be converted back to Python int
25- - Array indices for advanced indexing are handled separately
26-
27- This function converts all integer-like indices and slice components to Python int
28- while preserving None values and passing through array indices unchanged.
29-
30- This conversion is necessary because MLX's C++ indexing implementation
31- does not recognize NumPy scalar types, raising ValueError when encountered.
32- Additionally, mlx_typify converts NumPy scalars to MLX arrays, which also
33- need to be converted back to Python int for use in indexing operations.
34- Converting to Python int is zero-cost for Python int inputs and minimal
35- overhead for NumPy scalars and MLX scalar arrays.
36-
37- Parameters
38- ----------
39- ilist : tuple
40- Runtime index values to be passed to indices_from_subtensor
41- idx_list : tuple
42- Static index specification from the Op's idx_list attribute
43-
44- Returns
45- -------
46- tuple
47- Normalized indices compatible with MLX array indexing
48-
49- Examples
50- --------
51- >>> # Single np.int64 index converted to Python int
52- >>> normalize_indices_for_mlx((np.int64(1),), (True,))
53- (1,)
54-
55- >>> # Slice with np.int64 components
56- >>> indices = indices_from_subtensor(
57- ... (np.int64(0), np.int64(2)), (slice(None, None),)
58- ... )
59- >>> # After normalization, slice components are Python int
20+ """Convert numpy integers to Python integers for MLX indexing.
21+
22+ MLX requires index values to be Python int, not np.int64 or other NumPy types.
6023 """
61- import mlx .core as mx
6224
6325 def normalize_element (element ):
64- """Convert a single index element to MLX-compatible format."""
6526 if element is None :
66- # None is valid in slices (e.g., x[None:5] or x[:None])
6727 return None
6828 elif isinstance (element , slice ):
69- # Recursively normalize slice components
7029 return slice (
7130 normalize_element (element .start ),
7231 normalize_element (element .stop ),
7332 normalize_element (element .step ),
7433 )
75- elif isinstance (element , mx .array ):
76- # MLX arrays from mlx_typify need special handling
77- # If it's a 0-d array (scalar), convert to Python int/float
78- if element .ndim == 0 :
79- # Extract the scalar value
80- item = element .item ()
81- # Convert to Python int if it's an integer type
82- if element .dtype in (
83- mx .int8 ,
84- mx .int16 ,
85- mx .int32 ,
86- mx .int64 ,
87- mx .uint8 ,
88- mx .uint16 ,
89- mx .uint32 ,
90- mx .uint64 ,
91- ):
92- return int (item )
93- else :
94- return float (item )
95- else :
96- # Multi-dimensional array for advanced indexing - pass through
97- return element
98- elif isinstance (element , (np .integer , np .floating )):
99- # Convert NumPy scalar to Python int/float
100- # This handles np.int64, np.int32, np.float64, etc.
101- return int (element ) if isinstance (element , np .integer ) else float (element )
102- elif isinstance (element , (int , float )):
103- # Python int/float are already compatible
104- return element
34+ elif isinstance (element , mx .array ) and element .ndim == 0 :
35+ return int (element .item ())
36+ elif isinstance (element , np .integer ):
37+ return int (element )
10538 else :
106- # Pass through other types (arrays for advanced indexing, etc.)
10739 return element
10840
109- # Get indices from PyTensor's subtensor utility
110- raw_indices = indices_from_subtensor (ilist , idx_list )
111-
112- # Normalize each index element
113- normalized = tuple (normalize_element (idx ) for idx in raw_indices )
114-
115- return normalized
41+ indices = indices_from_subtensor (ilist , idx_list )
42+ return tuple (normalize_element (idx ) for idx in indices )
11643
11744
11845@mlx_funcify .register (Subtensor )
11946def mlx_funcify_Subtensor (op , node , ** kwargs ):
120- """MLX implementation of Subtensor operation.
121-
122- Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX.
123- """
47+ """MLX implementation of Subtensor."""
12448 idx_list = getattr (op , "idx_list" , None )
12549
12650 def subtensor (x , * ilists ):
@@ -137,11 +61,7 @@ def subtensor(x, *ilists):
13761@mlx_funcify .register (AdvancedSubtensor )
13862@mlx_funcify .register (AdvancedSubtensor1 )
13963def mlx_funcify_AdvancedSubtensor (op , node , ** kwargs ):
140- """MLX implementation of AdvancedSubtensor operation.
141-
142- Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX,
143- including handling np.int64 in mixed basic/advanced indexing scenarios.
144- """
64+ """MLX implementation of AdvancedSubtensor."""
14565 idx_list = getattr (op , "idx_list" , None )
14666
14767 def advanced_subtensor (x , * ilists ):
@@ -158,11 +78,7 @@ def advanced_subtensor(x, *ilists):
15878@mlx_funcify .register (IncSubtensor )
15979@mlx_funcify .register (AdvancedIncSubtensor1 )
16080def mlx_funcify_IncSubtensor (op , node , ** kwargs ):
161- """MLX implementation of IncSubtensor operation.
162-
163- Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX.
164- Handles both set_instead_of_inc=True (assignment) and False (increment).
165- """
81+ """MLX implementation of IncSubtensor."""
16682 idx_list = getattr (op , "idx_list" , None )
16783
16884 if getattr (op , "set_instead_of_inc" , False ):
@@ -195,11 +111,7 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
195111
196112@mlx_funcify .register (AdvancedIncSubtensor )
197113def mlx_funcify_AdvancedIncSubtensor (op , node , ** kwargs ):
198- """MLX implementation of AdvancedIncSubtensor operation.
199-
200- Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX.
201- Note: For advanced indexing, ilist contains the actual array indices.
202- """
114+ """MLX implementation of AdvancedIncSubtensor."""
203115 idx_list = getattr (op , "idx_list" , None )
204116
205117 if getattr (op , "set_instead_of_inc" , False ):
0 commit comments