Skip to content

[WIP]feat: add multiproc async #224

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions examples/map/async_multiproc_map/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
####################################################################################################
# builder: install needed dependencies
####################################################################################################

FROM python:3.10-slim-bullseye AS builder

# Populated automatically by BuildKit/buildx ("amd64", "arm64", ...).
# Used below to fetch the dumb-init binary that matches the target platform:
# the Makefile builds linux/amd64 AND linux/arm64, so a hard-coded x86_64
# binary would be non-executable in the arm64 image.
ARG TARGETARCH

ENV PYTHONFAULTHANDLER=1 \
    PYTHONUNBUFFERED=1 \
    PYTHONHASHSEED=random \
    PIP_NO_CACHE_DIR=on \
    PIP_DISABLE_PIP_VERSION_CHECK=on \
    PIP_DEFAULT_TIMEOUT=100 \
    POETRY_VERSION=1.2.2 \
    POETRY_HOME="/opt/poetry" \
    POETRY_VIRTUALENVS_IN_PROJECT=true \
    POETRY_NO_INTERACTION=1 \
    PYSETUP_PATH="/opt/pysetup"

ENV EXAMPLE_PATH="$PYSETUP_PATH/examples/map/async_multiproc_map"
ENV VENV_PATH="$EXAMPLE_PATH/.venv"
ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH"

