Skip to content

Commit c2e38f7

Browse files
[Cherry-pick] Cherry-picking #50210 to 2.42.1 (#50385)
Cherry-picking #50210 to 2.42.1 Signed-off-by: Alexey Kudinkin <[email protected]>
1 parent 91e8ee0 commit c2e38f7

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

python/ray/data/_internal/util.py

-4
Original file line numberDiff line numberDiff line change
@@ -1088,10 +1088,6 @@ def _run_transforming_worker(worker_id: int):
10881088
non_empty_queues.append(output_queue)
10891089
yield item
10901090

1091-
assert (
1092-
non_empty_queues + empty_queues == remaining_output_queues
1093-
), "Exhausted non-trailing queue!"
1094-
10951091
remaining_output_queues = non_empty_queues
10961092

10971093
finally:

python/ray/data/tests/block_batching/test_util.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from ray.data._internal.util import make_async_gen
2020

2121

22+
logger = logging.getLogger(__file__)
23+
24+
2225
def block_generator(num_rows: int, num_blocks: int):
2326
for _ in range(num_blocks):
2427
yield pa.table({"foo": [1] * num_rows})
@@ -131,7 +134,39 @@ def gen(base_iterator):
131134
assert e.match("Fail")
132135

133136

134-
logger = logging.getLogger(__file__)
137+
@pytest.mark.parametrize("buffer_size", [0, 1, 2])
138+
def test_make_async_gen_varying_seq_lengths(buffer_size: int):
139+
"""Tests that iterators of varying lengths are handled appropriately"""
140+
141+
def _gen(base_iterator):
142+
worker_id = next(base_iterator)
143+
144+
# Make workers produce sequences increasing the same order
145+
# as worker-ids (so that for left workers sequences run out first)
146+
target_length = worker_id + 1
147+
148+
return iter([f"worker_{worker_id}:{i}" for i in range(target_length)])
149+
150+
num_seqs = 3
151+
152+
iterator = make_async_gen(
153+
base_iterator=iter(list(range(num_seqs))),
154+
fn=_gen,
155+
# Make sure individual elements are handle by diff workers
156+
num_workers=num_seqs,
157+
queue_buffer_size=buffer_size,
158+
)
159+
160+
seq = list(iterator)
161+
162+
assert [
163+
"worker_0:0",
164+
"worker_1:0",
165+
"worker_2:0",
166+
"worker_1:1",
167+
"worker_2:1",
168+
"worker_2:2",
169+
] == seq
135170

136171

137172
@pytest.mark.parametrize("buffer_size", [0, 1, 2])

0 commit comments

Comments
 (0)