-from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
+import logging
+from typing import Any, AsyncIterable
+
+from anthropic import AsyncAnthropic, AsyncAnthropicBedrock, AsyncStream, NotGiven
 from anthropic._types import NOT_GIVEN
+from anthropic.types import (
+    RawContentBlockDeltaEvent,
+    RawContentBlockStartEvent,
+    RawContentBlockStopEvent,
+    RawMessageDeltaEvent,
+    RawMessageStartEvent,
+    RawMessageStopEvent,
+    RawMessageStreamEvent,
+)
 from fastapi.responses import JSONResponse, StreamingResponse
+from pydantic import BaseModel
 
 from .helpers import log, map_messages, map_resp, map_tools
 
 
 async def completions(client: AsyncAnthropic | AsyncAnthropicBedrock, input: dict):
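+    # Maps an incoming chat-completions style request onto the Anthropic
+    # Messages API and returns the result as a streaming (SSE) response.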
+    is_thinking = False
+
+    model = str(input["model"])
+    if model.endswith("-thinking"):
+        is_thinking = True
+        model = model.removesuffix("-thinking")
+
+    # max_tokens defaults:
+    # - 4096 for regular models, so that it works with even the smallest models
+    # - 64000 for thinking models - the current max for 3.7 Sonnet with extended thinking
+    max_tokens = input.get("max_tokens", 64000 if is_thinking else 4096)
+    if max_tokens is not None:
+        max_tokens = int(max_tokens)
+
+    thinking_config: Any | NotGiven = NOT_GIVEN
+    if is_thinking:
+        thinking_config = {
+            "type": "enabled",
+            "budget_tokens": round(
+                max_tokens / 2
+            ),  # TODO: figure out a good percentage of max_tokens to use for thinking
+        }
+
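+    # For example, a request for model "<base-model>-thinking" with the default
+    # max_tokens=64000 ends up with budget_tokens=32000 reserved for thinking.
+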
     tools = input.get("tools", NOT_GIVEN)
     if tools is not NOT_GIVEN:
         tools = map_tools(tools)
 
     system, messages = map_messages(input["messages"])
 
-    max_tokens = input.get("max_tokens", 1024)
-    if max_tokens is not None:
-        max_tokens = int(max_tokens)
-
-    temperature = input.get("temperature", NOT_GIVEN)
+    temperature = input.get("temperature", NOT_GIVEN) if not is_thinking else NOT_GIVEN
     if temperature is not NOT_GIVEN:
         temperature = float(temperature)
 
-    top_k = input.get("top_k", NOT_GIVEN)
+    top_k = input.get("top_k", NOT_GIVEN) if not is_thinking else NOT_GIVEN
     if top_k is not NOT_GIVEN:
         top_k = int(top_k)
 
-    top_p = input.get("top_p", NOT_GIVEN)
+    top_p = input.get("top_p", NOT_GIVEN) if not is_thinking else NOT_GIVEN
     if top_p is not NOT_GIVEN:
         top_p = float(top_p)
 
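+    # Note: the Anthropic API does not accept temperature/top_k/top_p overrides
+    # together with extended thinking, hence the NOT_GIVEN fallbacks above.
+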
+    stream = input.get("stream", False)
+
+    log(f"thinking_config: {thinking_config}")
     try:
         response = await client.messages.create(
+            thinking=thinking_config,
             max_tokens=max_tokens,
             system=system,
             messages=messages,
-            model=input["model"],
+            model=model,
             temperature=temperature,
             tools=tools,
             top_k=top_k,
             top_p=top_p,
+            stream=stream,
         )
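+        # With stream=True the SDK returns an async stream of raw events
+        # instead of a complete Message object.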
+        if stream:
+            return StreamingResponse(
+                convert_stream(response),
+                media_type="application/x-ndjson",
+            )
+        else:
+            log(f"Anthropic response: {response.model_dump_json()}")
+
+            mapped_response = map_resp(response)
+
+            log(f"Mapped Anthropic response: {mapped_response.model_dump_json()}")
+            return StreamingResponse(
+                "data: " + mapped_response.model_dump_json() + "\n\n",
+                media_type="application/x-ndjson",
+            )
+
     except Exception as e:
+        logging.error(f"Anthropic API error: {e}")
         return JSONResponse(
-            content={"error": str(e)}, status_code=e.__dict__.get("status_code", 500)
+            content={"error": str(e)}, status_code=getattr(e, "status_code", 500)
         )
 
-    log(f"Anthropic response: {response.model_dump_json()}")
 
-    mapped_response = map_resp(response)
+def map_event(event: RawMessageStreamEvent) -> BaseModel:
+    # Every known event type is currently passed through unchanged; the
+    # per-type branches are kept as hooks for mapping individual events later.
+    if isinstance(event, RawContentBlockStartEvent):
+        return event
+    elif isinstance(event, RawContentBlockDeltaEvent):
+        return event
+    elif isinstance(event, RawContentBlockStopEvent):
+        return event
+    elif isinstance(event, RawMessageStartEvent):
+        return event
+    elif isinstance(event, RawMessageDeltaEvent):
+        return event
+    elif isinstance(event, RawMessageStopEvent):
+        return event
+    else:
+        raise ValueError(f"Unknown event type: {event}")
 
-    log(f"Mapped Anthropic response: {mapped_response.model_dump_json()}")
 
-    return StreamingResponse(
-        "data: " + mapped_response.model_dump_json() + "\n\n",
-        media_type="application/x-ndjson",
-    )
+async def convert_stream(
+    stream: AsyncStream[RawMessageStreamEvent],
+) -> AsyncIterable[str]:
+    # Re-emit each raw Anthropic event as one SSE data frame.
+    async for chunk in stream:
+        log(f"Anthropic event: {chunk.model_dump_json()}")
+        yield "data: " + map_event(chunk).model_dump_json() + "\n\n"
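+
+# Illustrative output: each yielded frame looks like
+#   data: {"type": "content_block_delta", "index": 0, "delta": {...}}\n\n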