
Commit 769f646

add more type annotations

Parent: 8f28615

3 files changed: 64 additions, 58 deletions


adaptive_scheduler/client_support.py (+8, -6)
@@ -2,18 +2,20 @@
 import datetime
 import socket
 from contextlib import suppress
+from typing import Any, Dict, List, Union, Tuple

 import psutil
 import structlog
 import zmq

+from adaptive import AsyncRunner, BaseLearner
 from adaptive_scheduler._scheduler import get_job_id

 ctx = zmq.Context()
 log = structlog.get_logger("adaptive_scheduler.client")


-def get_learner(url, learners, fnames):
+def get_learner(url: str, learners: List[BaseLearner], fnames: List[str]) -> None:
     """Get a learner from the database running at `url`.

     Parameters
@@ -50,7 +52,7 @@ def get_learner(url, learners, fnames):
             fname = reply
             log.info(f"got fname")

-    def maybe_lst(fname):
+    def maybe_lst(fname: Union[Tuple[str], str]):
         if isinstance(fname, tuple):
             # TinyDB converts tuples to lists
             fname = list(fname)
@@ -67,7 +69,7 @@ def maybe_lst(fname):
     return learner, fname


-def tell_done(url, fname):
+def tell_done(url: str, fname: str) -> None:
     """Tell the database that the learner has reached it's goal.

     Parameters
@@ -86,15 +88,15 @@ def tell_done(url, fname):
         socket.recv_pyobj()  # Needed because of socket type


-def _get_npoints(learner):
+def _get_npoints(learner: BaseLearner) -> int:
     with suppress(AttributeError):
         return learner.npoints
     with suppress(AttributeError):
         # If the Learner is a BalancingLearner
         return sum(l.npoints for l in learner.learners)


-def _get_log_entry(runner, npoints_start):
+def _get_log_entry(runner: AsyncRunner, npoints_start: int) -> Dict[str, Any]:
     learner = runner.learner
     info = {}
     Δt = datetime.timedelta(seconds=runner.elapsed_time())
@@ -118,7 +120,7 @@ def _get_log_entry(runner, npoints_start):
     return info


-def log_info(runner, interval=300):
+def log_info(runner: AsyncRunner, interval=300) -> asyncio.Task:
     """Log info in the job's logfile, similar to `runner.live_info`.

     Parameters
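
Note on the `maybe_lst` hunk above: in `typing`, `Tuple[str]` describes a tuple of exactly one string, so `Union[Tuple[str], str]` only covers 1-tuples; if `fname` can round-trip through TinyDB as a tuple of several strings, `Tuple[str, ...]` would be the more precise spelling, and the function still lacks a return annotation. A minimal sketch of that variant (hypothetical, not part of this commit):

    from typing import List, Tuple, Union

    def maybe_lst(fname: Union[Tuple[str, ...], str]) -> Union[List[str], str]:
        # Hypothetical variant: accept a str or a tuple of str of any length
        # and convert tuples back to lists (TinyDB stores tuples as lists).
        if isinstance(fname, tuple):
            return list(fname)
        return fname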

adaptive_scheduler/server_support.py (+53, -49)
@@ -10,7 +10,7 @@
 import time
 import warnings
 from contextlib import suppress
-from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
+from typing import Any, Callable, Coroutine, Dict, List, Optional, Union, Tuple

 import adaptive
 import dill
@@ -40,7 +40,7 @@ class MaxRestartsReached(Exception):
     your Python code which results jobs being started indefinitely."""


-def _dispatch(request, db_fname):
+def _dispatch(request: Tuple[str, str], db_fname: str):
     request_type, request_arg = request
     log.debug("got a request", request=request)
     try:
@@ -53,7 +53,7 @@ def _dispatch(request, db_fname):
         elif request_type == "stop":
             fname = request_arg  # workers send us the fname they were given
             log.debug("got a stop request", fname=fname)
-            return _done_with_learner(db_fname, fname)  # reset the job_id to None
+            _done_with_learner(db_fname, fname)  # reset the job_id to None
     except Exception as e:
         return e

@@ -94,7 +94,7 @@ async def manage_database(url: str, db_fname: str) -> Coroutine:
     )


-def start_database_manager(url: str, db_fname: str):
+def start_database_manager(url: str, db_fname: str) -> asyncio.Task:
     ioloop = asyncio.get_event_loop()
     coro = manage_database(url, db_fname)
     return ioloop.create_task(coro)
@@ -148,17 +148,19 @@ def start_database_manager(url: str, db_fname: str):


 async def manage_jobs(
-    job_names,
-    db_fname,
+    job_names: List[str],
+    db_fname: str,
     ioloop,
     cores=8,
-    job_script_function=make_job_script,
-    run_script="run_learner.py",
-    python_executable=None,
-    interval=30,
+    job_script_function: Callable[
+        [str, int, str, Optional[str]], str
+    ] = make_job_script,
+    run_script: str = "run_learner.py",
+    python_executable: Optional[str] = None,
+    interval: int = 30,
     *,
-    max_simultaneous_jobs=5000,
-    max_fails_per_job=100,
+    max_simultaneous_jobs: int = 5000,
+    max_fails_per_job: int = 100,
 ) -> Coroutine:
     n_started = 0
     max_job_starts = max_fails_per_job * len(job_names)
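
Note: `Callable[[str, int, str, Optional[str]], str]` constrains `job_script_function` to a callable taking four positional arguments and returning the job-script text. Judging from the `_start_job(name, cores, job_script_function, run_script, python_executable)` signature further down, those arguments are presumably the job name, core count, run script, and optional Python executable; the parameter names and the SLURM-style body below are assumptions for illustration, not taken from the repository:

    from typing import Optional

    def my_job_script(
        name: str, cores: int, run_script: str, python_executable: Optional[str]
    ) -> str:
        # Hypothetical drop-in for make_job_script that satisfies
        # Callable[[str, int, str, Optional[str]], str].
        python = python_executable or "python"
        return "\n".join(
            [
                "#!/bin/bash",
                f"#SBATCH --job-name={name}",
                f"#SBATCH --ntasks={cores}",
                f"{python} {run_script}",
            ]
        )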
@@ -230,16 +232,18 @@ async def manage_jobs(


 def start_job_manager(
-    job_names,
-    db_fname,
-    cores=8,
-    job_script_function=make_job_script,
-    run_script="run_learner.py",
-    python_executable=None,
-    interval=30,
+    job_names: List[str],
+    db_fname: str,
+    cores: int = 8,
+    job_script_function: Callable[
+        [str, int, str, Optional[str]], str
+    ] = make_job_script,
+    run_script: str = "run_learner.py",
+    python_executable: Optional[str] = None,
+    interval: int = 30,
     *,
-    max_simultaneous_jobs=5000,
-    max_fails_per_job=40,
+    max_simultaneous_jobs: int = 5000,
+    max_fails_per_job: int = 40,
 ) -> asyncio.Task:
     ioloop = asyncio.get_event_loop()
     coro = manage_jobs(
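
With the defaults now annotated, a call that relies on them needs only the required arguments. A usage sketch, assuming the job names and an empty database were prepared elsewhere (e.g. with `create_empty_db`); the naming scheme is made up for the example:

    job_names = [f"adaptive-job-{i}" for i in range(10)]  # hypothetical names

    # start_job_manager wraps the manage_jobs coroutine in an asyncio.Task.
    job_task = start_job_manager(
        job_names,
        db_fname="running.json",
        cores=8,
        max_fails_per_job=40,
    )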
@@ -275,7 +279,7 @@ def _start_job(name, cores, job_script_function, run_script, python_executable):
     time.sleep(0.5)


-def get_allowed_url():
+def get_allowed_url() -> str:
     """Get an allowed url for the database manager.

     Returns
@@ -289,7 +293,7 @@ def get_allowed_url():
     return f"tcp://{ip}:{port}"


-def create_empty_db(db_fname: str, fnames: List[str]):
+def create_empty_db(db_fname: str, fnames: List[str]) -> None:
     """Create an empty database that keeps track of fname -> (job_id, is_done).

     Parameters
@@ -312,14 +316,14 @@ def get_database(db_fname: str) -> List[Dict[str, Any]]:
         return db.all()


-def _update_db(db_fname: str, running: Dict[str, dict]):
+def _update_db(db_fname: str, running: Dict[str, dict]) -> None:
     """If the job_id isn't running anymore, replace it with None."""
     with TinyDB(db_fname) as db:
         doc_ids = [entry.doc_id for entry in db.all() if entry["job_id"] not in running]
         db.update({"job_id": None}, doc_ids=doc_ids)


-def _choose_fname(db_fname: str, job_id: str):
+def _choose_fname(db_fname: str, job_id: str) -> str:
     Entry = Query()
     with TinyDB(db_fname) as db:
         if db.contains(Entry.job_id == job_id):
@@ -339,13 +343,13 @@ def _choose_fname(db_fname: str, job_id: str):
     return entry["fname"]


-def _done_with_learner(db_fname: str, fname: str):
+def _done_with_learner(db_fname: str, fname: str) -> None:
     Entry = Query()
     with TinyDB(db_fname) as db:
         db.update({"job_id": None, "is_done": True}, Entry.fname == fname)


-def _get_n_jobs_done(db_fname: str):
+def _get_n_jobs_done(db_fname: str) -> int:
     Entry = Query()
     with TinyDB(db_fname) as db:
         return db.count(Entry.is_done == True)  # noqa: E711
@@ -440,14 +444,14 @@ def start_kill_manager(


 def _make_default_run_script(
-    url,
-    learners_file,
-    save_interval,
-    log_interval,
-    goal=None,
-    runner_kwargs=None,
-    run_script_fname="run_learner.py",
-    executor_type="mpi4py",
+    url: str,
+    learners_file: str,
+    save_interval: int,
+    log_interval: int,
+    goal: Optional[Callable[[adaptive.BaseLearner], bool]] = None,
+    runner_kwargs: Optional[Dict[str, Any]] = None,
+    run_script_fname: str = "run_learner.py",
+    executor_type: str = "mpi4py",
 ):
     default_runner_kwargs = dict(shutdown_executor=True)
     runner_kwargs = dict(default_runner_kwargs, goal=goal, **(runner_kwargs or {}))
@@ -668,8 +672,8 @@ def __init__(
         log_file_folder: str = "",
         db_fname: str = "running.json",
         overwrite_db: bool = True,
-        start_job_manager_kwargs: Optional[dict] = None,
-        start_kill_manager_kwargs: Optional[dict] = None,
+        start_job_manager_kwargs: Optional[Dict[str, Any]] = None,
+        start_kill_manager_kwargs: Optional[Dict[str, Any]] = None,
     ):
         # Set from arguments
         self.run_script = run_script
@@ -814,16 +818,16 @@ def _start_kill_manager(self) -> None:
             **self.start_kill_manager_kwargs,
         )

-    def cancel(self):
+    def cancel(self) -> None:
         """Cancel the manager tasks and the jobs in the queue."""
         if self.job_task is not None:
             self.job_task.cancel()
             self.database_task.cancel()
         if self.kill_task is not None:
             self.kill_task.cancel()
-        return cancel(self.job_names)
+        cancel(self.job_names)

-    def cleanup(self):
+    def cleanup(self) -> None:
         """Cleanup the log and batch files.

         If the `RunManager` is not running, the ``run_script.py`` file
@@ -838,9 +842,9 @@ def cleanup(self):
         running_job_ids = set(queue().keys())
         if self.executor_type == "ipyparallel":
             _delete_old_ipython_profiles(running_job_ids)
-        return cleanup_files(self.job_names, log_file_folder=self.log_file_folder)
+        cleanup_files(self.job_names, log_file_folder=self.log_file_folder)

-    def parse_log_files(self, only_last=True):
+    def parse_log_files(self, only_last: bool = True):
         """Parse the log-files and convert it to a `~pandas.core.frame.DataFrame`.

         Parameters
@@ -859,7 +863,7 @@ def parse_log_files(self, only_last=True):
             self.job_names, only_last, self.db_fname, self.log_file_folder
         )

-    def task_status(self):
+    def task_status(self) -> None:
         r"""Print the stack of the `asyncio.Task`\s."""
         if self.job_task is not None:
             self.job_task.print_stack()
@@ -872,13 +876,13 @@ def get_database(self) -> List[Dict[str, Any]]:
         """Get the database as a list of dicts."""
         return get_database(self.db_fname)

-    def load_learners(self):
+    def load_learners(self) -> None:
         """Load the learners in parallel using `adaptive_scheduler.utils.load_parallel`."""
         from adaptive_scheduler.utils import load_parallel

         load_parallel(self.learners_module.learners, self.learners_module.fnames)

-    def elapsed_time(self):
+    def elapsed_time(self) -> float:
         """Total time elapsed since the RunManager was started."""
         if not self.is_started:
             return 0
@@ -893,7 +897,7 @@ def elapsed_time(self):
             end_time = time.time()
         return end_time - self.start_time

-    def status(self):
+    def status(self) -> str:
         """Return the current status of the RunManager."""
         if not self.is_started:
             return "not yet started"
@@ -912,7 +916,7 @@ def status(self):
             self.end_time = time.time()
         return status

-    def info(self):
+    def info(self) -> None:
         """Display information about the `RunManager`.

         Returns an interactive ipywidget that can be
@@ -958,7 +962,7 @@ def cleanup(_):
             )
         )

-    def _info_html(self):
+    def _info_html(self) -> str:
         jobs = [job for job in queue().values() if job["name"] in self.job_names]
         n_running = sum(job["state"] in ("RUNNING", "R") for job in jobs)
         n_pending = sum(job["state"] in ("PENDING", "Q") for job in jobs)
@@ -1002,5 +1006,5 @@ def _info_html(self):
         </dl>
         """

-    def _repr_html_(self):
+    def _repr_html_(self) -> None:
         return self.info()
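
The `start_*_manager` helpers are annotated `-> asyncio.Task` because `loop.create_task(coro)` returns a `Task`, which is exactly the interface `RunManager.cancel` relies on (`task.cancel()`). A self-contained sketch of the same pattern, outside this codebase (in a notebook the event loop is typically already running, so the task starts immediately):

    import asyncio

    async def manage(interval: float = 0.1) -> None:
        # Stand-in for a manager coroutine such as manage_database/manage_jobs.
        while True:
            await asyncio.sleep(interval)

    def start_manager(loop: asyncio.AbstractEventLoop) -> asyncio.Task:
        # Same shape as start_database_manager: wrap the coroutine in a Task.
        return loop.create_task(manage())

    loop = asyncio.get_event_loop()
    task = start_manager(loop)
    task.cancel()  # what RunManager.cancel() does with its stored tasks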

adaptive_scheduler/utils.py (+3, -3)
@@ -454,15 +454,15 @@ def logs_with_string_or_condition(
     return dict(has_string)


-def _print_same_line(msg, new_line_end=False):
+def _print_same_line(msg: str, new_line_end: bool = False):
     msg = msg.strip()
     global MAX_LINE_LENGTH
     MAX_LINE_LENGTH = max(len(msg), MAX_LINE_LENGTH)
     empty_space = max(MAX_LINE_LENGTH - len(msg), 0) * " "
     print(msg + empty_space, end="\r" if not new_line_end else "\n")


-def _wait_for_successful_ipyparallel_client_start(client, n, timeout):
+def _wait_for_successful_ipyparallel_client_start(client, n: int, timeout: int):
     from ipyparallel.error import NoEnginesRegistered

     n_engines_old = 0
@@ -518,7 +518,7 @@ def connect_to_ipyparallel(
     return client


-def _get_default_args(func):
+def _get_default_args(func: Callable) -> Dict[str, str]:
     signature = inspect.signature(func)
     return {
         k: v.default
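
The hunk above is cut off mid-comprehension. For context, a plausible completion of `_get_default_args` (a reconstruction, not copied from the repository) filters out parameters without defaults via `inspect.Parameter.empty`; note that defaults elsewhere in this diff are not all strings, so `Dict[str, Any]` would arguably be a more accurate return type than `Dict[str, str]`:

    import inspect
    from typing import Any, Callable, Dict

    def _get_default_args(func: Callable) -> Dict[str, Any]:
        # Reconstruction: map parameter name -> default value for every
        # parameter of `func` that defines a default.
        signature = inspect.signature(func)
        return {
            k: v.default
            for k, v in signature.parameters.items()
            if v.default is not inspect.Parameter.empty
        }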
