Skip to content

Commit c31073c

Browse files
authored
Implement SimpleDatabase and remove tinydb dependency (#161)
* Implement SimpleDatabase and remove tinydb dependency * add metadata to JSON * add clear existing * tests * Fix types * fix str * remove unused classes * Fix database * Remove maybe_list
1 parent 1d1b380 commit c31073c

9 files changed

+293
-148
lines changed

adaptive_scheduler/_mock_scheduler.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import os
88
import subprocess
9-
from typing import TYPE_CHECKING
9+
from typing import TYPE_CHECKING, List, Tuple, Union
1010

1111
import structlog
1212
import zmq
@@ -25,6 +25,10 @@
2525

2626
DEFAULT_URL = "tcp://127.0.0.1:60547"
2727

28+
_RequestSubmitType = Tuple[str, str, Union[str, List[str]]]
29+
_RequestCancelType = Tuple[str, str]
30+
_RequestQueueType = Tuple[str]
31+
2832

2933
class MockScheduler:
3034
"""Emulates a HPC-like scheduler.
@@ -156,18 +160,18 @@ async def _command_listener(self) -> Coroutine[None, None, None]:
156160

157161
def _dispatch(
158162
self,
159-
request: tuple[str, ...],
163+
request: _RequestSubmitType | _RequestCancelType | _RequestQueueType,
160164
) -> str | None | dict[str, dict[str, Any]] | Exception:
161165
log.debug("got a request", request=request)
162166
request_type, *request_arg = request
163167
try:
164168
if request_type == "submit":
165169
job_name, fname = request_arg
166170
log.debug("submitting a task", fname=fname, job_name=job_name)
167-
job_id = self.submit(job_name, fname)
171+
job_id = self.submit(job_name, fname) # type: ignore[arg-type]
168172
return job_id
169173
if request_type == "cancel":
170-
job_id = request_arg[0]
174+
job_id = request_arg[0] # type: ignore[assignment]
171175
log.debug("got a cancel request", job_id=job_id)
172176
self.cancel(job_id)
173177
return None

adaptive_scheduler/_server_support/database_manager.py

+149-81
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
"""The DatabaseManager."""
22
from __future__ import annotations
33

4+
import json
45
import pickle
6+
from dataclasses import asdict, dataclass, field
57
from pathlib import Path
6-
from typing import TYPE_CHECKING, Any, List, Union
8+
from typing import TYPE_CHECKING, Any, Callable, List, Union
79

810
import pandas as pd
911
import zmq
1012
import zmq.asyncio
1113
import zmq.ssh
12-
from tinydb import Query, TinyDB
1314

1415
from adaptive_scheduler.utils import (
1516
_deserialize,
@@ -26,8 +27,9 @@
2627

2728
from adaptive_scheduler.scheduler import BaseScheduler
2829

29-
ctx = zmq.asyncio.Context()
3030

31+
ctx = zmq.asyncio.Context()
32+
FnameType = Union[str, Path, List[str], List[Path]]
3133
FnamesTypes = Union[List[str], List[Path], List[List[str]], List[List[Path]]]
3234

3335

@@ -56,6 +58,72 @@ def _ensure_str(
5658
raise ValueError(msg)
5759

5860

61+
@dataclass
62+
class _DBEntry:
63+
fname: str | list[str]
64+
job_id: str | None = None
65+
is_done: bool = False
66+
log_fname: str | None = None
67+
job_name: str | None = None
68+
output_logs: list[str] = field(default_factory=list)
69+
start_time: float | None = None
70+
71+
72+
class SimpleDatabase:
73+
def __init__(self, db_fname: str | Path, *, clear_existing: bool = False) -> None:
74+
self.db_fname = Path(db_fname)
75+
self._data: list[_DBEntry] = []
76+
self._meta: dict[str, Any] = {}
77+
78+
if self.db_fname.exists():
79+
if clear_existing:
80+
self.db_fname.unlink()
81+
else:
82+
with self.db_fname.open() as f:
83+
raw_data = json.load(f)
84+
self._data = [_DBEntry(**entry) for entry in raw_data["data"]]
85+
86+
def all(self) -> list[_DBEntry]: # noqa: A003
87+
return self._data
88+
89+
def insert_multiple(self, entries: list[_DBEntry]) -> None:
90+
self._data.extend(entries)
91+
self._save()
92+
93+
def update(self, update_dict: dict, indices: list[int] | None = None) -> None:
94+
for index, entry in enumerate(self._data):
95+
if indices is None or index in indices:
96+
for key, value in update_dict.items():
97+
assert hasattr(entry, key)
98+
setattr(entry, key, value)
99+
self._save()
100+
101+
def count(self, condition: Callable[[_DBEntry], bool]) -> int:
102+
return sum(1 for entry in self._data if condition(entry))
103+
104+
def get(self, condition: Callable[[_DBEntry], bool]) -> _DBEntry | None:
105+
for entry in self._data:
106+
if condition(entry):
107+
return entry
108+
return None
109+
110+
def get_all(
111+
self,
112+
condition: Callable[[_DBEntry], bool],
113+
) -> list[tuple[int, _DBEntry]]:
114+
return [(i, entry) for i, entry in enumerate(self._data) if condition(entry)]
115+
116+
def contains(self, condition: Callable[[_DBEntry], bool]) -> bool:
117+
return any(condition(entry) for entry in self._data)
118+
119+
def as_dicts(self) -> list[dict[str, Any]]:
120+
return [asdict(entry) for entry in self._data]
121+
122+
def _save(self) -> None:
123+
with self.db_fname.open("w") as f:
124+
json.dump({"data": self.as_dicts(), "meta": self._meta}, f)
125+
126+
59127
class DatabaseManager(BaseManager):
60128
"""Database manager.
61129
@@ -100,20 +168,12 @@ def __init__( # noqa: PLR0913
100168
self.fnames = fnames
101169
self.overwrite_db = overwrite_db
102170

103-
self.defaults: dict[str, Any] = {
104-
"job_id": None,
105-
"is_done": False,
106-
"log_fname": None,
107-
"job_name": None,
108-
"output_logs": [],
109-
"start_time": None,
110-
}
111-
112-
self._last_reply: str | Exception | None = None
171+
self._last_reply: str | list[str] | Exception | None = None
113172
self._last_request: tuple[str, ...] | None = None
114173
self.failed: list[dict[str, Any]] = []
115174
self._pickling_time: float | None = None
116175
self._total_learner_size: int | None = None
176+
self._db: SimpleDatabase | None = None
117177

118178
def _setup(self) -> None:
119179
if self.db_fname.exists() and not self.overwrite_db:
@@ -127,24 +187,21 @@ def _setup(self) -> None:
127187

128188
def update(self, queue: dict[str, dict[str, str]] | None = None) -> None:
    """If the ``job_id`` isn't running anymore, replace it with None."""
    # Database must exist: `create_empty_db` sets `self._db`.
    assert self._db is not None
    if queue is None:
        queue = self.scheduler.queue(me_only=True)
    # An entry whose job_id is set but no longer reported by the scheduler
    # is considered failed.
    failed = self._db.get_all(
        lambda e: (e.job_id is not None) and (e.job_id not in queue),  # type: ignore[operator]
    )
    # Record the failed entries (as plain dicts) for later inspection.
    self.failed.extend([asdict(entry) for _, entry in failed])
    indices = [index for index, _ in failed]
    # Release the entries so their fnames can be claimed by new jobs.
    self._db.update({"job_id": None, "job_name": None}, indices)
142199

143200
def n_done(self) -> int:
    """Return the number of jobs that are done."""
    database = self._db
    # Before the database has been created, no job can be done yet.
    return 0 if database is None else database.count(lambda entry: entry.is_done)
148205

149206
def is_done(self) -> bool:
150207
"""Return True if all jobs are done."""
@@ -155,18 +212,18 @@ def create_empty_db(self) -> None:
155212
156213
It keeps track of ``fname -> (job_id, is_done, log_fname, job_name)``.
157214
"""
158-
entries = [
159-
dict(fname=_ensure_str(fname), **self.defaults) for fname in self.fnames
215+
entries: list[_DBEntry] = [
216+
_DBEntry(fname=fname) for fname in _ensure_str(self.fnames)
160217
]
161218
if self.db_fname.exists():
162219
self.db_fname.unlink()
163-
with TinyDB(self.db_fname) as db:
164-
db.insert_multiple(entries)
220+
self._db = SimpleDatabase(self.db_fname)
221+
self._db.insert_multiple(entries)
165222

166223
def as_dicts(self) -> list[dict[str, str]]:
    """Return the database as a list of dictionaries."""
    # Requires `create_empty_db` to have been called first.
    assert self._db is not None
    return self._db.as_dicts()
170227

171228
def as_df(self) -> pd.DataFrame:
172229
"""Return the database as a `pandas.DataFrame`."""
@@ -180,75 +237,86 @@ def _output_logs(self, job_id: str, job_name: str) -> list[Path]:
180237
for f in output_fnames
181238
]
182239

183-
def _start_request(self, job_id: str, log_fname: str, job_name: str) -> str | None:
184-
entry = Query()
185-
with TinyDB(self.db_fname) as db:
186-
if db.contains(entry.job_id == job_id):
187-
entry = db.get(entry.job_id == job_id)
188-
fname = entry["fname"] # already running
189-
msg = (
190-
f"The job_id {job_id} already exists in the database and "
191-
f"runs {fname}. You might have forgotten to use the "
192-
"`if __name__ == '__main__': ...` idom in your code. Read the "
193-
"warning in the [mpi4py](https://bit.ly/2HAk0GG) documentation.",
194-
)
195-
raise JobIDExistsInDbError(msg)
196-
entry = db.get(
197-
(entry.job_id == None) & (entry.is_done == False), # noqa: E711,E712
198-
)
199-
log.debug("choose fname", entry=entry)
200-
if entry is None:
201-
return None
202-
db.update(
203-
{
204-
"job_id": job_id,
205-
"log_fname": log_fname,
206-
"job_name": job_name,
207-
"output_logs": _ensure_str(self._output_logs(job_id, job_name)),
208-
"start_time": _now(),
209-
},
210-
doc_ids=[entry.doc_id],
240+
def _start_request(
241+
self,
242+
job_id: str,
243+
log_fname: str,
244+
job_name: str,
245+
) -> str | list[str] | None:
246+
assert self._db is not None
247+
if self._db.contains(lambda e: e.job_id == job_id):
248+
entry = self._db.get(lambda e: e.job_id == job_id)
249+
assert entry is not None
250+
fname = entry.fname # already running
251+
msg = (
252+
f"The job_id {job_id} already exists in the database and "
253+
f"runs {fname}. You might have forgotten to use the "
254+
"`if __name__ == '__main__': ...` idiom in your code. Read the "
255+
"warning in the [mpi4py](https://bit.ly/2HAk0GG) documentation.",
211256
)
212-
return entry["fname"]
257+
raise JobIDExistsInDbError(msg)
258+
entry = self._db.get(
259+
lambda e: e.job_id is None and not e.is_done,
260+
)
261+
log.debug("choose fname", entry=entry)
262+
if entry is None:
263+
return None
264+
index = self._db.all().index(entry)
265+
self._db.update(
266+
{
267+
"job_id": job_id,
268+
"log_fname": log_fname,
269+
"job_name": job_name,
270+
"output_logs": _ensure_str(self._output_logs(job_id, job_name)),
271+
"start_time": _now(),
272+
},
273+
indices=[index],
274+
)
275+
return _ensure_str(entry.fname) # type: ignore[return-value]
213276

214277
def _stop_request(self, fname: str | list[str] | Path | list[Path]) -> None:
    """Mark the entry (or entries) for ``fname`` as done and release the job."""
    fname_str = _ensure_str(fname)
    assert self._db is not None
    matches = self._db.get_all(lambda e: e.fname == fname_str)
    self._db.update(
        {"job_id": None, "is_done": True, "job_name": None},
        [index for index, _ in matches],
    )
223285

224286
def _stop_requests(self, fnames: FnamesTypes) -> None:
    # Same as `_stop_request` but optimized for processing many `fnames` at once
    assert self._db is not None
    # Compare as strings: fnames may be Paths (or lists of them).
    fnames_str = {str(fname) for fname in _ensure_str(fnames)}
    reset = {"job_id": None, "is_done": True, "job_name": None}
    entry_indices = [
        index for index, _ in self._db.get_all(lambda e: str(e.fname) in fnames_str)
    ]
    self._db.update(reset, entry_indices)
231295

232-
def _dispatch(self, request: tuple[str, ...]) -> str | Exception | None:
296+
def _dispatch(
297+
self,
298+
request: tuple[str, str | list[str]] | tuple[str],
299+
) -> str | list[str] | Exception | None:
233300
request_type, *request_arg = request
234301
log.debug("got a request", request=request)
235302
try:
236303
if request_type == "start":
237304
# workers send us their slurm ID for us to fill in
238305
job_id, log_fname, job_name = request_arg
239-
kwargs = {
240-
"job_id": job_id,
241-
"log_fname": log_fname,
242-
"job_name": job_name,
243-
}
244306
# give the worker a job and send back the fname to the worker
245-
fname = self._start_request(**kwargs)
307+
fname = self._start_request(job_id, log_fname, job_name) # type: ignore[arg-type]
246308
if fname is None:
247309
# This should never happen because the _manage co-routine
248310
# should have stopped the workers before this happens.
249311
msg = "No more learners to run in the database."
250312
raise RuntimeError(msg) # noqa: TRY301
251-
log.debug("choose a fname", fname=fname, **kwargs)
313+
log.debug(
314+
"choose a fname",
315+
fname=fname,
316+
job_id=job_id,
317+
log_fname=log_fname,
318+
job_name=job_name,
319+
)
252320
return fname
253321
if request_type == "stop":
254322
fname = request_arg[0] # workers send us the fname they were given
@@ -291,7 +359,7 @@ async def _manage(self) -> None:
291359
)
292360
else:
293361
assert self._last_request is not None # for mypy
294-
self._last_reply = self._dispatch(self._last_request)
362+
self._last_reply = self._dispatch(self._last_request) # type: ignore[arg-type]
295363
await socket.send_serialized(self._last_reply, _serialize)
296364
if self.is_done():
297365
break

adaptive_scheduler/client_support.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
_serialize,
1818
fname_to_learner,
1919
log_exception,
20-
maybe_lst,
2120
sleep_unless_task_is_done,
2221
)
2322

@@ -106,7 +105,7 @@ def get_learner(
106105
log.info("got fname and loaded learner")
107106

108107
log.info("picked a learner")
109-
return learner, maybe_lst(fname)
108+
return learner, fname
110109

111110

112111
def tell_done(url: str, fname: str | list[str]) -> None:

0 commit comments

Comments
 (0)