Add tests for function list_bedrock_models. #120

Open. Wants to merge 1 commit into base: main.
.github/workflows/aws-genai-cicd-suite.yml

@@ -25,16 +25,17 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v3

-      - name: Set up Node.js
-        uses: actions/setup-node@v3
+      - name: Set up Python
+        uses: actions/setup-python@v2
         with:
-          node-version: '20'
+          python-version: '3.12' # Adjust the Python version as needed

-      - name: Install dependencies @actions/core and @actions/github
-        run: |
-          npm install @actions/core
-          npm install @actions/github
-        shell: bash
+      - name: Install dependencies
+        run: pip install -r requirements.txt pytest
+
+      - name: Test
+        # The tests are plain pytest-style functions, so run them with pytest
+        # from the repo root (unittest discovery would collect none of them).
+        run: pytest tests/

       # check if required dependencies @actions/core and @actions/github are installed
       - name: Check if required dependencies are installed
src/api/models/bedrock.py

@@ -3,7 +3,7 @@
 import logging
 import re
 import time
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import AsyncIterable, Iterable, Literal

 import boto3
@@ -73,8 +73,27 @@ def get_inference_region_prefix():

 ENCODER = tiktoken.get_encoding("cl100k_base")

-def list_bedrock_models() -> dict:
+class BedrockClientInterface(ABC):
+    @abstractmethod
+    def list_inference_profiles(self, **kwargs) -> dict:
+        pass
+
+    @abstractmethod
+    def list_foundation_models(self, **kwargs) -> dict:
+        pass
+
+class BedrockClient(BedrockClientInterface):
+    def __init__(self, client):
+        self.bedrock_client = client
+
+    def list_inference_profiles(self, **kwargs) -> dict:
+        return self.bedrock_client.list_inference_profiles(**kwargs)
+
+    def list_foundation_models(self, **kwargs) -> dict:
+        return self.bedrock_client.list_foundation_models(**kwargs)
+
+
+def list_bedrock_models(client: BedrockClientInterface) -> dict:
     """Automatically getting a list of supported models.

     Returns a model list combines:
@@ -86,11 +105,11 @@ def list_bedrock_models(client: BedrockClientInterface) -> dict:
     profile_list = []
     if ENABLE_CROSS_REGION_INFERENCE:
         # List system defined inference profile IDs
-        response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
+        response = client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
         profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]

     # List foundation models, only cares about text outputs here.
-    response = bedrock_client.list_foundation_models(byOutputModality="TEXT")
+    response = client.list_foundation_models(byOutputModality="TEXT")

     for model in response["modelSummaries"]:
         model_id = model.get("modelId", "N/A")
@@ -123,14 +142,14 @@ def list_bedrock_models(client: BedrockClientInterface) -> dict:


 # Initialize the model list.
-bedrock_model_list = list_bedrock_models()
+bedrock_model_list = list_bedrock_models(BedrockClient(bedrock_client))


 class BedrockModel(BaseChatModel):
     def list_models(self) -> list[str]:
         """Always refresh the latest model list"""
         global bedrock_model_list
-        bedrock_model_list = list_bedrock_models()
+        bedrock_model_list = list_bedrock_models(BedrockClient(bedrock_client))
         return list(bedrock_model_list.keys())

     def validate(self, chat_request: ChatRequest):
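For reference, the production wiring after this change is thin: the module wraps its existing boto3 client in BedrockClient and hands it to list_bedrock_models, while tests can substitute any other BedrockClientInterface implementation. A minimal sketch, assuming bedrock_client is the module's boto3 Bedrock control-plane client (e.g. boto3.client("bedrock")):

import boto3

from src.api.models.bedrock import BedrockClient, list_bedrock_models

# Control-plane client that provides list_foundation_models and
# list_inference_profiles; region and credentials follow the usual boto3 chain.
bedrock_client = boto3.client("bedrock")

# Adapt the raw client to BedrockClientInterface, then build the model map.
models = list_bedrock_models(BedrockClient(bedrock_client))
print(models)  # e.g. {"model-id": {"modalities": ["TEXT", ...]}, ...}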
tests/__init__.py (new, empty file)
tests/list_bedrock_models_test.py (new file)

@@ -0,0 +1,179 @@
from typing import Literal

from src.api.models.bedrock import list_bedrock_models, BedrockClientInterface

def test_default_model():
client = FakeBedrockClient(
inference_profile("p1-id", "p1", "SYSTEM_DEFINED"),
inference_profile("p2-id", "p2", "APPLICATION"),
inference_profile("p3-id", "p3", "SYSTEM_DEFINED"),
)

models = list_bedrock_models(client)

assert models == {
"anthropic.claude-3-sonnet-20240229-v1:0": {
"modalities": ["TEXT", "IMAGE"]
}
}

def test_one_model():
client = FakeBedrockClient(
model("model-id", "model-name", stream_supported=True, input_modalities=["TEXT", "IMAGE"])
)

models = list_bedrock_models(client)

assert models == {
"model-id": {
"modalities": ["TEXT", "IMAGE"]
}
}

def test_two_models():
client = FakeBedrockClient(
model("model-id-1", "model-name-1", stream_supported=True, input_modalities=["TEXT", "IMAGE"]),
model("model-id-2", "model-name-2", stream_supported=True, input_modalities=["IMAGE"])
)

models = list_bedrock_models(client)

assert models == {
"model-id-1": {
"modalities": ["TEXT", "IMAGE"]
},
"model-id-2": {
"modalities": ["IMAGE"]
}
}

def test_filter_models():
client = FakeBedrockClient(
model("model-id", "model-name-1", stream_supported=True, input_modalities=["TEXT"], status="LEGACY"),
model("model-id-no-stream", "model-name-2", stream_supported=False, input_modalities=["TEXT", "IMAGE"]),
model("model-id-not-active", "model-name-3", stream_supported=True, status="DISABLED"),
model("model-id-not-text-output", "model-name-4", stream_supported=True, output_modalities=["IMAGE"])
)

models = list_bedrock_models(client)

assert models == {
"model-id": {
"modalities": ["TEXT"]
}
}

def test_one_inference_profile():
client = FakeBedrockClient(
inference_profile("us.model-id", "p1", "SYSTEM_DEFINED"),
model("model-id", "model-name", stream_supported=True, input_modalities=["TEXT"])
)

models = list_bedrock_models(client)

assert models == {
"model-id": {
"modalities": ["TEXT"]
},
"us.model-id": {
"modalities": ["TEXT"]
}
}

def test_default_model_on_throw():
client = ThrowingBedrockClient()

models = list_bedrock_models(client)

assert models == {
"anthropic.claude-3-sonnet-20240229-v1:0": {
"modalities": ["TEXT", "IMAGE"]
}
}

def inference_profile(profile_id: str, name: str, profile_type: Literal["SYSTEM_DEFINED", "APPLICATION"]):
return {
"inferenceProfileName": name,
"inferenceProfileId": profile_id,
"type": profile_type
}

def model(
model_id: str,
model_name: str,
input_modalities: list[str] = None,
output_modalities: list[str] = None,
stream_supported: bool = False,
inference_types: list[str] = None,
status: str = "ACTIVE") -> dict:
if input_modalities is None:
input_modalities = ["TEXT"]
if output_modalities is None:
output_modalities = ["TEXT"]
if inference_types is None:
inference_types = ["ON_DEMAND"]
return {
"modelArn": "arn:model:" + model_id,
"modelId": model_id,
"modelName": model_name,
"providerName": "anthropic",
"inputModalities":input_modalities,
"outputModalities": output_modalities,
"responseStreamingSupported": stream_supported,
"customizationsSupported": ["FINE_TUNING"],
"inferenceTypesSupported": inference_types,
"modelLifecycle": {
"status": status
}
}

def _filter_inference_profiles(inference_profiles: list[dict], profile_type: Literal["SYSTEM_DEFINED", "APPLICATION"], max_results: int = 100):
return [p for p in inference_profiles if p.get("type") == profile_type][:max_results]

def _filter_models(
models: list[dict],
provider_name: str | None,
customization_type: Literal["FINE_TUNING","CONTINUED_PRE_TRAINING","DISTILLATION"] | None,
output_modality: Literal["TEXT","IMAGE","EMBEDDING"] | None,
inference_type: Literal["ON_DEMAND","PROVISIONED"] | None):
return [m for m in models if
(provider_name is None or m.get("providerName") == provider_name) and
(output_modality is None or output_modality in m.get("outputModalities")) and
(customization_type is None or customization_type in m.get("customizationsSupported")) and
(inference_type is None or inference_type in m.get("inferenceTypesSupported"))
]

class ThrowingBedrockClient(BedrockClientInterface):
    """Stub client whose calls always fail, to exercise the default-model fallback."""

    def list_inference_profiles(self, **kwargs) -> dict:
        raise Exception("throwing bedrock client always throws exception")

    def list_foundation_models(self, **kwargs) -> dict:
        raise Exception("throwing bedrock client always throws exception")

class FakeBedrockClient(BedrockClientInterface):
    """In-memory fake that serves the given model and inference-profile dicts
    through the same filter parameters the real Bedrock API accepts."""

    def __init__(self, *args):
        self.inference_profiles = [p for p in args if p.get("inferenceProfileId", "") != ""]
        self.models = [m for m in args if m.get("modelId", "") != ""]

        # Anything that is neither a model nor a profile is a test-setup error.
        unexpected = [u for u in args if u.get("modelId", "") == "" and u.get("inferenceProfileId", "") == ""]
        if len(unexpected) > 0:
            raise Exception("expected a model or a profile")

def list_inference_profiles(self, **kwargs) -> dict:
return {
"inferenceProfileSummaries": _filter_inference_profiles(
self.inference_profiles,
profile_type=kwargs["typeEquals"],
max_results=kwargs.get("maxResults", 100)
)
}

def list_foundation_models(self, **kwargs) -> dict:
return {
"modelSummaries": _filter_models(
self.models,
provider_name=kwargs.get("byProvider", None),
customization_type=kwargs.get("byCustomizationType", None),
output_modality=kwargs.get("byOutputModality", None),
inference_type=kwargs.get("byInferenceType", None)
)
}
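The fakes above keep the suite hermetic: list_bedrock_models is exercised against deterministic in-memory data, with no AWS calls or credentials. Locally the tests can be run the same way the workflow above runs them, from the repository root (assuming pytest is installed):

pytest tests/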