Commit eb2e1d4
Author: Andrew Bernat

Add tests for function list_bedrock_models.
Parent: 0ead770

File tree

4 files changed: +213 -14 lines

.github/workflows/aws-genai-cicd-suite.yml

Lines changed: 9 additions & 8 deletions
@@ -25,16 +25,17 @@ jobs:
       - name: Checkout code
        uses: actions/checkout@v3
 
-      - name: Set up Node.js
-        uses: actions/setup-node@v3
+      - name: Set up Python
+        uses: actions/setup-python@v2
         with:
-          node-version: '20'
+          python-version: 3.12 # Adjust the Python version as needed
 
-      - name: Install dependencies @actions/core and @actions/github
-        run: |
-          npm install @actions/core
-          npm install @actions/github
-        shell: bash
+      - name: Install dependencies
+        run: pip install -r requirements.txt
+
+      - name: Test
+        run: python -m unittest
+        working-directory: ./tests
 
       # check if required dependencies @actions/core and @actions/github are installed
       - name: Check if required dependencies are installed
src/api/models/bedrock.py

Lines changed: 25 additions & 6 deletions
@@ -3,7 +3,7 @@
 import logging
 import re
 import time
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import AsyncIterable, Iterable, Literal
 
 import boto3
@@ -73,8 +73,27 @@ def get_inference_region_prefix():
 
 ENCODER = tiktoken.get_encoding("cl100k_base")
 
+class BedrockClientInterface(ABC):
+    @abstractmethod
+    def list_inference_profiles(self, **kwargs) -> dict:
+        pass
 
-def list_bedrock_models() -> dict:
+    @abstractmethod
+    def list_foundation_models(self, **kwargs) -> dict:
+        pass
+
+class BedrockClient(BedrockClientInterface):
+    def __init__(self, client):
+        self.bedrock_client = client
+
+    def list_inference_profiles(self, **kwargs) -> dict:
+        return self.bedrock_client.list_inference_profiles(**kwargs)
+
+    def list_foundation_models(self, **kwargs) -> dict:
+        return self.bedrock_client.list_foundation_models(**kwargs)
+
+
+def list_bedrock_models(client: BedrockClientInterface) -> dict:
     """Automatically getting a list of supported models.
 
     Returns a model list combines:
@@ -86,11 +105,11 @@ def list_bedrock_models() -> dict:
     profile_list = []
     if ENABLE_CROSS_REGION_INFERENCE:
         # List system defined inference profile IDs
-        response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
+        response = client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
         profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]
 
     # List foundation models, only cares about text outputs here.
-    response = bedrock_client.list_foundation_models(byOutputModality="TEXT")
+    response = client.list_foundation_models(byOutputModality="TEXT")
 
     for model in response["modelSummaries"]:
         model_id = model.get("modelId", "N/A")
@@ -123,14 +142,14 @@ def list_bedrock_models() -> dict:
 
 
 # Initialize the model list.
-bedrock_model_list = list_bedrock_models()
+bedrock_model_list = list_bedrock_models(BedrockClient(bedrock_client))
 
 
 class BedrockModel(BaseChatModel):
     def list_models(self) -> list[str]:
         """Always refresh the latest model list"""
         global bedrock_model_list
-        bedrock_model_list = list_bedrock_models()
+        bedrock_model_list = list_bedrock_models(BedrockClient(bedrock_client))
         return list(bedrock_model_list.keys())
 
     def validate(self, chat_request: ChatRequest):
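Editor's note: a minimal sketch of how the new seam is wired outside of tests, assuming AWS credentials and region are configured and that bedrock_client in this module is a boto3 "bedrock" control-plane client (the service that exposes list_foundation_models and list_inference_profiles); the exact client construction is not shown in this diff.

import boto3

from src.api.models.bedrock import BedrockClient, list_bedrock_models

# Wrap the real boto3 client in the adapter; tests substitute a fake
# BedrockClientInterface implementation instead.
bedrock = boto3.client("bedrock")
models = list_bedrock_models(BedrockClient(bedrock))
print(sorted(models.keys()))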

tests/__init__.py

New empty file (marks tests/ as a Python package).

tests/list_bedrock_models_test.py

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
+from typing import Literal
+
+from src.api.models.bedrock import list_bedrock_models, BedrockClientInterface
+
+def test_default_model():
+    client = FakeBedrockClient(
+        inference_profile("p1-id", "p1", "SYSTEM_DEFINED"),
+        inference_profile("p2-id", "p2", "APPLICATION"),
+        inference_profile("p3-id", "p3", "SYSTEM_DEFINED"),
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "anthropic.claude-3-sonnet-20240229-v1:0": {
+            "modalities": ["TEXT", "IMAGE"]
+        }
+    }
+
+def test_one_model():
+    client = FakeBedrockClient(
+        model("model-id", "model-name", stream_supported=True, input_modalities=["TEXT", "IMAGE"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id": {
+            "modalities": ["TEXT", "IMAGE"]
+        }
+    }
+
+def test_two_models():
+    client = FakeBedrockClient(
+        model("model-id-1", "model-name-1", stream_supported=True, input_modalities=["TEXT", "IMAGE"]),
+        model("model-id-2", "model-name-2", stream_supported=True, input_modalities=["IMAGE"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id-1": {
+            "modalities": ["TEXT", "IMAGE"]
+        },
+        "model-id-2": {
+            "modalities": ["IMAGE"]
+        }
+    }
+
+def test_filter_models():
+    client = FakeBedrockClient(
+        model("model-id", "model-name-1", stream_supported=True, input_modalities=["TEXT"], status="LEGACY"),
+        model("model-id-no-stream", "model-name-2", stream_supported=False, input_modalities=["TEXT", "IMAGE"]),
+        model("model-id-not-active", "model-name-3", stream_supported=True, status="DISABLED"),
+        model("model-id-not-text-output", "model-name-4", stream_supported=True, output_modalities=["IMAGE"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id": {
+            "modalities": ["TEXT"]
+        }
+    }
+
+def test_one_inference_profile():
+    client = FakeBedrockClient(
+        inference_profile("us.model-id", "p1", "SYSTEM_DEFINED"),
+        model("model-id", "model-name", stream_supported=True, input_modalities=["TEXT"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id": {
+            "modalities": ["TEXT"]
+        },
+        "us.model-id": {
+            "modalities": ["TEXT"]
+        }
+    }
+
+def test_default_model_on_throw():
+    client = ThrowingBedrockClient()
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "anthropic.claude-3-sonnet-20240229-v1:0": {
+            "modalities": ["TEXT", "IMAGE"]
+        }
+    }
+
+def inference_profile(profile_id: str, name: str, profile_type: Literal["SYSTEM_DEFINED", "APPLICATION"]):
+    return {
+        "inferenceProfileName": name,
+        "inferenceProfileId": profile_id,
+        "type": profile_type
+    }
+
+def model(
+        model_id: str,
+        model_name: str,
+        input_modalities: list[str] = None,
+        output_modalities: list[str] = None,
+        stream_supported: bool = False,
+        inference_types: list[str] = None,
+        status: str = "ACTIVE") -> dict:
+    if input_modalities is None:
+        input_modalities = ["TEXT"]
+    if output_modalities is None:
+        output_modalities = ["TEXT"]
+    if inference_types is None:
+        inference_types = ["ON_DEMAND"]
+    return {
+        "modelArn": "arn:model:" + model_id,
+        "modelId": model_id,
+        "modelName": model_name,
+        "providerName": "anthropic",
+        "inputModalities": input_modalities,
+        "outputModalities": output_modalities,
+        "responseStreamingSupported": stream_supported,
+        "customizationsSupported": ["FINE_TUNING"],
+        "inferenceTypesSupported": inference_types,
+        "modelLifecycle": {
+            "status": status
+        }
+    }
+
+def _filter_inference_profiles(inference_profiles: list[dict], profile_type: Literal["SYSTEM_DEFINED", "APPLICATION"], max_results: int = 100):
+    return [p for p in inference_profiles if p.get("type") == profile_type][:max_results]
+
+def _filter_models(
+        models: list[dict],
+        provider_name: str | None,
+        customization_type: Literal["FINE_TUNING","CONTINUED_PRE_TRAINING","DISTILLATION"] | None,
+        output_modality: Literal["TEXT","IMAGE","EMBEDDING"] | None,
+        inference_type: Literal["ON_DEMAND","PROVISIONED"] | None):
+    return [m for m in models if
+            (provider_name is None or m.get("providerName") == provider_name) and
+            (output_modality is None or output_modality in m.get("outputModalities")) and
+            (customization_type is None or customization_type in m.get("customizationsSupported")) and
+            (inference_type is None or inference_type in m.get("inferenceTypesSupported"))
+            ]
+
+class ThrowingBedrockClient(BedrockClientInterface):
+    def list_inference_profiles(self, **kwargs) -> dict:
+        raise Exception("throwing bedrock client always throws exception")
+    def list_foundation_models(self, **kwargs) -> dict:
+        raise Exception("throwing bedrock client always throws exception")
+
+class FakeBedrockClient(BedrockClientInterface):
+    def __init__(self, *args):
+        self.inference_profiles = [p for p in args if p.get("inferenceProfileId", "") != ""]
+        self.models = [m for m in args if m.get("modelId", "") != ""]
+
+        unexpected = [u for u in args if (u.get("modelId", "") == "" and u.get("inferenceProfileId", "") == "")]
+        if len(unexpected) > 0:
+            raise Exception("expected a model or a profile")
+
+    def list_inference_profiles(self, **kwargs) -> dict:
+        return {
+            "inferenceProfileSummaries": _filter_inference_profiles(
+                self.inference_profiles,
+                profile_type=kwargs["typeEquals"],
+                max_results=kwargs.get("maxResults", 100)
+            )
+        }
+
+    def list_foundation_models(self, **kwargs) -> dict:
+        return {
+            "modelSummaries": _filter_models(
+                self.models,
+                provider_name=kwargs.get("byProvider", None),
+                customization_type=kwargs.get("byCustomizationType", None),
+                output_modality=kwargs.get("byOutputModality", None),
+                inference_type=kwargs.get("byInferenceType", None)
+            )
+        }
\ No newline at end of file
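Editor's note: the helpers above make further cases cheap to write. A hypothetical follow-on test (not part of this commit) that checks APPLICATION profiles never surface in the result, since list_bedrock_models only requests SYSTEM_DEFINED profiles:

def test_application_profile_excluded():
    # The APPLICATION profile is filtered out by the typeEquals="SYSTEM_DEFINED"
    # request; the active streaming text model is still listed.
    client = FakeBedrockClient(
        inference_profile("app.model-id", "app-profile", "APPLICATION"),
        model("model-id", "model-name", stream_supported=True),
    )

    models = list_bedrock_models(client)

    assert "app.model-id" not in models
    assert "model-id" in models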
