12
12
import tiktoken
13
13
from botocore .config import Config
14
14
from fastapi import HTTPException
15
+ from starlette .concurrency import run_in_threadpool
15
16
16
17
from api .models .base import BaseChatModel , BaseEmbeddingsModel
17
18
from api .schema import (
@@ -145,7 +146,7 @@ def validate(self, chat_request: ChatRequest):
145
146
detail = error ,
146
147
)
147
148
148
- def _invoke_bedrock (self , chat_request : ChatRequest , stream = False ):
149
+ async def _invoke_bedrock (self , chat_request : ChatRequest , stream = False ):
149
150
"""Common logic for invoke bedrock models"""
150
151
if DEBUG :
151
152
logger .info ("Raw request: " + chat_request .model_dump_json ())
@@ -157,9 +158,11 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
157
158
158
159
try :
159
160
if stream :
160
- response = bedrock_runtime .converse_stream (** args )
161
+ # Run the blocking boto3 call in a thread pool
162
+ response = await run_in_threadpool (bedrock_runtime .converse_stream , ** args )
161
163
else :
162
- response = bedrock_runtime .converse (** args )
164
+ # Run the blocking boto3 call in a thread pool
165
+ response = await run_in_threadpool (bedrock_runtime .converse , ** args )
163
166
except bedrock_runtime .exceptions .ValidationException as e :
164
167
logger .error ("Validation Error: " + str (e ))
165
168
raise HTTPException (status_code = 400 , detail = str (e ))
@@ -171,11 +174,11 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
171
174
raise HTTPException (status_code = 500 , detail = str (e ))
172
175
return response
173
176
174
- def chat (self , chat_request : ChatRequest ) -> ChatResponse :
177
+ async def chat (self , chat_request : ChatRequest ) -> ChatResponse :
175
178
"""Default implementation for Chat API."""
176
179
177
180
message_id = self .generate_message_id ()
178
- response = self ._invoke_bedrock (chat_request )
181
+ response = await self ._invoke_bedrock (chat_request )
179
182
180
183
output_message = response ["output" ]["message" ]
181
184
input_tokens = response ["usage" ]["inputTokens" ]
@@ -194,9 +197,9 @@ def chat(self, chat_request: ChatRequest) -> ChatResponse:
194
197
logger .info ("Proxy response :" + chat_response .model_dump_json ())
195
198
return chat_response
196
199
197
- def chat_stream (self , chat_request : ChatRequest ) -> AsyncIterable [bytes ]:
200
+ async def chat_stream (self , chat_request : ChatRequest ) -> AsyncIterable [bytes ]:
198
201
"""Default implementation for Chat Stream API"""
199
- response = self ._invoke_bedrock (chat_request , stream = True )
202
+ response = await self ._invoke_bedrock (chat_request , stream = True )
200
203
message_id = self .generate_message_id ()
201
204
stream = response .get ("stream" )
202
205
for chunk in stream :
0 commit comments