From 12d932172757cf3a59265c55f9fd37be87db851c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 7 Feb 2024 18:19:06 -0700 Subject: [PATCH] Fix selected_indices logic when newaxis comes after the array indices --- ndindex/tests/test_selected_indices.py | 1 + ndindex/tuple.py | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ndindex/tests/test_selected_indices.py b/ndindex/tests/test_selected_indices.py index af60b9f0..7ebd2ca6 100644 --- a/ndindex/tests/test_selected_indices.py +++ b/ndindex/tests/test_selected_indices.py @@ -8,6 +8,7 @@ from ..integer import Integer from .helpers import ndindices, check_same, short_shapes, prod +@example(([False], None), (1,)) @example((False, slice(0, 10)), (5, 2)) @example((None, True, 0), (5, 2)) @example((slice(0, 10), [0, -1]), (5, 2)) diff --git a/ndindex/tuple.py b/ndindex/tuple.py index 029f70e5..cdf6969e 100644 --- a/ndindex/tuple.py +++ b/ndindex/tuple.py @@ -762,11 +762,11 @@ def _flatten(l): array_indices = [] axis = 0 for i in idx.args: - if i in [None, True]: - continue if i == False: return - if isinstance(i, IntegerArray): + elif i == True: + pass + elif isinstance(i, IntegerArray): array_indices.append(i) else: # Tuples do not support array indices separated by slices, @@ -780,8 +780,9 @@ def _flatten(l): shape, axis=axis)) axis += len(array_indices) array_indices.clear() - iterators.append(i.selected_indices(shape, axis=axis)) - axis += 1 + if i != None: + iterators.append(i.selected_indices(shape, axis=axis)) + axis += 1 if idx.args and isinstance(idx.args[-1], IntegerArray): iterators.append(_zipped_array_indices(array_indices, shape, axis=axis))