13
13
14
14
app .mount ("/demo" , StaticFiles (directory = "static" , html = True ))
15
15
16
+
16
17
@app .get ("/" )
17
18
async def root ():
18
19
return RedirectResponse (url = '/demo/' )
19
20
21
+
20
22
class Story (BaseModel ):
21
- topic : Optional [str ] = None
23
+ topic : Optional [str ] = None
24
+
22
25
23
26
@app .post ("/api/story" )
24
27
def api_story (story : Story ):
25
28
if story .topic == None or story .topic == "" :
26
- return None
29
+ return None
27
30
28
31
return StreamingResponse (bedrock_stream (story .topic ), media_type = "text/html" )
29
32
30
33
31
34
bedrock = boto3 .client ('bedrock-runtime' )
32
35
36
+
33
37
async def bedrock_stream (topic : str ):
34
38
instruction = f"""
35
39
You are a world class writer. Please write a sweet bedtime story about { topic } .
36
40
"""
37
-
38
41
body = json .dumps ({
39
- 'prompt' : f'Human:{ instruction } \n \n Assistant:' ,
40
- 'max_tokens_to_sample' : 1028 ,
41
- 'temperature' : 1 ,
42
- 'top_k' : 250 ,
43
- 'top_p' : 0.999 ,
44
- 'stop_sequences' : ['\n \n Human:' ]
42
+ "anthropic_version" : "bedrock-2023-05-31" ,
43
+ "max_tokens" : 1024 ,
44
+ "messages" : [
45
+ {
46
+ "role" : "user" ,
47
+ "content" : instruction ,
48
+ }
49
+ ],
45
50
})
51
+
46
52
response = bedrock .invoke_model_with_response_stream (
47
- modelId = 'anthropic.claude-v2 ' ,
53
+ modelId = 'anthropic.claude-3-haiku-20240307-v1:0 ' ,
48
54
body = body
49
55
)
50
56
@@ -53,8 +59,12 @@ async def bedrock_stream(topic: str):
53
59
for event in stream :
54
60
chunk = event .get ('chunk' )
55
61
if chunk :
56
- yield json .loads (chunk .get ('bytes' ).decode ())['completion' ]
62
+ message = json .loads (chunk .get ("bytes" ).decode ())
63
+ if message ['type' ] == "content_block_delta" :
64
+ yield message ['delta' ]['text' ] or ""
65
+ elif message ['type' ] == "message_stop" :
66
+ yield "\n "
57
67
58
68
59
69
if __name__ == "__main__" :
60
- uvicorn .run (app , host = "0.0.0.0" , port = int (os .environ .get ("PORT" , "8080" )))
70
+ uvicorn .run (app , host = "0.0.0.0" , port = int (os .environ .get ("PORT" , "8080" )))
0 commit comments