Skip to content

Commit 6115632

Browse files
authored
feat: Azure Client Support (#135)
1 parent 3eae31e commit 6115632

File tree

11 files changed

+205
-9
lines changed

11 files changed

+205
-9
lines changed

README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
- [AWS](#AWS)
3737
- [SageMaker](#SageMaker)
3838
- [Bedrock](#Bedrock)
39+
- [Azure](#Azure)
3940

4041
## Examples (tl;dr)
4142

@@ -530,4 +531,33 @@ response = client.completion.create(
530531
)
531532
```
532533

534+
### Azure
535+
536+
If you wish to interact with your Azure endpoint on Azure AI Studio, you can use the `AI21AzureClient`
537+
and `AsyncAI21AzureClient`.
538+
539+
The following models are supported on Azure:
540+
541+
- `jamba-instruct`
542+
543+
```python
544+
from ai21 import AI21AzureClient
545+
from ai21.models.chat import ChatMessage
546+
547+
client = AI21AzureClient(
548+
base_url="https://your-endpoint.inference.ai.azure.com/v1/chat/completions",
549+
api_key="<your api key>",
550+
)
551+
552+
messages = [
553+
ChatMessage(content="You are a helpful assistant", role="system"),
554+
ChatMessage(content="What is the meaning of life?", role="user")
555+
]
556+
557+
response = client.chat.completions.create(
558+
model="jamba-instruct",
559+
messages=[messages],
560+
)
561+
```
562+
533563
Happy prompting! 🚀

ai21/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from typing import Any
22

33
from ai21.ai21_env_config import AI21EnvConfig
4+
from ai21.clients.azure.ai21_azure_client import AI21AzureClient, AsyncAI21AzureClient
45
from ai21.clients.studio.ai21_client import AI21Client
56
from ai21.clients.studio.async_ai21_client import AsyncAI21Client
7+
68
from ai21.errors import (
79
AI21APIError,
810
APITimeoutError,
@@ -65,4 +67,6 @@ def __getattr__(name: str) -> Any:
6567
"AI21SageMakerClient",
6668
"BedrockModelID",
6769
"SageMaker",
70+
"AI21AzureClient",
71+
"AsyncAI21AzureClient",
6872
]

ai21/ai21_http_client/base_ai21_http_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def __init__(
2727
timeout_sec: Optional[int] = None,
2828
num_retries: Optional[int] = None,
2929
via: Optional[str] = None,
30-
http_client: Optional[HttpClient] = None,
3130
):
3231
self._api_key = api_key
3332

ai21/clients/azure/__init__.py

Whitespace-only changes.
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC
4+
from typing import Optional, Callable, Dict
5+
6+
from ai21.ai21_http_client.ai21_http_client import AI21HTTPClient
7+
from ai21.ai21_http_client.async_ai21_http_client import AsyncAI21HTTPClient
8+
from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat
9+
10+
AzureADTokenProvider = Callable[[], str]
11+
12+
13+
class BaseAzureClient(ABC):
14+
_azure_endpoint: str
15+
_api_key: Optional[str]
16+
_azure_ad_token: Optional[str]
17+
_azure_ad_token_provider: Optional[AzureADTokenProvider]
18+
19+
def _prepare_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
20+
azure_ad_token = self._get_azure_ad_token()
21+
22+
if azure_ad_token is not None and "Authorization" not in headers:
23+
return {
24+
"Authorization": f"Bearer {azure_ad_token}",
25+
**headers,
26+
}
27+
28+
if self._api_key is not None:
29+
return {
30+
"api-key": self._api_key,
31+
**headers,
32+
}
33+
34+
return headers
35+
36+
def _get_azure_ad_token(self) -> Optional[str]:
37+
if self._azure_ad_token is not None:
38+
return self._azure_ad_token
39+
40+
if self._azure_ad_token_provider is not None:
41+
return self._azure_ad_token_provider()
42+
43+
return None
44+
45+
46+
class AsyncAI21AzureClient(BaseAzureClient, AsyncAI21HTTPClient):
47+
def __init__(
48+
self,
49+
base_url: str,
50+
api_key: Optional[str] = None,
51+
azure_ad_token: str | None = None,
52+
azure_ad_token_provider: AzureADTokenProvider | None = None,
53+
default_headers: Dict[str, str] | None = None,
54+
timeout_sec: int | None = None,
55+
num_retries: int | None = None,
56+
):
57+
self._api_key = api_key
58+
self._azure_ad_token = azure_ad_token
59+
self._azure_ad_token_provider = azure_ad_token_provider
60+
61+
if self._api_key is None and self._azure_ad_token_provider is None and self._azure_ad_token is None:
62+
raise ValueError("Must provide either api_key or azure_ad_token_provider or azure_ad_token")
63+
64+
headers = self._prepare_headers(headers=default_headers or {})
65+
66+
super().__init__(
67+
api_key=api_key,
68+
base_url=base_url,
69+
headers=headers,
70+
timeout_sec=timeout_sec,
71+
num_retries=num_retries,
72+
)
73+
74+
self.chat = AsyncStudioChat(self)
75+
# Override the chat.create method to match the completions endpoint,
76+
# so it wouldn't get to the old J2 completion endpoint
77+
self.chat.create = self.chat.completions.create
78+
79+
80+
class AI21AzureClient(BaseAzureClient, AI21HTTPClient):
81+
def __init__(
82+
self,
83+
base_url: str,
84+
api_key: Optional[str] = None,
85+
azure_ad_token: str | None = None,
86+
azure_ad_token_provider: AzureADTokenProvider | None = None,
87+
default_headers: Dict[str, str] | None = None,
88+
timeout_sec: int | None = None,
89+
num_retries: int | None = None,
90+
):
91+
self._api_key = api_key
92+
self._azure_ad_token = azure_ad_token
93+
self._azure_ad_token_provider = azure_ad_token_provider
94+
95+
if self._api_key is None and self._azure_ad_token_provider is None and self._azure_ad_token is None:
96+
raise ValueError("Must provide either api_key or azure_ad_token_provider or azure_ad_token")
97+
98+
headers = self._prepare_headers(headers=default_headers or {})
99+
100+
super().__init__(
101+
api_key=api_key,
102+
base_url=base_url,
103+
headers=headers,
104+
timeout_sec=timeout_sec,
105+
num_retries=num_retries,
106+
)
107+
108+
self.chat = StudioChat(self)
109+
# Override the chat.create method to match the completions endpoint,
110+
# so it wouldn't get to the old J2 completion endpoint
111+
self.chat.create = self.chat.completions.create

examples/azure/__init__.py

Whitespace-only changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import asyncio
2+
3+
from ai21 import AsyncAI21AzureClient
4+
from ai21.models.chat import ChatMessage
5+
6+
7+
async def chat_completions():
8+
client = AsyncAI21AzureClient(
9+
base_url="<Your endpoint>",
10+
api_key="<your api key>",
11+
)
12+
13+
messages = ChatMessage(content="What is the meaning of life?", role="user")
14+
15+
completion = await client.chat.completions.create(
16+
model="jamba-instruct",
17+
messages=[messages],
18+
)
19+
20+
print(completion.to_json())
21+
22+
23+
asyncio.run(chat_completions())
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from ai21 import AI21AzureClient
2+
3+
from ai21.models.chat import ChatMessage
4+
5+
client = AI21AzureClient(
6+
base_url="<Your endpoint>",
7+
api_key="<your api key>",
8+
)
9+
10+
messages = ChatMessage(content="What is the meaning of life?", role="user")
11+
12+
completion = client.chat.completions.create(
13+
model="jamba-instruct",
14+
messages=[messages],
15+
)
16+
17+
print(completion.to_json())

tests/unittests/clients/azure/__init__.py

Whitespace-only changes.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import pytest
2+
3+
from ai21 import AI21AzureClient
4+
5+
6+
def test__azure_client__when_init_with_no_auth__should_raise_error():
7+
with pytest.raises(ValueError) as e:
8+
AI21AzureClient(base_url="http://some_endpoint_url")
9+
10+
assert str(e.value) == "Must provide either api_key or azure_ad_token_provider or azure_ad_token"

0 commit comments

Comments
 (0)