Skip to content

Commit 9452257

Browse files
committed
Fix expand_dims type hint
1 parent 42a7adb commit 9452257

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytensor/tensor/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4369,7 +4369,7 @@ def atleast_Nd(
43694369
atleast_3d = partial(atleast_Nd, n=3)
43704370

43714371

4372-
def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
4372+
def expand_dims(a: "TensorLike", axis: Sequence[int] | int) -> TensorVariable:
43734373
"""Expand the shape of an array.
43744374
43754375
Insert a new axis that will appear at the `axis` position in the expanded

0 commit comments

Comments
 (0)