diff --git a/projects/fal/src/fal/apps.py b/projects/fal/src/fal/apps.py index 208b7491..30ec503f 100644 --- a/projects/fal/src/fal/apps.py +++ b/projects/fal/src/fal/apps.py @@ -4,15 +4,19 @@ import time from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Any, Iterator +from typing import TYPE_CHECKING, Any, Iterator import httpx from fal import flags from fal.sdk import Credentials, get_default_credentials +if TYPE_CHECKING: + from websockets.sync.connection import Connection + _QUEUE_URL_FORMAT = f"https://queue.{flags.FAL_RUN_HOST}/{{app_id}}" _REALTIME_URL_FORMAT = f"wss://{flags.FAL_RUN_HOST}/{{app_id}}" +_WS_URL_FORMAT = f"wss://ws.{flags.FAL_RUN_HOST}/{{app_id}}" def _backwards_compatible_app_id(app_id: str) -> str: @@ -245,3 +249,127 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne url, additional_headers=creds.to_headers(), open_timeout=90 ) as ws: yield _RealtimeConnection(ws) + + +class _MetaMessageFound(Exception): ... + + +@dataclass +class _WSConnection: + """A WS connection to an HTTP Fal app.""" + + _ws: Connection + _buffer: str | bytes | None = None + + def run(self, arguments: dict[str, Any]) -> dict[str, Any]: + """Run an inference task on the app and return the result.""" + self.send(arguments) + return self.recv() + + def send(self, arguments: dict[str, Any]) -> None: + import json + + payload = json.dumps(arguments) + self._ws.send(payload) + + def _peek(self) -> bytes | str: + if self._buffer is None: + self._buffer = self._ws.recv() + + return self._buffer + + def _consume(self) -> None: + if self._buffer is None: + raise ValueError("No data to consume") + + self._buffer = None + + @contextmanager + def _recv(self) -> Iterator[str | bytes]: + res = self._peek() + + yield res + + # Only consume if it went through the context manager without raising + self._consume() + + def _is_meta(self, res: str | bytes) -> bool: + if not isinstance(res, str): + return False + + try: + json_payload: Any = json.loads(res) + except json.JSONDecodeError: + return False + + if not isinstance(json_payload, dict): + return False + + return "type" in json_payload and "request_id" in json_payload + + def _recv_meta(self, type: str) -> dict[str, Any]: + with self._recv() as res: + if not self._is_meta(res): + raise ValueError(f"Expected a {type} message") + + json_payload: dict = json.loads(res) + if json_payload.get("type") != type: + raise ValueError(f"Expected a {type} message") + + return json_payload + + def _recv_response(self) -> Any: + import msgpack + + body: bytes = b"" + while True: + try: + with self._recv() as res: + if self._is_meta(res): + # Keep the meta message for later + raise _MetaMessageFound() + + if isinstance(res, str): + return res + else: + body += res + except _MetaMessageFound: + break + + if not body: + raise ValueError("Empty response body") + + return msgpack.unpackb(body) + + def recv(self) -> Any: + start = self._recv_meta("start") + request_id = start["request_id"] + + response = self._recv_response() + + end = self._recv_meta("end") + if end["request_id"] != request_id: + raise ValueError("Mismatched request_id in end message") + + return response + + +@contextmanager +def ws(app_id: str, *, path: str = "") -> Iterator[_WSConnection]: + """Connect to a HTTP endpoint but with websocket protocol. This is an internal and + experimental API, use it at your own risk.""" + + from websockets.sync import client + + app_id = _backwards_compatible_app_id(app_id) + url = _WS_URL_FORMAT.format(app_id=app_id) + if path: + _path = path[len("/") :] if path.startswith("/") else path + url += "/" + _path + + creds = get_default_credentials() + + with client.connect( + url, additional_headers=creds.to_headers(), open_timeout=90 + ) as ws: + yield _WSConnection(ws) diff --git a/projects/fal/tests/test_apps.py b/projects/fal/tests/test_apps.py index 75ffcd36..c03a31d1 100644 --- a/projects/fal/tests/test_apps.py +++ b/projects/fal/tests/test_apps.py @@ -395,6 +395,21 @@ def test_app_client(test_app: str, test_nomad_app: str): assert response["result"] == 5 +def test_ws_client(test_app: str): + with apps.ws(test_app) as connection: + for i in range(3): + response = json.loads(connection.run({"lhs": 1, "rhs": i})) + assert response["result"] == 1 + i + + for i in range(3): + connection.send({"lhs": 2, "rhs": i}) + + for i in range(3): + # they should be in order + response = json.loads(connection.recv()) + assert response["result"] == 2 + i + + def test_app_client_old_format(test_app: str): assert test_app.count("/") == 1, "Test app should be in new format" old_format = test_app.replace("/", "-")