Skip to content

Commit 882de79

Browse files
miri-barpazshalev
andauthored
fix: Bug fixes - embed, streaming response, request retry rename (#206)
Co-authored-by: Paz Shalev <[email protected]>
1 parent 307cfc6 commit 882de79

File tree

5 files changed

+52
-8
lines changed

5 files changed

+52
-8
lines changed

ai21/http_client/async_http_client.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
wait=wait_exponential(multiplier=RETRY_BACK_OFF_FACTOR, min=TIME_BETWEEN_RETRIES),
5656
retry=retry_if_result(self._should_retry),
5757
stop=stop_after_attempt(self._num_retries),
58-
)(self._request)
58+
)(self._run_request)
5959
self._streaming_decoder = _SSEDecoder()
6060

6161
async def execute_http_request(
@@ -103,11 +103,16 @@ async def execute_http_request(
103103
logger.error(
104104
f"Calling {method} {self._base_url} failed with a non-200 response code: {response.status_code}"
105105
)
106-
handle_non_success_response(response.status_code, response.text)
106+
107+
if stream:
108+
details = self._extract_streaming_error_details(response)
109+
handle_non_success_response(response.status_code, details)
110+
else:
111+
handle_non_success_response(response.status_code, response.text)
107112

108113
return response
109114

110-
async def _request(self, options: RequestOptions) -> httpx.Response:
115+
async def _run_request(self, options: RequestOptions) -> httpx.Response:
111116
request = self._build_request(options)
112117

113118
_logger.debug(f"Calling {request.method} {request.url} {request.headers}, {options.body}")

ai21/http_client/base_http_client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def execute_http_request(
118118
pass
119119

120120
@abstractmethod
121-
def _request(
121+
def _run_request(
122122
self,
123123
options: RequestOptions,
124124
) -> httpx.Response:
@@ -171,3 +171,9 @@ def _prepare_url(self, options: RequestOptions) -> str:
171171
return f"{options.url}{options.path}"
172172

173173
return options.url
174+
175+
def _extract_streaming_error_details(self, response: httpx.Response) -> str:
176+
try:
177+
return response.read().decode("utf-8")
178+
except Exception:
179+
return "could not extract streaming error details"

ai21/http_client/http_client.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
wait=wait_exponential(multiplier=RETRY_BACK_OFF_FACTOR, min=TIME_BETWEEN_RETRIES),
5555
retry=retry_if_result(self._should_retry),
5656
stop=stop_after_attempt(self._num_retries),
57-
)(self._request)
57+
)(self._run_request)
5858
self._streaming_decoder = _SSEDecoder()
5959

6060
def execute_http_request(
@@ -102,11 +102,16 @@ def execute_http_request(
102102
f"Calling {method} {self._base_url} failed with a non-200 "
103103
f"response code: {response.status_code} headers: {response.headers}"
104104
)
105-
handle_non_success_response(response.status_code, response.text)
105+
106+
if stream:
107+
details = self._extract_streaming_error_details(response)
108+
handle_non_success_response(response.status_code, details)
109+
else:
110+
handle_non_success_response(response.status_code, response.text)
106111

107112
return response
108113

109-
def _request(self, options: RequestOptions) -> httpx.Response:
114+
def _run_request(self, options: RequestOptions) -> httpx.Response:
110115
request = self._build_request(options)
111116

112117
_logger.debug(f"Calling {request.method} {request.url} {request.headers}, {options.body}")

ai21/models/responses/embed_response.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
class EmbedResult(AI21BaseModel):
77
embedding: List[float]
88

9+
def __init__(self, embedding: List[float]):
10+
super().__init__(embedding=embedding)
11+
912

1013
class EmbedResponse(AI21BaseModel):
1114
id: str

tests/unittests/test_http_client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import httpx
66

7-
from ai21.errors import ServiceUnavailable
7+
from ai21.errors import ServiceUnavailable, Unauthorized
88
from ai21.http_client.base_http_client import RETRY_ERROR_CODES
99
from ai21.http_client.http_client import AI21HTTPClient
1010
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
@@ -42,6 +42,17 @@ def test__execute_http_request__when_retry_error__should_retry_and_stop(mock_htt
4242
assert mock_httpx_client.send.call_count == retries
4343

4444

45+
def test__execute_http_request__when_streaming__should_handle_non_200_response_code(mock_httpx_client: Mock) -> None:
46+
error_details = "test_error"
47+
request = Request(method=_METHOD, url=_URL)
48+
response = httpx.Response(status_code=401, request=request, text=error_details)
49+
mock_httpx_client.send.return_value = response
50+
51+
client = AI21HTTPClient(client=mock_httpx_client, base_url=_URL, api_key=_API_KEY)
52+
with pytest.raises(Unauthorized, match=error_details):
53+
client.execute_http_request(method=_METHOD, stream=True)
54+
55+
4556
@pytest.mark.asyncio
4657
async def test__execute_async_http_request__when_retry_error_code_once__should_retry_and_succeed(
4758
mock_httpx_async_client: Mock,
@@ -74,3 +85,17 @@ async def test__execute_async_http_request__when_retry_error__should_retry_and_s
7485
await client.execute_http_request(method=_METHOD)
7586

7687
assert mock_httpx_async_client.send.call_count == retries
88+
89+
90+
@pytest.mark.asyncio
91+
async def test__execute_async_http_request__when_streaming__should_handle_non_200_response_code(
92+
mock_httpx_async_client: Mock,
93+
) -> None:
94+
error_details = "test_error"
95+
request = Request(method=_METHOD, url=_URL)
96+
response = httpx.Response(status_code=401, request=request, text=error_details)
97+
mock_httpx_async_client.send.return_value = response
98+
99+
client = AsyncAI21HTTPClient(client=mock_httpx_async_client, base_url=_URL, api_key=_API_KEY)
100+
with pytest.raises(Unauthorized, match=error_details):
101+
await client.execute_http_request(method=_METHOD, stream=True)

0 commit comments

Comments
 (0)