Skip to content

Commit 42755ec

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 502f325 commit 42755ec

11 files changed

+70
-60
lines changed

example.ipynb

+31-21
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,25 @@
2323
"outputs": [],
2424
"source": [
2525
"import numpy as np\n",
26+
"\n",
2627
"import adaptive_scheduler\n",
27-
"import random\n",
28+
"\n",
2829
"\n",
2930
"def h(x, width=0.01, offset=0):\n",
3031
" for _ in range(10): # Burn some CPU time just because\n",
3132
" np.linalg.eig(np.random.rand(1000, 1000))\n",
32-
" return x + width ** 2 / (width ** 2 + (x - offset) ** 2)\n",
33+
" return x + width**2 / (width**2 + (x - offset) ** 2)\n",
34+
"\n",
3335
"\n",
3436
"# Define the sequence/samples we want to run\n",
3537
"xs = np.linspace(0, 1, 10_000)\n",
3638
"\n",
3739
"# ⚠️ Here a `learner` is an `adaptive` concept, read it as `jobs`.\n",
3840
"# ⚠️ `fnames` are the result locations\n",
3941
"learners, fnames = adaptive_scheduler.utils.split_sequence_in_sequence_learners(\n",
40-
" h, xs, n_learners=10\n",
42+
" h,\n",
43+
" xs,\n",
44+
" n_learners=10,\n",
4145
")\n",
4246
"\n",
4347
"run_manager = adaptive_scheduler.slurm_run(\n",
@@ -48,7 +52,7 @@
4852
" nodes=1, # number of nodes per `learner`\n",
4953
" cores_per_node=1, # number of cores on 1 node per `learner`\n",
5054
" log_interval=5, # how often to produce a log message\n",
51-
" save_interval=5, # how often to save the results\n",
55+
" save_interval=5, # how often to save the results\n",
5256
")\n",
5357
"run_manager.start()"
5458
]
@@ -85,18 +89,18 @@
8589
"from functools import partial\n",
8690
"\n",
8791
"import adaptive\n",
92+
"\n",
8893
"import adaptive_scheduler\n",
8994
"\n",
9095
"\n",
9196
"def h(x, width=0.01, offset=0):\n",
9297
" import numpy as np\n",
93-
" import random\n",
9498
"\n",
9599
" for _ in range(10): # Burn some CPU time just because\n",
96100
" np.linalg.eig(np.random.rand(1000, 1000))\n",
97101
"\n",
98102
" a = width\n",
99-
" return x + a ** 2 / (a ** 2 + (x - offset) ** 2)\n",
103+
" return x + a**2 / (a**2 + (x - offset) ** 2)\n",
100104
"\n",
101105
"\n",
102106
"offsets = [i / 10 - 0.5 for i in range(5)]\n",
@@ -266,16 +270,16 @@
266270
"outputs": [],
267271
"source": [
268272
"import numpy as np\n",
269-
"\n",
270273
"from adaptive import SequenceLearner\n",
271-
"from adaptive_scheduler.utils import split, combo_to_fname\n",
274+
"\n",
275+
"from adaptive_scheduler.utils import split\n",
272276
"\n",
273277
"\n",
274278
"def g(xyz):\n",
275279
" x, y, z = xyz\n",
276280
" for _ in range(5): # Burn some CPU time just because\n",
277281
" np.linalg.eig(np.random.rand(1000, 1000))\n",
278-
" return x ** 2 + y ** 2 + z ** 2\n",
282+
" return x**2 + y**2 + z**2\n",
279283
"\n",
280284
"\n",
281285
"xs = np.linspace(0, 10, 11)\n",
@@ -302,11 +306,17 @@
302306
"\n",
303307
"\n",
304308
"scheduler = adaptive_scheduler.scheduler.DefaultScheduler(\n",
305-
" cores=10, executor_type=\"ipyparallel\",\n",
309+
" cores=10,\n",
310+
" executor_type=\"ipyparallel\",\n",
306311
") # PBS or SLURM\n",
307312
"\n",
308313
"run_manager2 = adaptive_scheduler.server_support.RunManager(\n",
309-
" scheduler, learners, fnames, goal=goal, log_interval=30, save_interval=30,\n",
314+
" scheduler,\n",
315+
" learners,\n",
316+
" fnames,\n",
317+
" goal=goal,\n",
318+
" log_interval=30,\n",
319+
" save_interval=30,\n",
310320
")\n",
311321
"run_manager2.start()"
312322
]
@@ -343,19 +353,19 @@
343353
"outputs": [],
344354
"source": [
345355
"import numpy as np\n",
346-
"\n",
347356
"from adaptive import SequenceLearner\n",
348-
"from adaptive_scheduler.utils import split, combo2fname\n",
349357
"from adaptive.utils import named_product\n",
350358
"\n",
359+
"from adaptive_scheduler.utils import combo2fname\n",
360+
"\n",
351361
"\n",
352362
"def g(combo):\n",
353363
" x, y, z = combo[\"x\"], combo[\"y\"], combo[\"z\"]\n",
354364
"\n",
355365
" for _ in range(5): # Burn some CPU time just because\n",
356366
" np.linalg.eig(np.random.rand(1000, 1000))\n",
357367
"\n",
358-
" return x ** 2 + y ** 2 + z ** 2\n",
368+
" return x**2 + y**2 + z**2\n",
359369
"\n",
360370
"\n",
361371
"combos = named_product(x=np.linspace(0, 10), y=np.linspace(-1, 1), z=np.linspace(-3, 3))\n",
@@ -364,15 +374,15 @@
364374
"\n",
365375
"# We could run this as 1 job with N nodes, but we can also split it up in multiple jobs.\n",
366376
"# This is desireable when you don't want to run a single job with 300 nodes for example.\n",
367-
"# Note that \n",
377+
"# Note that\n",
368378
"# `adaptive_scheduler.utils.split_sequence_in_sequence_learners(g, combos, 100, \"data\")`\n",
369379
"# does the same!\n",
370380
"\n",
371381
"njobs = 100\n",
372382
"split_combos = list(split(combos, njobs))\n",
373383
"\n",
374384
"print(\n",
375-
" f\"Length of split_combos: {len(split_combos)} and length of split_combos[0]: {len(split_combos[0])}.\"\n",
385+
" f\"Length of split_combos: {len(split_combos)} and length of split_combos[0]: {len(split_combos[0])}.\",\n",
376386
")\n",
377387
"\n",
378388
"learners = [SequenceLearner(g, combos_part) for combos_part in split_combos]\n",
@@ -393,17 +403,16 @@
393403
"outputs": [],
394404
"source": [
395405
"from functools import partial\n",
406+
"\n",
396407
"import adaptive_scheduler\n",
397-
"from adaptive_scheduler.scheduler import DefaultScheduler, PBS, SLURM\n",
408+
"from adaptive_scheduler.scheduler import SLURM, DefaultScheduler\n",
398409
"\n",
399410
"\n",
400411
"def goal(learner):\n",
401412
" return learner.done() # the standard goal for a SequenceLearner\n",
402413
"\n",
403414
"\n",
404-
"extra_scheduler = (\n",
405-
" [\"--exclusive\", \"--time=24:00:00\"] if DefaultScheduler is SLURM else []\n",
406-
")\n",
415+
"extra_scheduler = [\"--exclusive\", \"--time=24:00:00\"] if DefaultScheduler is SLURM else []\n",
407416
"\n",
408417
"scheduler = adaptive_scheduler.scheduler.DefaultScheduler(\n",
409418
" cores=10,\n",
@@ -459,7 +468,8 @@
459468
"source": [
460469
"run_manager3.load_learners() # load the data into the learners\n",
461470
"result = sum(\n",
462-
" [l.result() for l in learners], []\n",
471+
" [l.result() for l in learners],\n",
472+
" [],\n",
463473
") # combine all learner's result into 1 list"
464474
]
465475
}

tests/conftest.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323
import zmq.asyncio
2424

2525

26-
@pytest.fixture()
26+
@pytest.fixture
2727
def mock_scheduler(tmp_path: Path) -> MockScheduler:
2828
"""Fixture for creating a MockScheduler instance."""
2929
return MockScheduler(log_folder=str(tmp_path), cores=8)
3030

3131

32-
@pytest.fixture()
32+
@pytest.fixture
3333
def db_manager(
3434
mock_scheduler: MockScheduler,
3535
learners: list[adaptive.Learner1D]
@@ -99,14 +99,14 @@ def fnames(
9999
raise NotImplementedError(msg)
100100

101101

102-
@pytest.fixture()
102+
@pytest.fixture
103103
def socket(db_manager: DatabaseManager) -> zmq.asyncio.Socket:
104104
"""Fixture for creating a ZMQ socket."""
105105
with get_socket(db_manager) as socket:
106106
yield socket
107107

108108

109-
@pytest.fixture()
109+
@pytest.fixture
110110
def job_manager(
111111
db_manager: DatabaseManager,
112112
mock_scheduler: MockScheduler,
@@ -116,7 +116,7 @@ def job_manager(
116116
return JobManager(job_names, db_manager, mock_scheduler, interval=0.05)
117117

118118

119-
@pytest.fixture()
119+
@pytest.fixture
120120
def _mock_slurm_partitions_output() -> Generator[None, None, None]:
121121
"""Mock `slurm_partitions` function."""
122122
mock_output = "hb120v2-low\nhb60-high\nnc24-low*\nnd40v2-mpi\n"
@@ -125,7 +125,7 @@ def _mock_slurm_partitions_output() -> Generator[None, None, None]:
125125
yield
126126

127127

128-
@pytest.fixture()
128+
@pytest.fixture
129129
def _mock_slurm_partitions() -> Generator[None, None, None]:
130130
"""Mock `slurm_partitions` function."""
131131
with (
@@ -141,7 +141,7 @@ def _mock_slurm_partitions() -> Generator[None, None, None]:
141141
yield
142142

143143

144-
@pytest.fixture()
144+
@pytest.fixture
145145
def _mock_slurm_queue() -> Generator[None, None, None]:
146146
"""Mock `SLURM.queue` function."""
147147
with patch(

tests/test_client_support.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def client(zmq_url: str) -> zmq.Socket:
2727
return client
2828

2929

30-
@pytest.mark.asyncio()
30+
@pytest.mark.asyncio
3131
async def test_get_learner(zmq_url: str) -> None:
3232
"""Test `get_learner` function."""
3333
with tempfile.NamedTemporaryFile() as tmpfile:
@@ -94,7 +94,7 @@ async def test_get_learner(zmq_url: str) -> None:
9494
mock_log.exception.assert_called_with("got an exception")
9595

9696

97-
@pytest.mark.asyncio()
97+
@pytest.mark.asyncio
9898
async def test_tell_done(zmq_url: str) -> None:
9999
"""Test `tell_done` function."""
100100
fname = "test_learner_file.pkl"

tests/test_database_manager.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_simple_database_get_all(tmp_path: Path) -> None:
102102
assert done_entries[1][1].fname == "file3.txt"
103103

104104

105-
@pytest.mark.asyncio()
105+
@pytest.mark.asyncio
106106
async def test_database_manager_start_and_cancel(db_manager: DatabaseManager) -> None:
107107
"""Test starting and canceling the DatabaseManager."""
108108
db_manager.start()
@@ -172,7 +172,7 @@ def test_database_manager_as_dicts(
172172
]
173173

174174

175-
@pytest.mark.asyncio()
175+
@pytest.mark.asyncio
176176
async def test_database_manager_dispatch_start_stop(
177177
db_manager: DatabaseManager,
178178
learners: list[adaptive.Learner1D]
@@ -205,7 +205,7 @@ async def test_database_manager_dispatch_start_stop(
205205
assert entry.is_done is True
206206

207207

208-
@pytest.mark.asyncio()
208+
@pytest.mark.asyncio
209209
async def test_database_manager_start_and_update(
210210
socket: zmq.asyncio.Socket,
211211
db_manager: DatabaseManager,
@@ -259,7 +259,7 @@ async def test_database_manager_start_and_update(
259259
assert entry.job_id is None
260260

261261

262-
@pytest.mark.asyncio()
262+
@pytest.mark.asyncio
263263
async def test_database_manager_start_stop(
264264
socket: zmq.asyncio.Socket,
265265
db_manager: DatabaseManager,
@@ -322,7 +322,7 @@ async def test_database_manager_start_stop(
322322
await send_message(socket, start_message)
323323

324324

325-
@pytest.mark.asyncio()
325+
@pytest.mark.asyncio
326326
async def test_database_manager_stop_request_and_requests(
327327
socket: zmq.asyncio.Socket,
328328
db_manager: DatabaseManager,
@@ -531,7 +531,7 @@ def test_ensure_str_invalid_input(invalid_input: list[str]) -> None:
531531
_ensure_str(invalid_input) # type: ignore[arg-type]
532532

533533

534-
@pytest.mark.asyncio()
534+
@pytest.mark.asyncio
535535
async def test_dependencies(
536536
db_manager: DatabaseManager,
537537
fnames: list[str] | list[Path],

tests/test_job_manager.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
from adaptive_scheduler.server_support import JobManager, MaxRestartsReachedError
1313

1414

15-
@pytest.mark.asyncio()
15+
@pytest.mark.asyncio
1616
async def test_job_manager_init(job_manager: JobManager) -> None:
1717
"""Test the initialization of JobManager."""
1818
job_manager.database_manager.start()
1919
job_manager.start()
2020
assert job_manager.task is not None
2121

2222

23-
@pytest.mark.asyncio()
23+
@pytest.mark.asyncio
2424
async def test_job_manager_queued(job_manager: JobManager) -> None:
2525
"""Test the _queued method of JobManager."""
2626
job_manager.scheduler.start_job("job1")
@@ -30,7 +30,7 @@ async def test_job_manager_queued(job_manager: JobManager) -> None:
3030
assert job_manager._queued(job_manager.scheduler.queue()) == {"job1", "job2"}
3131

3232

33-
@pytest.mark.asyncio()
33+
@pytest.mark.asyncio
3434
async def test_job_manager_manage_max_restarts_reached(job_manager: JobManager) -> None:
3535
"""Test the JobManager when the maximum restarts are reached."""
3636
job_manager.n_started = 105
@@ -48,7 +48,7 @@ async def test_job_manager_manage_max_restarts_reached(job_manager: JobManager)
4848
job_manager.task.result()
4949

5050

51-
@pytest.mark.asyncio()
51+
@pytest.mark.asyncio
5252
async def test_job_manager_manage_start_jobs(job_manager: JobManager) -> None:
5353
"""Test the JobManager when managing the start of jobs."""
5454
job_manager.database_manager.n_done = MagicMock(return_value=0) # type: ignore[method-assign]
@@ -60,7 +60,7 @@ async def test_job_manager_manage_start_jobs(job_manager: JobManager) -> None:
6060
assert set(job_manager.scheduler._started_jobs) == {"job1", "job2"} # type: ignore[attr-defined]
6161

6262

63-
@pytest.mark.asyncio()
63+
@pytest.mark.asyncio
6464
async def test_job_manager_manage_start_max_simultaneous_jobs(
6565
job_manager: JobManager,
6666
) -> None:
@@ -76,7 +76,7 @@ async def test_job_manager_manage_start_max_simultaneous_jobs(
7676
assert len(job_manager.scheduler._started_jobs) == 1 # type: ignore[attr-defined]
7777

7878

79-
@pytest.mark.asyncio()
79+
@pytest.mark.asyncio
8080
async def test_job_manager_manage_cancelled_error(
8181
job_manager: JobManager,
8282
caplog: pytest.LogCaptureFixture,
@@ -100,7 +100,7 @@ async def test_job_manager_manage_cancelled_error(
100100
assert "task was cancelled because of a CancelledError" in caplog.text
101101

102102

103-
@pytest.mark.asyncio()
103+
@pytest.mark.asyncio
104104
async def test_job_manager_manage_n_done_equal_job_names(
105105
job_manager: JobManager,
106106
) -> None:
@@ -116,7 +116,7 @@ async def test_job_manager_manage_n_done_equal_job_names(
116116
assert job_manager.task.result() is None
117117

118118

119-
@pytest.mark.asyncio()
119+
@pytest.mark.asyncio
120120
async def test_job_manager_manage_generic_exception(
121121
job_manager: JobManager,
122122
caplog: pytest.LogCaptureFixture,

0 commit comments

Comments
 (0)