From c54bf3b5cc19c02c705b46233800ecd351e49387 Mon Sep 17 00:00:00 2001 From: Pedro Ribeiro <47680931+tubarao312@users.noreply.github.com> Date: Sat, 1 Feb 2025 18:55:11 +0000 Subject: [PATCH] Python SDK Rework (#73) * chore: add Pydantic * refactor: Completely rework task models * docs: Improve types documentation * chore: Add aiohttp-retry lib * refactor: Refactor manager config to use aiohttp-retry * refactor: Rework manager client to only have the required functions and use Pydantic * refactor: Rework publisher client * feat: Add from_str method to TaskKind model * refactor: Slightly improve interface for brokers * docs: Fix error in Task docss * fix: Remove useless broker client interface code * refactor: Rework RabbitMQ clients so there exists one for the worker and one for the publisher * feat: Add durable and auto_delete config to BrokerConfig * chore: dependencies organization * fix: Fix broker to find exchanges and queues based on the published task * feat: Set prefetch count to 1 * docs: Simplify docs for worker app config * feat: Add worker kind registration to manager client * feat: Add worker kind broker info model * feat: Register worker kind on worker app init * refactor: Remove irrelevant fields from broker config * refactor: Remove broker instance create function Doesn't make sense to keep it given that we only support RabbitMQ * refactor: Rework broker clients to be more lightweight and only connect to exchanges and queues instead of creating them * docs: Add RabbitMQ to brokerconfig docs * refactor: Remove broker client interface * chore: add all brokers to __init__ export * rework: Refactor worker client to use new broker clients * chore: Fix inconsistent url name * docs: Update examples * refactor: Refactor publisher client to use new broker client * fix: Fix task status name * chore: Add new deps and settings for tests * chore: Update UV lock * chore: Remove outdated tests * chore: Update package info * fix: Fix multiple pydantic models to use 
proper pydantic inits * chore: Improve inits for exports * chore: Tune modal config to account for tests * test: Refactor worker client tests * test: Refactor publisher client tests * test: Refactor manager client tests * test: Simplify conftest * test: Temporarily disable benchmarks * fix: Improve pydantic model usage for WorkerApplication * feat: Completely remove task kind model * fix: Fix URL for mock manager client * docs: Fix examples to use new interfaces * docs: Inline worker app config example * docs: Tweak examples * feat: Add health check to local docker compose * chore: Move all tests into unit test folder * test: Add e2e test suite * feat: Move all queue setup logic back to client SDK * fix: Remove worker kind broker info model and references, make task result optional * feat: Create an auto-serializable exception. This must be improved in the future. * fix: Fix Rust queues to be compatible with Python ones * fix: Fix wrong default shutdown signal * feat: Make tasks only get acknowledged AFTER being completed --- .../python/examples/example_producer.py | 21 +- client_sdks/python/examples/example_worker.py | 23 +- client_sdks/python/pyproject.toml | 6 + client_sdks/python/src/broker/__init__.py | 33 +- client_sdks/python/src/broker/client.py | 311 ++++++++++++++++++ client_sdks/python/src/broker/config.py | 8 +- client_sdks/python/src/broker/core.py | 22 -- client_sdks/python/src/broker/rabbitmq.py | 57 ---- client_sdks/python/src/manager/client.py | 180 ++++------ client_sdks/python/src/manager/config.py | 27 +- client_sdks/python/src/models/task.py | 161 ++++----- client_sdks/python/src/publisher/__init__.py | 3 + client_sdks/python/src/publisher/client.py | 94 ++++-- .../python/src/tacoq.egg-info/PKG-INFO | 3 + .../python/src/tacoq.egg-info/SOURCES.txt | 4 +- .../python/src/tacoq.egg-info/requires.txt | 3 + client_sdks/python/src/worker/__init__.py | 4 +- client_sdks/python/src/worker/client.py | 262 ++++++++------- 
client_sdks/python/src/worker/config.py | 16 +- client_sdks/python/tests/__init__.py | 0 .../tests/benchmarks/test_task_creation.py | 4 + client_sdks/python/tests/conftest.py | 35 +- client_sdks/python/tests/e2e/test_full.py | 143 ++++++++ .../python/tests/manager/test_health_check.py | 13 - .../tests/manager/test_task_management.py | 91 ----- .../tests/manager/test_worker_registry.py | 27 -- .../python/tests/unit/manager/test_health.py | 36 ++ .../python/tests/unit/manager/test_tasks.py | 65 ++++ .../unit/publisher/test_publisher_client.py | 105 ++++++ .../tests/unit/worker/test_worker_client.py | 204 ++++++++++++ .../python/tests/worker/test_worker_client.py | 91 ----- client_sdks/python/uv.lock | 93 ++++++ docker-compose.yml | 5 + server/libs/common/src/brokers/mod.rs | 4 +- server/libs/common/src/brokers/rabbit.rs | 22 +- server/services/manager/src/constants.rs | 6 +- server/services/manager/src/main.rs | 20 +- 37 files changed, 1438 insertions(+), 764 deletions(-) create mode 100644 client_sdks/python/src/broker/client.py delete mode 100644 client_sdks/python/src/broker/core.py delete mode 100644 client_sdks/python/src/broker/rabbitmq.py create mode 100644 client_sdks/python/src/publisher/__init__.py create mode 100644 client_sdks/python/tests/__init__.py create mode 100644 client_sdks/python/tests/e2e/test_full.py delete mode 100644 client_sdks/python/tests/manager/test_health_check.py delete mode 100644 client_sdks/python/tests/manager/test_task_management.py delete mode 100644 client_sdks/python/tests/manager/test_worker_registry.py create mode 100644 client_sdks/python/tests/unit/manager/test_health.py create mode 100644 client_sdks/python/tests/unit/manager/test_tasks.py create mode 100644 client_sdks/python/tests/unit/publisher/test_publisher_client.py create mode 100644 client_sdks/python/tests/unit/worker/test_worker_client.py delete mode 100644 client_sdks/python/tests/worker/test_worker_client.py diff --git 
a/client_sdks/python/examples/example_producer.py b/client_sdks/python/examples/example_producer.py index 94287de..b4fa0ab 100644 --- a/client_sdks/python/examples/example_producer.py +++ b/client_sdks/python/examples/example_producer.py @@ -1,5 +1,6 @@ import asyncio +from broker.config import BrokerConfig from manager.config import ManagerConfig from publisher.client import PublisherClient @@ -9,21 +10,35 @@ # Setup the manager location configuration manager_config = ManagerConfig(url="http://localhost:3000") +# Setup the broker configuration +broker_config = BrokerConfig(url="amqp://user:password@localhost:5672") + # Both the publisher and the worker need to know about the task kinds and # should have unified names for them. +WORKER_KIND_NAME = "worker_kind" TASK_1_NAME = "task_1" TASK_2_NAME = "task_2" # APPLICATION CONFIGURATION ___________________________________________________ # 1. Create a producer application -worker_application = PublisherClient(manager_config) +worker_application = PublisherClient( + manager_config=manager_config, broker_config=broker_config +) # 2. Start the application async def main(): - task1 = await worker_application.publish_task(TASK_1_NAME, {"data": "task_1_data"}) - task2 = await worker_application.publish_task(TASK_2_NAME, {"data": "task_2_data"}) + task1 = await worker_application.publish_task( + TASK_1_NAME, + WORKER_KIND_NAME, + {"data": "task_1_data"}, + ) + task2 = await worker_application.publish_task( + TASK_2_NAME, + WORKER_KIND_NAME, + {"data": "task_2_data"}, + ) print(f"Task 1: {task1}") print(f"Task 2: {task2}") diff --git a/client_sdks/python/examples/example_worker.py b/client_sdks/python/examples/example_worker.py index 4661143..475c0c0 100644 --- a/client_sdks/python/examples/example_worker.py +++ b/client_sdks/python/examples/example_worker.py @@ -14,23 +14,24 @@ # Both the publisher and the worker need to know about the task kinds and # should have unified names for them. 
+WORKER_KIND_NAME = "worker_kind" TASK_1_NAME = "task_1" TASK_2_NAME = "task_2" # APPLICATION CONFIGURATION ___________________________________________________ -# 2. Configure the worker -worker_config = WorkerApplicationConfig( - name="test_worker", - manager_config=manager_config, - broker_config=broker_config, +# 1. Create a worker application +worker_application = WorkerApplication( + config=WorkerApplicationConfig( + name="test_worker", + manager_config=manager_config, + broker_config=broker_config, + kind=WORKER_KIND_NAME, + ), ) -# 3. Create a worker application -worker_application = WorkerApplication(worker_config) - -# 4. Create tasks and register them with the worker application +# 2. Create tasks and register them with the worker application @worker_application.task(TASK_1_NAME) async def task_1(input_data: dict[Any, Any]) -> dict[Any, Any]: await asyncio.sleep(1) @@ -38,10 +39,12 @@ async def task_1(input_data: dict[Any, Any]) -> dict[Any, Any]: @worker_application.task(TASK_2_NAME) -async def task_2(input_data: dict[Any, Any]) -> dict[Any, Any]: +async def task_2(_: dict[Any, Any]) -> dict[Any, Any]: raise Exception("This is a test exception") +# 3. Run the worker application + if __name__ == "__main__": # Application can be run either as a standalone script or via the CLI. asyncio.run(worker_application.entrypoint()) diff --git a/client_sdks/python/pyproject.toml b/client_sdks/python/pyproject.toml index e0165d9..d441302 100644 --- a/client_sdks/python/pyproject.toml +++ b/client_sdks/python/pyproject.toml @@ -16,9 +16,12 @@ classifiers = [ requires-python = ">=3.12" dependencies = [ "aio-pika>=9.5.3", + "aiohttp-retry>=2.9.1", "aiohttp>=3.11.8", "aioredis>=2.0.1", + "aioresponses>=0.7.8", "click>=8.1.7", + "pydantic>=2.10.5", "uuid>=1.30", "watchfiles>=1.0.3", ] @@ -28,6 +31,9 @@ asyncio_mode = "auto" pythonpath = ["."] markers = [ "bench: benchmark a function", + "target: current target test. 
Used locally.", + "unit: unit tests", + "service: tests that interact with external services", ] [build-system] diff --git a/client_sdks/python/src/broker/__init__.py b/client_sdks/python/src/broker/__init__.py index 0488954..aa507ce 100644 --- a/client_sdks/python/src/broker/__init__.py +++ b/client_sdks/python/src/broker/__init__.py @@ -1,28 +1,9 @@ +from broker.client import PublisherBrokerClient, WorkerBrokerClient, BaseBrokerClient from broker.config import BrokerConfig -from broker.core import BrokerClient -# TODO: Import only rabbit if tacoq[amqp] is installed -from broker.rabbitmq import RabbitMQBroker - - -def create_broker_instance( - config: BrokerConfig, exchange_name: str, worker_id: str -) -> BrokerClient: - """Create appropriate broker client based on configuration. - - ### Parameters - - `config`: Configuration for the broker connection - - `exchange_name`: Name of the exchange to use - - `worker_id`: Unique identifier for this worker instance - - ### Returns - - `BrokerClient`: Configured broker client instance - - ### Raises - - `ValueError`: If broker URL scheme is not supported - """ - - if config.url.startswith("amqp"): - return RabbitMQBroker(config, exchange_name, worker_id) - else: - raise ValueError(f"Unsupported broker URL: {config.url}") +__all__ = [ + "PublisherBrokerClient", + "WorkerBrokerClient", + "BaseBrokerClient", + "BrokerConfig", +] diff --git a/client_sdks/python/src/broker/client.py b/client_sdks/python/src/broker/client.py new file mode 100644 index 0000000..087a752 --- /dev/null +++ b/client_sdks/python/src/broker/client.py @@ -0,0 +1,311 @@ +import asyncio +import json +from typing import AsyncGenerator, Optional +from broker.config import BrokerConfig +from aio_pika import Message, connect_robust +from models.task import Task +from pydantic import BaseModel +from logging import warn + +from aio_pika.abc import ( + AbstractChannel, + AbstractQueue, + AbstractRobustConnection, + AbstractExchange, +) + +# 
========================================= +# Constants - Exchange and queue names +# ========================================= + +# Publisher + +TASK_ASSIGNMENT_EXCHANGE = "task_assignment_exchange" +""" Exchange for task assignments. Used by the publisher to +send tasks to the manager and workers. """ + +TASK_ASSIGNMENT_QUEUE = "task_assignment_queue" +""" Queue for task assignments. Used by the publisher to send +tasks to the manager and workers. """ + +MANAGER_ROUTING_KEY = "tasks.#" # Matches all tasks +""" Routing key for the manager queue. Receives all tasks to +save them. """ + +WORKER_ROUTING_KEY_PREFIX = "tasks.{worker_kind}" # Will be combined with worker_kind +""" Routing key for worker queues. Only workers of a +specific kind will receive these tasks. """ + + +def get_worker_routing_key(worker_kind: str) -> str: + """Get the routing key for a worker kind. + + ### Args: + - `worker_kind`: The kind of worker to get the routing key for. + """ + return WORKER_ROUTING_KEY_PREFIX.format(worker_kind=worker_kind) + + +# Worker + +TASK_RESULT_EXCHANGE = "task_result_exchange" +""" Exchange for task results. Used by all workers to +publish their results. """ + +TASK_RESULT_QUEUE = "task_results" +""" Queue for task results. Used by all workers to publish their results. 
""" + + +# ========================================= +# Errors +# ========================================= + + +class NoChannelError(Exception): + """Raised when a RabbitMQ client is not connected to the broker while + trying to perform an operation that requires a channel.""" + + pass + + +class NotConnectedError(Exception): + """Raised when a RabbitMQ client is not connected to the broker while + trying to perform an operation that requires a connection.""" + + pass + + +class QueueNotDeclaredError(Exception): + """Raised when a RabbitMQ client tries to use a queue that has not been + declared.""" + + pass + + +class ExchangeNotDeclaredError(Exception): + """Raised when a RabbitMQ client tries to use an exchange that has not been + declared.""" + + pass + + +## ========================================= +## Base Client +## ========================================= + + +class BaseBrokerClient(BaseModel): + """RabbitMQ implementation of the broker interface.""" + + config: BrokerConfig + """ Configuration for the broker. """ + + _connection: Optional[AbstractRobustConnection] = None + """ The connection to the RabbitMQ server. """ + + _channel: Optional[AbstractChannel] = None + """ The channel to the RabbitMQ server. """ + + async def connect(self) -> None: + """Establish connection to RabbitMQ server and setup channel. + + ### Raises + - `ConnectionError`: If connection to RabbitMQ fails + """ + + self._connection = await connect_robust(self.config.url) + self._channel = await self._connection.channel() + + async def disconnect(self) -> None: + """Close RabbitMQ connection. + + ### Raises + - `RabbitMQNotConnectedError`: If connection is not established + """ + + if self._connection is None: + raise NotConnectedError( + "Tried to disconnect from RabbitMQ, but connection was not established." 
+ ) + + # Close the connection to the broker + await self._connection.close() + + +## ========================================= +## Publisher Client +## ========================================= + + +class PublisherBrokerClient(BaseBrokerClient): + """RabbitMQ client for publishing tasks to workers. + Uses a topic exchange to route tasks to both the manager queue + and the appropriate worker kind queue.""" + + _task_exchange: Optional[AbstractExchange] = None + """ The exchange for task assignments. """ + + _binded_worker_queues: set[str] = set() + """ The queues that have been bound to the exchange. We keep track of them + so we don't have to bind them again every time we submit a new task. """ + + async def connect(self) -> None: + await super().connect() + + if self._channel is None: + raise NoChannelError( + "Tried to connect to RabbitMQ, but channel was not established." + ) + + # Declare a topic exchange instead of the default direct + self._task_exchange = await self._channel.declare_exchange( + TASK_ASSIGNMENT_EXCHANGE, type="topic", durable=True + ) + + # Declare the manager queue and bind it to the exchange + manager_queue = await self._channel.declare_queue( + TASK_ASSIGNMENT_QUEUE, durable=True + ) + await manager_queue.bind( + TASK_ASSIGNMENT_EXCHANGE, routing_key=MANAGER_ROUTING_KEY + ) + + async def _declare_worker_queue(self, worker_kind: str) -> None: + """Declare a worker queue and bind it to the exchange.""" + + if worker_kind in self._binded_worker_queues: + return + + if self._channel is None: + raise NoChannelError( + "Tried to declare worker queue, but channel was not established." + ) + + worker_queue = await self._channel.declare_queue(worker_kind, durable=True) + await worker_queue.bind( + TASK_ASSIGNMENT_EXCHANGE, routing_key=get_worker_routing_key(worker_kind) + ) + self._binded_worker_queues.add(worker_kind) + + async def publish_task(self, task: Task) -> None: + """Publish a task to both manager and worker queues via exchange and routing mechanisms. 
+ + ### Arguments + - `task`: The task to publish. Its `worker_kind` selects the worker queue, which + this client declares and binds to the exchange on first use, and + the routing key the message is published with. + + ### Raises + - `RuntimeError`: If the exchange was not declared. + """ + + if self._channel is None: + await self.connect() + + await self._declare_worker_queue(task.worker_kind) + + if self._task_exchange is None: + raise RuntimeError("Tried to publish task, but exchange was not declared.") + + message = Message(body=task.model_dump_json().encode()) + + await self._task_exchange.publish( + message, routing_key=get_worker_routing_key(task.worker_kind) + ) + + +## ========================================= +## Worker Client +## ========================================= + + +class WorkerBrokerClient(BaseBrokerClient): + """RabbitMQ client for workers to consume tasks and publish results. + Each worker kind has its own queue for task assignments, but all workers + share a single queue for publishing results.""" + + worker_kind: str + """ The name of the worker kind. """ + + _task_assignment_queue: Optional[AbstractQueue] = None + """ Queue for task assignments. """ + + _result_exchange: Optional[AbstractExchange] = None + """ Exchange for publishing results (shared by all workers). """ + + async def connect(self) -> None: + await super().connect() + + if self._channel is None: + raise NoChannelError( + "Tried to connect to RabbitMQ, but channel was not established." + ) + + # ========================================= + # Setup task assignment queue for this worker kind + # ========================================= + + # Set prefetch to one to enable fair dispatching + await self._channel.set_qos(prefetch_count=1) + + while True: + try: + self._task_assignment_queue = await self._channel.declare_queue( + self.worker_kind, + passive=True, # We only want to connect to the queue if it already exists. 
+ ) + break + except Exception as e: + warn( + f"Failed to passively declare task assignment queue: {e}. Retrying in 1 second...\nThis might mean the queue doesn't exist because: 1. no tasks for this worker kind have been published yet, or 2. there is a mismatch between the worker kind named on the publisher vs the worker kind named on the worker." + ) + await asyncio.sleep(1) + + # ========================================= + # Setup result publishing infrastructure + # ========================================= + + # Exchange + self._result_exchange = await self._channel.declare_exchange( + TASK_RESULT_EXCHANGE, + durable=True, + ) + + # Set up the result queue and bind it to the exchange + _result_queue = await self._channel.declare_queue( + TASK_RESULT_QUEUE, durable=True + ) + await _result_queue.bind(TASK_RESULT_EXCHANGE) + + async def listen(self) -> AsyncGenerator[Task, None]: + """Listen for tasks assigned to this worker's kind. + Messages are only acknowledged after the task has been processed.""" + + if self._task_assignment_queue is None: + raise QueueNotDeclaredError( + "Tried to listen for tasks, but queue was not declared." + ) + + async with self._task_assignment_queue.iterator() as queue_iter: + async for message in queue_iter: + task = Task(**json.loads(message.body.decode())) + yield task + await message.ack() + + async def publish_task_result(self, task: Task) -> None: + """Publish a task result to the shared results queue.""" + + # Check if the task has a result attached + if task.result is None: + raise ValueError( + "Tried to publish task result, but task has no result attached. How did it get to this point?" + ) + + if self._result_exchange is None: + raise ExchangeNotDeclaredError( + "Tried to publish task result, but exchange was not declared." 
+ ) + + message = Message(body=task.model_dump_json().encode()) + + await self._result_exchange.publish(message, routing_key=TASK_RESULT_QUEUE) diff --git a/client_sdks/python/src/broker/config.py b/client_sdks/python/src/broker/config.py index c2ef7e4..1368408 100644 --- a/client_sdks/python/src/broker/config.py +++ b/client_sdks/python/src/broker/config.py @@ -1,8 +1,8 @@ -from dataclasses import dataclass +from pydantic import BaseModel -@dataclass -class BrokerConfig: - """Configuration for a broker.""" +class BrokerConfig(BaseModel): + """Configuration for a RabbitMQ broker.""" url: str + """ The URL of the broker. """ diff --git a/client_sdks/python/src/broker/core.py b/client_sdks/python/src/broker/core.py deleted file mode 100644 index 65c40cf..0000000 --- a/client_sdks/python/src/broker/core.py +++ /dev/null @@ -1,22 +0,0 @@ -from abc import ABC, abstractmethod -from typing import AsyncGenerator -from uuid import UUID - -from models.task import TaskInput - - -class BrokerClient(ABC): - @abstractmethod - async def connect(self) -> None: - """Connect to the broker.""" - pass - - @abstractmethod - async def disconnect(self) -> None: - """Disconnect from the broker.""" - pass - - @abstractmethod - def listen(self) -> AsyncGenerator[tuple[TaskInput, UUID, str], None]: - """Listen to the worker queue.""" - pass diff --git a/client_sdks/python/src/broker/rabbitmq.py b/client_sdks/python/src/broker/rabbitmq.py deleted file mode 100644 index cd00dc2..0000000 --- a/client_sdks/python/src/broker/rabbitmq.py +++ /dev/null @@ -1,57 +0,0 @@ -import json -from typing import AsyncGenerator, Tuple -from broker.config import BrokerConfig -from broker.core import BrokerClient -from aio_pika import connect_robust - - -class RabbitMQBroker(BrokerClient): - """RabbitMQ implementation of the broker interface. 
- - Attributes: - config (BrokerConfig): Configuration for connecting to RabbitMQ - exchange_name (str): Name of the primary exchange - worker_id (str): Unique identifier for this worker instance - connection: Active connection to RabbitMQ server - channel: Active channel for communication - exchange: Declared exchange for message routing - """ - - def __init__(self, config: BrokerConfig, exchange_name: str, worker_id: str): - self.config = config - self.exchange_name = exchange_name - self.worker_id = worker_id # Add worker_id to identify this worker - - async def connect(self) -> None: - """Establish connection to RabbitMQ server and setup channel. - - ### Raises - - `ConnectionError`: If connection to RabbitMQ fails - """ - self.connection = await connect_robust(self.config.url) - self.channel = await self.connection.channel() - - async def disconnect(self) -> None: - """Close RabbitMQ connection.""" - # Remove the exchanges - await self.connection.close() - - async def listen(self) -> AsyncGenerator[Tuple[str, str, str], None]: - """Listen for tasks of a specific type. 
- - ### Yields - - `str`: Decoded message body containing task data - - ### Raises - - `ConnectionError`: If broker connection is lost - """ - # The queue should have been created sucessfully on the gateway side - # The queue name should be the id of the worker - queue_instance = await self.channel.declare_queue(self.worker_id, durable=False) - - async for message in queue_instance.iterator(): - async with message.process(): - task_kind = message.headers.get("task_kind") - yield json.loads(message.body.decode()), message.message_id, task_kind - - await queue_instance.delete() diff --git a/client_sdks/python/src/manager/client.py b/client_sdks/python/src/manager/client.py index d6c5393..253f401 100644 --- a/client_sdks/python/src/manager/client.py +++ b/client_sdks/python/src/manager/client.py @@ -1,168 +1,112 @@ from enum import Enum from typing import Optional -from dataclasses import dataclass - from uuid import UUID -import aiohttp as aio + +from pydantic import BaseModel +from aiohttp import ClientSession, ClientConnectorError +from aiohttp_retry import RetryClient, RetryOptionsBase from manager.config import ManagerConfig -from models.task import ( - TaskStatus, - TaskInput, - TaskOutput, - TaskInstance, -) +from models.task import Task -WORKER_PATH = "/workers" -""" Base path for worker registration and unregistration endpoints.""" +# ========================================= +# Constants +# ========================================= TASK_PATH = "/tasks" """ Base path for task CRUD operations.""" +HEALTH_PATH = "/health" +""" Base path for health checking.""" -class ManagerStates(str, Enum): - """Possible states of the manager. +# ========================================= +# Manager States +# ========================================= - ### States - - `HEALTHY`: The manager is healthy. - - `UNHEALTHY`: The manager is unhealthy. - - `NOT_REACHABLE`: The manager is not reachable. - - `UNKNOWN`: The manager is in an unknown state. 
- """ + +class ManagerStates(str, Enum): + """Possible states of the manager. Used for health checking during tests.""" HEALTHY = "healthy" + """ The manager is healthy. """ + UNHEALTHY = "unhealthy" + """ The manager is unhealthy. """ + NOT_REACHABLE = "not_reachable" + """ The manager is not reachable. """ + UNKNOWN = "unknown" + """ The manager is in an unknown state. Schrödinger's Manager?""" -@dataclass -class ManagerClient: - """Abstracts the manager API for worker registration and unregistration.""" +class ManagerClient(BaseModel): + """Abstracts the manager API.""" config: ManagerConfig + """Configuration for the manager client.""" - # Check whether the manager is healthy + # ================================ + # Health Checking + # ================================ - async def check_health(self) -> ManagerStates: + async def check_health( + self, override_retry_options: Optional[RetryOptionsBase] = None + ) -> ManagerStates: """Check whether the manager is healthy. This is currently used before tests are run to notify the user if the manager is not healthy or even running at all. + ### Args + - `override_retry_options`: Retry options to override the default ones + ### Returns - `ManagerStates`: Whether the manager is healthy. 
""" + try: - async with aio.ClientSession(timeout=self.config.timeout) as session: - async with session.get(f"{self.config.url}/health") as resp: + async with ClientSession() as session: + retry_client = RetryClient( + session, + retry_options=override_retry_options or self.config.retry_options, + ) + async with retry_client.get(f"{self.config.url}{HEALTH_PATH}") as resp: match resp.status: case 200: return ManagerStates.HEALTHY - case 503: - return ManagerStates.UNHEALTHY case _: - return ManagerStates.UNHEALTHY - except aio.ClientConnectorError: + return ManagerStates.UNKNOWN + except ClientConnectorError: return ManagerStates.NOT_REACHABLE + # ================================ # Task Get/Set Operations + # ================================ - async def get_task(self, task_id: UUID) -> TaskInstance: + async def get_task( + self, task_id: UUID, override_retry_options: Optional[RetryOptionsBase] = None + ) -> Optional[Task]: """Get a task by its UUID. - ### Parameters + ### Args - `task_id`: UUID of the task to retrieve + - `override_retry_options`: Retry options to override the default ones ### Returns - - `TaskInstance`: The task details - """ - async with aio.ClientSession(timeout=self.config.timeout) as session: - async with session.get(f"{self.config.url}{TASK_PATH}/{task_id}") as resp: - resp.raise_for_status() - data = await resp.json() - return TaskInstance.from_dict(data) - - async def publish_task( - self, task_kind_name: str, input_data: Optional[TaskInput] = None - ) -> TaskInstance: - """Create a new task. 
- - ### Parameters - - `task_kind_name`: Name of the task kind - - `input_data`: Optional input data for the task - - ### Returns - - `TaskInstance`: The created task details - """ - async with aio.ClientSession(timeout=self.config.timeout) as session: - async with session.post( - f"{self.config.url}{TASK_PATH}", - json={"task_kind_name": task_kind_name, "input_data": input_data}, - ) as resp: - resp.raise_for_status() - data = await resp.json() - return TaskInstance.from_dict(data) - - async def update_task_status(self, task_id: UUID, status: TaskStatus) -> None: - """Update the status of a task. - - ### Parameters - - `task_id`: UUID of the task to update - - `status`: New status for the task - """ - async with aio.ClientSession(timeout=self.config.timeout) as session: - async with session.put( - f"{self.config.url}{TASK_PATH}/{task_id}/status", json=status.value - ) as resp: - resp.raise_for_status() - - async def update_task_result( - self, task_id: UUID, data: TaskOutput, is_error: bool = False - ) -> None: - """Submit results or error for a task. - - ### Parameters - - `task_id`: UUID of the task to update - - `data`: Result data or error message - - `is_error`: Whether this is an error result + - `Task`: The task details """ - async with aio.ClientSession(timeout=self.config.timeout) as session: - async with session.put( - f"{self.config.url}{TASK_PATH}/{task_id}/result", - json={"data": data, "is_error": is_error}, - ) as resp: - resp.raise_for_status() - # Worker registration and unregistration + async with ClientSession() as session: + retry_client = RetryClient( + session, + retry_options=override_retry_options or self.config.retry_options, + ) - async def register_worker(self, name: str, task_kinds: list[str]) -> UUID: - """Register a new worker with the manager service. Called on worker startup. - - ### Parameters - - `name`: The name of the worker. - - `task_kinds`: The task kinds that the worker can handle. 
- - ### Returns - - `UUID`: The ID of the registered worker. - """ - - async with aio.ClientSession(timeout=self.config.timeout) as session: - async with session.post( - f"{self.config.url}{WORKER_PATH}", - json={"name": name, "task_kinds": task_kinds}, + async with retry_client.get( + f"{self.config.url}{TASK_PATH}/{task_id}" ) as resp: + if resp.status == 404: + return None resp.raise_for_status() data = await resp.json() - return UUID(data["id"]) - - async def unregister_worker(self, worker_id: UUID) -> None: - """Unregister an existing worker. Called on graceful worker shutdown. - - ### Parameters - - `worker_id`: The ID of the worker to unregister. - """ - async with aio.ClientSession(timeout=self.config.timeout) as session: - async with session.delete( - f"{self.config.url}{WORKER_PATH}/{worker_id}" - ) as resp: - resp.raise_for_status() + return Task(**data) diff --git a/client_sdks/python/src/manager/config.py b/client_sdks/python/src/manager/config.py index 4b6486c..14e5081 100644 --- a/client_sdks/python/src/manager/config.py +++ b/client_sdks/python/src/manager/config.py @@ -1,15 +1,24 @@ -from dataclasses import dataclass +from aiohttp_retry import ExponentialRetry, RetryOptionsBase -import aiohttp as aio +from pydantic import BaseModel -@dataclass -class ManagerConfig: - """Configuration for the manager. +class ManagerConfig(BaseModel): + """Configuration for communicating with the manager.""" - ### Attributes - - `url`: The URL of the manager. - """ + model_config = {"arbitrary_types_allowed": True} url: str - timeout: aio.ClientTimeout = aio.ClientTimeout(total=10) + """ The base URL of the manager (with no paths). """ + retry_options: RetryOptionsBase = ExponentialRetry( + attempts=3, + start_timeout=0.2, + max_timeout=10, + factor=2.0, + statuses={500, 502, 503, 504}, + ) + """ The retry options for the publisher's HTTP requests to the manager. + This can be overridden on a per-request basis. 
+ + Based on [aiohttp_retry](https://github.com/inyutin/aiohttp_retry). + """ diff --git a/client_sdks/python/src/models/task.py b/client_sdks/python/src/models/task.py index 2c46e3c..73abf51 100644 --- a/client_sdks/python/src/models/task.py +++ b/client_sdks/python/src/models/task.py @@ -1,137 +1,86 @@ from typing import Any, Optional from uuid import UUID -from dataclasses import dataclass from enum import Enum from datetime import datetime +import uuid +from pydantic import BaseModel, Field -TaskInput = dict[str, Any] # Maps to Option -TaskOutput = dict[str, Any] # Maps to Option +TaskInput = Any +""" Task input data defined by the user - they can use whatever format they want, but +they must handle the serialization and deserialization of the data themselves. """ + +TaskOutput = Any +""" Task output data defined by the user - they can use whatever format they want, but +they must handle the serialization and deserialization of the data themselves. """ class TaskStatus(str, Enum): - """The status of a task. 
- - ### Possible Status: - - `Pending`: Task is created but not yet assigned - - `Queued`: Task has been assigned to a worker and sent to a queue - - `Running`: Worker has started processing - - `Completed`: Task completed successfully - - `Failed`: Task failed to complete - - `Cancelled`: Task was cancelled before completion - - `Accepted`: Worker acknowledged receipt - - `Paused`: Temporarily suspended - - `Retrying`: Failed but attempting again - - `Timeout`: Exceeded time limit - - `Rejected`: Worker refused task - - `Blocked`: Waiting on dependencies - """ + """The status of a task""" PENDING = "pending" - ACCEPTED = "accepted" - QUEUED = "queued" - RUNNING = "running" - PAUSED = "paused" - RETRYING = "retrying" - COMPLETED = "completed" - FAILED = "failed" - CANCELLED = "cancelled" - TIMEOUT = "timeout" - REJECTED = "rejected" - BLOCKED = "blocked" + """ The task is created but not yet assigned """ + PROCESSING = "processing" + """ The task is being processed by a worker """ -@dataclass -class TaskResult: - """Task results contain the output or error data from a completed task. + COMPLETED = "completed" + """ The task has completed successfully """ - ### Parameters - - `data`: The data of the task. - - `is_error`: Whether the task failed. - - `created_at`: The time the task was created. - ### Methods - - `from_dict`: Creates a TaskResult from a dictionary. - """ +class TaskResult(BaseModel): + """Task results contain the output or error data from a completed task.""" - data: Optional[TaskOutput] - is_error: bool - created_at: datetime + data: TaskOutput + """ The data of the task's result. """ - @staticmethod - def from_dict(data: dict[str, Any]) -> "TaskResult": - """Creates a TaskResult from a dictionary.""" + is_error: bool + """ Whether the task failed. """ - is_error: bool = data["output_data"] is None - output_data = data["output_data"] if not is_error else data["error_data"] + started_at: datetime + """ The time the task was started. 
""" - return TaskResult( - data=output_data, - is_error=is_error, - created_at=datetime.fromisoformat(data["created_at"]), - ) + completed_at: datetime + """ The time the task was completed. """ -@dataclass -class TaskInstance: +class Task(BaseModel): """Tasks are sent to workers to be executed with a specific payload. Workers are eligible for receiving certain tasks depending on their - list of capabilities. - - ### Parameters - - `id`: The ID of the task. - - `task_kind`: The kind/class of the task. - - `input_data`: The data of the task. - - `status`: The status of the task. - - `created_at`: The time the task was created. - - `assigned_to`: The ID of the worker that is assigned to the task. - - `result`: The result of the task. - - ### Properties - - `has_finished`: Whether the task has finished. - - `has_completed`: Whether the task has completed. - - `has_failed`: Whether the task has failed. - - ### Methods - - `from_dict`: Creates a TaskInstance from a dictionary. + kind. """ - id: UUID + id: UUID = Field(default_factory=uuid.uuid4) + """The unique ID of the task. Generated by the client so that it can be communicated to the + manager and the workers directly.""" + task_kind: str - input_data: Optional[TaskInput] - status: TaskStatus - created_at: datetime - assigned_to: Optional[UUID] - result: Optional[TaskResult] + """ The kind of the task - dictates the queue that it will be placed in (via the worker kind) + and dictates how the worker will interpret the task. """ + + worker_kind: str + """ The kind of worker that will execute the task. """ + + input_data: TaskInput = Field(default=None) + """ The input data of the task. """ + + status: TaskStatus = Field(default=TaskStatus.PENDING) + """ The current status of the task at the time of retrieval.""" + + created_at: datetime = Field(default_factory=datetime.now) + """ The time the task was created. """ + + result: Optional[TaskResult] = Field(default=None) + """ The result of the task. 
""" + + priority: int = Field(default=0) + """ The priority of the task. """ + + is_error: bool = Field(default=False) + """ Whether the task failed. """ @property def has_finished(self) -> bool: - return self.status in [ - TaskStatus.COMPLETED, - TaskStatus.FAILED, - TaskStatus.CANCELLED, - TaskStatus.TIMEOUT, - TaskStatus.REJECTED, - ] + """Whether the task has finished.""" - @property - def has_completed(self) -> bool: return self.status == TaskStatus.COMPLETED - - @property - def has_failed(self) -> bool: - return self.status == TaskStatus.FAILED - - @staticmethod - def from_dict(data: dict[str, Any]) -> "TaskInstance": - """Creates a TaskInstance from a dictionary.""" - - return TaskInstance( - id=UUID(data["id"]), - task_kind=data["task_kind"]["name"], - input_data=data["input_data"], - status=TaskStatus(data["status"].lower()), - created_at=datetime.fromisoformat(data["created_at"]), - assigned_to=UUID(data["assigned_to"]) if data["assigned_to"] else None, - result=TaskResult.from_dict(data["result"]) if data["result"] else None, - ) diff --git a/client_sdks/python/src/publisher/__init__.py b/client_sdks/python/src/publisher/__init__.py new file mode 100644 index 0000000..46e2606 --- /dev/null +++ b/client_sdks/python/src/publisher/__init__.py @@ -0,0 +1,3 @@ +from .client import PublisherClient + +__all__ = ["PublisherClient"] diff --git a/client_sdks/python/src/publisher/client.py b/client_sdks/python/src/publisher/client.py index a690ce3..3cc4260 100644 --- a/client_sdks/python/src/publisher/client.py +++ b/client_sdks/python/src/publisher/client.py @@ -1,61 +1,85 @@ -import asyncio -from uuid import UUID +from typing import Optional -from dataclasses import dataclass +from uuid import UUID, uuid4 +from aiohttp_retry import RetryOptionsBase +from broker import PublisherBrokerClient, BrokerConfig from manager import ManagerClient, ManagerConfig -from models.task import TaskInput, TaskInstance +from models.task import TaskInput, Task +from pydantic import 
BaseModel -@dataclass -class PublisherClient: - """A client for publishing tasks to the manager. +class PublisherClient(BaseModel): + """A client for publishing and retrieving tasks.""" - ### Attributes - - `manager_config`: The configuration for the manager. - - `_manager_client`: The manager client. - - ### Methods - - `publish_task`: Publish a task to the manager. - - `get_task`: Get the status of a task by its UUID. - """ + # Broker + broker_config: BrokerConfig + _broker_client: Optional[PublisherBrokerClient] = None + # Manager manager_config: ManagerConfig - _manager_client: ManagerClient - - def __init__(self, manager_config: ManagerConfig): - self._manager_config = manager_config - self._manager_client = ManagerClient(manager_config) - - async def publish_task(self, task_kind: str, input_data: TaskInput) -> TaskInstance: + _manager_client: ManagerClient = None # type: ignore + + def model_post_init(self, _) -> None: + self._manager_client = ManagerClient(config=self.manager_config) + + def _connect_to_broker(self): + self._broker_client = PublisherBrokerClient(config=self.broker_config) + + async def publish_task( + self, + task_kind: str, + worker_kind: str, + input_data: Optional[TaskInput] = None, + task_id: Optional[UUID] = None, + priority: int = 0, + ) -> Task: """Publish a task to the manager. ### Arguments - - `task_kind`: The kind of the task. + - `task_kind`: The kind of the task. Can either be in the format of `worker_kind:task_name` string or a `TaskKind` object. - `input_data`: The data to publish. ### Returns - `TaskInstance`: The task instance. 
""" - return await self._manager_client.publish_task(task_kind, input_data) + # Connect to the broker if that hasn't yet been done + if not self._broker_client: + self._connect_to_broker() + if not self._broker_client: + raise ConnectionError("Failed to connect to the broker") + + # Create a task with base values + task = Task( + id=task_id or uuid4(), + task_kind=task_kind, + worker_kind=worker_kind, + input_data=input_data, + priority=priority, + ) + + # Publish the task to the manager + await self._broker_client.publish_task( + task, + ) + + # Return the task + return task - async def get_task(self, task_id: UUID, long_poll: bool = False) -> TaskInstance: + async def get_task( + self, task_id: UUID, override_retry_options: Optional[RetryOptionsBase] = None + ) -> Optional[Task]: """Get the status of a task by its UUID. ### Arguments - `task_id`: The UUID of the task. - - `long_poll`: Whether to long poll for the task to finish. + - `override_retry_options`: The retry options to use if you want to override the default ones. ### Returns - - `TaskInstance`: The task instance. + - `Task`: The task. 
""" - task = await self._manager_client.get_task(task_id) - - if long_poll: - while not task.has_finished: - await asyncio.sleep(1) - task = await self._manager_client.get_task(task_id) - - return task + return await self._manager_client.get_task( + task_id, override_retry_options=override_retry_options + ) diff --git a/client_sdks/python/src/tacoq.egg-info/PKG-INFO b/client_sdks/python/src/tacoq.egg-info/PKG-INFO index ac302ce..7d1cff4 100644 --- a/client_sdks/python/src/tacoq.egg-info/PKG-INFO +++ b/client_sdks/python/src/tacoq.egg-info/PKG-INFO @@ -9,9 +9,12 @@ Classifier: Operating System :: OS Independent Requires-Python: >=3.12 Description-Content-Type: text/markdown Requires-Dist: aio-pika>=9.5.3 +Requires-Dist: aiohttp-retry>=2.9.1 Requires-Dist: aiohttp>=3.11.8 Requires-Dist: aioredis>=2.0.1 +Requires-Dist: aioresponses>=0.7.8 Requires-Dist: click>=8.1.7 +Requires-Dist: pydantic>=2.10.5 Requires-Dist: uuid>=1.30 Requires-Dist: watchfiles>=1.0.3 diff --git a/client_sdks/python/src/tacoq.egg-info/SOURCES.txt b/client_sdks/python/src/tacoq.egg-info/SOURCES.txt index 821889f..8c3132a 100644 --- a/client_sdks/python/src/tacoq.egg-info/SOURCES.txt +++ b/client_sdks/python/src/tacoq.egg-info/SOURCES.txt @@ -2,9 +2,8 @@ README.md pyproject.toml src/__init__.py src/broker/__init__.py +src/broker/client.py src/broker/config.py -src/broker/core.py -src/broker/rabbitmq.py src/cli/__init__.py src/cli/cli.py src/cli/importer.py @@ -15,6 +14,7 @@ src/manager/__init__.py src/manager/client.py src/manager/config.py src/models/task.py +src/publisher/__init__.py src/publisher/client.py src/tacoq.egg-info/PKG-INFO src/tacoq.egg-info/SOURCES.txt diff --git a/client_sdks/python/src/tacoq.egg-info/requires.txt b/client_sdks/python/src/tacoq.egg-info/requires.txt index 0450359..8506b7c 100644 --- a/client_sdks/python/src/tacoq.egg-info/requires.txt +++ b/client_sdks/python/src/tacoq.egg-info/requires.txt @@ -1,6 +1,9 @@ aio-pika>=9.5.3 +aiohttp-retry>=2.9.1 aiohttp>=3.11.8 
aioredis>=2.0.1 +aioresponses>=0.7.8 click>=8.1.7 +pydantic>=2.10.5 uuid>=1.30 watchfiles>=1.0.3 diff --git a/client_sdks/python/src/worker/__init__.py b/client_sdks/python/src/worker/__init__.py index 5bf8869..8f39cb9 100644 --- a/client_sdks/python/src/worker/__init__.py +++ b/client_sdks/python/src/worker/__init__.py @@ -1,4 +1,4 @@ -from worker.client import WorkerApplication +from worker.client import WorkerApplication, TaskNotRegisteredError from worker.config import WorkerApplicationConfig -__all__ = ["WorkerApplication", "WorkerApplicationConfig"] +__all__ = ["WorkerApplication", "WorkerApplicationConfig", "TaskNotRegisteredError"] diff --git a/client_sdks/python/src/worker/client.py b/client_sdks/python/src/worker/client.py index a9bf158..09e4808 100644 --- a/client_sdks/python/src/worker/client.py +++ b/client_sdks/python/src/worker/client.py @@ -1,47 +1,75 @@ -import asyncio -from dataclasses import dataclass +from datetime import datetime from typing import Callable, Awaitable, Optional, Dict -from uuid import UUID -from broker import create_broker_instance, BrokerClient +import asyncio +from broker import WorkerBrokerClient from manager import ManagerClient -from models.task import TaskInput, TaskOutput +from models.task import Task, TaskInput, TaskOutput, TaskResult, TaskStatus +from pydantic import BaseModel from worker.config import WorkerApplicationConfig -@dataclass -class WorkerApplication: - """A worker application that processes tasks from a task queue. 
- - ### Attributes - - `_config`: Configuration for the worker application - - `_tasks`: Mapping of task kinds to their handler functions - - `_broker_client`: Client for communicating with the message broker - - `_manager_client`: Client for communicating with the task manager - - `_id`: Unique identifier assigned by the manager - - ### Methods - - `register_task`: Register a task handler function for a specific task kind - - `task`: Decorator for registering task handler functions - - `_register_worker`: Register this worker with the manager and initialize broker connection - - `_unregister_worker`: Unregister from the manager and clean up broker connection - - `_execute_task`: Execute a task and update its status in the manager - - `_listen`: Listen for tasks of a specific kind from the broker - - `entrypoint`: Start the worker application - """ - - _config: WorkerApplicationConfig - _registered_tasks: Dict[str, Callable[[TaskInput], Awaitable[TaskOutput]]] - _broker_client: Optional[BrokerClient] - _manager_client: ManagerClient - - def __init__(self, config: WorkerApplicationConfig): - self._config = config - self._id = None - - self._manager_client = ManagerClient(config.manager_config) +# ========================================= +# Errors +# ========================================= + + +class TaskNotRegisteredError(Exception): + """Exception raised when a task is not registered.""" + + def __init__( + self, + task_kind: str, + registered_tasks: Dict[str, Callable[[TaskInput], Awaitable[TaskOutput]]], + ): + self.message = f"Task {task_kind} not registered for this worker. Available tasks: {registered_tasks.keys()}" + super().__init__(self.message) + + +# ========================================= +# Worker Application +# ========================================= + + +class SerializableException(BaseModel): + """A serializable exception.""" + + type: str + """ The type of the exception. 
`RuntimeError` evaluates to `"RuntimeError"`.""" + + message: str + """ The message of the exception. """ + + +class WorkerApplication(BaseModel): + """A worker application that processes tasks from a task queue.""" + + config: WorkerApplicationConfig + """ The configuration for this worker application. """ + + _manager_client: Optional[ManagerClient] = None + """ The manager client that this worker application uses to interface with the manager service. """ + + _registered_tasks: Dict[str, Callable[[TaskInput], Awaitable[TaskOutput]]] = {} + """ All the tasks that this worker application can handle. """ + + _broker_client: Optional[WorkerBrokerClient] = None + """ The broker client that this worker application uses. """ + + _queue_name: Optional[str] = None + """ The queue name that this worker application listens to. """ + + _shutdown: bool = False + """ Whether the worker application is shutting down. """ + + def model_post_init(self, _) -> None: self._registered_tasks = {} + self._manager_client = ManagerClient(config=self.config.manager_config) + + # ================================ + # Task Registration & Execution + # ================================ def register_task( self, kind: str, task: Callable[[TaskInput], Awaitable[TaskOutput]] @@ -77,117 +105,125 @@ def decorator( return decorator - async def _register_worker(self): - """Register this worker with the manager and initialize broker connection. + async def _execute_task(self, task: Task): + """Execute a task and update its status in the manager. 
+ + ### Parameters + - `task`: Task to execute ### Raises - - `ConnectionError`: If connection to manager or broker fails + - `TaskNotRegisteredError`: If task kind is not registered """ - worker = await self._manager_client.register_worker( - self._config.name, list(self._registered_tasks.keys()) - ) - self._id = worker - # For this ideally we would get the broker information from the manager - self._broker_client = create_broker_instance( - self._config.broker_config, self._config.name, str(self._id) - ) - await self._broker_client.connect() + # Check if broker client is initialized + if self._broker_client is None: + raise RuntimeError("Broker client not initialized") - async def _unregister_worker(self): - """Unregister from the manager and clean up broker connection. + # Find task handler + task_func = self._registered_tasks.get(task.task_kind) + if task_func is None: + raise TaskNotRegisteredError(task.task_kind, self._registered_tasks) - ### Raises - - `ValueError`: If worker is not registered - """ - if self._id is None: - raise ValueError("Worker is not registered.") + # Compute task result + result: Optional[TaskOutput | Exception] = None + is_error: bool = False - try: - if self._broker_client: - await self._broker_client.disconnect() - except Exception as e: - raise ValueError(f"Error during broker disconnect: {e}") + # Start timer + started_at = datetime.now() + # Execute task try: - await self._manager_client.unregister_worker(self._id) + result = await task_func(task.input_data) except Exception as e: - raise ValueError(f"Error during worker unregister: {e}") + result = SerializableException( + type=e.__class__.__name__, + message=e.__str__(), + ) + is_error = True - # Clear local variables state - # Important for hot reloading code - self.cleanup() + # Stop timer + completed_at = datetime.now() - async def _execute_task(self, kind: str, input_data: TaskInput, task_id: UUID): - """Execute a task and update its status in the manager. 
+ # Update task + task.result = TaskResult( + data=result, + is_error=is_error, + started_at=started_at, + completed_at=completed_at, + ) + task.status = TaskStatus.COMPLETED - ### Parameters - - `kind`: Type of task to execute - - `input_data`: Input data for the task - - `task_id`: Unique identifier for the task + # Submit task result via broker + await self._broker_client.publish_task_result( + task=task, + ) + + # ================================ + # Worker Lifecycle + # ================================ + + async def _init_broker_client(self): + """Initialize the broker client for this worker. ### Raises - - `ValueError`: If task kind is not registered + - `RuntimeError`: If manager client is not initialized """ - task_func = self._registered_tasks.get(kind) - if task_func is None: - raise ValueError(f"Task {kind} not registered.") + if self._manager_client is None: + raise RuntimeError("Manager client not initialized") - try: - # Check what to do with the task result - result = await task_func(input_data) - await self._manager_client.update_task_result( - task_id, result, is_error=False - ) - except Exception as e: - # Log the exception (could improve error handling) - await self._manager_client.update_task_result( - task_id, {"error": str(e)}, is_error=True - ) + # Init the broker client using the queue name of the worker kind + self._broker_client = WorkerBrokerClient( + config=self.config.broker_config, + worker_kind=self.config.kind, + ) + await self._broker_client.connect() - async def _listen(self): - """Listen for tasks of a specific kind from the broker. + async def entrypoint(self): + """Start the worker application. - ### Raises - - `RuntimeError`: If broker client is not initialized + This method registers the worker kind with the manager, + starts listening for tasks, and handles graceful shutdown. 
""" - if not self._broker_client: - raise RuntimeError("Broker client is not initialized.") + await self._init_broker_client() - input_data: TaskInput - task_id: UUID - task_kind: str + print("Worker application initialized!") + # Begin loop try: - async for ( - input_data, - task_id, - task_kind, - ) in self._broker_client.listen(): - await self._execute_task(task_kind, input_data, task_id) + await self._listen() except asyncio.CancelledError: - await self._broker_client.disconnect() + pass + finally: + await self._cleanup() - async def entrypoint(self): - """Start the worker application. + async def _listen(self): + """Listen for tasks of a specific kind from the broker. - This method registers the worker, starts listening for tasks, - and handles graceful shutdown. + ### Raises + - `RuntimeError`: If broker client is not initialized """ - await self._register_worker() + + if self._broker_client is None: + raise RuntimeError("Broker client not initialized") try: - await self._listen() + async for task in self._broker_client.listen(): + if self._shutdown: + break + await self._execute_task(task) except asyncio.CancelledError: pass finally: - await self._unregister_worker() + await self._cleanup() + + def shutdown(self): + """Shutdown the worker application.""" + self._shutdown = True - def cleanup(self): + async def _cleanup(self): """Cleanup the worker application. This method is called when the worker is shutting down. Used for cleaning internal state. """ - self._broker_client = None - self._registered_tasks = {} - self._id = None + if self._broker_client is not None: + await self._broker_client.disconnect() diff --git a/client_sdks/python/src/worker/config.py b/client_sdks/python/src/worker/config.py index 06cce91..62664b2 100644 --- a/client_sdks/python/src/worker/config.py +++ b/client_sdks/python/src/worker/config.py @@ -8,14 +8,16 @@ class WorkerApplicationConfig: """Configuration for a worker application. 
This is passed in on initialization of the `WorkerApplication` class, and can come from a config - file or other sources. - - ### Attributes - - `name`: The name of the worker. - - `broker_config`: Configuration for the broker. - - `manager_config`: Configuration for the manager. - """ + file or other sources.""" name: str + """ The name of the worker. This is used to identify the worker in the manager.""" + + kind: str + """ The kind of worker. This dictates which tasks get routed to this worker.""" + broker_config: BrokerConfig + """ Configuration for the broker. """ + manager_config: ManagerConfig + """ Configuration for the manager. """ diff --git a/client_sdks/python/tests/__init__.py b/client_sdks/python/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client_sdks/python/tests/benchmarks/test_task_creation.py b/client_sdks/python/tests/benchmarks/test_task_creation.py index 1e01248..70eb890 100644 --- a/client_sdks/python/tests/benchmarks/test_task_creation.py +++ b/client_sdks/python/tests/benchmarks/test_task_creation.py @@ -22,6 +22,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: @async_to_sync +@pytest.mark.skip(reason="Not implemented") async def create_n_tasks_sequential( manager_client: ManagerClient, test_task_kind: str, n: int ): @@ -32,6 +33,7 @@ async def create_n_tasks_sequential( @async_to_sync +@pytest.mark.skip(reason="Not implemented") async def create_n_tasks_concurrent( manager_client: ManagerClient, test_task_kind: str, n: int ): @@ -56,6 +58,7 @@ def n_tasks(request: pytest.FixtureRequest) -> int: return request.param +@pytest.mark.skip(reason="Not implemented") @pytest.mark.bench def test_task_creation_benchmark_sync( benchmark: BenchmarkFixture, manager_client: ManagerClient, n_tasks: int @@ -85,6 +88,7 @@ def test_task_creation_benchmark_sync( asyncio.run(manager_client.unregister_worker(worker_id)) +@pytest.mark.skip(reason="Not implemented") @pytest.mark.bench def test_task_creation_benchmark_concurrent( 
benchmark: BenchmarkFixture, manager_client: ManagerClient, n_tasks: int diff --git a/client_sdks/python/tests/conftest.py b/client_sdks/python/tests/conftest.py index 9b013aa..f301e91 100644 --- a/client_sdks/python/tests/conftest.py +++ b/client_sdks/python/tests/conftest.py @@ -1,4 +1,4 @@ -from src.manager import ManagerClient, ManagerConfig, ManagerStates +from src.manager import ManagerClient, ManagerConfig from src.worker import WorkerApplication, WorkerApplicationConfig from src.broker import BrokerConfig import pytest @@ -10,11 +10,17 @@ "BROKER_TEST_URL", "amqp://user:password@localhost:5672/" ) +WORKER_KIND_NAME = "test_worker_kind" WORKER_NAME = "test_worker" pytest_plugins = ["pytest_asyncio"] +## ============================== +## Manager Fixtures +## ============================== + + @pytest.fixture async def manager_config() -> ManagerConfig: """Fixture that provides a configured ManagerConfig instance.""" @@ -22,19 +28,18 @@ async def manager_config() -> ManagerConfig: @pytest.fixture -async def manager_client(manager_config: ManagerConfig) -> ManagerClient: - """Fixture that provides a configured ManagerClient instance. Also checks - whether the manager client is running and healthy before starting the tests.""" +def manager_client(manager_config: ManagerConfig) -> ManagerClient: + return ManagerClient(config=manager_config) - client = ManagerClient(config=manager_config) - # Check if the manager is healthy - client_health = await client.check_health() +@pytest.fixture +def mock_manager_client() -> ManagerClient: + return ManagerClient(config=ManagerConfig(url="http://test")) - if client_health != ManagerStates.HEALTHY: - raise RuntimeError(f"Manager is not healthy. 
Current state: {client_health}") - return client +## ============================== +## Broker Fixtures +## ============================== @pytest.fixture @@ -43,13 +48,21 @@ async def broker_config() -> BrokerConfig: return BrokerConfig(url=BROKER_TEST_URL) +## ============================== +## Worker Fixtures +## ============================== + + @pytest.fixture async def worker_config( manager_config: ManagerConfig, broker_config: BrokerConfig ) -> WorkerApplicationConfig: """Fixture that provides a configured WorkerConfig instance.""" return WorkerApplicationConfig( - name=WORKER_NAME, manager_config=manager_config, broker_config=broker_config + kind=WORKER_KIND_NAME, + name=WORKER_NAME, + manager_config=manager_config, + broker_config=broker_config, ) diff --git a/client_sdks/python/tests/e2e/test_full.py b/client_sdks/python/tests/e2e/test_full.py new file mode 100644 index 0000000..d5da640 --- /dev/null +++ b/client_sdks/python/tests/e2e/test_full.py @@ -0,0 +1,143 @@ +from asyncio import sleep, get_event_loop +import pytest +from uuid import uuid4 + +from models.task import TaskInput, TaskOutput, TaskStatus +from publisher import PublisherClient +from worker import WorkerApplication +from worker.config import WorkerApplicationConfig +from broker.config import BrokerConfig +from manager.config import ManagerConfig + +WORKER_NAME = "test_worker" +WORKER_KIND = "test_worker_kind" + +DELAYED_TASK = "delayed_task" +FAILING_TASK = "failing_task" + + +class WorkerContext: + """Context manager for running a worker as a background asyncio task.""" + + _worker_app: WorkerApplication = None # type: ignore + + def __init__(self): + async def delayed_task(input_data: TaskInput) -> TaskOutput: + await sleep(2) + return {"message": "Task completed", "input": input_data} + + async def failing_task(_: TaskInput) -> TaskOutput: + raise ValueError("Task failed intentionally") + + # Create and configure worker + self._worker_app = WorkerApplication( + 
config=WorkerApplicationConfig( + name=WORKER_NAME, + kind=WORKER_KIND, + manager_config=ManagerConfig(url="http://localhost:3000"), + broker_config=BrokerConfig(url="amqp://user:password@localhost:5672"), + ) + ) + + # Register appropriate task handler + self._worker_app.register_task(DELAYED_TASK, delayed_task) + self._worker_app.register_task(FAILING_TASK, failing_task) + + def __enter__(self): + # Run worker in background + loop = get_event_loop() + loop.create_task(self._worker_app.entrypoint()) + return self._worker_app + + def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore + self._worker_app.shutdown() + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_delayed_task_e2e(): + """Test a task that takes 2 seconds to complete. + This test verifies the full lifecycle of a task: + 1. Task submission + 2. Immediate task status check (should be pending) + 3. Wait for completion + 4. Final task status check (should be completed) + """ + publisher = PublisherClient( + manager_config=ManagerConfig(url="http://localhost:3000"), + broker_config=BrokerConfig(url="amqp://user:password@localhost:5672"), + ) + + # Start worker in background + with WorkerContext(): + task = await publisher.publish_task( + task_kind=DELAYED_TASK, + worker_kind=WORKER_KIND, + input_data={"test": "data"}, + ) + + await sleep(1) + + # Check immediate status + task_status = await publisher.get_task(task.id) + assert task_status is not None, "Task status is None" + assert task_status.status == TaskStatus.PENDING + assert task_status.result is None + + # Wait and check final status + await sleep(3) # Wait for task completion + buffer + task_status = await publisher.get_task(task.id) + assert task_status is not None, "Task status is None" + assert task_status.status == TaskStatus.COMPLETED + assert task_status.is_error is False + assert task_status.result is not None + assert task_status.result.data["message"] == "Task completed" + assert task_status.result.data["input"] == 
{"test": "data"} + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_error_task_e2e(): + """Test a task that fails immediately. + This test verifies error handling: + 1. Task submission + 2. Task execution (fails) + 3. Task status check (should be completed with `is_error` set) + """ + publisher = PublisherClient( + manager_config=ManagerConfig(url="http://localhost:3000"), + broker_config=BrokerConfig(url="amqp://user:password@localhost:5672"), + ) + + # Start worker in background + with WorkerContext(): + # Submit task + task = await publisher.publish_task( + task_kind=FAILING_TASK, + worker_kind=WORKER_KIND, + input_data={}, + ) + + # Wait a bit for task to be processed + await sleep(0.5) + + # Check status + task_status = await publisher.get_task(task.id) + assert task_status is not None, "Task status is None" + assert task_status.status == TaskStatus.COMPLETED + assert task_status.is_error is True + assert task_status.result is not None + assert "Task failed intentionally" in str(task_status.result.data["error"]) + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_task_not_found(): + """Test that requesting a non-existent task returns None""" + publisher = PublisherClient( + manager_config=ManagerConfig(url="http://localhost:3000"), + broker_config=BrokerConfig(url="amqp://user:password@localhost:5672"), + ) + + task_status = await publisher.get_task(uuid4()) + assert task_status is None diff --git a/client_sdks/python/tests/manager/test_health_check.py b/client_sdks/python/tests/manager/test_health_check.py deleted file mode 100644 index e65040d..0000000 --- a/client_sdks/python/tests/manager/test_health_check.py +++ /dev/null @@ -1,13 +0,0 @@ -from src.manager import ManagerClient, ManagerStates -import pytest - - -@pytest.mark.asyncio -async def test_health_check_client(manager_client: ManagerClient): - """Tests whether the manager client can check the health of the manager at all.""" - - health_state = await manager_client.check_health() - - assert health_state 
== ManagerStates.HEALTHY, ( - f"Manager is not healthy. Current state: {health_state}" - ) diff --git a/client_sdks/python/tests/manager/test_task_management.py b/client_sdks/python/tests/manager/test_task_management.py deleted file mode 100644 index da6d94b..0000000 --- a/client_sdks/python/tests/manager/test_task_management.py +++ /dev/null @@ -1,91 +0,0 @@ -from src.manager import ManagerClient -from src.models.task import TaskStatus -import pytest -from uuid import UUID, uuid4 - - -@pytest.mark.asyncio -async def test_task_lifecycle(manager_client: ManagerClient): - """Tests full task lifecycle - publish, get, update status, update result.""" - - # NOTE - We use a random UUID for the task kind to avoid conflicts in parallel tests - TEST_TASK_KIND = str(uuid4()) - TEST_WORKER_NAME = str(uuid4()) - - # Register worker - worker_id = await manager_client.register_worker(TEST_WORKER_NAME, [TEST_TASK_KIND]) - - try: - # Publish task - input_data = {"test": "data"} - task = await manager_client.publish_task(TEST_TASK_KIND, input_data) - - assert task.task_kind == TEST_TASK_KIND - assert task.input_data == input_data - - # Get task - retrieved = await manager_client.get_task(task.id) - assert retrieved.id == task.id - assert retrieved.task_kind == TEST_TASK_KIND - - # Update status - await manager_client.update_task_status(task.id, TaskStatus.RUNNING) - updated = await manager_client.get_task(task.id) - assert updated.status == TaskStatus.RUNNING - - # Update result - result_data = {"result": "success"} - await manager_client.update_task_result(task.id, result_data) - final = await manager_client.get_task(task.id) - assert final.result is not None - assert final.result.data == result_data - assert not final.result.is_error - - finally: - await manager_client.unregister_worker(worker_id) - - -@pytest.mark.asyncio -async def test_task_error_result(manager_client: ManagerClient): - """Tests submitting error results for tasks.""" - - # NOTE - We use a random UUID for the 
task kind to avoid conflicts in parallel tests - TEST_TASK_KIND = str(uuid4()) - TEST_WORKER_NAME = str(uuid4()) - - worker_id = await manager_client.register_worker(TEST_WORKER_NAME, [TEST_TASK_KIND]) - - try: - task = await manager_client.publish_task(TEST_TASK_KIND) - error_data = {"error": "test failure"} - await manager_client.update_task_result(task.id, error_data, is_error=True) - - result = await manager_client.get_task(task.id) - assert result.result is not None - assert result.result.is_error - assert result.result.data == error_data - - finally: - await manager_client.unregister_worker(worker_id) - - -@pytest.mark.asyncio -async def test_get_nonexistent_task(manager_client: ManagerClient): - """Tests getting a task that doesn't exist.""" - - fake_id = UUID("00000000-0000-0000-0000-000000000000") - with pytest.raises(Exception): - await manager_client.get_task(fake_id) - - -@pytest.mark.asyncio -async def test_update_nonexistent_task(manager_client: ManagerClient): - """Tests updating status/result of nonexistent task.""" - - fake_id = UUID("00000000-0000-0000-0000-000000000000") - - with pytest.raises(Exception): - await manager_client.update_task_status(fake_id, TaskStatus.COMPLETED) - - with pytest.raises(Exception): - await manager_client.update_task_result(fake_id, {"data": "test"}) diff --git a/client_sdks/python/tests/manager/test_worker_registry.py b/client_sdks/python/tests/manager/test_worker_registry.py deleted file mode 100644 index 220eb73..0000000 --- a/client_sdks/python/tests/manager/test_worker_registry.py +++ /dev/null @@ -1,27 +0,0 @@ -from src.manager import ManagerClient -import pytest -from uuid import UUID - - -@pytest.mark.asyncio -async def test_worker_lifecycle(manager_client: ManagerClient): - """Tests worker registration and unregistration.""" - - # Register worker - worker_name = "test_worker" - task_kinds = ["test_task", "another_task"] - worker_id = await manager_client.register_worker(worker_name, task_kinds) - - assert 
isinstance(worker_id, UUID) - - # Unregister worker - await manager_client.unregister_worker(worker_id) - - -@pytest.mark.asyncio -async def test_unregister_nonexistent_worker(manager_client: ManagerClient): - """Tests unregistering a worker that doesn't exist.""" - - nonexistent_id = UUID("00000000-0000-0000-0000-000000000000") - with pytest.raises(Exception): - await manager_client.unregister_worker(nonexistent_id) diff --git a/client_sdks/python/tests/unit/manager/test_health.py b/client_sdks/python/tests/unit/manager/test_health.py new file mode 100644 index 0000000..e1315ab --- /dev/null +++ b/client_sdks/python/tests/unit/manager/test_health.py @@ -0,0 +1,36 @@ +import pytest +from aiohttp import ClientConnectorError +from aiohttp.client_reqrep import ConnectionKey +from aioresponses import aioresponses + +from manager.client import ManagerClient, ManagerStates + + +@pytest.mark.asyncio +async def test_health_check_healthy(mock_manager_client: ManagerClient): + with aioresponses() as m: + m.get("http://test/health", status=200) # type: ignore + state = await mock_manager_client.check_health() + assert state == ManagerStates.HEALTHY + + +@pytest.mark.asyncio +async def test_health_check_unknown(mock_manager_client: ManagerClient): + with aioresponses() as m: + # Mock multiple attempts since RetryClient is used for 500 errors + m.get("http://test/health", status=500, body=b"{}", repeat=True) # type: ignore + state = await mock_manager_client.check_health() + assert state == ManagerStates.UNKNOWN + + +@pytest.mark.asyncio +async def test_health_check_not_reachable(mock_manager_client: ManagerClient): + with aioresponses() as m: + m.get( # type: ignore + "http://test/health", + exception=ClientConnectorError( + ConnectionKey("test", 80, False, None, None, None, None), OSError() + ), + ) + state = await mock_manager_client.check_health() + assert state == ManagerStates.NOT_REACHABLE diff --git a/client_sdks/python/tests/unit/manager/test_tasks.py 
b/client_sdks/python/tests/unit/manager/test_tasks.py new file mode 100644 index 0000000..083bed9 --- /dev/null +++ b/client_sdks/python/tests/unit/manager/test_tasks.py @@ -0,0 +1,65 @@ +import pytest +from uuid import UUID +from aiohttp import ClientResponseError +from aioresponses import aioresponses + +from manager.client import ManagerClient +from models.task import Task, TaskStatus + + +@pytest.mark.asyncio +async def test_get_task_success(mock_manager_client: ManagerClient): + task_id = UUID("00000000-0000-0000-0000-000000000000") + task_data = { + "id": str(task_id), + "task_kind": "test_kind", + "worker_kind": "test_worker_kind", + "created_at": "2024-01-01T00:00:00Z", + "input_data": {"foo": "bar"}, + "status": TaskStatus.PENDING.value, # Use enum value for serialization + "priority": 5, + "result": None, + } + + with aioresponses() as m: + m.get( # type: ignore + f"http://test/tasks/{task_id}", + payload=task_data, + status=200, + ) + task = await mock_manager_client.get_task(task_id) + assert isinstance(task, Task) + assert task.id == task_id + assert task.task_kind == "test_kind" + assert task.status == TaskStatus.PENDING + + +@pytest.mark.asyncio +async def test_get_task_not_found(mock_manager_client: ManagerClient): + task_id = UUID("00000000-0000-0000-0000-000000000000") + + with aioresponses() as m: + m.get( # type: ignore + f"http://test/tasks/{task_id}", + status=404, + body=b"Task not found", + repeat=True, + ) + response = await mock_manager_client.get_task(task_id) + assert response is None + + +@pytest.mark.asyncio +async def test_get_task_server_error(mock_manager_client: ManagerClient): + task_id = UUID("00000000-0000-0000-0000-000000000000") + + with aioresponses() as m: + m.get( # type: ignore + f"http://test/tasks/{task_id}", + status=500, + body=b"Internal server error", + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await mock_manager_client.get_task(task_id) + assert exc_info.value.status == 500 diff --git 
a/client_sdks/python/tests/unit/publisher/test_publisher_client.py b/client_sdks/python/tests/unit/publisher/test_publisher_client.py new file mode 100644 index 0000000..3626d0d --- /dev/null +++ b/client_sdks/python/tests/unit/publisher/test_publisher_client.py @@ -0,0 +1,105 @@ +# pyright: reportPrivateUsage=false + +from unittest import mock +from broker.client import PublisherBrokerClient +from manager.client import ManagerClient +import pytest +from uuid import uuid4 + +from broker.config import BrokerConfig +from manager.config import ManagerConfig +from models.task import Task, TaskStatus +from publisher import PublisherClient + + +# ========================================= +# Fixtures +# ========================================= + + +@pytest.fixture +def publisher_client(): + """Creates a publisher client with mocked dependencies.""" + + client = PublisherClient( + manager_config=ManagerConfig(url="http://localhost:8080"), + broker_config=BrokerConfig(url="http://localhost:5672"), + ) + return client + + +# ========================================= +# Task Publishing Tests +# ========================================= + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_publish_task_success(publisher_client: PublisherClient): + """Test publishing a task successfully.""" + task_kind = "test_task" + worker_kind = "test_kind" + input_data = {"test": "data"} + priority = 5 + id = uuid4() + + publisher_client._broker_client = mock.create_autospec( + PublisherBrokerClient, instance=True + ) + + task = await publisher_client.publish_task( + task_kind=task_kind, + worker_kind=worker_kind, + input_data=input_data, + priority=priority, + task_id=id, + ) + + # Verify task properties + assert task.task_kind == task_kind + assert task.worker_kind == worker_kind + assert task.input_data == input_data + assert task.priority == priority + assert task.id == id + + # Verify broker client calls + publisher_client._broker_client.publish_task.assert_called_once_with( # 
type: ignore + task, + ) + + +# ========================================= +# Task Retrieval Tests +# ========================================= + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_get_task_success( + publisher_client: PublisherClient, +): + """Test retrieving a task successfully. Here we mock the manager client because the actual + task retrieval behaviour is already tested in the manager client tests.""" + + task_id = uuid4() + expected_task = Task( + id=task_id, + task_kind="test_task", + worker_kind="test_kind", + input_data={"test": "data"}, + priority=0, + status=TaskStatus.PENDING, + result=None, + ) + + publisher_client._manager_client = mock.create_autospec( + ManagerClient, instance=True + ) + publisher_client._manager_client.get_task.return_value = expected_task # type: ignore + + task = await publisher_client.get_task(task_id) + assert task == expected_task + publisher_client._manager_client.get_task.assert_called_once_with( # type: ignore + task_id, + override_retry_options=None, + ) diff --git a/client_sdks/python/tests/unit/worker/test_worker_client.py b/client_sdks/python/tests/unit/worker/test_worker_client.py new file mode 100644 index 0000000..7b5dace --- /dev/null +++ b/client_sdks/python/tests/unit/worker/test_worker_client.py @@ -0,0 +1,204 @@ +# pyright: reportPrivateUsage=false + +import asyncio +import datetime +from typing import AsyncGenerator +from unittest import mock +from uuid import uuid4 + +import pytest +from broker import WorkerBrokerClient +from broker.config import BrokerConfig +from manager.config import ManagerConfig +from models.task import Task, TaskInput, TaskOutput, TaskStatus +from worker import TaskNotRegisteredError, WorkerApplication +from worker.config import WorkerApplicationConfig + +# ========================================= +# Fixtures +# ========================================= + + +@pytest.fixture +def worker_app(): + """Creates a worker app with mocked dependencies.""" + + config = 
WorkerApplicationConfig( + name="test_worker", + kind="test_kind", + manager_config=ManagerConfig( + url="http://localhost:8080", + ), + broker_config=BrokerConfig( + url="http://localhost:5672", + ), + ) + return WorkerApplication(config=config) + + +@pytest.fixture +def sample_task(): + """Creates a sample task for testing.""" + return Task( + id=uuid4(), + task_kind="test_task", + worker_kind="test_kind", + input_data={"value": 5}, + priority=0, + created_at=datetime.datetime.now(), + status=TaskStatus.PENDING, + result=None, + ) + + +# ========================================= +# Task Registration Tests +# ========================================= + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_register_single_task(worker_app: WorkerApplication): + """Test registering a single task handler.""" + + async def task_handler(input_data: TaskInput) -> TaskOutput: + return {"result": input_data["value"] * 2} + + worker_app.register_task("test_task", task_handler) + assert "test_task" in worker_app._registered_tasks + assert worker_app._registered_tasks["test_task"] == task_handler + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_register_multiple_tasks(worker_app: WorkerApplication): + """Test registering multiple task handlers.""" + + async def task1(input_data: TaskInput) -> TaskOutput: + return {"result": input_data["value"] * 2} + + async def task2(input_data: TaskInput) -> TaskOutput: + return {"result": input_data["value"] + 1} + + worker_app.register_task("task1", task1) + worker_app.register_task("task2", task2) + + assert "task1" in worker_app._registered_tasks + assert "task2" in worker_app._registered_tasks + assert worker_app._registered_tasks["task1"] == task1 + assert worker_app._registered_tasks["task2"] == task2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_task_decorator_registration(worker_app: WorkerApplication): + """Test registering tasks using the decorator.""" + + @worker_app.task("decorated_task") + async 
def task_handler(input_data: TaskInput) -> TaskOutput: + return {"result": input_data["value"] * 2} + + assert "decorated_task" in worker_app._registered_tasks + assert worker_app._registered_tasks["decorated_task"] == task_handler + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_reregister_task(worker_app: WorkerApplication): + """Test re-registering a task (should overwrite).""" + + async def task1(input_data: TaskInput) -> TaskOutput: + return {"result": 1} + + async def task2(input_data: TaskInput) -> TaskOutput: + return {"result": 2} + + worker_app.register_task("same_kind", task1) + worker_app.register_task("same_kind", task2) + + assert worker_app._registered_tasks["same_kind"] == task2 + + +# ========================================= +# Task Execution Tests +# ========================================= + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_execute_registered_task( + worker_app: WorkerApplication, sample_task: Task +): + """Test executing a registered task successfully.""" + executed = False + worker_app._broker_client = mock.create_autospec(WorkerBrokerClient, instance=True) + + async def task_handler(input_data: TaskInput) -> TaskOutput: + nonlocal executed + executed = True + assert input_data == sample_task.input_data + return {"result": input_data["value"] * 2} + + worker_app.register_task(sample_task.task_kind, task_handler) + await worker_app._execute_task(sample_task) + assert executed + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_execute_unregistered_task( + worker_app: WorkerApplication, sample_task: Task +): + """Test executing an unregistered task.""" + worker_app._broker_client = mock.create_autospec(WorkerBrokerClient, instance=True) + with pytest.raises(TaskNotRegisteredError) as exc_info: + await worker_app._execute_task(sample_task) + assert sample_task.task_kind in str(exc_info.value) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_execute_task_with_error( + worker_app: 
WorkerApplication, sample_task: Task +): + """Test executing a task that raises an exception.""" + worker_app._broker_client = mock.create_autospec(WorkerBrokerClient, instance=True) + + async def failing_task(_: TaskInput) -> TaskOutput: + raise ValueError("Task failed") + + worker_app.register_task(sample_task.task_kind, failing_task) + await worker_app._execute_task(sample_task) + # TODO: Add assertions for error handling once implemented + + +# ========================================= +# Worker Lifecycle Tests +# ========================================= + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_worker_startup(worker_app: WorkerApplication): + """Test worker startup sequence and the graceful shutdown.""" + + # We create a full mock that gracefully shuts down as soon as it is initialized + worker_app._broker_client = mock.create_autospec(WorkerBrokerClient, instance=True) + if worker_app._broker_client: + + async def mock_listen() -> AsyncGenerator[Task, None]: + raise asyncio.CancelledError() + yield None + + worker_app._broker_client.listen = mock_listen # type: ignore + + await worker_app._listen() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_broker_not_initialized(worker_app: WorkerApplication): + """Test error when broker client is not initialized.""" + worker_app._broker_client = None + with pytest.raises(RuntimeError, match="Broker client not initialized"): + await worker_app._listen() diff --git a/client_sdks/python/tests/worker/test_worker_client.py b/client_sdks/python/tests/worker/test_worker_client.py deleted file mode 100644 index a820567..0000000 --- a/client_sdks/python/tests/worker/test_worker_client.py +++ /dev/null @@ -1,91 +0,0 @@ -# pyright: reportPrivateUsage=false, reportOptionalMemberAccess=false - -from src.worker import WorkerApplication -from src.manager import ManagerClient -import pytest -from uuid import uuid4 -from typing import Any - - -# Test Task Definitions. 
One will fail and one will complete successfully. -async def failing_task(input_data: Any): - raise Exception("Task failed") - - -async def successful_task(input_data: Any) -> Any: - return input_data - - -@pytest.mark.asyncio -async def test_worker_startup_and_task_success( - worker_application: WorkerApplication, manager_client: ManagerClient -): - """Tests that a worker can start and successfully process a task.""" - TEST_TASK_KIND = str(uuid4()) - - # Start worker - worker_application.register_task(TEST_TASK_KIND, successful_task) - await worker_application._register_worker() - - try: - # Create and fetch a task - input_data = {"test": "data"} - await manager_client.publish_task(TEST_TASK_KIND, input_data) - - async for ( - data, - task_id, - task_kind, - ) in worker_application._broker_client.listen(): - assert input_data == data - - # This should execute the task with the given function - await worker_application._execute_task(task_kind, data, task_id) - - # # Process task successfully - task = await manager_client.get_task(task_id) - assert task.has_completed - assert task.result is not None - assert task.result.data == input_data - - break - - finally: - await worker_application._unregister_worker() - - -@pytest.mark.asyncio -async def test_worker_task_failure_handling( - worker_application: WorkerApplication, manager_client: ManagerClient -): - """Tests that a worker can properly handle and report task failures.""" - TEST_TASK_KIND = str(uuid4()) - - # Start worker - worker_application.register_task(TEST_TASK_KIND, failing_task) - await worker_application._register_worker() - - try: - # Create and fetch a task - input_data = {"test": "data"} - await manager_client.publish_task(TEST_TASK_KIND, input_data) - - async for ( - data, - task_id, - task_kind, - ) in worker_application._broker_client.listen(): - assert input_data == data - - # This should execute the task with the given function - await worker_application._execute_task(task_kind, data, task_id) - - 
# Check that the task failed - task = await manager_client.get_task(task_id) - assert task.has_failed - assert "Task failed" in str(task.result.data) - - break - - finally: - await worker_application._unregister_worker() diff --git a/client_sdks/python/uv.lock b/client_sdks/python/uv.lock index b30daa4..1aa9617 100644 --- a/client_sdks/python/uv.lock +++ b/client_sdks/python/uv.lock @@ -71,6 +71,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/e7/45d57621d9caba3c7d2687618c0e12025e477bd035834cf9ec3334e82810/aiohttp-3.11.8-cp313-cp313-win_amd64.whl", hash = "sha256:481075a1949de79a8a6841e0086f2f5f464785c592cf527ed0db2c0cbd0e1ba2", size = 435403 }, ] +[[package]] +name = "aiohttp-retry" +version = "2.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9d/61/ebda4d8e3d8cfa1fd3db0fb428db2dd7461d5742cea35178277ad180b033/aiohttp_retry-2.9.1.tar.gz", hash = "sha256:8eb75e904ed4ee5c2ec242fefe85bf04240f685391c4879d8f541d6028ff01f1", size = 13608 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/99/84ba7273339d0f3dfa57901b846489d2e5c2cd731470167757f1935fffbd/aiohttp_retry-2.9.1-py3-none-any.whl", hash = "sha256:66d2759d1921838256a05a3f80ad7e724936f083e35be5abb5e16eed6be6dc54", size = 9981 }, +] + [[package]] name = "aioredis" version = "2.0.1" @@ -84,6 +96,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9b/a9/0da089c3ae7a31cbcd2dcf0214f6f571e1295d292b6139e2bac68ec081d0/aioredis-2.0.1-py3-none-any.whl", hash = "sha256:9ac0d0b3b485d293b8ca1987e6de8658d7dafcca1cddfcd1d506cae8cdebfdd6", size = 71243 }, ] +[[package]] +name = "aioresponses" +version = "0.7.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "packaging" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/de/03/532bbc645bdebcf3b6af3b25d46655259d66ce69abba7720b71ebfabbade/aioresponses-0.7.8.tar.gz", hash = "sha256:b861cdfe5dc58f3b8afac7b0a6973d5d7b2cb608dd0f6253d16b8ee8eaf6df11", size = 40253 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/12/b7/584157e43c98aa89810bc2f7099e7e01c728ecf905a66cf705106009228f/aioresponses-0.7.8-py2.py3-none-any.whl", hash = "sha256:b73bd4400d978855e55004b23a3a84cb0f018183bcf066a85ad392800b5b9a94", size = 12518 }, +] + [[package]] name = "aiormq" version = "6.8.1" @@ -109,6 +134,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17", size = 7617 }, ] +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, +] + [[package]] name = "anyio" version = "4.7.0" @@ -353,6 +387,59 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335 }, ] +[[package]] +name = "pydantic" +version = "2.10.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = 
"typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6a/c7/ca334c2ef6f2e046b1144fe4bb2a5da8a4c574e7f2ebf7e16b34a6a2fa92/pydantic-2.10.5.tar.gz", hash = "sha256:278b38dbbaec562011d659ee05f63346951b3a248a6f3642e1bc68894ea2b4ff", size = 761287 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/58/26/82663c79010b28eddf29dcdd0ea723439535fa917fce5905885c0e9ba562/pydantic-2.10.5-py3-none-any.whl", hash = "sha256:4dd4e322dbe55472cb7ca7e73f4b63574eecccf2835ffa2af9021ce113c83c53", size = 431426 }, +] + +[[package]] +name = "pydantic-core" +version = "2.27.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/01/f3e5ac5e7c25833db5eb555f7b7ab24cd6f8c322d3a3ad2d67a952dc0abc/pydantic_core-2.27.2.tar.gz", hash = "sha256:eb026e5a4c1fee05726072337ff51d1efb6f59090b7da90d30ea58625b1ffb39", size = 413443 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d6/74/51c8a5482ca447871c93e142d9d4a92ead74de6c8dc5e66733e22c9bba89/pydantic_core-2.27.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9e0c8cfefa0ef83b4da9588448b6d8d2a2bf1a53c3f1ae5fca39eb3061e2f0b0", size = 1893127 }, + { url = "https://files.pythonhosted.org/packages/d3/f3/c97e80721735868313c58b89d2de85fa80fe8dfeeed84dc51598b92a135e/pydantic_core-2.27.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:83097677b8e3bd7eaa6775720ec8e0405f1575015a463285a92bfdfe254529ef", size = 1811340 }, + { url = "https://files.pythonhosted.org/packages/9e/91/840ec1375e686dbae1bd80a9e46c26a1e0083e1186abc610efa3d9a36180/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:172fce187655fece0c90d90a678424b013f8fbb0ca8b036ac266749c09438cb7", size = 1822900 }, + { url = 
"https://files.pythonhosted.org/packages/f6/31/4240bc96025035500c18adc149aa6ffdf1a0062a4b525c932065ceb4d868/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:519f29f5213271eeeeb3093f662ba2fd512b91c5f188f3bb7b27bc5973816934", size = 1869177 }, + { url = "https://files.pythonhosted.org/packages/fa/20/02fbaadb7808be578317015c462655c317a77a7c8f0ef274bc016a784c54/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05e3a55d124407fffba0dd6b0c0cd056d10e983ceb4e5dbd10dda135c31071d6", size = 2038046 }, + { url = "https://files.pythonhosted.org/packages/06/86/7f306b904e6c9eccf0668248b3f272090e49c275bc488a7b88b0823444a4/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c3ed807c7b91de05e63930188f19e921d1fe90de6b4f5cd43ee7fcc3525cb8c", size = 2685386 }, + { url = "https://files.pythonhosted.org/packages/8d/f0/49129b27c43396581a635d8710dae54a791b17dfc50c70164866bbf865e3/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fb4aadc0b9a0c063206846d603b92030eb6f03069151a625667f982887153e2", size = 1997060 }, + { url = "https://files.pythonhosted.org/packages/0d/0f/943b4af7cd416c477fd40b187036c4f89b416a33d3cc0ab7b82708a667aa/pydantic_core-2.27.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:28ccb213807e037460326424ceb8b5245acb88f32f3d2777427476e1b32c48c4", size = 2004870 }, + { url = "https://files.pythonhosted.org/packages/35/40/aea70b5b1a63911c53a4c8117c0a828d6790483f858041f47bab0b779f44/pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:de3cd1899e2c279b140adde9357c4495ed9d47131b4a4eaff9052f23398076b3", size = 1999822 }, + { url = "https://files.pythonhosted.org/packages/f2/b3/807b94fd337d58effc5498fd1a7a4d9d59af4133e83e32ae39a96fddec9d/pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = 
"sha256:220f892729375e2d736b97d0e51466252ad84c51857d4d15f5e9692f9ef12be4", size = 2130364 }, + { url = "https://files.pythonhosted.org/packages/fc/df/791c827cd4ee6efd59248dca9369fb35e80a9484462c33c6649a8d02b565/pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a0fcd29cd6b4e74fe8ddd2c90330fd8edf2e30cb52acda47f06dd615ae72da57", size = 2158303 }, + { url = "https://files.pythonhosted.org/packages/9b/67/4e197c300976af185b7cef4c02203e175fb127e414125916bf1128b639a9/pydantic_core-2.27.2-cp312-cp312-win32.whl", hash = "sha256:1e2cb691ed9834cd6a8be61228471d0a503731abfb42f82458ff27be7b2186fc", size = 1834064 }, + { url = "https://files.pythonhosted.org/packages/1f/ea/cd7209a889163b8dcca139fe32b9687dd05249161a3edda62860430457a5/pydantic_core-2.27.2-cp312-cp312-win_amd64.whl", hash = "sha256:cc3f1a99a4f4f9dd1de4fe0312c114e740b5ddead65bb4102884b384c15d8bc9", size = 1989046 }, + { url = "https://files.pythonhosted.org/packages/bc/49/c54baab2f4658c26ac633d798dab66b4c3a9bbf47cff5284e9c182f4137a/pydantic_core-2.27.2-cp312-cp312-win_arm64.whl", hash = "sha256:3911ac9284cd8a1792d3cb26a2da18f3ca26c6908cc434a18f730dc0db7bfa3b", size = 1885092 }, + { url = "https://files.pythonhosted.org/packages/41/b1/9bc383f48f8002f99104e3acff6cba1231b29ef76cfa45d1506a5cad1f84/pydantic_core-2.27.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7d14bd329640e63852364c306f4d23eb744e0f8193148d4044dd3dacdaacbd8b", size = 1892709 }, + { url = "https://files.pythonhosted.org/packages/10/6c/e62b8657b834f3eb2961b49ec8e301eb99946245e70bf42c8817350cbefc/pydantic_core-2.27.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:82f91663004eb8ed30ff478d77c4d1179b3563df6cdb15c0817cd1cdaf34d154", size = 1811273 }, + { url = "https://files.pythonhosted.org/packages/ba/15/52cfe49c8c986e081b863b102d6b859d9defc63446b642ccbbb3742bf371/pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:71b24c7d61131bb83df10cc7e687433609963a944ccf45190cfc21e0887b08c9", size = 1823027 }, + { url = "https://files.pythonhosted.org/packages/b1/1c/b6f402cfc18ec0024120602bdbcebc7bdd5b856528c013bd4d13865ca473/pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa8e459d4954f608fa26116118bb67f56b93b209c39b008277ace29937453dc9", size = 1868888 }, + { url = "https://files.pythonhosted.org/packages/bd/7b/8cb75b66ac37bc2975a3b7de99f3c6f355fcc4d89820b61dffa8f1e81677/pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce8918cbebc8da707ba805b7fd0b382816858728ae7fe19a942080c24e5b7cd1", size = 2037738 }, + { url = "https://files.pythonhosted.org/packages/c8/f1/786d8fe78970a06f61df22cba58e365ce304bf9b9f46cc71c8c424e0c334/pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eda3f5c2a021bbc5d976107bb302e0131351c2ba54343f8a496dc8783d3d3a6a", size = 2685138 }, + { url = "https://files.pythonhosted.org/packages/a6/74/d12b2cd841d8724dc8ffb13fc5cef86566a53ed358103150209ecd5d1999/pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd8086fa684c4775c27f03f062cbb9eaa6e17f064307e86b21b9e0abc9c0f02e", size = 1997025 }, + { url = "https://files.pythonhosted.org/packages/a0/6e/940bcd631bc4d9a06c9539b51f070b66e8f370ed0933f392db6ff350d873/pydantic_core-2.27.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8d9b3388db186ba0c099a6d20f0604a44eabdeef1777ddd94786cdae158729e4", size = 2004633 }, + { url = "https://files.pythonhosted.org/packages/50/cc/a46b34f1708d82498c227d5d80ce615b2dd502ddcfd8376fc14a36655af1/pydantic_core-2.27.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7a66efda2387de898c8f38c0cf7f14fca0b51a8ef0b24bfea5849f1b3c95af27", size = 1999404 }, + { url = 
"https://files.pythonhosted.org/packages/ca/2d/c365cfa930ed23bc58c41463bae347d1005537dc8db79e998af8ba28d35e/pydantic_core-2.27.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:18a101c168e4e092ab40dbc2503bdc0f62010e95d292b27827871dc85450d7ee", size = 2130130 }, + { url = "https://files.pythonhosted.org/packages/f4/d7/eb64d015c350b7cdb371145b54d96c919d4db516817f31cd1c650cae3b21/pydantic_core-2.27.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ba5dd002f88b78a4215ed2f8ddbdf85e8513382820ba15ad5ad8955ce0ca19a1", size = 2157946 }, + { url = "https://files.pythonhosted.org/packages/a4/99/bddde3ddde76c03b65dfd5a66ab436c4e58ffc42927d4ff1198ffbf96f5f/pydantic_core-2.27.2-cp313-cp313-win32.whl", hash = "sha256:1ebaf1d0481914d004a573394f4be3a7616334be70261007e47c2a6fe7e50130", size = 1834387 }, + { url = "https://files.pythonhosted.org/packages/71/47/82b5e846e01b26ac6f1893d3c5f9f3a2eb6ba79be26eef0b759b4fe72946/pydantic_core-2.27.2-cp313-cp313-win_amd64.whl", hash = "sha256:953101387ecf2f5652883208769a79e48db18c6df442568a0b5ccd8c2723abee", size = 1990453 }, + { url = "https://files.pythonhosted.org/packages/51/b2/b2b50d5ecf21acf870190ae5d093602d95f66c9c31f9d5de6062eb329ad1/pydantic_core-2.27.2-cp313-cp313-win_arm64.whl", hash = "sha256:ac4dbfd1691affb8f48c2c13241a2e3b60ff23247cbcf981759c768b6633cf8b", size = 1885186 }, +] + [[package]] name = "pytest" version = "8.3.4" @@ -422,8 +509,11 @@ source = { editable = "." 
} dependencies = [ { name = "aio-pika" }, { name = "aiohttp" }, + { name = "aiohttp-retry" }, { name = "aioredis" }, + { name = "aioresponses" }, { name = "click" }, + { name = "pydantic" }, { name = "uuid" }, { name = "watchfiles" }, ] @@ -440,8 +530,11 @@ dev = [ requires-dist = [ { name = "aio-pika", specifier = ">=9.5.3" }, { name = "aiohttp", specifier = ">=3.11.8" }, + { name = "aiohttp-retry", specifier = ">=2.9.1" }, { name = "aioredis", specifier = ">=2.0.1" }, + { name = "aioresponses", specifier = ">=0.7.8" }, { name = "click", specifier = ">=8.1.7" }, + { name = "pydantic", specifier = ">=2.10.5" }, { name = "uuid", specifier = ">=1.30" }, { name = "watchfiles", specifier = ">=1.0.3" }, ] diff --git a/docker-compose.yml b/docker-compose.yml index fdf1071..a377e21 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,6 +16,11 @@ services: condition: service_healthy postgres: condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:3000/health"] + interval: 5s + timeout: 5s + retries: 5 env_file: - server/services/manager/.env.docker diff --git a/server/libs/common/src/brokers/mod.rs b/server/libs/common/src/brokers/mod.rs index a50a546..dbe443f 100644 --- a/server/libs/common/src/brokers/mod.rs +++ b/server/libs/common/src/brokers/mod.rs @@ -38,13 +38,13 @@ where pub async fn setup_consumer_broker( url_str: &str, queue: &str, - is_running: Arc, + shutdown: Arc, ) -> Result>, Box> where T: Debug + Send + Sync + serde::de::DeserializeOwned + 'static, { match url_str.split_once("://") { - Some(("amqp", _)) => Ok(setup_rabbit_consumer::(url_str, queue, is_running).await?), + Some(("amqp", _)) => Ok(setup_rabbit_consumer::(url_str, queue, shutdown).await?), _ => Err("Unsupported broker".into()), } } diff --git a/server/libs/common/src/brokers/rabbit.rs b/server/libs/common/src/brokers/rabbit.rs index da5d77f..e7beb2b 100644 --- a/server/libs/common/src/brokers/rabbit.rs +++ b/server/libs/common/src/brokers/rabbit.rs 
@@ -14,6 +14,7 @@ use std::{ Arc, }, }; +use tracing::warn; #[derive(Clone, Debug)] pub struct RabbitMQConsumer @@ -39,7 +40,14 @@ where let channel = connection.create_channel().await?; channel - .queue_declare(queue, QueueDeclareOptions::default(), FieldTable::default()) + .queue_declare( + queue, + QueueDeclareOptions { + durable: true, + ..QueueDeclareOptions::default() + }, + FieldTable::default(), + ) .await?; Ok(Self { @@ -72,6 +80,7 @@ where while let Some(delivery) = consumer.next().await { if self.shutdown.load(Ordering::SeqCst) { + warn!("Shutting down consumer due to shutdown signal"); break; } @@ -111,8 +120,11 @@ where channel .exchange_declare( exchange, - ExchangeKind::Direct, - ExchangeDeclareOptions::default(), + ExchangeKind::Topic, + ExchangeDeclareOptions { + durable: true, + ..ExchangeDeclareOptions::default() + }, FieldTable::default(), ) .await?; @@ -162,12 +174,12 @@ where pub async fn setup_rabbit_consumer( url_string: &str, queue: &str, - is_running: Arc, + shutdown: Arc, ) -> Result>, Box> where T: Debug, { Ok(Arc::new( - RabbitMQConsumer::::new(url_string, queue, is_running).await?, + RabbitMQConsumer::::new(url_string, queue, shutdown).await?, )) } diff --git a/server/services/manager/src/constants.rs b/server/services/manager/src/constants.rs index cc1f056..b0563ab 100644 --- a/server/services/manager/src/constants.rs +++ b/server/services/manager/src/constants.rs @@ -1,5 +1,5 @@ // This is the file for all the project constants -pub const TASK_INPUT_QUEUE: &str = "task_input"; -pub const TASK_RESULT_QUEUE: &str = "task_result"; +pub const TASK_INPUT_QUEUE: &str = "task_assignment_queue"; +pub const TASK_RESULT_QUEUE: &str = "task_results"; -pub const TASK_OUTPUT_EXCHANGE: &str = "task_input"; +pub const TASK_OUTPUT_EXCHANGE: &str = "task_assignment_exchange"; diff --git a/server/services/manager/src/main.rs b/server/services/manager/src/main.rs index cc13d3c..6f36d00 100644 --- a/server/services/manager/src/main.rs +++ 
b/server/services/manager/src/main.rs @@ -10,7 +10,7 @@ use common::brokers::{setup_consumer_broker, setup_publisher_broker}; use server::Server; use std::sync::{atomic::AtomicBool, Arc}; use tokio::sync::oneshot; -use tracing::{info, info_span}; +use tracing::{info, info_span, warn}; use axum::Router; use axum_tracing_opentelemetry::middleware::{OtelAxumLayer, OtelInResponseLayer}; @@ -96,7 +96,7 @@ async fn initialize_system( ), Box, > { - let is_running = Arc::new(AtomicBool::new(true)); + let shutdown = Arc::new(AtomicBool::new(false)); let db_pools = setup_db_pools(config).await; info!("Database connection pools created"); @@ -107,12 +107,12 @@ async fn initialize_system( .expect("Failed to setup publisher broker"); let task_result_consumer = - setup_consumer_broker::(&config.broker_addr, TASK_RESULT_QUEUE, is_running.clone()) + setup_consumer_broker::(&config.broker_addr, TASK_RESULT_QUEUE, shutdown.clone()) .await .expect("Failed to setup task result consumer"); let new_task_consumer = - setup_consumer_broker::(&config.broker_addr, TASK_INPUT_QUEUE, is_running.clone()) + setup_consumer_broker::(&config.broker_addr, TASK_INPUT_QUEUE, shutdown.clone()) .await .expect("Failed to setup task instance consumer"); info!("Brokers initialized"); @@ -188,9 +188,15 @@ async fn main() { // Wait for shutdown tokio::select! { - _ = input_handle => {}, - _ = result_handle => {}, - _ = server_handle => {}, + _ = input_handle => { + warn!("Task input controller shutdown"); + }, + _ = result_handle => { + warn!("Task result controller shutdown"); + }, + _ = server_handle => { + warn!("Server shutdown"); + }, } info!("Cleanup complete");