Skip to content

Commit c457a13

Browse files
committed
Update subtensor.py
1 parent 9632ad6 commit c457a13

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

pytensor/link/mlx/dispatch/subtensor.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,23 @@ def normalize_element(element):
3232
normalize_element(element.step),
3333
)
3434
elif isinstance(element, mx.array) and element.ndim == 0:
35-
return int(element.item())
35+
try:
36+
return int(element.item())
37+
except (TypeError, ValueError) as e:
38+
raise TypeError(
39+
"MLX backend does not support symbolic indices. "
40+
"Index values must be concrete (constant) integers, not symbolic variables. "
41+
f"Got: {element}"
42+
) from e
3643
elif isinstance(element, np.integer):
37-
return int(element)
44+
try:
45+
return int(element)
46+
except (TypeError, ValueError) as e:
47+
raise TypeError(
48+
"MLX backend does not support symbolic indices. "
49+
"Index values must be concrete (constant) integers, not symbolic variables. "
50+
f"Got: {element}"
51+
) from e
3852
else:
3953
return element
4054

0 commit comments

Comments
 (0)