Skip to content

Commit eee793b

Browse files
committed
Allow a custom_load function
1 parent c31073c commit eee793b

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

adaptive_scheduler/_server_support/job_manager.py

+8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from .common import MaxRestartsReachedError, log
1111

1212
if TYPE_CHECKING:
13+
from typing import Callable
14+
15+
import adaptive
16+
1317
from adaptive_scheduler.scheduler import BaseScheduler
1418
from adaptive_scheduler.utils import (
1519
_DATAFRAME_FORMATS,
@@ -31,6 +35,7 @@ def command_line_options(
3135
save_dataframe: bool = True,
3236
dataframe_format: _DATAFRAME_FORMATS = "parquet",
3337
loky_start_method: LOKY_START_METHODS = "loky",
38+
custom_load: Callable[[adaptive.BaseLearner, str], None] | None = None,
3439
) -> dict[str, Any]:
3540
"""Return the command line options for the job_script."""
3641
if runner_kwargs is None:
@@ -49,6 +54,9 @@ def command_line_options(
4954
"--save-interval": save_interval,
5055
"--serialized-runner-kwargs": base64_runner_kwargs,
5156
}
57+
if custom_load:
58+
base64_custom_load = _serialize_to_b64(custom_load)
59+
opts["--custom-load"] = base64_custom_load
5260
if scheduler.executor_type == "loky":
5361
opts["--loky-start-method"] = loky_start_method
5462
if save_dataframe:

adaptive_scheduler/_server_support/launcher.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def _parse_args() -> argparse.Namespace:
7474
parser.add_argument("--job-id", action="store", type=str)
7575
parser.add_argument("--name", action="store", dest="name", type=str, required=True)
7676
parser.add_argument("--url", action="store", type=str, required=True)
77-
7877
parser.add_argument("--save-dataframe", action="store_true", default=False)
7978
parser.add_argument(
8079
"--dataframe-format",
@@ -103,6 +102,7 @@ def _parse_args() -> argparse.Namespace:
103102
default=120,
104103
)
105104
parser.add_argument("--serialized-runner-kwargs", action="store", type=str)
105+
parser.add_argument("--custom-load", action="store", type=str, default=None)
106106
return parser.parse_args()
107107

108108

@@ -121,8 +121,15 @@ def main() -> None:
121121
if args.executor_type == "process-pool":
122122
learner.function = WrappedFunction(learner.function)
123123

124+
if args.custom_load is not None:
125+
custom_load = _deserialize_from_b64(args.custom_load)
126+
124127
with suppress(Exception):
125-
learner.load(fname)
128+
if args.custom_load is not None:
129+
custom_load(learner, fname)
130+
else:
131+
learner.load(fname)
132+
126133
npoints_start = learner.npoints
127134

128135
executor = _get_executor(

0 commit comments

Comments
 (0)