Skip to content

Commit

Permalink
Merge pull request #179 from MOV-AI/task/BP-1262/codebase-improvements
Browse files Browse the repository at this point in the history
BP-1262: Small codebase improvements
  • Loading branch information
andreparames authored Jan 30, 2025
2 parents fee4ddb + 3b88cb8 commit 1899822
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 41 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# v3.0.2
- [BP-1262]((https://movai.atlassian.net/browse/BP-1262): Small codebase improvements

# v3.0.1
- [BP-1340](https://movai.atlassian.net/browse/BP-1340): Migrate movai-core-shared to py-workflow@v2

Expand Down
15 changes: 8 additions & 7 deletions movai_core_shared/common/time.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from datetime import datetime, timedelta
from typing import Union

from movai_core_shared.exceptions import TimeError

Expand Down Expand Up @@ -30,23 +31,23 @@ def current_timestamp_int() -> int:
return int(datetime.now().timestamp())


def delta_time_int(delta: int) -> int:
def delta_time_int(delta: timedelta) -> int:
"""returns a future time in timestamp format.
Args:
expiration_delta (int): the time delta from now.
expiration_delta (timedelta): the time delta from now.
Returns:
int: an int representing the time delta.
"""
return int((datetime.now() + delta).timestamp())


def delta_time_float(delta: int) -> float:
def delta_time_float(delta: timedelta) -> float:
"""returns a future time in timestamp format.
Args:
expiration_delta (int): the time delta from now.
expiration_delta (timedelta): the time delta from now.
Returns:
float: an float representing the time delta.
Expand All @@ -73,11 +74,11 @@ def validate_timestamp(timestamp: int) -> int:
raise TimeError("The supplied time argument is not in timestamp format!") from exc


def validate_time(value: int) -> str:
def validate_time(value: Union[int, str]) -> int:
"""Validate if value is timestamp or datetime
Args:
value (int): The datetime to validate
value (int|str): The datetime to validate
Raises:
ValueError: In case value isn't a time format.
Expand Down
39 changes: 29 additions & 10 deletions movai_core_shared/core/message_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,25 @@
- Ofer Katz ([email protected]) - 2022
- Erez Zomer ([email protected]) - 2022
"""
from datetime import datetime
import time
from typing import TYPE_CHECKING, Optional, cast

from movai_core_shared.core.zmq.zmq_manager import ZMQManager, ZMQType
from movai_core_shared.core.zmq.zmq_manager import ZMQManager, ZMQType, AsyncZMQClient
from movai_core_shared.envvars import DEVICE_NAME, FLEET_NAME, SERVICE_NAME
from movai_core_shared.exceptions import ArgumentError, MessageFormatError

if TYPE_CHECKING:
from movai_core_shared.core.zmq.zmq_client import ZMQClient


class MessageClient:
"""
This class is the client for message-server.
It wraps the data into the message structure and send it to
the message-server using ZMQClient.
"""
_zmq_client: "ZMQClient"

def __init__(self, server_addr: str, robot_id: str = "") -> None:
"""
Expand All @@ -48,7 +54,6 @@ def __init__(self, server_addr: str, robot_id: str = "") -> None:
"service": SERVICE_NAME,
"id": robot_id,
}
self._zmq_client = None
self._init_zmq_client()

def _init_zmq_client(self) -> None:
Expand All @@ -58,26 +63,28 @@ def _init_zmq_client(self) -> None:
self._zmq_client = ZMQManager.get_client(self._server_addr, ZMQType.CLIENT)

def _build_request(
self, msg_type: str, data: dict, creation_time: str = None, response_required: bool = False
self, msg_type: str, data: dict, creation_time: Optional[datetime] = None, response_required: bool = False
) -> dict:
"""Build a request in the format accepted by the message server.
Args:
msg_type (str): The type of the message (logs, alerts, metrics....)
data (dict): The data to include in the request.
creation_time (str, optional): The time the request was created.
creation_time (str, optional): The time the request was created. Defaults to now.
response_required (bool, optional): Tells the message-server if the client is wainting for response.
Returns:
{dict}: The message request to send the message-server
"""
if creation_time is None:
creation_time = time.time_ns()
creation_time_ns = time.time_ns()
else:
creation_time_ns = creation_time.timestamp() * 1000000000 + creation_time.microsecond * 1000

request = {
"request": {
"req_type": msg_type,
"created": creation_time,
"created": creation_time_ns,
"response_required": response_required,
"req_data": data,
"robot_info": self._robot_info,
Expand Down Expand Up @@ -108,15 +115,19 @@ def _fetch_response(self, msg) -> dict:
return response

def send_request(
self, msg_type: str, data: dict, creation_time: str = None, response_required: bool = False
self,
msg_type: str,
data: dict,
creation_time: Optional[datetime] = None,
response_required: bool = False,
) -> dict:
"""
Wrap the data into a message request and sent it to the robot message server
Args:
msg_type (str): the type of message.
data (dict): The message data to be sent to the robot message server.
creation_time (str): The time where the request is created.
creation_time (datetime, optional): The time where the request is created. Defaults to now.
response_required (bool): whether to wait for response, Default False.
"""
# Add tags to the request data
Expand Down Expand Up @@ -167,14 +178,22 @@ def send_msg(self, data: dict, **kwargs) -> None:


class AsyncMessageClient(MessageClient):
_zmq_client: AsyncZMQClient

def _init_zmq_client(self) -> None:
"""
Initializes the ZMQ attributute.
"""
self._zmq_client = ZMQManager.get_client(self._server_addr, ZMQType.ASYNC_CLIENT)
self._zmq_client = cast(
AsyncZMQClient, ZMQManager.get_client(self._server_addr, ZMQType.ASYNC_CLIENT)
)

async def send_request(
self, msg_type: str, data: dict, creation_time: str = None, response_required: bool = False
self,
msg_type: str,
data: dict,
creation_time: Optional[datetime] = None,
response_required: bool = False,
) -> dict:
"""
Wrap the data into a message request and sent it asynchonously to the robot message server
Expand Down
11 changes: 6 additions & 5 deletions movai_core_shared/core/zmq/zmq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,18 @@ def handle_socket_errors(self, exc: zmq.error.ZMQError, reset_socket=True) -> No
exc: the exception
"""
if exc.errno == errno.ENOTSOCK:
self._logger.warning(f"ZMQ socket error: {self._addr} got exception: {exc}.")
self._logger.warning("ZMQ socket error: %s got exception: %s.", self._addr, exc)
if reset_socket:
self._logger.warning("Resetting ZMQ {self._addr} with potential data loss.")
self._logger.warning("Resetting ZMQ %s with potential data loss.", self._addr)
self.reset(force=True)
elif exc.errno == errno.EAGAIN:
self._logger.warning(f"ZMQ socket error: {self._addr} got exception: {exc}.")
self._logger.warning("ZMQ socket error: %s got exception: %s.", self._addr, exc)
if reset_socket:
self._logger.warning("Resetting ZMQ {self._addr}.")
self._logger.warning("Resetting ZMQ %s.", self._addr)
self.reset()
else:
self._logger.error(
f"ZMQ socket error: {self._addr} got unhandled ZMQ exception: {exc} "
"ZMQ socket error: %s got unhandled ZMQ exception: %s ", self._addr, exc
)

def send(self, msg: dict, use_lock: bool = False) -> None:
Expand Down Expand Up @@ -156,6 +156,7 @@ def receive(self, use_lock: bool = False) -> dict:
class AsyncZMQClient(ZMQClient):
"""An Async implementation of ZMQ Client"""

_socket: zmq.asyncio.Socket
_context = zmq.asyncio.Context()

def init_lock(self) -> None:
Expand Down
5 changes: 3 additions & 2 deletions movai_core_shared/core/zmq/zmq_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import json
from logging import getLogger
import random
from typing import List

from movai_core_shared.envvars import DEVICE_NAME, SERVICE_NAME
from movai_core_shared.exceptions import MessageError
Expand Down Expand Up @@ -41,11 +42,11 @@ def create_msg(msg: dict):
return None


def extract_reponse(buffer: bytes):
def extract_reponse(buffer: List[bytes]) -> dict:
"""Extracts the response from the buffer.
Args:
buffer (bytes): The memory buffer which contains the response msg.
buffer: List of memory buffers containing the message.
Returns:
(dict): A response from server.
Expand Down
13 changes: 9 additions & 4 deletions movai_core_shared/core/zmq/zmq_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
"""
from enum import Enum
from logging import getLogger
from typing import Dict, Type, TypedDict

from beartype import beartype

from movai_core_shared.core.zmq.zmq_base import ZMQBase
from movai_core_shared.core.zmq.zmq_client import ZMQClient, AsyncZMQClient
from movai_core_shared.core.zmq.zmq_subscriber import ZMQSubscriber, AsyncZMQSubscriber
from movai_core_shared.core.zmq.zmq_publisher import ZMQPublisher, AsyncZMQPublisher
Expand All @@ -31,7 +31,12 @@ class ZMQType(Enum):
ASYNC_SUBSCRIBER = 6


ZMQ_TYPES = {
class ZMQTypeValue(TypedDict):
type: Type[ZMQClient]
identity: str


ZMQ_TYPES: Dict[ZMQType, ZMQTypeValue] = {
ZMQType.CLIENT: {"type": ZMQClient, "identity": "dealer"},
ZMQType.ASYNC_CLIENT: {"type": AsyncZMQClient, "identity": "dealer"},
ZMQType.PUBLISHER: {"type": ZMQPublisher, "identity": "pub"},
Expand All @@ -45,7 +50,7 @@ class ZMQManager:
"""This class will host ZMQ objects by their type and address."""

_logger = getLogger("ZMQManager")
_clients = {
_clients: Dict[ZMQType, Dict[str, ZMQClient]] = {
ZMQType.CLIENT: {},
ZMQType.ASYNC_CLIENT: {},
ZMQType.PUBLISHER: {},
Expand All @@ -63,7 +68,7 @@ def validate_server_addr(cls, server_addr: str):

@classmethod
@beartype
def _get_or_create_zmq_object(cls, server_addr: str, zmq_type: ZMQType) -> ZMQBase:
def _get_or_create_zmq_object(cls, server_addr: str, zmq_type: ZMQType) -> ZMQClient:
if zmq_type not in cls._clients:
raise TypeError(f"{zmq_type} does not exist!")

Expand Down
8 changes: 5 additions & 3 deletions movai_core_shared/core/zmq/zmq_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import List

import zmq
import zmq.asyncio
Expand Down Expand Up @@ -54,11 +55,12 @@ async def spin(self) -> None:
"""accepts new connections requests to zmq."""
try:
self.init_server()
assert self._socket
if self._running:
self._logger.warning("%s is already running", self._name)
self._running = True
except Exception:
self._logger.error("Failed to start %s", self._name)
except Exception as e:
self._logger.error("Failed to start %s: %s", self._name, e)
return

await self.at_startup()
Expand Down Expand Up @@ -120,7 +122,7 @@ def stop(self):
self._running = False

@abstractmethod
async def handle(self, buffer: bytes) -> None:
async def handle(self, buffer: List[bytes]) -> None:
pass

async def at_startup(self):
Expand Down
13 changes: 5 additions & 8 deletions movai_core_shared/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,13 +536,10 @@ async def get_logs(
"count_field": "message",
}

try:
query_response = await message_client.send_request(
LOGS_QUERY_HANDLER_MSG_TYPE, query_data, None, True
)
if "response" in query_response:
response = query_response["response"]
except Exception as error:
raise error
query_response = await message_client.send_request(
LOGS_QUERY_HANDLER_MSG_TYPE, query_data, None, True
)
if "response" in query_response:
response = query_response["response"]

return response if pagination else response.get("data", [])
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "movai-core-shared"
version = "3.0.1.1"
version = "3.0.2.0"
authors = [
{name = "Backend team", email = "[email protected]"},
]
Expand Down Expand Up @@ -36,7 +36,7 @@ exclude = ["movai_core_shared.tests*"]
line-length = 100

[tool.bumpversion]
current_version = "3.0.1.1"
current_version = "3.0.2.0"
parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)?(\\.(?P<build>\\d+))?"
serialize = ["{major}.{minor}.{patch}.{build}"]

Expand Down

0 comments on commit 1899822

Please sign in to comment.