Skip to content

Commit c4cd39e

Browse files
authored
refactor: refactor the FastAPI response streaming with Claude3 (#416)
1 parent 1056168 commit c4cd39e

File tree

1 file changed

+22
-12
lines changed
  • examples/fastapi-response-streaming/app/main.py

1 file changed

+22
-12
lines changed

examples/fastapi-response-streaming/app/main.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,38 +13,44 @@
1313

1414
app.mount("/demo", StaticFiles(directory="static", html=True))
1515

16+
1617
@app.get("/")
1718
async def root():
1819
return RedirectResponse(url='/demo/')
1920

21+
2022
class Story(BaseModel):
21-
topic: Optional[str] = None
23+
topic: Optional[str] = None
24+
2225

2326
@app.post("/api/story")
2427
def api_story(story: Story):
2528
if story.topic == None or story.topic == "":
26-
return None
29+
return None
2730

2831
return StreamingResponse(bedrock_stream(story.topic), media_type="text/html")
2932

3033

3134
bedrock = boto3.client('bedrock-runtime')
3235

36+
3337
async def bedrock_stream(topic: str):
3438
instruction = f"""
3539
You are a world class writer. Please write a sweet bedtime story about {topic}.
3640
"""
37-
3841
body = json.dumps({
39-
'prompt': f'Human:{instruction}\n\nAssistant:',
40-
'max_tokens_to_sample': 1028,
41-
'temperature': 1,
42-
'top_k': 250,
43-
'top_p': 0.999,
44-
'stop_sequences': ['\n\nHuman:']
42+
"anthropic_version": "bedrock-2023-05-31",
43+
"max_tokens": 1024,
44+
"messages": [
45+
{
46+
"role": "user",
47+
"content": instruction,
48+
}
49+
],
4550
})
51+
4652
response = bedrock.invoke_model_with_response_stream(
47-
modelId='anthropic.claude-v2',
53+
modelId='anthropic.claude-3-haiku-20240307-v1:0',
4854
body=body
4955
)
5056

@@ -53,8 +59,12 @@ async def bedrock_stream(topic: str):
5359
for event in stream:
5460
chunk = event.get('chunk')
5561
if chunk:
56-
yield json.loads(chunk.get('bytes').decode())['completion']
62+
message = json.loads(chunk.get("bytes").decode())
63+
if message['type'] == "content_block_delta":
64+
yield message['delta']['text'] or ""
65+
elif message['type'] == "message_stop":
66+
yield "\n"
5767

5868

5969
if __name__ == "__main__":
60-
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
70+
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))

0 commit comments

Comments (0)