Skip to content

Commit

Permalink
feat: ws gateway client (#366)
Browse files Browse the repository at this point in the history
  • Loading branch information
chamini2 authored Nov 29, 2024
1 parent bea0df7 commit d9d552c
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 1 deletion.
130 changes: 129 additions & 1 deletion projects/fal/src/fal/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Iterator
from typing import TYPE_CHECKING, Any, Iterator

import httpx

from fal import flags
from fal.sdk import Credentials, get_default_credentials

if TYPE_CHECKING:
from websockets.sync.connection import Connection

_QUEUE_URL_FORMAT = f"https://queue.{flags.FAL_RUN_HOST}/{{app_id}}"
_REALTIME_URL_FORMAT = f"wss://{flags.FAL_RUN_HOST}/{{app_id}}"
_WS_URL_FORMAT = f"wss://ws.{flags.FAL_RUN_HOST}/{{app_id}}"


def _backwards_compatible_app_id(app_id: str) -> str:
Expand Down Expand Up @@ -245,3 +249,127 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
url, additional_headers=creds.to_headers(), open_timeout=90
) as ws:
yield _RealtimeConnection(ws)


class _MetaMessageFound(Exception): ...


@dataclass
class _WSConnection:
"""A WS connection to an HTTP Fal app."""

_ws: Connection
_buffer: str | bytes | None = None

def run(self, arguments: dict[str, Any]) -> dict[str, Any]:
"""Run an inference task on the app and return the result."""
self.send(arguments)
return self.recv()

def send(self, arguments: dict[str, Any]) -> None:
import json

payload = json.dumps(arguments)
self._ws.send(payload)

def _peek(self) -> bytes | str:
if self._buffer is None:
self._buffer = self._ws.recv()

return self._buffer

def _consume(self) -> None:
if self._buffer is None:
raise ValueError("No data to consume")

self._buffer = None

@contextmanager
def _recv(self) -> Iterator[str | bytes]:
res = self._peek()

yield res

# Only consume if it went through the context manager without raising
self._consume()

def _is_meta(self, res: str | bytes) -> bool:
if not isinstance(res, str):
return False

try:
json_payload: Any = json.loads(res)
except json.JSONDecodeError:
return False

if not isinstance(json_payload, dict):
return False

return "type" in json_payload and "request_id" in json_payload

def _recv_meta(self, type: str) -> dict[str, Any]:
with self._recv() as res:
if not self._is_meta(res):
raise ValueError(f"Expected a {type} message")

json_payload: dict = json.loads(res)
if json_payload.get("type") != type:
raise ValueError(f"Expected a {type} message")

return json_payload

def _recv_response(self) -> Any:
import msgpack

body: bytes = b""
while True:
try:
with self._recv() as res:
if self._is_meta(res):
# Keep the meta message for later
raise _MetaMessageFound()

if isinstance(res, str):
return res
else:
body += res
except _MetaMessageFound:
break

if not body:
raise ValueError("Empty response body")

return msgpack.unpackb(body)

def recv(self) -> Any:
start = self._recv_meta("start")
request_id = start["request_id"]

response = self._recv_response()

end = self._recv_meta("end")
if end["request_id"] != request_id:
raise ValueError("Mismatched request_id in end message")

return response


@contextmanager
def ws(app_id: str, *, path: str = "") -> Iterator[_WSConnection]:
"""Connect to a HTTP endpoint but with websocket protocol. This is an internal and
experimental API, use it at your own risk."""

from websockets.sync import client

app_id = _backwards_compatible_app_id(app_id)
url = _WS_URL_FORMAT.format(app_id=app_id)
if path:
_path = path[len("/") :] if path.startswith("/") else path
url += "/" + _path

creds = get_default_credentials()

with client.connect(
url, additional_headers=creds.to_headers(), open_timeout=90
) as ws:
yield _WSConnection(ws)
15 changes: 15 additions & 0 deletions projects/fal/tests/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,21 @@ def test_app_client(test_app: str, test_nomad_app: str):
assert response["result"] == 5


def test_ws_client(test_app: str):
with apps.ws(test_app) as connection:
for i in range(3):
response = json.loads(connection.run({"lhs": 1, "rhs": i}))
assert response["result"] == 1 + i

for i in range(3):
connection.send({"lhs": 2, "rhs": i})

for i in range(3):
# they should be in order
response = json.loads(connection.recv())
assert response["result"] == 2 + i


def test_app_client_old_format(test_app: str):
assert test_app.count("/") == 1, "Test app should be in new format"
old_format = test_app.replace("/", "-")
Expand Down

0 comments on commit d9d552c

Please sign in to comment.