feat: init claude3.7 extended thinking mode #491

Draft · wants to merge 1 commit into main
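In short: model ids ending in "-thinking" are treated as extended-thinking variants of Claude 3.7 Sonnet. The suffix is stripped before calling the Messages API, a thinking config with budget_tokens = max_tokens / 2 is passed along, sampling parameters (temperature, top_k, top_p) are dropped since they are not supported with extended thinking, streaming events are mapped onto OpenAI ChatCompletionChunk objects, and list_models advertises a "-thinking" sibling for each claude-3-7-sonnet model.

A rough usage sketch, assuming the provider keeps exposing an OpenAI-compatible chat-completions route (the base URL, path, and API-key handling below are assumptions, not part of this diff):

# Illustrative usage sketch; base URL, route, and API key handling are
# assumptions, not part of this diff.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="unused")

resp = client.chat.completions.create(
    # completions() strips the "-thinking" suffix and enables extended
    # thinking with budget_tokens = max_tokens / 2.
    model="claude-3-7-sonnet-latest-thinking",
    messages=[{"role": "user", "content": "What is 17 * 24? Think it through."}],
    stream=True,
)
for chunk in resp:
    print(chunk.choices[0].delta.content or "", end="")
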
148 changes: 132 additions & 16 deletions anthropic-model-provider/anthropic_common/api.py
@@ -1,56 +1,172 @@
import json
import logging
import time
from typing import Any, AsyncIterable

from anthropic import AsyncAnthropic, AsyncAnthropicBedrock, AsyncStream
from anthropic._types import NOT_GIVEN
from anthropic.types import (
    RawContentBlockDeltaEvent,
    RawContentBlockStartEvent,
    RawContentBlockStopEvent,
    RawMessageDeltaEvent,
    RawMessageStartEvent,
    RawMessageStopEvent,
    RawMessageStreamEvent,
    TextDelta,
    ThinkingDelta,
)
from fastapi.responses import JSONResponse, StreamingResponse
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import (
    Choice,
    ChoiceDelta,
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)

from .helpers import log, map_messages, map_resp, map_tools


async def completions(client: AsyncAnthropic | AsyncAnthropicBedrock, input: dict):
    is_thinking = False

    model = str(input["model"])
    if model.endswith("-thinking"):
        is_thinking = True
        model = model.removesuffix("-thinking")

    # max_tokens defaults:
    # - 4096 for regular models, so that it works with even the smallest models
    # - 64000 for thinking models - the max for 3.7 Sonnet with extended thinking mode right now
    max_tokens = input.get("max_tokens", 4096 if not is_thinking else 64000)
    if max_tokens is not None:
        max_tokens = int(max_tokens)

    thinking_config: Any | NOT_GIVEN = NOT_GIVEN
    if is_thinking:
        thinking_config = {
            "type": "enabled",
            "budget_tokens": round(
                max_tokens / 2
            ),  # TODO: figure out a good percentage of max_tokens to use for thinking
        }

    tools = input.get("tools", NOT_GIVEN)
    if tools is not NOT_GIVEN:
        tools = map_tools(tools)

    system, messages = map_messages(input["messages"])

    # Sampling parameters are not supported together with extended thinking,
    # so they are dropped for -thinking models.
    temperature = input.get("temperature", NOT_GIVEN) if not is_thinking else NOT_GIVEN
    if temperature is not NOT_GIVEN:
        temperature = float(temperature)

    top_k = input.get("top_k", NOT_GIVEN) if not is_thinking else NOT_GIVEN
    if top_k is not NOT_GIVEN:
        top_k = int(top_k)

    top_p = input.get("top_p", NOT_GIVEN) if not is_thinking else NOT_GIVEN
    if top_p is not NOT_GIVEN:
        top_p = float(top_p)

    stream = input.get("stream", False)

logging.error(f"@@@ thinking_config: {thinking_config}")
    try:
        response = await client.messages.create(
            thinking=thinking_config,
            max_tokens=max_tokens,
            system=system,
            messages=messages,
            model=model,
            temperature=temperature,
            tools=tools,
            top_k=top_k,
            top_p=top_p,
            stream=stream,
        )
        if stream:
            # Translate the Anthropic event stream into OpenAI-style SSE chunks.
            return StreamingResponse(
                convert_stream(response, model),
                media_type="application/x-ndjson",
            )
        else:
            log(f"Anthropic response: {response.model_dump_json()}")
            mapped_response = map_resp(response)
            log(f"Mapped Anthropic response: {mapped_response.model_dump_json()}")
            return StreamingResponse(
                "data: " + mapped_response.model_dump_json() + "\n\n",
                media_type="application/x-ndjson",
            )
    except Exception as e:
        logging.error(f"Anthropic API error: {e}")
        return JSONResponse(
            content={"error": str(e)}, status_code=e.__dict__.get("status_code", 500)
        )

log(f"Anthropic response: {response.model_dump_json()}")
def map_event(event: RawMessageStreamEvent, model: str) -> ChatCompletionChunk | None:
    c: ChatCompletionChunk | None = None
    if isinstance(event, RawContentBlockStartEvent):
        if event.content_block.type == "tool_use":
            # Surface the start of a tool-use block as an OpenAI-style
            # tool-call delta; the tool's JSON arguments arrive in later deltas.
            c = ChatCompletionChunk(
                id="0",
                choices=[
                    Choice(
                        delta=ChoiceDelta(
                            content=None,
                            tool_calls=[
                                ChoiceDeltaToolCall(
                                    index=event.index,
                                    id=event.content_block.id,
                                    type="function",
                                    function=ChoiceDeltaToolCallFunction(
                                        name=event.content_block.name,
                                        arguments="",
                                    ),
                                )
                            ],
                            role="assistant",
                        ),
                        finish_reason=None,
                        index=0,
                    )
                ],
                created=int(time.time()),
                model=model,
                object="chat.completion.chunk",
            )
    elif isinstance(event, RawContentBlockDeltaEvent):
        # Thinking deltas are surfaced as ordinary assistant text.
        content = ""
        if isinstance(event.delta, TextDelta):
            content = event.delta.text
        elif isinstance(event.delta, ThinkingDelta):
            content = event.delta.thinking
        c = ChatCompletionChunk(
            id="0",
            choices=[
                Choice(
                    delta=ChoiceDelta(
                        content=content,
                        tool_calls=None,
                        role="assistant",
                    ),
                    finish_reason=None,
                    index=0,
                )
            ],
            created=int(time.time()),
            model=model,
            object="chat.completion.chunk",
        )
    elif isinstance(event, RawContentBlockStopEvent):
        pass
    elif isinstance(event, RawMessageStartEvent):
        pass
    elif isinstance(event, RawMessageDeltaEvent):
        pass
    elif isinstance(event, RawMessageStopEvent):
        pass
    else:
        raise ValueError(f"Unknown event type: {event}")

    return c


async def convert_stream(
    stream: AsyncStream[RawMessageStreamEvent], model: str
) -> AsyncIterable[str]:
    async for chunk in stream:
        mapped = map_event(chunk, model)
        # map_event returns None for events with no OpenAI equivalent
        # (message/content-block start and stop markers), so skip those.
        if mapped is not None:
            yield "data: " + mapped.model_dump_json() + "\n\n"
16 changes: 15 additions & 1 deletion anthropic-model-provider/main.py
@@ -3,6 +3,7 @@

import anthropic.pagination
from anthropic import AsyncAnthropic
from anthropic.types import ModelInfo
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

@@ -15,6 +16,8 @@
app = FastAPI()
uri = "http://127.0.0.1:" + os.environ.get("PORT", "8000")

thinking_models_prefixes = ["claude-3-7-sonnet"]


def log(*args):
    if debug:
@@ -40,12 +43,23 @@ async def list_models() -> JSONResponse:
    resp: anthropic.pagination.AsyncPage[
        anthropic.types.ModelInfo
    ] = await client.models.list(limit=20)
    thinking_models = []
    for model in resp.data:
        if any(model.id.startswith(m) for m in thinking_models_prefixes):
            thinking_models.append(
                ModelInfo(
                    id=model.id + "-thinking",
                    display_name=model.display_name + " (Thinking)",
                    created_at=model.created_at,
                    type="model",
                )
            )
    return JSONResponse(
        content={
            "object": "list",
            "data": [
                set_model_usage(model.model_dump(exclude={"created_at"}))
                for model in resp.data + thinking_models
            ],
        }
    )
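
For each matching upstream model, list_models now also advertises a synthetic "(Thinking)" sibling. An illustrative sketch of how an entry is derived (the concrete model id and date here are assumptions):

# Illustrative only: the upstream model id and date are assumptions.
from datetime import datetime, timezone

from anthropic.types import ModelInfo

base = ModelInfo(
    id="claude-3-7-sonnet-20250219",
    display_name="Claude 3.7 Sonnet",
    created_at=datetime(2025, 2, 19, tzinfo=timezone.utc),
    type="model",
)
thinking = ModelInfo(
    id=base.id + "-thinking",
    display_name=base.display_name + " (Thinking)",
    created_at=base.created_at,
    type="model",
)
print(thinking.id)  # claude-3-7-sonnet-20250219-thinking
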
2 changes: 1 addition & 1 deletion anthropic-model-provider/requirements.txt
@@ -1,4 +1,4 @@
fastapi
uvicorn[standard]
anthropic==0.49.0
openai>=1.54.3
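
Note: the bump from anthropic 0.43.0 to 0.49.0 picks up the extended-thinking API surface (the thinking parameter and the ThinkingDelta event type) that api.py now imports.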
2 changes: 1 addition & 1 deletion anthropic-model-provider/setup.py
@@ -4,5 +4,5 @@
name="anthropic_common",
version="0.1",
packages=find_packages(include=["anthropic_common"]),
install_requires=["fastapi", "openai", "anthropic>=0.43.0", "openai>=1.35.7"],
install_requires=["fastapi", "openai", "anthropic>=0.49.0", "openai>=1.35.7"],
)