feat: Enable Imported models #97

Open
wants to merge 2 commits into
base: main
4 changes: 3 additions & 1 deletion README.md
@@ -216,10 +216,12 @@ Replace the repo url in the CloudFormation template before you deploy.
Yes, you can run this locally, e.g. run below command under `src` folder:

```bash
cd src/
pip install -r requirements.txt
uvicorn api.app:app --host 0.0.0.0 --port 8000
```

-The API base url should look like `http://localhost:8000/api/v1`.
+The API base url should look like `http://localhost:8000/api/v1` and the API key should be `bedrock`.

### Any performance sacrifice or latency increase by using the proxy APIs

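Review note: to try the README change locally, a request against the proxy could be built roughly as below. This is a minimal sketch, not taken from the PR. It assumes the proxy exposes an OpenAI-style `/chat/completions` route under the base URL and reads the key from a Bearer `Authorization` header; neither detail appears in this diff. The helper name `build_chat_request` is hypothetical.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8000/api/v1"  # base URL quoted in the README hunk
API_KEY = "bedrock"                        # local API key per the README change


def build_chat_request(model: str, prompt: str) -> urllib.request.Request:
    """Build (but do not send) a chat request for the local proxy.

    Assumption: an OpenAI-style /chat/completions route and a Bearer token
    header. Neither is confirmed by this diff.
    """
    body = json.dumps({
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }).encode("utf-8")
    return urllib.request.Request(
        f"{BASE_URL}/chat/completions",
        data=body,
        headers={
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


req = build_chat_request("anthropic.claude-3-sonnet-20240229-v1:0", "Hello")

# To actually send it once uvicorn is running:
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp))
```

The request is only constructed here, so the snippet can be run without the server; uncomment the last lines to send it.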
30 changes: 24 additions & 6 deletions deployment/BedrockProxy.template → deployment/BedrockProxy.yaml
@@ -8,6 +8,19 @@ Parameters:
     Type: String
     Default: anthropic.claude-3-sonnet-20240229-v1:0
     Description: The default model ID, please make sure the model ID is supported in the current region
+  ImageUri:
+    Type: String
+    Default: ""
+    Description: Specify a custom ECR image, if left blank defaults to 366590864501.dkr.ecr.us-east-1.amazonaws.com/bedrock-proxy-api:latest.
+  EnableImportedModels:
+    Type: String
+    Default: false
+    AllowedValues:
+      - true
+      - false
+    Description: If enabled, models imported into Bedrock will be available to use.
+Conditions:
+  UseDefaultImage: !Equals [!Ref ImageUri, ""]
 Resources:
   VPCB9E5F0B4:
     Type: AWS::EC2::VPC
@@ -142,6 +155,7 @@ Resources:
           - Action:
               - bedrock:ListFoundationModels
               - bedrock:ListInferenceProfiles
+              - bedrock:ListImportedModels
             Effect: Allow
             Resource: "*"
           - Action:
@@ -151,6 +165,7 @@ Resources:
             Resource:
               - arn:aws:bedrock:*::foundation-model/*
               - arn:aws:bedrock:*:*:inference-profile/*
+              - arn:aws:bedrock:*:*:imported-model/*
           - Action:
               - secretsmanager:GetSecretValue
               - secretsmanager:DescribeSecret
@@ -167,14 +182,16 @@ Resources:
       Architectures:
         - arm64
       Code:
-        ImageUri:
-          Fn::Join:
+        ImageUri: !If
+          - UseDefaultImage
+          - !Join
             - ""
-            - - 366590864501.dkr.ecr.
-              - Ref: AWS::Region
+            - - "366590864501.dkr.ecr."
+              - !Ref AWS::Region
               - "."
-              - Ref: AWS::URLSuffix
-              - /bedrock-proxy-api:latest
+              - !Ref AWS::URLSuffix
+              - "/bedrock-proxy-api:latest"
+          - !Ref ImageUri
       Description: Bedrock Proxy API Handler
       Environment:
         Variables:
@@ -185,6 +202,7 @@ Resources:
             Ref: DefaultModelId
           DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3
           ENABLE_CROSS_REGION_INFERENCE: "true"
+          ENABLE_IMPORTED_MODELS: !Ref EnableImportedModels
       MemorySize: 1024
       PackageType: Image
       Role:
@@ -8,6 +8,19 @@ Parameters:
     Type: String
     Default: anthropic.claude-3-sonnet-20240229-v1:0
     Description: The default model ID, please make sure the model ID is supported in the current region
+  ImageUri:
+    Type: String
+    Default: ""
+    Description: Specify a custom ECR image, if left blank defaults to 366590864501.dkr.ecr.us-east-1.amazonaws.com/bedrock-proxy-api:latest.
+  EnableImportedModels:
+    Type: String
+    Default: false
+    AllowedValues:
+      - true
+      - false
+    Description: If enabled, models imported into Bedrock will be available to use.
+Conditions:
+  UseDefaultImage: !Equals [!Ref ImageUri, ""]
 Resources:
   VPCB9E5F0B4:
     Type: AWS::EC2::VPC
@@ -184,6 +197,7 @@ Resources:
           - Action:
               - bedrock:ListFoundationModels
               - bedrock:ListInferenceProfiles
+              - bedrock:ListImportedModels
             Effect: Allow
             Resource: "*"
           - Action:
@@ -193,6 +207,7 @@ Resources:
             Resource:
               - arn:aws:bedrock:*::foundation-model/*
               - arn:aws:bedrock:*:*:inference-profile/*
+              - arn:aws:bedrock:*:*:imported-model/*
         Version: "2012-10-17"
       PolicyName: ProxyTaskRoleDefaultPolicy933321B8
       Roles:
@@ -222,15 +237,19 @@ Resources:
               Value: cohere.embed-multilingual-v3
             - Name: ENABLE_CROSS_REGION_INFERENCE
               Value: "true"
+            - Name: ENABLE_IMPORTED_MODELS
+              Value: !Ref EnableImportedModels
           Essential: true
-          Image:
-            Fn::Join:
-              - ""
-              - - 366590864501.dkr.ecr.
-                - Ref: AWS::Region
-                - "."
-                - Ref: AWS::URLSuffix
-                - /bedrock-proxy-api-ecs:latest
+          Image: !If
+            - UseDefaultImage
+            - !Join
+              - ""
+              - - "366590864501.dkr.ecr."
+                - !Ref AWS::Region
+                - "."
+                - !Ref AWS::URLSuffix
+                - "/bedrock-proxy-api:latest"
+            - !Ref ImageUri
           Name: proxy-api
           PortMappings:
             - ContainerPort: 80
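Review note: the `UseDefaultImage` condition and the `!If` around the image reference can be read as the small function below. This is a sketch of the template logic in Python, not code from the PR; `resolve_image_uri` and its parameters are hypothetical names. The repository name is a parameter because the removed lines of the container definition pointed at `bedrock-proxy-api-ecs:latest`, while the new fallback in both templates is `bedrock-proxy-api:latest`.

```python
DEFAULT_ACCOUNT = "366590864501"  # ECR account ID hard-coded in the templates


def resolve_image_uri(
    image_uri: str,
    region: str,
    url_suffix: str = "amazonaws.com",  # stands in for AWS::URLSuffix
    repo: str = "bedrock-proxy-api",
) -> str:
    """Mirror `!If [UseDefaultImage, <joined default>, !Ref ImageUri]`."""
    use_default_image = image_uri == ""  # the UseDefaultImage condition
    if use_default_image:
        # Equivalent of the !Join over account, region, URL suffix and repo.
        return f"{DEFAULT_ACCOUNT}.dkr.ecr.{region}.{url_suffix}/{repo}:latest"
    return image_uri


default_uri = resolve_image_uri("", "us-west-2")
custom_uri = resolve_image_uri(
    "111122223333.dkr.ecr.eu-west-1.amazonaws.com/my-proxy:1.0", "us-west-2"
)
```

Because the join uses `AWS::Region`, the default follows the stack's region rather than always resolving to `us-east-1` as the `ImageUri` parameter description reads.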
28 changes: 27 additions & 1 deletion src/api/models/bedrock.py
@@ -38,7 +38,7 @@
     Embedding,

 )
-from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL
+from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL, ENABLE_IMPORTED_MODELS

 logger = logging.getLogger(__name__)

@@ -99,6 +99,18 @@ def list_bedrock_models() -> dict:
             byOutputModality='TEXT'
         )

+        # Add imported models to the list if ENABLE_IMPORTED_MODELS is true
+        if ENABLE_IMPORTED_MODELS:
+            response_imported = bedrock_client.list_imported_models()
+            print(response_imported)
+
+            # Add imported models to the default model list
+            for model in response_imported['modelSummaries']:
+                model_id = model.get('modelName')
+                model_list[f"custom.{model_id}"] = {
+                    'modalities': ["TEXT"]
+                }
+
         for model in response['modelSummaries']:
             model_id = model.get('modelId', 'N/A')
             stream_supported = model.get('responseStreamingSupported', True)
@@ -170,6 +182,20 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         if DEBUG:
             logger.info("Bedrock request: " + json.dumps(str(args)))

+        if args["modelId"].startswith("custom."):
+            # For custom models, get the model ARN by listing models and finding matching name
+            model_name = args["modelId"].replace("custom.", "")
+            response = bedrock_client.list_imported_models()
+            for model in response["modelSummaries"]:
+                if model["modelName"] == model_name:
+                    args["modelId"] = model["modelArn"]
+                    break
+            else:
+                raise HTTPException(
+                    status_code=404,
+                    detail=f"Custom model {model_name} not found"
+                )
+
         try:
             if stream:
                 response = bedrock_runtime.converse_stream(**args)
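Review note: the lookup added to `_invoke_bedrock` reads only the first page of `list_imported_models()` and strips every occurrence of `custom.` from the ID. One possible variant is sketched below, under the assumption that the response carries a `nextToken` when more pages exist. It is not the PR's code: `resolve_imported_model_arn` and `FakeBedrockClient` are hypothetical names, and the fake client exists only so the sketch runs without AWS credentials.

```python
from typing import Any, Dict, List, Optional

CUSTOM_PREFIX = "custom."


def resolve_imported_model_arn(client: Any, model_id: str) -> Optional[str]:
    """Return the ARN for a 'custom.<name>' model ID, or None if not found.

    Follows nextToken so a match beyond the first page is not missed, and
    strips only the leading prefix rather than every 'custom.' substring.
    """
    if not model_id.startswith(CUSTOM_PREFIX):
        return None
    model_name = model_id[len(CUSTOM_PREFIX):]
    kwargs: Dict[str, Any] = {}
    while True:
        page = client.list_imported_models(**kwargs)
        for summary in page.get("modelSummaries", []):
            if summary.get("modelName") == model_name:
                return summary.get("modelArn")
        token = page.get("nextToken")
        if not token:
            return None
        kwargs["nextToken"] = token


class FakeBedrockClient:
    """Hypothetical stand-in for the boto3 Bedrock client, for illustration only."""

    def __init__(self, pages: List[List[Dict[str, str]]]) -> None:
        self._pages = pages

    def list_imported_models(self, **kwargs: Any) -> Dict[str, Any]:
        index = int(kwargs.get("nextToken", "0"))
        page: Dict[str, Any] = {"modelSummaries": self._pages[index]}
        if index + 1 < len(self._pages):
            page["nextToken"] = str(index + 1)
        return page


client = FakeBedrockClient([
    [{"modelName": "llama-ft",
      "modelArn": "arn:aws:bedrock:us-east-1:111122223333:imported-model/aaa"}],
    [{"modelName": "mistral-ft",
      "modelArn": "arn:aws:bedrock:us-east-1:111122223333:imported-model/bbb"}],
])
arn = resolve_imported_model_arn(client, "custom.mistral-ft")
```

A caller would map a `None` result for a `custom.` ID to the same 404 the PR raises. Caching the name-to-ARN mapping would avoid a list call on every request, at the cost of serving stale entries after a model is deleted.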
1 change: 1 addition & 0 deletions src/api/setting.py
@@ -20,3 +20,4 @@
     "DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3"
 )
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
+ENABLE_IMPORTED_MODELS = os.environ.get("ENABLE_IMPORTED_MODELS", "true").lower() != "false"
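Review note: the parsing rule on the added line treats any value other than `false` (case-insensitive) as enabled, and an unset variable falls back to `"true"`. The sketch below isolates that rule so the edge cases are easy to see; the helper name `env_flag` is hypothetical and not part of the PR.

```python
import os


def env_flag(name: str, default: str = "true") -> bool:
    """Same rule as setting.py: enabled unless the value is 'false' (any case)."""
    return os.environ.get(name, default).lower() != "false"


results = {}
for raw in ("false", "False", "true", "0", ""):
    os.environ["ENABLE_IMPORTED_MODELS"] = raw
    results[raw] = env_flag("ENABLE_IMPORTED_MODELS")

# With the variable unset, the code default ("true") applies.
del os.environ["ENABLE_IMPORTED_MODELS"]
results["<unset>"] = env_flag("ENABLE_IMPORTED_MODELS")
```

Only `false` and `False` disable the feature here; `0` and the empty string leave it enabled. Stacks deployed from the templates pass the `EnableImportedModels` parameter, whose default is `false`, so the `"true"` code default mainly affects local runs where the variable is not set.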