RUN apt-get update \
    && apt-get install --no-install-recommends -y \
    curl \
    wget \
    # deps for building python deps
    build-essential \
    && apt-get install -y git \
    && apt-get clean && rm -rf /var/lib/apt/lists/* \
    \
    # install dumb-init for the platform being built (release assets are
    # named dumb-init_<ver>_x86_64 / dumb-init_<ver>_aarch64)
    && case "${TARGETARCH:-amd64}" in \
         arm64) DUMB_INIT_ARCH="aarch64" ;; \
         *) DUMB_INIT_ARCH="x86_64" ;; \
       esac \
    && wget -O /dumb-init "https://github.com/Yelp/dumb-init/releases/download/v1.2.5/dumb-init_1.2.5_${DUMB_INIT_ARCH}" \
    && chmod +x /dumb-init \
    && curl -sSL https://install.python-poetry.org | python3 -

####################################################################################################
# udf: used for running the udf vertices
####################################################################################################
FROM builder AS udf

WORKDIR $PYSETUP_PATH
COPY ./ ./

WORKDIR $EXAMPLE_PATH
# Resolve and install example dependencies (pynumaflow comes from the local
# repo root via the pyproject path dependency).
RUN poetry lock
RUN poetry install --no-cache --no-root && \
    rm -rf ~/.cache/pypoetry/

RUN chmod +x entry.sh

ENTRYPOINT ["/dumb-init", "--"]
CMD ["sh", "-c", "$EXAMPLE_PATH/entry.sh"]

EXPOSE 5000
22 changes: 22 additions & 0 deletions examples/map/async_multiproc_map/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Image tag and registry for the async multiproc example image.
# NOTE(review): pipeline.yaml currently references tag v5 — confirm the
# manifests are updated whenever TAG changes here.
TAG ?= v6
PUSH ?= false
IMAGE_REGISTRY = quay.io/skohli/numaflow-python/async-multiproc:${TAG}
# Path is relative to the repository root, which is the build context below.
DOCKER_FILE_PATH = examples/map/async_multiproc_map/Dockerfile

# Refresh this example's poetry lock/virtualenv.
.PHONY: update
update:
	poetry update -vv

# Build a multi-platform (amd64 + arm64) image and push it to the registry.
.PHONY: image-push
image-push: update
	cd ../../../ && docker buildx build \
	-f ${DOCKER_FILE_PATH} \
	-t ${IMAGE_REGISTRY} \
	--platform linux/amd64,linux/arm64 . --push

# Build a local single-platform image; push only when PUSH=true.
.PHONY: image
image: update
	cd ../../../ && docker build \
	-f ${DOCKER_FILE_PATH} \
	-t ${IMAGE_REGISTRY} .
	@if [ "$(PUSH)" = "true" ]; then docker push ${IMAGE_REGISTRY}; fi
20 changes: 20 additions & 0 deletions examples/map/async_multiproc_map/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Multiprocessing Map

`pynumaflow` favors asyncio-based UDFs because procedural (synchronous) Python is not able to
handle any substantial traffic.

This feature enables the `pynumaflow` developer to utilise multiprocessing capabilities while
writing UDFs using the map function. It is particularly useful for CPU-intensive operations,
as it allows for better resource utilisation.

In this mode we spawn N gRPC servers (N = CPU count by default) in separate processes, each
listening on its own TCP port or UDS socket.

To enable multiprocessing mode start the multiproc server in the UDF using the following command,
providing the optional argument `server_count` to specify the number of
servers to be forked (defaults to `os.cpu_count` if not provided):
```python
if __name__ == "__main__":
    grpc_server = AsyncMapMultiprocServer(handler, server_count=3)
grpc_server.start()
```
4 changes: 4 additions & 0 deletions examples/map/async_multiproc_map/entry.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/sh
set -eux

# exec replaces this shell with the Python process, so signals forwarded by
# dumb-init reach the UDF server directly instead of an intermediate sh.
exec python example.py
40 changes: 40 additions & 0 deletions examples/map/async_multiproc_map/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os

from pynumaflow.mapper import Messages, Message, Datum, Mapper, AsyncMapMultiprocServer
from pynumaflow._constants import _LOGGER


class FlatMap(Mapper):
    """
    Example async map handler for AsyncMapMultiprocServer.

    Must subclass Mapper so it can be passed as the handler to the server
    class. Forwards each input payload unchanged (an identity map) and logs
    the worker process PID, demonstrating that requests are spread across
    the forked server processes.
    """

    async def handler(self, keys: list[str], datum: Datum) -> Messages:
        val = datum.value
        # Touch event_time/watermark only to show they are available on Datum.
        _ = datum.event_time
        _ = datum.watermark
        messages = Messages()
        # Emit the payload unchanged, preserving the incoming keys.
        messages.append(Message(val, keys=keys))
        # The PID log identifies which of the forked servers handled the request.
        _LOGGER.info(f"MY PID {os.getpid()}")
        return messages


if __name__ == "__main__":
    """
    Example of starting a multiprocessing map vertex.
    """
    # NUM_CPU_MULTIPROC sets how many server processes are forked.
    server_count = int(os.getenv("NUM_CPU_MULTIPROC", "2"))
    # SERVER_KIND selects the transport: "tcp" binds TCP ports, anything
    # else (e.g. "uds") uses unix domain sockets.
    use_tcp = os.getenv("SERVER_KIND", "tcp") == "tcp"
    handler = FlatMap()
    # server_count is the number of server processes to start.
    grpc_server = AsyncMapMultiprocServer(handler, server_count=server_count, use_tcp=use_tcp)
    grpc_server.start()
42 changes: 42 additions & 0 deletions examples/map/async_multiproc_map/pipeline.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
apiVersion: numaflow.numaproj.io/v1alpha1
kind: Pipeline
metadata:
  name: simple-pipeline
spec:
  limits:
    readBatchSize: 10
  vertices:
    - name: in
      source:
        # A self data generating source
        generator:
          rpu: 200
          duration: 1s
    - name: mult
      udf:
        container:
          # NOTE(review): the Makefile defaults to TAG ?= v6 — confirm this
          # manifest references the intended image tag before merging.
          image: quay.io/skohli/numaflow-python/async-multiproc:v5
          # imagePullPolicy: Always
          env:
            # Transport used by the example UDF: "uds" or "tcp".
            - name: SERVER_KIND
              value: "uds"
            - name: PYTHONDEBUG
              value: "true"
            # Number of server processes the example forks.
            - name: NUM_CPU_MULTIPROC
              value: "3" # DO NOT forget the double quotes!!!
      containerTemplate:
        env:
          - name: NUMAFLOW_RUNTIME
            value: "rust"
          - name: NUMAFLOW_DEBUG
            value: "true" # DO NOT forget the double quotes!!!

    - name: out
      sink:
        # A simple log printing sink
        log: {}
  edges:
    - from: in
      to: mult
    - from: mult
      to: out
15 changes: 15 additions & 0 deletions examples/map/async_multiproc_map/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Poetry manifest for the async multiprocess map example.
[tool.poetry]
name = "async-multiproc-forward-message"
version = "0.2.4"
description = ""
authors = ["Numaflow developers"]

[tool.poetry.dependencies]
python = ">=3.10,<3.13"
# pynumaflow is consumed from the repository root as a path dependency,
# which is why the Docker build copies the whole repo before installing.
pynumaflow = { path = "../../../"}

[tool.poetry.dev-dependencies]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
2 changes: 1 addition & 1 deletion pynumaflow/batchmapper/servicer/async_servicer.py
Original file line number Diff line number Diff line change
@@ -98,7 +98,7 @@ async def MapFn(

except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
return

async def IsReady(
1 change: 1 addition & 0 deletions pynumaflow/info/types.py
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
# MULTIPROC_KEY is the field used to indicate that Multiproc map mode is enabled
# The value contains the number of servers spawned.
MULTIPROC_KEY = "MULTIPROC"
MULTIPROC_ENDPOINTS = "MULTIPROC_ENDPOINTS"

SI = TypeVar("SI", bound="ServerInfo")

2 changes: 2 additions & 0 deletions pynumaflow/mapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pynumaflow.mapper.async_multiproc_server import AsyncMapMultiprocServer
from pynumaflow.mapper.async_server import MapAsyncServer
from pynumaflow.mapper.multiproc_server import MapMultiprocServer
from pynumaflow.mapper.sync_server import MapServer
@@ -13,4 +14,5 @@
"MapServer",
"MapAsyncServer",
"MapMultiprocServer",
"AsyncMapMultiprocServer",
]
76 changes: 49 additions & 27 deletions pynumaflow/mapper/_servicer/_async_servicer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextlib
from collections.abc import AsyncIterable

from google.protobuf import empty_pb2 as _empty_pb2
@@ -18,11 +19,10 @@
Provides the functionality for the required rpc methods.
"""

def __init__(
self,
handler: MapAsyncCallable,
):
def __init__(self, handler: MapAsyncCallable, multiproc: bool = False):
self.background_tasks = set()
# This indicates whether the grpc server attached is multiproc or not
self.multiproc = multiproc
self.__map_handler: MapAsyncCallable = handler

async def MapFn(
@@ -36,6 +36,7 @@
"""
# proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer
# we need to explicitly convert it to list
producer = None
try:
# The first message to be received should be a valid handshake
req = await request_iterator.__anext__()
@@ -56,44 +57,65 @@
async for msg in consumer:
# If the message is an exception, we raise the exception
if isinstance(msg, BaseException):
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING, self.multiproc)
return
# Send window response back to the client
else:
yield msg
# wait for the producer task to complete
await producer
except GeneratorExit:
_LOGGER.info("Client disconnected, generator closed.")
raise

Check warning on line 69 in pynumaflow/mapper/_servicer/_async_servicer.py

Codecov / codecov/patch

pynumaflow/mapper/_servicer/_async_servicer.py#L68-L69

Added lines #L68 - L69 were not covered by tests
except BaseException as e:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, self.multiproc)
return
finally:
if producer and not producer.done():
producer.cancel()
with contextlib.suppress(asyncio.CancelledError):
await producer

Check warning on line 78 in pynumaflow/mapper/_servicer/_async_servicer.py

Codecov / codecov/patch

pynumaflow/mapper/_servicer/_async_servicer.py#L76-L78

Added lines #L76 - L78 were not covered by tests

async def _process_inputs(
self,
request_iterator: AsyncIterable[map_pb2.MapRequest],
result_queue: NonBlockingIterator,
):
"""
Utility function for processing incoming MapRequests
"""
async def _process_inputs(self, request_iterator, result_queue):
try:
# for each incoming request, create a background task to execute the
# UDF code
async for req in request_iterator:
msg_task = asyncio.create_task(self._invoke_map(req, result_queue))
# save a reference to a set to store active tasks
self.background_tasks.add(msg_task)
msg_task.add_done_callback(self.background_tasks.discard)

# wait for all tasks to complete
for task in self.background_tasks:
await task
task = asyncio.create_task(self._invoke_map(req, result_queue))
self.background_tasks.add(task)
task.add_done_callback(self.background_tasks.discard)

# send an EOF to result queue to indicate that all tasks have completed
await asyncio.gather(*self.background_tasks)
except BaseException:
_LOGGER.critical("MapFn Error in _process_inputs", exc_info=True)
finally:
await result_queue.put(STREAM_EOF)

except BaseException:
_LOGGER.critical("MapFn Error, re-raising the error", exc_info=True)
# async def _process_inputs(
# self,
# request_iterator: AsyncIterable[map_pb2.MapRequest],
# result_queue: NonBlockingIterator,
# ):
# """
# Utility function for processing incoming MapRequests
# """
# try:
# # for each incoming request, create a background task to execute the
# # UDF code
# async for req in request_iterator:
# msg_task = asyncio.create_task(self._invoke_map(req, result_queue))
# # save a reference to a set to store active tasks
# self.background_tasks.add(msg_task)
# msg_task.add_done_callback(self.background_tasks.discard)
#
# # wait for all tasks to complete
# for task in self.background_tasks:
# await task
#
# # send an EOF to result queue to indicate that all tasks have completed
# await result_queue.put(STREAM_EOF)
#
# except BaseException:
# _LOGGER.critical("MapFn Error, re-raising the error", exc_info=True)

async def _invoke_map(self, req: map_pb2.MapRequest, result_queue: NonBlockingIterator):
"""
138 changes: 138 additions & 0 deletions pynumaflow/mapper/async_multiproc_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import logging
import multiprocessing
from typing import Optional

import aiorun
import grpc

from pynumaflow._constants import (
MAX_NUM_THREADS,
MAX_MESSAGE_SIZE,
MAP_SERVER_INFO_FILE_PATH,
_PROCESS_COUNT,
NUM_THREADS_DEFAULT,
MULTIPROC_MAP_SOCK_ADDR,
)
from pynumaflow.info.server import get_metadata_env
from pynumaflow.info.types import (
ServerInfo,
MINIMUM_NUMAFLOW_VERSION,
ContainerType,
MAP_MODE_KEY,
MapMode,
METADATA_ENVS,
MULTIPROC_KEY,
MULTIPROC_ENDPOINTS,
Protocol,
)
from pynumaflow.mapper._dtypes import MapAsyncCallable
from pynumaflow.mapper._servicer._async_servicer import AsyncMapServicer
from pynumaflow.proto.mapper import map_pb2_grpc
from pynumaflow.shared.server import start_async_server, NumaflowServer, reserve_port
from pynumaflow.info.server import write as info_server_write

_LOGGER = logging.getLogger(__name__)


class AsyncMapMultiprocServer(NumaflowServer):
    """
    A multiprocess asynchronous gRPC server for Numaflow Map UDFs.

    Forks ``server_count`` worker processes, each running an asyncio-based
    gRPC server bound either to its own UDS socket (default) or to its own
    reserved TCP port (``use_tcp=True``). SO_REUSEPORT/SO_REUSEADDR options
    are set so forked workers can share listening addresses where applicable.
    """

    def __init__(
        self,
        mapper_instance: MapAsyncCallable,
        server_count: int = _PROCESS_COUNT,
        sock_path: str = MULTIPROC_MAP_SOCK_ADDR,
        max_message_size: int = MAX_MESSAGE_SIZE,
        max_threads: int = NUM_THREADS_DEFAULT,
        server_info_file: Optional[str] = MAP_SERVER_INFO_FILE_PATH,
        use_tcp: bool = False,
    ):
        """
        Args:
            mapper_instance: the async map handler invoked for each request
            server_count: number of worker processes to fork
                (capped at 2 * CPU count)
            sock_path: base path for the per-worker UDS sockets
            max_message_size: gRPC max send/receive message size in bytes
            max_threads: per-server thread cap (capped at MAX_NUM_THREADS)
            server_info_file: path of the server info file written by the
                parent process; None disables writing it
            use_tcp: bind workers to TCP ports instead of UDS sockets
        """
        self.sock_path = f"unix://{sock_path}"
        self.max_threads = min(max_threads, MAX_NUM_THREADS)
        self.max_message_size = max_message_size
        self.server_info_file = server_info_file
        self.use_tcp = use_tcp

        self.mapper_instance = mapper_instance

        self._server_options = [
            ("grpc.max_send_message_length", self.max_message_size),
            ("grpc.max_receive_message_length", self.max_message_size),
            # let the forked workers share listening addresses
            ("grpc.so_reuseport", 1),
            ("grpc.so_reuseaddr", 1),
        ]

        self._process_count = min(server_count, 2 * _PROCESS_COUNT)
        # multiproc=True: UDF errors must take down the whole process tree
        self.servicer = AsyncMapServicer(handler=self.mapper_instance, multiproc=True)

    def start(self):
        """
        Fork the worker processes, write the server info file from the
        parent, and block until every worker exits.
        """
        _LOGGER.info(
            "Starting async multiprocess gRPC server with %d workers", self._process_count
        )

        workers = []
        ports = []

        for idx in range(self._process_count):
            if self.use_tcp:
                # Reserve a free port (SO_REUSEPORT); the child re-binds it.
                with reserve_port(0) as reserved_port:
                    bind_address = f"0.0.0.0:{reserved_port}"
                    ports.append(f"http://{bind_address}")
            else:
                bind_address = f"{self.sock_path}{idx}.sock"
            _LOGGER.info("Binding server to: %s", bind_address)

            worker = multiprocessing.Process(
                target=self._run_server_process,
                args=(bind_address,),
            )
            worker.start()
            workers.append(worker)

        # The parent writes the single server info file; the children pass
        # server_info_file=None to start_async_server and skip writing.
        if self.server_info_file:
            server_info = ServerInfo.get_default_server_info()
            server_info.metadata[MULTIPROC_KEY] = str(self._process_count)
            server_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap
            if self.use_tcp:
                server_info.protocol = Protocol.TCP
                server_info.metadata[MULTIPROC_ENDPOINTS] = ",".join(map(str, ports))
            info_server_write(server_info=server_info, info_file=self.server_info_file)

        for worker in workers:
            worker.join()

    def _run_server_process(self, bind_address):
        """Run one asyncio gRPC server bound to bind_address (in a child process)."""

        async def run_server():
            server = grpc.aio.server(options=self._server_options)
            server.add_insecure_port(bind_address)
            map_pb2_grpc.add_MapServicer_to_server(self.servicer, server)

            # NOTE(review): server_info built here is handed to
            # start_async_server together with server_info_file=None, so it is
            # never written to disk — the parent already wrote the info file in
            # start(). Confirm whether the child-side metadata (minimum
            # numaflow version, env metadata) should be merged into the
            # parent's file instead.
            server_info = None
            if self.server_info_file:
                server_info = ServerInfo.get_default_server_info()
                server_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[
                    ContainerType.Mapper
                ]
                server_info.metadata = get_metadata_env(envs=METADATA_ENVS)
                if self.use_tcp:
                    server_info.protocol = Protocol.TCP
                # Add the MULTIPROC metadata using the number of servers to use
                server_info.metadata[MULTIPROC_KEY] = str(self._process_count)
                # Add the MAP_MODE metadata to the server info for the correct map mode
                server_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap

            await start_async_server(
                server_async=server,
                sock_path=bind_address,
                max_threads=self.max_threads,
                cleanup_coroutines=list(),
                server_info_file=None,
                server_info=server_info,
            )

        aiorun.run(run_server(), use_uvloop=True)
2 changes: 1 addition & 1 deletion pynumaflow/mapper/async_server.py
Original file line number Diff line number Diff line change
@@ -82,7 +82,7 @@ def __init__(
("grpc.max_receive_message_length", self.max_message_size),
]
# Get the servicer instance for the async server
self.servicer = AsyncMapServicer(handler=mapper_instance)
self.servicer = AsyncMapServicer(handler=mapper_instance, multiproc=False)

def start(self) -> None:
"""
2 changes: 1 addition & 1 deletion pynumaflow/mapstreamer/servicer/async_servicer.py
Original file line number Diff line number Diff line change
@@ -59,7 +59,7 @@ async def MapFn(
yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id)
except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
return

async def __invoke_map_stream(self, keys: list[str], req: Datum):
4 changes: 2 additions & 2 deletions pynumaflow/reducer/servicer/async_servicer.py
Original file line number Diff line number Diff line change
@@ -105,7 +105,7 @@ async def ReduceFn(
_LOGGER.critical("Reduce Error", exc_info=True)
# Send a context abort signal for the rpc, this is required for numa container to get
# the correct grpc error
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)

# send EOF to all the tasks once the request iterator is exhausted
# This will signal the tasks to stop reading the data on their
@@ -136,7 +136,7 @@ async def ReduceFn(
_LOGGER.critical("Reduce Error", exc_info=True)
# Send a context abort signal for the rpc, this is required for numa container to get
# the correct grpc error
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)

async def IsReady(
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
6 changes: 3 additions & 3 deletions pynumaflow/reducestreamer/servicer/async_servicer.py
Original file line number Diff line number Diff line change
@@ -95,20 +95,20 @@
async for msg in consumer:
# If the message is an exception, we raise the exception
if isinstance(msg, BaseException):
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING, False)
return
# Send window EOF response or Window result response
# back to the client
else:
yield msg
except BaseException as e:
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)

Check warning on line 105 in pynumaflow/reducestreamer/servicer/async_servicer.py

Codecov / codecov/patch

pynumaflow/reducestreamer/servicer/async_servicer.py#L105

Added line #L105 was not covered by tests
return
# Wait for the process_input_stream task to finish for a clean exit
try:
await producer
except BaseException as e:
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)

Check warning on line 111 in pynumaflow/reducestreamer/servicer/async_servicer.py

Codecov / codecov/patch

pynumaflow/reducestreamer/servicer/async_servicer.py#L111

Added line #L111 was not covered by tests
return

async def IsReady(
17 changes: 9 additions & 8 deletions pynumaflow/shared/server.py
Original file line number Diff line number Diff line change
@@ -180,7 +180,7 @@ async def start_async_server(
sock_path: str,
max_threads: int,
cleanup_coroutines: list,
server_info_file: str,
server_info_file: Optional[str] = None,
server_info: Optional[ServerInfo] = None,
):
"""
@@ -190,11 +190,9 @@ async def start_async_server(
"""
await server_async.start()

if server_info is None:
# Create the server info file if not provided
server_info = ServerInfo.get_default_server_info()
# Add the server information to the server info file
info_server_write(server_info=server_info, info_file=server_info_file)
if server_info_file:
info_server_write(server_info=server_info, info_file=server_info_file)

# Log the server start
_LOGGER.info(
@@ -217,7 +215,7 @@ async def server_graceful_shutdown():


@contextlib.contextmanager
def _reserve_port(port_num: int) -> Iterator[int]:
def reserve_port(port_num: int) -> Iterator[int]:
"""Find and reserve a port for all subprocesses to use."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -312,7 +310,10 @@ def get_exception_traceback_str(exc) -> str:


async def handle_async_error(
context: NumaflowServicerContext, exception: BaseException, exception_type: str
context: NumaflowServicerContext,
exception: BaseException,
exception_type: str,
parent: bool = False,
):
"""
Handle exceptions for async servers by updating the context and exiting.
@@ -322,4 +323,4 @@ async def handle_async_error(
await asyncio.gather(
context.abort(grpc.StatusCode.INTERNAL, details=err_msg), return_exceptions=True
)
exit_on_error(err=err_msg, parent=False, context=context, update_context=False)
exit_on_error(err=err_msg, parent=parent, context=context, update_context=False)
2 changes: 1 addition & 1 deletion pynumaflow/sinker/servicer/async_servicer.py
Original file line number Diff line number Diff line change
@@ -85,7 +85,7 @@ async def SinkFn(
# if there is an exception, we will mark all the responses as a failure
err_msg = f"UDSinkError: {repr(err)}"
_LOGGER.critical(err_msg, exc_info=True)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
return

async def __invoke_sink(
10 changes: 5 additions & 5 deletions pynumaflow/sourcer/servicer/async_servicer.py
Original file line number Diff line number Diff line change
@@ -108,7 +108,7 @@ async def ReadFn(

async for resp in riter:
if isinstance(resp, BaseException):
await handle_async_error(context, resp)
await handle_async_error(context, resp, ERR_UDF_EXCEPTION_STRING, False)
return

yield _create_read_response(resp)
@@ -119,7 +119,7 @@ async def ReadFn(
yield _create_eot_response()
except BaseException as err:
_LOGGER.critical("User-Defined Source ReadFn error", exc_info=True)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)

async def __invoke_read(self, req, niter):
"""Invoke the read handler and manage the iterator."""
@@ -165,7 +165,7 @@ async def AckFn(
yield _create_ack_response()
except BaseException as err:
_LOGGER.critical("User-Defined Source AckFn error", exc_info=True)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)

async def IsReady(
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
@@ -187,7 +187,7 @@ async def PendingFn(
count = await self.__source_pending_handler()
except BaseException as err:
_LOGGER.critical("PendingFn Error", exc_info=True)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
return
resp = source_pb2.PendingResponse.Result(count=count.count)
return source_pb2.PendingResponse(result=resp)
@@ -202,7 +202,7 @@ async def PartitionsFn(
partitions = await self.__source_partitions_handler()
except BaseException as err:
_LOGGER.critical("PartitionsFn Error", exc_info=True)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
return
resp = source_pb2.PartitionsResponse.Result(partitions=partitions.partitions)
return source_pb2.PartitionsResponse(result=resp)
85 changes: 85 additions & 0 deletions tests/map/test_async_multiproc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import uuid
from pynumaflow.mapper import Datum, Messages, Message

sock_prefix = f"/tmp/test_async_multiproc_map_{uuid.uuid4().hex}_"


async def async_handler(keys, datum: Datum) -> Messages:
    """Echo handler: encode payload, event time and watermark into one message."""
    msg = (
        f"payload:{datum.value.decode()} event_time:{datum.event_time} watermark:{datum.watermark}"
    )
    return Messages(Message(value=msg.encode(), keys=keys))


#
# class TestAsyncMapMultiprocServer(unittest.TestCase):
# def setUp(self):
# self.base_sock_path = sock_prefix
# self.server = AsyncMapMultiprocServer(
# mapper_instance=async_handler,
# server_count=2,
# sock_path=self.base_sock_path,
# use_tcp=False,
# server_info_file=None,
# )
# self.process = Process(target=self.server.start)
# self.process.start()
#
# # Wait for both servers to bind
# self.socket_paths = [f"{self.base_sock_path}{i}.sock" for i in range(2)]
# for path in self.socket_paths:
# for _ in range(10):
# if os.path.exists(path):
# break
# time.sleep(0.5)
#
# def tearDown(self):
# self.process.terminate()
# self.process.join()
# for path in self.socket_paths:
# try:
# os.remove(path)
# except FileNotFoundError:
# pass
#
# def test_map_fn(self):
# bind_address = f"unix://{self.socket_paths[0]}"
# request = get_test_datums()
# with grpc.insecure_channel(bind_address) as channel:
# stub = map_pb2_grpc.MapStub(channel)
# responses_iter = stub.MapFn(request_iterator=request_generator(request))
# responses = []
# # capture the output from the ReadFn generator and assert.
# for r in responses_iter:
# responses.append(r)
#
# # 1 handshake + 3 data responses
# self.assertEqual(4, len(responses))
#
# self.assertTrue(responses[0].handshake.sot)
#
# idx = 1
# while idx < len(responses):
# _id = "test-id-" + str(idx)
# self.assertEqual(_id, responses[idx].id)
# self.assertEqual(
# bytes(
# "payload:test_mock_message "
# "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00",
# encoding="utf-8",
# ),
# responses[idx].results[0].value,
# )
# self.assertEqual(1, len(responses[idx].results))
# idx += 1
#
# def test_server_start(self):
# for path in self.socket_paths:
# self.assertTrue(
# os.path.exists(path), f"Server socket {path} was not created successfully"
# )

#
# if __name__ == "__main__":
# unittest.main()
#
7 changes: 1 addition & 6 deletions tests/source/test_async_source_err.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,6 @@
from grpc.aio._server import Server

from pynumaflow import setup_logging
from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING
from pynumaflow.sourcer import SourceAsyncServer
from pynumaflow.proto.sourcer import source_pb2_grpc
from google.protobuf import empty_pb2 as _empty_pb2
@@ -93,11 +92,7 @@ def test_read_error(self) -> None:
for _ in generator_response:
pass
except BaseException as e:
self.assertTrue(
f"{ERR_UDF_EXCEPTION_STRING}: TypeError("
'"handle_async_error() missing 1 required positional argument: '
"'exception_type'\")" in e.__str__()
)
self.assertTrue("Got a runtime error from read handler" in e.__str__())
return
except grpc.RpcError as e:
grpc_exception = e