Skip to content

Commit e826bd6

Browse files
committed
Allow a custom_load function
1 parent c31073c commit e826bd6

File tree

3 files changed

+28
-3
lines changed

3 files changed

+28
-3
lines changed

adaptive_scheduler/_server_support/job_manager.py

+11
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from .common import MaxRestartsReachedError, log
1111

1212
if TYPE_CHECKING:
13+
from typing import Callable
14+
15+
import adaptive
16+
1317
from adaptive_scheduler.scheduler import BaseScheduler
1418
from adaptive_scheduler.utils import (
1519
_DATAFRAME_FORMATS,
@@ -31,6 +35,7 @@ def command_line_options(
3135
save_dataframe: bool = True,
3236
dataframe_format: _DATAFRAME_FORMATS = "parquet",
3337
loky_start_method: LOKY_START_METHODS = "loky",
38+
custom_load: Callable[[adaptive.BaseLearner, str], None] | None = None,
3439
) -> dict[str, Any]:
3540
"""Return the command line options for the job_script."""
3641
if runner_kwargs is None:
@@ -49,6 +54,9 @@ def command_line_options(
4954
"--save-interval": save_interval,
5055
"--serialized-runner-kwargs": base64_runner_kwargs,
5156
}
57+
if custom_load:
58+
base64_custom_load = _serialize_to_b64(custom_load)
59+
opts["--custom-load"] = base64_custom_load
5260
if scheduler.executor_type == "loky":
5361
opts["--loky-start-method"] = loky_start_method
5462
if save_dataframe:
@@ -106,6 +114,7 @@ def __init__(
106114
save_interval: int | float = 300,
107115
runner_kwargs: dict[str, Any] | None = None,
108116
goal: GoalTypes = None,
117+
custom_load: Callable[[adaptive.BaseLearner, str], None] | None = None,
109118
) -> None:
110119
super().__init__()
111120
self.job_names = job_names
@@ -127,6 +136,7 @@ def __init__(
127136
self.save_interval = save_interval
128137
self.runner_kwargs = runner_kwargs
129138
self.goal = goal
139+
self.custom_load = custom_load
130140

131141
@property
132142
def max_job_starts(self) -> int:
@@ -152,6 +162,7 @@ def _setup(self) -> None:
152162
dataframe_format=self.dataframe_format,
153163
goal=self.goal,
154164
loky_start_method=self.loky_start_method,
165+
custom_load=self.custom_load,
155166
)
156167
self.scheduler.write_job_script(name_prefix, options)
157168

adaptive_scheduler/_server_support/launcher.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def _parse_args() -> argparse.Namespace:
7474
parser.add_argument("--job-id", action="store", type=str)
7575
parser.add_argument("--name", action="store", dest="name", type=str, required=True)
7676
parser.add_argument("--url", action="store", type=str, required=True)
77-
7877
parser.add_argument("--save-dataframe", action="store_true", default=False)
7978
parser.add_argument(
8079
"--dataframe-format",
@@ -103,6 +102,7 @@ def _parse_args() -> argparse.Namespace:
103102
default=120,
104103
)
105104
parser.add_argument("--serialized-runner-kwargs", action="store", type=str)
105+
parser.add_argument("--custom-load", action="store", type=str, default=None)
106106
return parser.parse_args()
107107

108108

@@ -121,8 +121,15 @@ def main() -> None:
121121
if args.executor_type == "process-pool":
122122
learner.function = WrappedFunction(learner.function)
123123

124+
if args.custom_load is not None:
125+
custom_load = _deserialize_from_b64(args.custom_load)
126+
124127
with suppress(Exception):
125-
learner.load(fname)
128+
if args.custom_load is not None:
129+
custom_load(learner, fname)
130+
else:
131+
learner.load(fname)
132+
126133
npoints_start = learner.npoints
127134

128135
executor = _get_executor(

adaptive_scheduler/_server_support/run_manager.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ class RunManager(BaseManager):
6464
runner_kwargs : dict, default: None
6565
Extra keyword argument to pass to the `adaptive.Runner`. Note that this dict
6666
will be serialized and pasted in the ``job_script``.
67+
custom_load : callable, default: None
68+
A function that is called to load the learner. It is called as
69+
``custom_load(learner, fname)``. If None, the learner is loaded with
70+
``learner.load(fname)``.
6771
url : str, default: None
6872
The url of the database manager, with the format
6973
``tcp://ip_of_this_machine:allowed_port.``. If None, a correct url will be chosen.
@@ -149,7 +153,7 @@ class RunManager(BaseManager):
149153
150154
"""
151155

152-
def __init__(
156+
def __init__( # noqa: PLR0915
153157
self,
154158
scheduler: BaseScheduler,
155159
learners: list[adaptive.BaseLearner],
@@ -158,6 +162,7 @@ def __init__(
158162
goal: GoalTypes = None,
159163
check_goal_on_start: bool = True,
160164
runner_kwargs: dict | None = None,
165+
custom_load: Callable[[adaptive.BaseLearner, str], None] | None = None,
161166
url: str | None = None,
162167
save_interval: int | float = 300,
163168
log_interval: int | float = 300,
@@ -185,6 +190,7 @@ def __init__(
185190
self.goal = smart_goal(goal, learners)
186191
self.check_goal_on_start = check_goal_on_start
187192
self.runner_kwargs = runner_kwargs
193+
self.custom_load = custom_load
188194
self.save_interval = save_interval
189195
self.log_interval = log_interval
190196
self.job_name = job_name
@@ -264,6 +270,7 @@ def __init__(
264270
save_interval=self.save_interval,
265271
runner_kwargs=self.runner_kwargs,
266272
goal=self.goal,
273+
custom_load=self.custom_load,
267274
**self.job_manager_kwargs,
268275
)
269276
self.kill_manager: KillManager | None

0 commit comments

Comments
 (0)