Skip to content

Commit d1933e4

Browse files
authored
feat: add support for studio async client (#129)
* feat: async support - stream, http, ai21 http * fix: commit changes * feat: studio resource, chat, chat completions, answer * feat: beta, dataset, completion, custom model * feat: embed, gec, improvements * feat: paraphrase, segmentation, summarize, by segment * feat: library * feat: client * refactor: sync and async http, ai21 http, ai21 client, resources * test: update imports, create tests for async * fix: base client * fix: add pytest marker asyncio * fix: add pytest asyncio to poetry * fix: add delete to lib files, add examples, test examples * fix: tests * fix: fix stream, add stream tests, add readme * fix: fix import on sm stub * feat: async support - stream, http, ai21 http * fix: commit changes * feat: studio resource, chat, chat completions, answer * feat: beta, dataset, completion, custom model * feat: embed, gec, improvements * feat: paraphrase, segmentation, summarize, by segment * feat: library * feat: client * refactor: sync and async http, ai21 http, ai21 client, resources * test: update imports, create tests for async * fix: base client * fix: add pytest marker asyncio * fix: add pytest asyncio to poetry * fix: add delete to lib files, add examples, test examples * fix: tests * fix: fix stream, add stream tests, add readme * fix: fix import on sm stub * fix: fix async http client, fix tests * fix: remove commented out code * fix: CR comments * fix: fix failing test * fix: fix failing test * fix: fix failing test * fix: fix library test * fix: cr comments
1 parent 8b91187 commit d1933e4

File tree

93 files changed

+3126
-380
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

93 files changed

+3126
-380
lines changed

README.md

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,10 @@ client = AI21Client(
167167
api_key='my_api_key',
168168
)
169169

170+
system = "You're a support engineer in a SaaS company"
170171
messages = [
171-
# Could be a dict or a ChatMessage object
172-
ChatMessage(content="Hello, this is a readme", role="user"),
173-
ChatMessage(content="You are correct, how can I help you?", role="assistant"),
172+
ChatMessage(content=system, role="system"),
173+
ChatMessage(content="Hello, I need help with a signup process.", role="user"),
174174
]
175175

176176
chat_completions = client.chat.completions.create(
@@ -179,6 +179,42 @@ chat_completions = client.chat.completions.create(
179179
)
180180
```
181181

182+
### Async Usage
183+
184+
You can use the `AsyncAI21Client` to make asynchronous requests.
185+
There is no difference between the sync and the async client in terms of usage.
186+
187+
```python
188+
import asyncio
189+
190+
from ai21 import AsyncAI21Client
191+
from ai21.models.chat import ChatMessage
192+
193+
system = "You're a support engineer in a SaaS company"
194+
messages = [
195+
ChatMessage(content=system, role="system"),
196+
ChatMessage(content="Hello, I need help with a signup process.", role="user"),
197+
]
198+
199+
client = AsyncAI21Client(
200+
# defaults to os.enviorn.get('AI21_API_KEY')
201+
api_key='my_api_key',
202+
)
203+
204+
205+
async def main():
206+
response = await client.chat.completions.create(
207+
messages=messages,
208+
model="jamba-instruct-preview",
209+
)
210+
211+
print(response)
212+
213+
214+
asyncio.run(main())
215+
216+
```
217+
182218
A more detailed example can be found [here](examples/studio/chat/chat_completions.py).
183219

184220
## Older Models Support Usage
@@ -260,6 +296,33 @@ for chunk in response:
260296

261297
```
262298

299+
### Async Streaming
300+
301+
```python
302+
import asyncio
303+
304+
from ai21 import AsyncAI21Client
305+
from ai21.models.chat import ChatMessage
306+
307+
messages = [ChatMessage(content="What is the meaning of life?", role="user")]
308+
309+
client = AsyncAI21Client()
310+
311+
312+
async def main():
313+
response = await client.chat.completions.create(
314+
messages=messages,
315+
model="jamba-instruct-preview",
316+
stream=True,
317+
)
318+
async for chunk in response:
319+
print(chunk.choices[0].delta.content, end="")
320+
321+
322+
asyncio.run(main())
323+
324+
```
325+
263326
---
264327

265328
## More Models

ai21/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ai21.ai21_env_config import AI21EnvConfig
44
from ai21.clients.studio.ai21_client import AI21Client
5+
from ai21.clients.studio.async_ai21_client import AsyncAI21Client
56
from ai21.errors import (
67
AI21APIError,
78
APITimeoutError,
@@ -53,6 +54,7 @@ def __getattr__(name: str) -> Any:
5354
__all__ = [
5455
"AI21EnvConfig",
5556
"AI21Client",
57+
"AsyncAI21Client",
5658
"AI21APIError",
5759
"APITimeoutError",
5860
"AI21Error",

ai21/ai21_http_client/__init__.py

Whitespace-only changes.
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from typing import Optional, Dict, Any, BinaryIO
2+
3+
import httpx
4+
5+
from ai21.http_client.http_client import HttpClient
6+
from ai21.ai21_http_client.base_ai21_http_client import BaseAI21HTTPClient
7+
8+
9+
class AI21HTTPClient(BaseAI21HTTPClient):
10+
def __init__(
11+
self,
12+
*,
13+
api_key: Optional[str] = None,
14+
requires_api_key: bool = True,
15+
base_url: Optional[str] = None,
16+
api_version: Optional[str] = None,
17+
headers: Optional[Dict[str, Any]] = None,
18+
timeout_sec: Optional[int] = None,
19+
num_retries: Optional[int] = None,
20+
via: Optional[str] = None,
21+
http_client: Optional[HttpClient] = None,
22+
):
23+
super().__init__(
24+
api_key=api_key,
25+
requires_api_key=requires_api_key,
26+
base_url=base_url,
27+
api_version=api_version,
28+
headers=headers,
29+
timeout_sec=timeout_sec,
30+
num_retries=num_retries,
31+
via=via,
32+
)
33+
34+
headers = self._build_headers(passed_headers=headers)
35+
self._http_client = self._init_http_client(http_client=http_client, headers=headers)
36+
37+
def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str, Any]) -> HttpClient:
38+
if http_client is None:
39+
return HttpClient(
40+
timeout_sec=self._timeout_sec,
41+
num_retries=self._num_retries,
42+
headers=headers,
43+
)
44+
45+
http_client.add_headers(headers)
46+
47+
return http_client
48+
49+
def execute_http_request(
50+
self,
51+
method: str,
52+
path: str,
53+
params: Optional[Dict] = None,
54+
body: Optional[Dict] = None,
55+
stream: bool = False,
56+
files: Optional[Dict[str, BinaryIO]] = None,
57+
) -> httpx.Response:
58+
return self._http_client.execute_http_request(
59+
method=method,
60+
url=f"{self._base_url}{path}",
61+
params=params or {},
62+
files=files,
63+
stream=stream,
64+
body=body or {},
65+
)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from typing import Optional, Dict, Any, BinaryIO
2+
3+
import httpx
4+
5+
from ai21.http_client.async_http_client import AsyncHttpClient
6+
from ai21.ai21_http_client.base_ai21_http_client import BaseAI21HTTPClient
7+
8+
9+
class AsyncAI21HTTPClient(BaseAI21HTTPClient):
10+
def __init__(
11+
self,
12+
*,
13+
api_key: Optional[str] = None,
14+
requires_api_key: bool = True,
15+
base_url: Optional[str] = None,
16+
api_version: Optional[str] = None,
17+
headers: Optional[Dict[str, Any]] = None,
18+
timeout_sec: Optional[int] = None,
19+
num_retries: Optional[int] = None,
20+
via: Optional[str] = None,
21+
http_client: Optional[AsyncHttpClient] = None,
22+
):
23+
super().__init__(
24+
api_key=api_key,
25+
requires_api_key=requires_api_key,
26+
base_url=base_url,
27+
api_version=api_version,
28+
headers=headers,
29+
timeout_sec=timeout_sec,
30+
num_retries=num_retries,
31+
via=via,
32+
)
33+
34+
headers = self._build_headers(passed_headers=headers)
35+
self._http_client = self._init_http_client(http_client=http_client, headers=headers)
36+
37+
def _init_http_client(self, http_client: Optional[AsyncHttpClient], headers: Dict[str, Any]) -> AsyncHttpClient:
38+
if http_client is None:
39+
return AsyncHttpClient(
40+
timeout_sec=self._timeout_sec,
41+
num_retries=self._num_retries,
42+
headers=headers,
43+
)
44+
45+
http_client.add_headers(headers)
46+
47+
return http_client
48+
49+
async def execute_http_request(
50+
self,
51+
method: str,
52+
path: str,
53+
params: Optional[Dict] = None,
54+
body: Optional[Dict] = None,
55+
stream: bool = False,
56+
files: Optional[Dict[str, BinaryIO]] = None,
57+
) -> httpx.Response:
58+
return await self._http_client.execute_http_request(
59+
method=method,
60+
url=f"{self._base_url}{path}",
61+
params=params or {},
62+
files=files,
63+
stream=stream,
64+
body=body or {},
65+
)

ai21/ai21_http_client.py renamed to ai21/ai21_http_client/base_ai21_http_client.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
import platform
2-
from typing import Optional, Dict, Any, BinaryIO
2+
from typing import Optional, Dict, Any, BinaryIO, TypeVar, Union
3+
from abc import ABC, abstractmethod
34

45
import httpx
56

67
from ai21.errors import MissingApiKeyError
7-
from ai21.http_client import HttpClient
8+
from ai21.http_client.http_client import HttpClient
9+
from ai21.http_client.async_http_client import AsyncHttpClient
810
from ai21.version import VERSION
911

1012

11-
class AI21HTTPClient:
13+
_HttpClientT = TypeVar("_HttpClientT", bound=Union[HttpClient, AsyncHttpClient])
14+
15+
16+
class BaseAI21HTTPClient(ABC):
17+
_http_client: Optional[_HttpClientT] = None
18+
1219
def __init__(
1320
self,
1421
*,
@@ -34,9 +41,6 @@ def __init__(
3441
self._num_retries = num_retries
3542
self._via = via
3643

37-
headers = self._build_headers(passed_headers=headers)
38-
self._http_client = self._init_http_client(http_client=http_client, headers=headers)
39-
4044
def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]:
4145
headers = {
4246
"Content-Type": "application/json",
@@ -51,18 +55,6 @@ def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str,
5155

5256
return headers
5357

54-
def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str, Any]) -> HttpClient:
55-
if http_client is None:
56-
return HttpClient(
57-
timeout_sec=self._timeout_sec,
58-
num_retries=self._num_retries,
59-
headers=headers,
60-
)
61-
62-
http_client.add_headers(headers)
63-
64-
return http_client
65-
6658
def _build_user_agent(self) -> str:
6759
user_agent = (
6860
f"AI21 studio SDK {VERSION} Python {platform.python_version()} Operating System {platform.platform()}"
@@ -73,6 +65,11 @@ def _build_user_agent(self) -> str:
7365

7466
return user_agent
7567

68+
@abstractmethod
69+
def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str, Any]) -> _HttpClientT:
70+
pass
71+
72+
@abstractmethod
7673
def execute_http_request(
7774
self,
7875
method: str,
@@ -82,11 +79,4 @@ def execute_http_request(
8279
stream: bool = False,
8380
files: Optional[Dict[str, BinaryIO]] = None,
8481
) -> httpx.Response:
85-
return self._http_client.execute_http_request(
86-
method=method,
87-
url=f"{self._base_url}{path}",
88-
params=params or {},
89-
files=files,
90-
stream=stream,
91-
body=body or {},
92-
)
82+
pass

ai21/clients/bedrock/bedrock_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from ai21.logger import logger
99
from ai21.errors import AccessDenied, NotFound, APITimeoutError
10-
from ai21.http_client import handle_non_success_response
10+
from ai21.http_client.base_http_client import handle_non_success_response
1111

1212
_ERROR_MSG_TEMPLATE = (
1313
r"Received client error \((.*?)\) from primary with message \"(.*?)\". "

ai21/clients/common/answer_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from abc import ABC
1+
from abc import ABC, abstractmethod
22
from typing import Any, Dict
33

44
from ai21.models import AnswerResponse
@@ -7,6 +7,7 @@
77
class Answer(ABC):
88
_module_name = "answer"
99

10+
@abstractmethod
1011
def create(
1112
self,
1213
context: str,

0 commit comments

Comments
 (0)