1
1
"""The DatabaseManager."""
2
2
from __future__ import annotations
3
3
4
+ import json
4
5
import pickle
6
+ from dataclasses import asdict , dataclass , field
5
7
from pathlib import Path
6
- from typing import TYPE_CHECKING , Any , List , Union
8
+ from typing import TYPE_CHECKING , Any , Callable , List , Union
7
9
8
10
import pandas as pd
9
11
import zmq
10
12
import zmq .asyncio
11
13
import zmq .ssh
12
- from tinydb import Query , TinyDB
13
14
14
15
from adaptive_scheduler .utils import (
15
16
_deserialize ,
26
27
27
28
from adaptive_scheduler .scheduler import BaseScheduler
28
29
29
- ctx = zmq .asyncio .Context ()
30
30
31
+ ctx = zmq .asyncio .Context ()
32
+ FnameType = Union [str , Path , List [str ], List [Path ]]
31
33
FnamesTypes = Union [List [str ], List [Path ], List [List [str ]], List [List [Path ]]]
32
34
33
35
@@ -56,6 +58,72 @@ def _ensure_str(
56
58
raise ValueError (msg )
57
59
58
60
61
@dataclass
class _DBEntry:
    """A single database record: one ``fname`` plus the state of its job."""

    # File name (or list of file names) this record tracks.
    fname: str | list[str]
    # Id of the job currently assigned to this record; ``None`` when unassigned.
    job_id: str | None = None
    # Flipped to ``True`` once a stop request for ``fname`` was processed.
    is_done: bool = False
    # Log file name reported by the job that picked up this record.
    log_fname: str | None = None
    # Name of the job currently assigned to this record.
    job_name: str | None = None
    # Output log paths (as strings) of the assigned job; a fresh list per record.
    output_logs: list[str] = field(default_factory=list)
    # Time at which the record was handed to a job.
    start_time: float | None = None
70
+
71
+
72
class SimpleDatabase:
    """A small JSON-file-backed list of `_DBEntry` records.

    All records are kept in memory and the complete database is rewritten
    to ``db_fname`` after every mutating call (`insert_multiple`, `update`).
    The file is a JSON object with a ``"data"`` list (one dict per record)
    and a ``"meta"`` dict.
    """

    def __init__(self, db_fname: str | Path, *, clear_existing: bool = False) -> None:
        """Open the database stored at ``db_fname``.

        Parameters
        ----------
        db_fname
            Path of the JSON file backing the database.
        clear_existing
            If True, delete an existing file and start empty instead of
            loading its contents.
        """
        self.db_fname = Path(db_fname)
        self._data: list[_DBEntry] = []
        self._meta: dict[str, Any] = {}

        if not self.db_fname.exists():
            return
        if clear_existing:
            self.db_fname.unlink()
            return
        with self.db_fname.open() as f:
            raw_data = json.load(f)
        self._data = [_DBEntry(**entry) for entry in raw_data["data"]]
        # Restore the metadata as well; otherwise the next `_save` would
        # overwrite whatever was stored under "meta" with an empty dict.
        self._meta = raw_data.get("meta", {})

    def all(self) -> list[_DBEntry]:  # noqa: A003
        """Return the internal list of records (not a copy)."""
        return self._data

    def insert_multiple(self, entries: list[_DBEntry]) -> None:
        """Append ``entries`` to the database and write it to disk."""
        self._data.extend(entries)
        self._save()

    def update(self, update_dict: dict, indices: list[int] | None = None) -> None:
        """Set the attributes in ``update_dict`` on the selected records.

        Parameters
        ----------
        update_dict
            Mapping of attribute name to new value; every name must already
            be an attribute of the record.
        indices
            Positions of the records to update. ``None`` updates every
            record; an empty list updates none.

        The database is written to disk afterwards, also when no record
        was selected.
        """
        # A set gives O(1) membership tests instead of scanning the index
        # list once per record.
        selected = None if indices is None else set(indices)
        for index, entry in enumerate(self._data):
            if selected is not None and index not in selected:
                continue
            for key, value in update_dict.items():
                assert hasattr(entry, key), f"unknown database field {key!r}"
                setattr(entry, key, value)
        self._save()

    def count(self, condition: Callable[[_DBEntry], bool]) -> int:
        """Return the number of records for which ``condition`` is true."""
        return sum(1 for entry in self._data if condition(entry))

    def get(self, condition: Callable[[_DBEntry], bool]) -> _DBEntry | None:
        """Return the first record satisfying ``condition``, or ``None``."""
        return next((entry for entry in self._data if condition(entry)), None)

    def get_all(
        self,
        condition: Callable[[_DBEntry], bool],
    ) -> list[tuple[int, _DBEntry]]:
        """Return ``(index, record)`` pairs for all records satisfying ``condition``."""
        return [(i, entry) for i, entry in enumerate(self._data) if condition(entry)]

    def contains(self, condition: Callable[[_DBEntry], bool]) -> bool:
        """Return True if at least one record satisfies ``condition``."""
        return any(condition(entry) for entry in self._data)

    def as_dicts(self) -> list[dict[str, Any]]:
        """Return every record converted to a plain dict."""
        return [asdict(entry) for entry in self._data]

    def _save(self) -> None:
        """Write all records and the metadata to ``db_fname`` as JSON."""
        with self.db_fname.open("w") as f:
            json.dump({"data": self.as_dicts(), "meta": self._meta}, f)
125
+
126
+
59
127
class DatabaseManager (BaseManager ):
60
128
"""Database manager.
61
129
@@ -100,20 +168,12 @@ def __init__( # noqa: PLR0913
100
168
self .fnames = fnames
101
169
self .overwrite_db = overwrite_db
102
170
103
- self .defaults : dict [str , Any ] = {
104
- "job_id" : None ,
105
- "is_done" : False ,
106
- "log_fname" : None ,
107
- "job_name" : None ,
108
- "output_logs" : [],
109
- "start_time" : None ,
110
- }
111
-
112
- self ._last_reply : str | Exception | None = None
171
+ self ._last_reply : str | list [str ] | Exception | None = None
113
172
self ._last_request : tuple [str , ...] | None = None
114
173
self .failed : list [dict [str , Any ]] = []
115
174
self ._pickling_time : float | None = None
116
175
self ._total_learner_size : int | None = None
176
+ self ._db : SimpleDatabase | None = None
117
177
118
178
def _setup (self ) -> None :
119
179
if self .db_fname .exists () and not self .overwrite_db :
@@ -127,24 +187,21 @@ def _setup(self) -> None:
127
187
128
188
def update(self, queue: dict[str, dict[str, str]] | None = None) -> None:
    """If the ``job_id`` isn't running anymore, replace it with None.

    Parameters
    ----------
    queue
        Mapping keyed by the ids of the jobs that are currently queued.
        When omitted it is fetched with ``self.scheduler.queue(me_only=True)``.

    Entries that have a ``job_id`` which is missing from ``queue`` are
    appended (as dicts) to ``self.failed`` and get their ``job_id`` and
    ``job_name`` reset to ``None``.
    """
    assert self._db is not None
    if queue is None:
        queue = self.scheduler.queue(me_only=True)
    failed = self._db.get_all(
        lambda e: (e.job_id is not None) and (e.job_id not in queue),  # type: ignore[operator]
    )
    if not failed:
        # Nothing to reset; skip the database update so the file is not
        # rewritten when no entry changes.
        return
    self.failed.extend([asdict(entry) for _, entry in failed])
    indices = [index for index, _ in failed]
    self._db.update({"job_id": None, "job_name": None}, indices)
142
199
143
200
def n_done(self) -> int:
    """Count the finished jobs; ``0`` while no database has been created."""
    database = self._db
    return 0 if database is None else database.count(lambda e: e.is_done)
148
205
149
206
def is_done (self ) -> bool :
150
207
"""Return True if all jobs are done."""
@@ -155,18 +212,18 @@ def create_empty_db(self) -> None:
155
212
156
213
It keeps track of ``fname -> (job_id, is_done, log_fname, job_name)``.
157
214
"""
158
- entries = [
159
- dict (fname = _ensure_str ( fname ), ** self . defaults ) for fname in self .fnames
215
+ entries : list [ _DBEntry ] = [
216
+ _DBEntry (fname = fname ) for fname in _ensure_str ( self .fnames )
160
217
]
161
218
if self .db_fname .exists ():
162
219
self .db_fname .unlink ()
163
- with TinyDB (self .db_fname ) as db :
164
- db .insert_multiple (entries )
220
+ self . _db = SimpleDatabase (self .db_fname )
221
+ self . _db .insert_multiple (entries )
165
222
166
223
def as_dicts(self) -> list[dict[str, str]]:
    """Give every database entry as a dictionary, collected in a list.

    The database must already exist (``self._db`` is not ``None``).
    """
    database = self._db
    assert database is not None
    return database.as_dicts()
170
227
171
228
def as_df (self ) -> pd .DataFrame :
172
229
"""Return the database as a `pandas.DataFrame`."""
@@ -180,75 +237,86 @@ def _output_logs(self, job_id: str, job_name: str) -> list[Path]:
180
237
for f in output_fnames
181
238
]
182
239
183
def _start_request(
    self,
    job_id: str,
    log_fname: str,
    job_name: str,
) -> str | list[str] | None:
    """Hand the first unassigned, unfinished entry to the job ``job_id``.

    Parameters
    ----------
    job_id
        Id of the job asking for work.
    log_fname
        Log file name reported by that job; stored on the entry.
    job_name
        Name of that job; stored on the entry.

    Returns
    -------
    The ``fname`` of the chosen entry, or ``None`` when no entry is left
    that is both unassigned and not done.

    Raises
    ------
    JobIDExistsInDbError
        If an entry is already assigned to ``job_id``.
    """
    assert self._db is not None
    # One lookup both detects the duplicate and yields the entry for the message.
    running = self._db.get(lambda e: e.job_id == job_id)
    if running is not None:
        fname = running.fname  # already running
        # No trailing comma here: it would turn the message into a 1-tuple.
        msg = (
            f"The job_id {job_id} already exists in the database and "
            f"runs {fname}. You might have forgotten to use the "
            "`if __name__ == '__main__': ...` idiom in your code. Read the "
            "warning in the [mpi4py](https://bit.ly/2HAk0GG) documentation."
        )
        raise JobIDExistsInDbError(msg)
    entry = self._db.get(
        lambda e: e.job_id is None and not e.is_done,
    )
    log.debug("choose fname", entry=entry)
    if entry is None:
        return None
    index = self._db.all().index(entry)
    self._db.update(
        {
            "job_id": job_id,
            "log_fname": log_fname,
            "job_name": job_name,
            "output_logs": _ensure_str(self._output_logs(job_id, job_name)),
            "start_time": _now(),
        },
        indices=[index],
    )
    return _ensure_str(entry.fname)  # type: ignore[return-value]
213
276
214
277
def _stop_request(self, fname: str | list[str] | Path | list[Path]) -> None:
    """Flag every entry whose ``fname`` equals ``fname`` as done.

    The matching entries also get ``job_id`` and ``job_name`` cleared.
    """
    target = _ensure_str(fname)
    assert self._db is not None
    matches = self._db.get_all(lambda e: e.fname == target)
    self._db.update(
        {"job_id": None, "is_done": True, "job_name": None},
        [position for position, _ in matches],
    )
223
285
224
286
def _stop_requests(self, fnames: FnamesTypes) -> None:
    """Flag all entries listed in ``fnames`` as done and clear their job info."""
    # Same as `_stop_request` but optimized for processing many `fnames` at once
    assert self._db is not None
    finished = {str(name) for name in _ensure_str(fnames)}
    matches = self._db.get_all(lambda e: str(e.fname) in finished)
    self._db.update(
        {"job_id": None, "is_done": True, "job_name": None},
        [position for position, _ in matches],
    )
231
295
232
- def _dispatch (self , request : tuple [str , ...]) -> str | Exception | None :
296
+ def _dispatch (
297
+ self ,
298
+ request : tuple [str , str | list [str ]] | tuple [str ],
299
+ ) -> str | list [str ] | Exception | None :
233
300
request_type , * request_arg = request
234
301
log .debug ("got a request" , request = request )
235
302
try :
236
303
if request_type == "start" :
237
304
# workers send us their slurm ID for us to fill in
238
305
job_id , log_fname , job_name = request_arg
239
- kwargs = {
240
- "job_id" : job_id ,
241
- "log_fname" : log_fname ,
242
- "job_name" : job_name ,
243
- }
244
306
# give the worker a job and send back the fname to the worker
245
- fname = self ._start_request (** kwargs )
307
+ fname = self ._start_request (job_id , log_fname , job_name ) # type: ignore[arg-type]
246
308
if fname is None :
247
309
# This should never happen because the _manage co-routine
248
310
# should have stopped the workers before this happens.
249
311
msg = "No more learners to run in the database."
250
312
raise RuntimeError (msg ) # noqa: TRY301
251
- log .debug ("choose a fname" , fname = fname , ** kwargs )
313
+ log .debug (
314
+ "choose a fname" ,
315
+ fname = fname ,
316
+ job_id = job_id ,
317
+ log_fname = log_fname ,
318
+ job_name = job_name ,
319
+ )
252
320
return fname
253
321
if request_type == "stop" :
254
322
fname = request_arg [0 ] # workers send us the fname they were given
@@ -291,7 +359,7 @@ async def _manage(self) -> None:
291
359
)
292
360
else :
293
361
assert self ._last_request is not None # for mypy
294
- self ._last_reply = self ._dispatch (self ._last_request )
362
+ self ._last_reply = self ._dispatch (self ._last_request ) # type: ignore[arg-type]
295
363
await socket .send_serialized (self ._last_reply , _serialize )
296
364
if self .is_done ():
297
365
break
0 commit comments