Skip to content

Commit 9ec9af8

Browse files
committed
Implement correct iteration through disjoint enumerated set for infinite set
1 parent dc99dc8 commit 9ec9af8

File tree

1 file changed

+108
-14
lines changed

1 file changed

+108
-14
lines changed

src/sage/sets/disjoint_union_enumerated_sets.py

Lines changed: 108 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -392,32 +392,126 @@ def __iter__(self):
392392
"""
393393
TESTS::
394394
395+
sage: from itertools import islice
395396
sage: U4 = DisjointUnionEnumeratedSets(
396397
....: Family(NonNegativeIntegers(), Permutations))
397-
sage: it = iter(U4)
398-
sage: [next(it), next(it), next(it), next(it), next(it), next(it)]
398+
sage: list(islice(iter(U4), 6))
399399
[[], [1], [1, 2], [2, 1], [1, 2, 3], [1, 3, 2]]
400400
401401
sage: # needs sage.combinat
402402
sage: U4 = DisjointUnionEnumeratedSets(
403403
....: Family(NonNegativeIntegers(), Permutations),
404404
....: keepkey=True, facade=False)
405-
sage: it = iter(U4)
406-
sage: [next(it), next(it), next(it), next(it), next(it), next(it)]
407-
[(0, []), (1, [1]), (2, [1, 2]), (2, [2, 1]), (3, [1, 2, 3]), (3, [1, 3, 2])]
408-
sage: el = next(it); el.parent() == U4
409-
True
410-
sage: el.value == (3, Permutation([2,1,3]))
405+
sage: l = list(islice(iter(U4), 7)); l
406+
[(0, []), (1, [1]), (2, [1, 2]), (2, [2, 1]), (3, [1, 2, 3]), (3, [1, 3, 2]), (3, [2, 1, 3])]
407+
sage: l[-1].parent() is U4
411408
True
409+
410+
Check when both the set of keys and each element set is finite::
411+
412+
sage: list(DisjointUnionEnumeratedSets(
413+
....: Family({1: FiniteEnumeratedSet([1,2,3]),
414+
....: 2: FiniteEnumeratedSet([4,5,6])})))
415+
[1, 2, 3, 4, 5, 6]
416+
417+
Check when the set of keys is finite but each element set is infinite::
418+
419+
sage: list(islice(DisjointUnionEnumeratedSets(
420+
....: Family({1: NonNegativeIntegers(),
421+
....: 2: NonNegativeIntegers()}), keepkey=True), 0, 10))
422+
[(1, 0), (1, 1), (2, 0), (1, 2), (2, 1), (1, 3), (2, 2), (1, 4), (2, 3), (1, 5)]
423+
424+
Check when the set of keys is infinite but each element set is finite::
425+
426+
sage: list(islice(DisjointUnionEnumeratedSets(
427+
....: Family(NonNegativeIntegers(), lambda x: FiniteEnumeratedSet(range(x))),
428+
....: keepkey=True), 0, 20))
429+
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2), (4, 0), (4, 1), (4, 2), (4, 3),
430+
(5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (6, 0), (6, 1), (6, 2), (6, 3), (6, 4)]
431+
432+
Check when some element sets are empty (note that if there are infinitely many sets
433+
but only finitely many elements in total, the iteration will hang)::
434+
435+
sage: list(DisjointUnionEnumeratedSets(
436+
....: Family({1: FiniteEnumeratedSet([]),
437+
....: 2: FiniteEnumeratedSet([]),
438+
....: 3: FiniteEnumeratedSet([]),
439+
....: 4: FiniteEnumeratedSet([]),
440+
....: 5: FiniteEnumeratedSet([1,2,3]),
441+
....: 6: FiniteEnumeratedSet([4,5,6])})))
442+
[1, 2, 3, 4, 5, 6]
443+
444+
Check when there's one infinite set and infinitely many finite sets::
445+
446+
sage: list(islice(DisjointUnionEnumeratedSets(
447+
....: Family(NonNegativeIntegers(), lambda x: FiniteEnumeratedSet([]) if x else NonNegativeIntegers())),
448+
....: 0, 10))
449+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
450+
451+
The following cannot be determined to be finite, but the first elements can still be retrieved::
452+
453+
sage: U = DisjointUnionEnumeratedSets(
454+
....: Family(NonNegativeIntegers(), lambda x: FiniteEnumeratedSet([] if x >= 2 else [1, 2])),
455+
....: keepkey=True)
456+
sage: list(U) # not tested
457+
sage: list(islice(iter(U), 5)) # not tested, hangs
458+
sage: list(islice(iter(U), 4))
459+
[(0, 1), (0, 2), (1, 1), (1, 2)]
412460
"""
413-
for k in self._family.keys():
414-
for el in self._family[k]:
461+
def wrap_element(el, k):
462+
nonlocal self
463+
if self._keepkey:
464+
el = (k, el)
465+
if self._facade:
466+
return el
467+
else:
468+
return self.element_class(self, el) # Bypass correctness tests
469+
470+
keys_iter = iter(self._family.keys())
471+
if self._keepkey:
472+
seen_keys = []
473+
el_iters = []
474+
while keys_iter is not None or el_iters:
475+
if keys_iter is not None:
476+
try:
477+
k = next(keys_iter)
478+
except StopIteration:
479+
keys_iter = None
480+
if keys_iter is not None:
481+
el_set = self._family[k]
482+
if el_set.is_finite():
483+
for el in el_set:
484+
yield wrap_element(el, k)
485+
else:
486+
el_iters.append(iter(el_set))
487+
if self._keepkey:
488+
seen_keys.append(k)
489+
any_stopped = False
490+
for i, obj in enumerate(zip(seen_keys, el_iters) if self._keepkey else el_iters):
491+
if self._keepkey:
492+
k, el_iter = obj
493+
else:
494+
k = None
495+
el_iter = obj
496+
try:
497+
el = next(el_iter)
498+
except StopIteration:
499+
el_iters[i] = None
500+
any_stopped = True
501+
continue
502+
yield wrap_element(el, k)
503+
if any_stopped:
415504
if self._keepkey:
416-
el = (k, el)
417-
if self._facade:
418-
yield el
505+
filtered = [*zip(
506+
*[(k, el_iter) for k, el_iter in zip(seen_keys, el_iters) if el_iter is not None])]
507+
if filtered:
508+
seen_keys = list(filtered[0])
509+
el_iters = list(filtered[1])
510+
else:
511+
seen_keys = []
512+
el_iters = []
419513
else:
420-
yield self.element_class(self, el) # Bypass correctness tests
514+
el_iters = [el_iter for el_iter in el_iters if el_iter is not None]
421515

422516
def an_element(self):
423517
"""

0 commit comments

Comments
 (0)