1
+ from typing import Literal
2
+
3
+ from src .api .models .bedrock import list_bedrock_models , BedrockClientInterface
4
+
5
+ def test_default_model ():
6
+ client = FakeBedrockClient (
7
+ inference_profile ("p1-id" , "p1" , "SYSTEM_DEFINED" ),
8
+ inference_profile ("p2-id" , "p2" , "APPLICATION" ),
9
+ inference_profile ("p3-id" , "p3" , "SYSTEM_DEFINED" ),
10
+ )
11
+
12
+ models = list_bedrock_models (client )
13
+
14
+ assert models == {
15
+ "anthropic.claude-3-sonnet-20240229-v1:0" : {
16
+ "modalities" : ["TEXT" , "IMAGE" ]
17
+ }
18
+ }
19
+
20
+ def test_one_model ():
21
+ client = FakeBedrockClient (
22
+ model ("model-id" , "model-name" , stream_supported = True , input_modalities = ["TEXT" , "IMAGE" ])
23
+ )
24
+
25
+ models = list_bedrock_models (client )
26
+
27
+ assert models == {
28
+ "model-id" : {
29
+ "modalities" : ["TEXT" , "IMAGE" ]
30
+ }
31
+ }
32
+
33
+ def test_two_models ():
34
+ client = FakeBedrockClient (
35
+ model ("model-id-1" , "model-name-1" , stream_supported = True , input_modalities = ["TEXT" , "IMAGE" ]),
36
+ model ("model-id-2" , "model-name-2" , stream_supported = True , input_modalities = ["IMAGE" ])
37
+ )
38
+
39
+ models = list_bedrock_models (client )
40
+
41
+ assert models == {
42
+ "model-id-1" : {
43
+ "modalities" : ["TEXT" , "IMAGE" ]
44
+ },
45
+ "model-id-2" : {
46
+ "modalities" : ["IMAGE" ]
47
+ }
48
+ }
49
+
50
+ def test_filter_models ():
51
+ client = FakeBedrockClient (
52
+ model ("model-id" , "model-name-1" , stream_supported = True , input_modalities = ["TEXT" ], status = "LEGACY" ),
53
+ model ("model-id-no-stream" , "model-name-2" , stream_supported = False , input_modalities = ["TEXT" , "IMAGE" ]),
54
+ model ("model-id-not-active" , "model-name-3" , stream_supported = True , status = "DISABLED" ),
55
+ model ("model-id-not-text-output" , "model-name-4" , stream_supported = True , output_modalities = ["IMAGE" ])
56
+ )
57
+
58
+ models = list_bedrock_models (client )
59
+
60
+ assert models == {
61
+ "model-id" : {
62
+ "modalities" : ["TEXT" ]
63
+ }
64
+ }
65
+
66
+ def test_one_inference_profile ():
67
+ client = FakeBedrockClient (
68
+ inference_profile ("us.model-id" , "p1" , "SYSTEM_DEFINED" ),
69
+ model ("model-id" , "model-name" , stream_supported = True , input_modalities = ["TEXT" ])
70
+ )
71
+
72
+ models = list_bedrock_models (client )
73
+
74
+ assert models == {
75
+ "model-id" : {
76
+ "modalities" : ["TEXT" ]
77
+ },
78
+ "us.model-id" : {
79
+ "modalities" : ["TEXT" ]
80
+ }
81
+ }
82
+
83
+ def test_default_model_on_throw ():
84
+ client = ThrowingBedrockClient ()
85
+
86
+ models = list_bedrock_models (client )
87
+
88
+ assert models == {
89
+ "anthropic.claude-3-sonnet-20240229-v1:0" : {
90
+ "modalities" : ["TEXT" , "IMAGE" ]
91
+ }
92
+ }
93
+
94
+ def inference_profile (profile_id : str , name : str , profile_type : Literal ["SYSTEM_DEFINED" , "APPLICATION" ]):
95
+ return {
96
+ "inferenceProfileName" : name ,
97
+ "inferenceProfileId" : profile_id ,
98
+ "type" : profile_type
99
+ }
100
+
101
+ def model (
102
+ model_id : str ,
103
+ model_name : str ,
104
+ input_modalities : list [str ] = None ,
105
+ output_modalities : list [str ] = None ,
106
+ stream_supported : bool = False ,
107
+ inference_types : list [str ] = None ,
108
+ status : str = "ACTIVE" ) -> dict :
109
+ if input_modalities is None :
110
+ input_modalities = ["TEXT" ]
111
+ if output_modalities is None :
112
+ output_modalities = ["TEXT" ]
113
+ if inference_types is None :
114
+ inference_types = ["ON_DEMAND" ]
115
+ return {
116
+ "modelArn" : "arn:model:" + model_id ,
117
+ "modelId" : model_id ,
118
+ "modelName" : model_name ,
119
+ "providerName" : "anthropic" ,
120
+ "inputModalities" :input_modalities ,
121
+ "outputModalities" : output_modalities ,
122
+ "responseStreamingSupported" : stream_supported ,
123
+ "customizationsSupported" : ["FINE_TUNING" ],
124
+ "inferenceTypesSupported" : inference_types ,
125
+ "modelLifecycle" : {
126
+ "status" : status
127
+ }
128
+ }
129
+
130
+ def _filter_inference_profiles (inference_profiles : list [dict ], profile_type : Literal ["SYSTEM_DEFINED" , "APPLICATION" ], max_results : int = 100 ):
131
+ return [p for p in inference_profiles if p .get ("type" ) == profile_type ][:max_results ]
132
+
133
+ def _filter_models (
134
+ models : list [dict ],
135
+ provider_name : str | None ,
136
+ customization_type : Literal ["FINE_TUNING" ,"CONTINUED_PRE_TRAINING" ,"DISTILLATION" ] | None ,
137
+ output_modality : Literal ["TEXT" ,"IMAGE" ,"EMBEDDING" ] | None ,
138
+ inference_type : Literal ["ON_DEMAND" ,"PROVISIONED" ] | None ):
139
+ return [m for m in models if
140
+ (provider_name is None or m .get ("providerName" ) == provider_name ) and
141
+ (output_modality is None or output_modality in m .get ("outputModalities" )) and
142
+ (customization_type is None or customization_type in m .get ("customizationsSupported" )) and
143
+ (inference_type is None or inference_type in m .get ("inferenceTypesSupported" ))
144
+ ]
145
+
146
+ class ThrowingBedrockClient (BedrockClientInterface ):
147
+ def list_inference_profiles (self , ** kwargs ) -> dict :
148
+ raise Exception ("throwing bedrock client always throws exception" )
149
+ def list_foundation_models (self , ** kwargs ) -> dict :
150
+ raise Exception ("throwing bedrock client always throws exception" )
151
+
152
+ class FakeBedrockClient (BedrockClientInterface ):
153
+ def __init__ (self , * args ):
154
+ self .inference_profiles = [p for p in args if p .get ("inferenceProfileId" , "" ) != "" ]
155
+ self .models = [m for m in args if m .get ("modelId" , "" ) != "" ]
156
+
157
+ unexpected = [u for u in args if (u .get ("modelId" , "" ) == "" and u .get ("inferenceProfileId" , "" ) == "" )]
158
+ if len (unexpected ) > 0 :
159
+ raise Exception ("expected a model or a profile" )
160
+
161
+ def list_inference_profiles (self , ** kwargs ) -> dict :
162
+ return {
163
+ "inferenceProfileSummaries" : _filter_inference_profiles (
164
+ self .inference_profiles ,
165
+ profile_type = kwargs ["typeEquals" ],
166
+ max_results = kwargs .get ("maxResults" , 100 )
167
+ )
168
+ }
169
+
170
+ def list_foundation_models (self , ** kwargs ) -> dict :
171
+ return {
172
+ "modelSummaries" : _filter_models (
173
+ self .models ,
174
+ provider_name = kwargs .get ("byProvider" , None ),
175
+ customization_type = kwargs .get ("byCustomizationType" , None ),
176
+ output_modality = kwargs .get ("byOutputModality" , None ),
177
+ inference_type = kwargs .get ("byInferenceType" , None )
178
+ )
179
+ }
0 commit comments