Commit c91a693

BUG: Fix size_per_learner (#257)

* BUG: Fix size_per_learner
* Add tests

1 parent 7c6b19c commit c91a693

File tree: 2 files changed, +151 -20 lines changed
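
In short: with `size_per_learner` set, a function's submitted tasks are split across several `SequenceLearner`s, so the second element of a `TaskID` is a global index over that function's tasks, not an index into any single learner's `data`. The old `_get` unpacked it as `(idx_learner, idx_data)` and indexed one learner directly, which breaks result lookup once sequences are chunked. The fix has `_to_learners` also return an explicit `(func_id, global_index) -> (learner_idx, local_index)` mapping, which `finalize()` stores on the executor and `_get` consults.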

adaptive_scheduler/_executor.py (+45, -19)
@@ -114,16 +114,25 @@ def _get(self) -> Any | None:  # noqa: PLR0911
         if self.done():
             return super().result(timeout=0)

-        idx_learner, idx_data = self.task_id
-        learner, fname = self._learner_and_fname
+        func_id, global_index = self.task_id
+        try:
+            learner_idx, local_index = self.executor._task_mapping[(func_id, global_index)]
+        except KeyError as e:
+            msg = "Task mapping not found; finalize() must be called first."
+            raise RuntimeError(msg) from e
+        # Now retrieve the correct learner and filename:
+        run_manager = self.executor._run_manager
+        assert run_manager is not None, "RunManager not initialized"
+        learner = run_manager.learners[learner_idx]
+        fname = run_manager.fnames[learner_idx]

         if learner.done():
-            result = learner.data[idx_data]
+            result = learner.data[local_index]
             self.set_result(result)
             return result

         assert self.executor._run_manager is not None
-        last_load_time = self.executor._run_manager._last_load_time.get(idx_learner, 0)
+        last_load_time = self.executor._run_manager._last_load_time.get(learner_idx, 0)
         now = time.monotonic()
         time_since_last_load = now - last_load_time
         if time_since_last_load < self.min_load_interval:
@@ -141,10 +150,10 @@ def _get(self) -> Any | None:  # noqa: PLR0911
             learner.load(fname)
             self._load_time = time.monotonic() - now
             self.min_load_interval = max(1.0, 20.0 * self._load_time)
-            self.executor._run_manager._last_load_time[idx_learner] = now
+            self.executor._run_manager._last_load_time[learner_idx] = now

-        if idx_data in learner.data:
-            result = learner.data[idx_data]
+        if local_index in learner.data:
+            result = learner.data[local_index]
             self.set_result(result)
             return result
         return None
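
To make the new lookup concrete, here is a minimal runnable sketch of the path `_get` now takes; the plain dict and list below are hypothetical stand-ins for the executor's `_task_mapping` and the `RunManager`'s learners:

# Hypothetical stand-ins for executor._task_mapping and run_manager.learners[i].data:
task_mapping = {  # (func_id, global_index) -> (learner_idx, local_index)
    (0, 0): (0, 0),
    (0, 1): (0, 1),
    (0, 2): (1, 0),  # third task of function 0 lives in the *second* learner
}
learner_data = [
    {0: "result-0", 1: "result-1"},  # learner 0 (chunk of size 2)
    {0: "result-2"},                 # learner 1 (remainder chunk)
]

def get_result(func_id: int, global_index: int) -> str | None:
    try:
        learner_idx, local_index = task_mapping[(func_id, global_index)]
    except KeyError as e:
        msg = "Task mapping not found; finalize() must be called first."
        raise RuntimeError(msg) from e
    # The local index keys into the chunked learner's data, not the full sequence.
    return learner_data[learner_idx].get(local_index)

assert get_result(0, 2) == "result-2"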
@@ -371,6 +380,7 @@ class SlurmExecutor(AdaptiveSchedulerExecutorBase):
     _sequences: dict[Callable[..., Any], list[Any]] = field(default_factory=dict)
     _sequence_mapping: dict[Callable[..., Any], int] = field(default_factory=dict)
     _run_manager: adaptive_scheduler.RunManager | None = None
+    _task_mapping: dict[tuple[int, int], tuple[int, int]] = field(default_factory=dict)

     def __post_init__(self) -> None:
         if self.folder is None:
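
Note that the two `(int, int)` tuples in `_task_mapping` mean different things: the key is the `(func_id, global_index)` pair that `submit()` packs into a `TaskID`, while the value is `(learner_idx, local_index)` after chunking. A pair of hypothetical type aliases (not in the commit) makes the roles explicit:

# Hypothetical aliases for readability; the commit uses raw tuples.
TaskKey = tuple[int, int]      # (func_id, global_index), from TaskID
LearnerSlot = tuple[int, int]  # (learner_idx, local_index), after chunking
_task_mapping: dict[TaskKey, LearnerSlot]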
@@ -390,32 +400,48 @@ def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> SlurmTask:
         task_id = TaskID(self._sequence_mapping[fn], i)
         return SlurmTask(self, task_id)

-    def _to_learners(self) -> tuple[list[SequenceLearner], list[Path]]:
+    def _to_learners(
+        self,
+    ) -> tuple[
+        list[SequenceLearner],
+        list[Path],
+        dict[tuple[int, int], tuple[int, int]],
+    ]:
         learners = []
         fnames = []
-        for func, args_kwargs_list in self._sequences.items():
-            # Chunk the sequence if size_per_learner is specified
+        task_mapping = {}
+        learner_idx = 0
+        for func, args_list in self._sequences.items():
+            func_id = self._sequence_mapping[func]
+            # Chunk the sequence if size_per_learner is set; otherwise one chunk.
             if self.size_per_learner is not None:
-                chunked_args_kwargs_list = [
-                    args_kwargs_list[i : i + self.size_per_learner]
-                    for i in range(0, len(args_kwargs_list), self.size_per_learner)
+                chunked_args = [
+                    args_list[i : i + self.size_per_learner]
+                    for i in range(0, len(args_list), self.size_per_learner)
                 ]
             else:
-                chunked_args_kwargs_list = [args_kwargs_list]
+                chunked_args = [args_list]
+
+            global_index = 0  # global index for tasks of this function
+            for chunk in chunked_args:
+                # Map each task in the chunk: global index -> (current learner, local index)
+                for local_index in range(len(chunk)):
+                    task_mapping[(func_id, global_index)] = (learner_idx, local_index)
+                    global_index += 1

-            for i, chunk in enumerate(chunked_args_kwargs_list):
                 learner = SequenceLearner(_SerializableFunctionSplatter(func), chunk)
                 learners.append(learner)
+                name = func.__name__ if hasattr(func, "__name__") else "func"
                 assert isinstance(self.folder, Path)
-                name = func.__name__ if hasattr(func, "__name__") else ""
-                fnames.append(self.folder / f"{name}-{i}-{uuid.uuid4().hex}.pickle")
-        return learners, fnames
+                fnames.append(self.folder / f"{name}-{learner_idx}-{uuid.uuid4().hex}.pickle")
+                learner_idx += 1
+        return learners, fnames, task_mapping

     def finalize(self, *, start: bool = True) -> adaptive_scheduler.RunManager | None:
         if self._run_manager is not None:
             msg = "RunManager already initialized. Create a new SlurmExecutor instance."
             raise RuntimeError(msg)
-        learners, fnames = self._to_learners()
+        learners, fnames, self._task_mapping = self._to_learners()
         if not learners:
             return None
         assert self.folder is not None
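
Because the chunks come from fixed-size slicing, the per-function part of the mapping is equivalent to a divmod over the global index. A standalone sketch (the helper name and `base_learner_idx` offset are hypothetical; in the commit, `learner_idx` keeps counting across functions):

def chunk_mapping(
    n_tasks: int,
    size_per_learner: int | None,
    base_learner_idx: int = 0,
) -> dict[int, tuple[int, int]]:
    # One learner holds the whole sequence when no chunk size is given.
    if size_per_learner is None:
        return {g: (base_learner_idx, g) for g in range(n_tasks)}
    # Otherwise: learner offset = g // size, local index = g % size.
    return {
        g: (base_learner_idx + g // size_per_learner, g % size_per_learner)
        for g in range(n_tasks)
    }

# Matches the loop in _to_learners for 5 tasks in chunks of 2:
assert chunk_mapping(5, 2) == {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1), 4: (2, 0)}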

tests/test_slurm_executor.py (+106, -1)
@@ -176,7 +176,7 @@ def test_cleanup(executor: SlurmExecutor) -> None:
 def test_task_get_before_finalize(executor: SlurmExecutor) -> None:
     """Test that _get before finalize returns None."""
     task = executor.submit(example_func, 1.0)
-    with pytest.raises(AssertionError, match="RunManager not initialized"):
+    with pytest.raises(RuntimeError, match="Task mapping not found; finalize()"):
         task._get()


@@ -272,3 +272,108 @@ async def simulate_result() -> None:
     asyncio.create_task(simulate_result())  # noqa: RUF006
     result = await task
     assert result == 42
+
+
+@pytest.mark.usefixtures("_mock_slurm_partitions")
+@pytest.mark.usefixtures("_mock_slurm_queue")
+def test_to_learners_mapping_single_function(tmp_path: Path) -> None:
+    """Test that _to_learners creates the correct mapping for a single function."""
+    executor = SlurmExecutor(folder=tmp_path, size_per_learner=2)
+    # Submit 5 tasks to example_func so that they are split into chunks of 2.
+    for i in range(5):
+        executor.submit(example_func, i)
+    learners, fnames, mapping = executor._to_learners()
+
+    # We expect ceil(5/2) = 3 learners.
+    assert len(learners) == 3
+
+    func_id = executor._sequence_mapping[example_func]
+    expected_mapping = {
+        (func_id, 0): (0, 0),  # first learner, first task
+        (func_id, 1): (0, 1),  # first learner, second task
+        (func_id, 2): (1, 0),  # second learner, first task
+        (func_id, 3): (1, 1),  # second learner, second task
+        (func_id, 4): (2, 0),  # third learner, first task (only one task in this chunk)
+    }
+    assert mapping == expected_mapping
+
+
+@pytest.mark.usefixtures("_mock_slurm_partitions")
+@pytest.mark.usefixtures("_mock_slurm_queue")
+def test_finalize_mapping_and_learners(tmp_path: Path) -> None:
+    """Test that finalize() sets the task mapping correctly and creates the right number of learners."""
+    executor = SlurmExecutor(folder=tmp_path, size_per_learner=2)
+    # Submit 3 tasks to example_func.
+    for i in range(3):
+        executor.submit(example_func, i)

+    rm = executor.finalize(start=False)
+    # For 3 tasks with chunk size 2:
+    # - The first chunk (learner 0) has tasks 0 and 1.
+    # - The second chunk (learner 1) has task 2.
+    func_id = executor._sequence_mapping[example_func]
+    expected_mapping = {
+        (func_id, 0): (0, 0),
+        (func_id, 1): (0, 1),
+        (func_id, 2): (1, 0),
+    }
+    assert executor._task_mapping == expected_mapping
+    # Also, the run manager should have 2 learners.
+    assert isinstance(rm, RunManager)
+    assert len(rm.learners) == 2
+
+
+@pytest.mark.usefixtures("_mock_slurm_partitions")
+@pytest.mark.usefixtures("_mock_slurm_queue")
+def test_task_get_with_chunking(tmp_path: Path) -> None:
+    """Test that tasks in different learners retrieve the correct result when using size_per_learner."""
+    executor = SlurmExecutor(folder=tmp_path, size_per_learner=2, save_interval=1)
+    # Submit three tasks; with size_per_learner=2, this will produce 2 learners.
+    task1 = executor.submit(example_func, 42)
+    task2 = executor.submit(example_func, 43)
+    task3 = executor.submit(example_func, 44)
+    rm = executor.finalize(start=False)
+
+    # For learner 0 (tasks 0 and 1)
+    assert isinstance(rm, RunManager)
+    learner0 = rm.learners[0]
+    fname0 = rm.fnames[0]
+    learner0.data[0] = 42
+    learner0.data[1] = 43
+    learner0.save(fname0)
+    # For learner 1 (task 2)
+    learner1 = rm.learners[1]
+    fname1 = rm.fnames[1]
+    learner1.data[0] = 44
+    learner1.save(fname1)
+
+    # _get() should now retrieve the correct values based on the mapping.
+    assert task1._get() == 42
+    assert task2._get() == 43
+    assert task3._get() == 44
+
+
+@pytest.mark.usefixtures("_mock_slurm_partitions")
+@pytest.mark.usefixtures("_mock_slurm_queue")
+def test_mapping_multiple_functions(tmp_path: Path) -> None:
+    """Test that the mapping is correct when tasks are submitted for multiple functions."""
+    executor = SlurmExecutor(folder=tmp_path, size_per_learner=2)
+    # Submit two tasks for example_func and two for another_func.
+    executor.submit(example_func, 10)
+    executor.submit(example_func, 20)
+    executor.submit(another_func, 5)
+    executor.submit(another_func, 6)
+
+    # Directly call _to_learners to examine the mapping.
+    learners, fnames, mapping = executor._to_learners()
+
+    expected_mapping = {
+        # For example_func: two tasks in one learner (since 2 tasks fit in one chunk).
+        (executor._sequence_mapping[example_func], 0): (0, 0),
+        (executor._sequence_mapping[example_func], 1): (0, 1),
+        # For another_func: two tasks in one learner.
+        (executor._sequence_mapping[another_func], 0): (1, 0),
+        (executor._sequence_mapping[another_func], 1): (1, 1),
+    }
+    assert mapping == expected_mapping
+    assert len(learners) == 2
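
For context, the user-facing flow these tests exercise looks roughly like this (a sketch assuming a working SLURM setup and default scheduler options; `double` and the folder path are stand-ins):

from pathlib import Path
from adaptive_scheduler import SlurmExecutor

def double(x: float) -> float:
    return 2 * x

executor = SlurmExecutor(folder=Path("runs/demo"), size_per_learner=2)
tasks = [executor.submit(double, x) for x in range(5)]  # 5 tasks -> 3 chunked learners
run_manager = executor.finalize()  # builds learners, fnames, and _task_mapping, then starts the RunManager
# Each task now resolves through _task_mapping to the right learner and local index.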
