optimize the communication loop #44

Open · wants to merge 8 commits into base: main
35 changes: 18 additions & 17 deletions adaptive_scheduler/client_support.py
@@ -2,6 +2,7 @@
import datetime
import logging
import socket
import time
from contextlib import suppress
from typing import Any, Dict, List, Tuple, Union

@@ -10,15 +11,10 @@
import zmq
from adaptive import AsyncRunner, BaseLearner

from adaptive_scheduler.utils import (
_deserialize,
_get_npoints,
_serialize,
log_exception,
maybe_lst,
)
from adaptive_scheduler.utils import _get_npoints, log_exception, maybe_lst

ctx = zmq.Context()
ctx.linger = 0
logger = logging.getLogger("adaptive_scheduler.client")
logger.setLevel(logging.INFO)
log = structlog.wrap_logger(
@@ -67,13 +63,13 @@ def get_learner(
"trying to get learner", job_id=job_id, log_fname=log_fname, job_name=job_name
)
with ctx.socket(zmq.REQ) as socket:
socket.setsockopt(zmq.LINGER, 0)
socket.connect(url)
socket.send_serialized(("start", job_id, log_fname, job_name), _serialize)
log.info(f"sent start signal, going to wait 60s for a reply.")
socket.setsockopt(zmq.RCVTIMEO, 60_000) # timeout after 60s
reply = socket.recv_serialized(_deserialize)
log.info("got reply", reply=str(reply))
t_start = time.time()
socket.send_pyobj(("start", job_id, log_fname, job_name))
log.info(f"sent start signal, going to wait 180s for a reply.")
socket.setsockopt(zmq.RCVTIMEO, 180_000) # timeout after 180s
reply = socket.recv_pyobj()
log.info("got reply", reply=str(reply), t_total=time.time() - t_start)
if reply is None:
msg = f"No learners to be run."
exception = RuntimeError(msg)
@@ -104,10 +100,15 @@ def tell_done(url: str, fname: str) -> None:
log.info("goal reached! 🎉🎊🥳")
with ctx.socket(zmq.REQ) as socket:
socket.connect(url)
socket.send_serialized(("stop", fname), _serialize)
socket.setsockopt(zmq.RCVTIMEO, 10_000) # timeout after 10s
log.info("sent stop signal, going to wait 10s for a reply", fname=fname)
socket.recv_serialized(_deserialize) # Needed because of socket type
t_start = time.time()
socket.send_pyobj(("stop", fname))
socket.setsockopt(zmq.RCVTIMEO, 180_000)  # timeout after 180s
log.info(
"sent stop signal, going to wait 180s for a reply",
fname=fname,
t_total=time.time() - t_start,
)
socket.recv_pyobj() # Needed because of socket type


def _get_log_entry(runner: AsyncRunner, npoints_start: int) -> Dict[str, Any]:
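Note on the client-side change above: the custom cloudpickle helpers (`_serialize`/`_deserialize`) are replaced by pyzmq's built-in `send_pyobj`/`recv_pyobj`, which pickle the object directly, and the reply timeout is raised from 60 s to 180 s. A minimal sketch of the resulting request pattern (the endpoint URL in the usage comment is hypothetical):

```python
import zmq

ctx = zmq.Context()
ctx.linger = 0  # do not block interpreter exit on unsent messages


def request(url: str, msg: tuple, timeout_ms: int = 180_000):
    """Send ``msg`` as a pickled object and wait for the pickled reply."""
    with ctx.socket(zmq.REQ) as sock:
        sock.setsockopt(zmq.RCVTIMEO, timeout_ms)  # recv raises zmq.error.Again on timeout
        sock.connect(url)
        sock.send_pyobj(msg)      # pickles the tuple with the standard pickle module
        return sock.recv_pyobj()  # unpickles the server's reply


# usage, mirroring get_learner (hypothetical endpoint):
# reply = request("tcp://localhost:5555", ("start", job_id, log_fname, job_name))
```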
2 changes: 0 additions & 2 deletions adaptive_scheduler/run_script.py.j2
@@ -13,9 +13,7 @@ import cloudpickle
from adaptive_scheduler import client_support
{%- if executor_type == "mpi4py" %}
import cloudpickle
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
MPI.pickle.__init__(cloudpickle.dumps, cloudpickle.loads)
{% elif executor_type == "ipyparallel" %}
from adaptive_scheduler.utils import connect_to_ipyparallel
{% elif executor_type == "dask-mpi" %}
60 changes: 52 additions & 8 deletions adaptive_scheduler/server_support.py
@@ -5,6 +5,7 @@
import glob
import json
import logging
import multiprocessing
import os
import shutil
import socket
@@ -26,22 +27,37 @@

from adaptive_scheduler.scheduler import BaseScheduler
from adaptive_scheduler.utils import (
_deserialize,
_progress,
_remove_or_move_files,
_serialize,
hash_anything,
load_parallel,
maybe_lst,
)
from adaptive_scheduler.widgets import log_explorer

ctx = zmq.asyncio.Context()
ctx.linger = 0

logger = logging.getLogger("adaptive_scheduler.server")
logger.setLevel(logging.INFO)
log = structlog.wrap_logger(logger)


async def _run_proxy(socket_from, socket_to):
poller = zmq.asyncio.Poller()
poller.register(socket_from, zmq.POLLIN)
poller.register(socket_to, zmq.POLLIN)
while True:
events = await poller.poll()
events = dict(events)
if socket_from in events:
msg = await socket_from.recv_multipart()
await socket_to.send_multipart(msg)
elif socket_to in events:
msg = await socket_to.recv_multipart()
await socket_from.send_multipart(msg)


class MaxRestartsReached(Exception):
"""Jobs can fail instantly because of a error in
your Python code which results jobs being started indefinitely."""
@@ -116,6 +132,7 @@ def __init__(
):
super().__init__()
self.url = url
self._url_worker = f"inproc://workers-{hash_anything(time.time())}"
self.scheduler = scheduler
self.db_fname = db_fname
self.learners = learners
@@ -129,6 +146,7 @@ def __init__(
self._last_reply: Union[str, Exception, None] = None
self._last_request: Optional[Tuple[str, ...]] = None
self.failed: List[Dict[str, Any]] = []
self._comm_times: List[Tuple[float, float, float]] = []

def _setup(self) -> None:
if os.path.exists(self.db_fname) and not self.overwrite_db:
@@ -252,24 +270,47 @@ def _dispatch(self, request: Tuple[str, ...]) -> Union[str, Exception, None]:
except Exception as e:
return e

async def _manage(self) -> None:
"""Database manager co-routine.
async def _manage_worker(self) -> None:
"""Database worker manager co-routine.

Returns
-------
coroutine
"""
log.debug("started database")
socket = ctx.socket(zmq.REP)
socket.bind(self.url)
socket.connect(self._url_worker)
try:
while True:
self._last_request = await socket.recv_serialized(_deserialize)
self._last_request = await socket.recv_pyobj()
t_0 = time.time()
self._last_reply = self._dispatch(self._last_request)
await socket.send_serialized(self._last_reply, _serialize)
t_1 = time.time()
await socket.send_pyobj(self._last_reply)
t_2 = time.time()
self._comm_times.append((t_0, t_1, t_2))
finally:
socket.close()

async def _manage(self) -> None:
"""Database manager co-routine.

Runs multiple instances of ``_manage_worker`` so that the manager is not
limited by slow ``send_pyobj`` calls.

Returns
-------
coroutine
"""
clients = ctx.socket(zmq.ROUTER)
clients.bind(self.url)
workers = ctx.socket(zmq.DEALER)
workers.bind(self._url_worker)
proxy_coro = _run_proxy(clients, workers)
ncores = multiprocessing.cpu_count()
worker_coros = [self._manage_worker() for i in range(ncores - 1)]
await asyncio.wait((proxy_coro, *worker_coros))


class JobManager(_BaseManager):
"""Job manager.
@@ -909,8 +950,8 @@ async def _manage(self) -> None:

def cancel(self) -> None:
"""Cancel the manager tasks and the jobs in the queue."""
self.database_manager.cancel()
self.job_manager.cancel()
self.database_manager.cancel()
self.kill_manager.cancel()
self.scheduler.cancel(self.job_names)
self.task.cancel()
@@ -1091,6 +1132,8 @@ def _info_html(self) -> str:
n_running = sum(job["state"] in ("RUNNING", "R") for job in jobs)
n_pending = sum(job["state"] in ("PENDING", "Q", "CONFIGURING") for job in jobs)
n_done = sum(job["is_done"] for job in self.database_manager.as_dicts())
comm_time = sum([t2 - t0 for t0, t1, t2 in self.database_manager._comm_times])
comm_time = datetime.timedelta(seconds=comm_time)

status = self.status()
color = {
@@ -1114,6 +1157,7 @@ def _table_row(i, key, value):
("# pending jobs", f'<font color="orange">{n_pending}</font>'),
("# finished jobs", f'<font color="green">{n_done}</font>'),
("elapsed time", datetime.timedelta(seconds=self.elapsed_time())),
("communication time", comm_time),
]

with suppress(Exception):
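The server-side changes above are the heart of this PR. Instead of a single REP socket bound to `self.url`, the database manager now binds a ROUTER socket for the clients and a DEALER socket on an in-process endpoint, `_run_proxy` shuttles frames between the two, and one `_manage_worker` REP coroutine per core (minus one) handles requests behind the DEALER. A self-contained sketch of the same ROUTER/DEALER load-balancing pattern, with hypothetical endpoint names and a trivial echo handler standing in for `_dispatch`:

```python
import asyncio

import zmq
import zmq.asyncio

ctx = zmq.asyncio.Context()

FRONTEND = "tcp://*:5555"          # clients connect here (hypothetical port)
BACKEND = "inproc://workers-demo"  # in-process endpoint shared with the workers


async def proxy(frontend: zmq.asyncio.Socket, backend: zmq.asyncio.Socket) -> None:
    """Forward multipart frames between the ROUTER and the DEALER in both directions."""
    poller = zmq.asyncio.Poller()
    poller.register(frontend, zmq.POLLIN)
    poller.register(backend, zmq.POLLIN)
    while True:
        events = dict(await poller.poll())
        if frontend in events:
            await backend.send_multipart(await frontend.recv_multipart())
        if backend in events:
            await frontend.send_multipart(await backend.recv_multipart())


async def worker(n: int) -> None:
    """One REP worker: receive a pickled request, reply with a pickled answer."""
    sock = ctx.socket(zmq.REP)
    sock.connect(BACKEND)
    while True:
        request = await sock.recv_pyobj()
        await sock.send_pyobj(("worker", n, request))  # stand-in for _dispatch


async def main(n_workers: int = 4) -> None:
    frontend = ctx.socket(zmq.ROUTER)
    frontend.bind(FRONTEND)
    backend = ctx.socket(zmq.DEALER)
    backend.bind(BACKEND)
    await asyncio.gather(proxy(frontend, backend), *(worker(i) for i in range(n_workers)))


# asyncio.run(main())
```

Because the workers are coroutines in one process rather than threads, the pattern overlaps time spent waiting on the sockets; CPU-bound work such as unpickling still runs on the single event-loop thread.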
9 changes: 0 additions & 9 deletions adaptive_scheduler/utils.py
@@ -16,7 +16,6 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import adaptive
import cloudpickle
import numpy as np
import toolz
from adaptive.notebook_integration import in_ipynb
@@ -521,11 +520,3 @@ def maybe_lst(fname: Union[List[str], str]):
# TinyDB converts tuples to lists
fname = list(fname)
return fname


def _serialize(msg):
return [cloudpickle.dumps(msg)]


def _deserialize(frames):
return cloudpickle.loads(frames[0])
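
For reference, the removed helpers above wrapped every message in cloudpickle; `send_pyobj`/`recv_pyobj` do the same round-trip with the standard `pickle` module, which avoids the cloudpickle overhead for the plain tuples and strings exchanged here but cannot serialize objects such as lambdas or locally defined functions. A simplified sketch of what the pyobj methods do under the hood:

```python
import pickle

import zmq


def send_pyobj_equivalent(socket: zmq.Socket, msg) -> None:
    # roughly what socket.send_pyobj(msg) does (pyzmq also accepts a pickle protocol)
    socket.send(pickle.dumps(msg))


def recv_pyobj_equivalent(socket: zmq.Socket):
    # roughly what socket.recv_pyobj() does
    return pickle.loads(socket.recv())
```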