Skip to content

Commit 564df3e

Browse files
committed
feat(fal): better endpoint error
1 parent 4aba448 commit 564df3e

File tree

2 files changed

+27
-14
lines changed

2 files changed

+27
-14
lines changed

Diff for: projects/fal/src/fal/app.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import time
1010
import typing
1111
from contextlib import asynccontextmanager, contextmanager
12+
from dataclasses import dataclass
1213
from typing import Any, Callable, ClassVar, Literal, TypeVar
1314

1415
import httpx
@@ -17,7 +18,7 @@
1718
import fal.api
1819
from fal._serialization import include_modules_from
1920
from fal.api import RouteSignature
20-
from fal.exceptions import RequestCancelledException
21+
from fal.exceptions import FalServerlessException, RequestCancelledException
2122
from fal.logging import get_logger
2223
from fal.toolkit.file import get_lifecycle_preference
2324
from fal.toolkit.file.providers.fal import GLOBAL_LIFECYCLE_PREFERENCE
@@ -76,6 +77,12 @@ def initialize_and_serve():
7677
return fn
7778

7879

80+
@dataclass
81+
class AppClientError(FalServerlessException):
82+
message: str
83+
status_code: int
84+
85+
7986
class EndpointClient:
8087
def __init__(self, url, endpoint, signature, timeout: int | None = None):
8188
self.url = url
@@ -88,17 +95,19 @@ def __init__(self, url, endpoint, signature, timeout: int | None = None):
8895

8996
def __call__(self, data):
9097
with httpx.Client() as client:
98+
url = self.url + self.signature.path
9199
resp = client.post(
92100
self.url + self.signature.path,
93101
json=data.dict() if hasattr(data, "dict") else dict(data),
94102
timeout=self.timeout,
95103
)
96-
try:
97-
resp.raise_for_status()
98-
except httpx.HTTPStatusError:
104+
if not resp.is_success:
99105
# allow logs to be printed before raising the exception
100106
time.sleep(1)
101-
raise
107+
raise AppClientError(
108+
f"Failed to POST {url}: {resp.status_code} {resp.text}",
109+
status_code=resp.status_code,
110+
)
102111
resp_dict = resp.json()
103112

104113
if not self.return_type:
@@ -151,12 +160,16 @@ def _print_logs():
151160
with httpx.Client() as client:
152161
retries = 100
153162
for _ in range(retries):
154-
resp = client.get(info.url + "/health", timeout=60)
163+
url = info.url + "/health"
164+
resp = client.get(url, timeout=60)
155165

156166
if resp.is_success:
157167
break
158168
elif resp.status_code not in (500, 404):
159-
resp.raise_for_status()
169+
raise AppClientError(
170+
f"Failed to GET {url}: {resp.status_code} {resp.text}",
171+
status_code=resp.status_code,
172+
)
160173
time.sleep(0.1)
161174

162175
client = cls(app_cls, info.url)

Diff for: projects/fal/tests/test_apps.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import httpx
1313
import pytest
1414
from fal import apps
15-
from fal.app import AppClient
15+
from fal.app import AppClient, AppClientError
1616
from fal.cli.deploy import _get_user
1717
from fal.container import ContainerImage
1818
from fal.exceptions import AppException, FieldException, RequestCancelledException
@@ -692,7 +692,7 @@ def test_workflows(test_app: str):
692692
def test_traceback_logs(test_exception_app: AppClient):
693693
date = datetime.utcnow().isoformat()
694694

695-
with pytest.raises(HTTPStatusError):
695+
with pytest.raises(AppClientError):
696696
test_exception_app.fail({})
697697

698698
with httpx.Client(
@@ -714,17 +714,17 @@ def test_traceback_logs(test_exception_app: AppClient):
714714

715715

716716
def test_app_exceptions(test_exception_app: AppClient):
717-
with pytest.raises(HTTPStatusError) as app_exc:
717+
with pytest.raises(AppClientError) as app_exc:
718718
test_exception_app.app_exception({})
719719

720-
assert app_exc.value.response.status_code == 401
720+
assert app_exc.status_code == 401
721721

722-
with pytest.raises(HTTPStatusError) as field_exc:
722+
with pytest.raises(AppClientError) as field_exc:
723723
test_exception_app.field_exception({"lhs": 1, "rhs": "2"})
724724

725-
assert field_exc.value.response.status_code == 422
725+
assert field_exc.status_code == 422
726726

727-
with pytest.raises(HTTPStatusError) as cuda_exc:
727+
with pytest.raises(AppClientError) as cuda_exc:
728728
test_exception_app.cuda_exception({})
729729

730730
assert cuda_exc.value.response.status_code == _CUDA_OOM_STATUS_CODE

0 commit comments

Comments
 (0)