19 | 19 | from ray.data._internal.util import make_async_gen
20 | 20 |
21 | 21 |
   | 22 | +logger = logging.getLogger(__file__)
   | 23 | +
   | 24 | +
22 | 25 | def block_generator(num_rows: int, num_blocks: int):
23 | 26 |     for _ in range(num_blocks):
24 | 27 |         yield pa.table({"foo": [1] * num_rows})
@@ -131,7 +134,39 @@ def gen(base_iterator):
131 | 134 |     assert e.match("Fail")
132 | 135 |
133 | 136 |
134 |     | -logger = logging.getLogger(__file__)
    | 137 | +@pytest.mark.parametrize("buffer_size", [0, 1, 2])
    | 138 | +def test_make_async_gen_varying_seq_lengths(buffer_size: int):
    | 139 | +    """Tests that iterators of varying lengths are handled appropriately."""
    | 140 | +
    | 141 | +    def _gen(base_iterator):
    | 142 | +        worker_id = next(base_iterator)
    | 143 | +
    | 144 | +        # Make workers produce sequences whose lengths increase in the same order
    | 145 | +        # as their worker ids (so that lower-id workers' sequences run out first)
    | 146 | +        target_length = worker_id + 1
    | 147 | +
    | 148 | +        return iter([f"worker_{worker_id}:{i}" for i in range(target_length)])
    | 149 | +
    | 150 | +    num_seqs = 3
    | 151 | +
    | 152 | +    iterator = make_async_gen(
    | 153 | +        base_iterator=iter(list(range(num_seqs))),
    | 154 | +        fn=_gen,
    | 155 | +        # Make sure individual elements are handled by different workers
    | 156 | +        num_workers=num_seqs,
    | 157 | +        queue_buffer_size=buffer_size,
    | 158 | +    )
    | 159 | +
    | 160 | +    seq = list(iterator)
    | 161 | +
    | 162 | +    assert [
    | 163 | +        "worker_0:0",
    | 164 | +        "worker_1:0",
    | 165 | +        "worker_2:0",
    | 166 | +        "worker_1:1",
    | 167 | +        "worker_2:1",
    | 168 | +        "worker_2:2",
    | 169 | +    ] == seq
135 | 170 |
136 | 171 |
137 | 172 | @pytest.mark.parametrize("buffer_size", [0, 1, 2])
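For context on the ordering the new assertion encodes: below is a minimal, synchronous sketch, assuming `make_async_gen` yields results round-robin across workers and skips workers whose iterators are exhausted. The real implementation is threaded and queue-based, so this only illustrates the expected output order, not the actual mechanism; `round_robin` is a hypothetical helper, not part of the Ray API.

```python
def round_robin(iterators):
    # Assumption: models make_async_gen's output order only. Drain the
    # iterators one element at a time, in order, dropping exhausted
    # iterators at the end of each pass.
    iterators = list(iterators)
    while iterators:
        alive = []
        for it in iterators:
            try:
                yield next(it)
            except StopIteration:
                continue
            alive.append(it)
        iterators = alive


# Three workers producing 1, 2, and 3 items, as in the test above:
seqs = [iter([f"worker_{w}:{i}" for i in range(w + 1)]) for w in range(3)]
assert list(round_robin(seqs)) == [
    "worker_0:0", "worker_1:0", "worker_2:0",
    "worker_1:1", "worker_2:1", "worker_2:2",
]
```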