-
Notifications
You must be signed in to change notification settings - Fork 78
/
Copy pathprovider.py
164 lines (142 loc) · 5.17 KB
/
provider.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os
from typing import Callable, List
import httpx
import structlog
from fastapi import Header, HTTPException, Request
from codegate.clients.clients import ClientType
from codegate.clients.detector import DetectClient
from codegate.pipeline.factory import PipelineFactory
from codegate.profiling import profiled
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
from codegate.providers.base import BaseProvider, ModelFetchError
from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.types.anthropic import (
ChatCompletionRequest,
single_message,
single_response,
stream_generator,
)
from codegate.types.generators import (
completion_handler_replacement,
)
# Module-level structured logger (bound to the "codegate" logger name); used by
# AnthropicProvider.process_request to log upstream completion errors.
logger = structlog.get_logger("codegate")
class AnthropicProvider(BaseProvider):
    """Provider that proxies Anthropic Messages API requests through CodeGate.

    Requests arrive on the routes registered in ``_setup_routes`` and are
    forwarded to ``self.base_url``. That URL comes from ``_get_base_url()``
    when it returns a non-empty value; otherwise it is the public Anthropic API.
    """

    # Upstream base URL used when no override is configured.
    _DEFAULT_BASE_URL = "https://api.anthropic.com/v1"
    # Host queried by ``models`` when the caller passes no endpoint.
    _DEFAULT_MODELS_ENDPOINT = "https://api.anthropic.com"
    # Value sent in the ``anthropic-version`` header when listing models.
    _ANTHROPIC_VERSION = "2023-06-01"

    def __init__(
        self,
        pipeline_factory: PipelineFactory,
    ):
        """Resolve the upstream base URL and set up the Anthropic completion handler.

        Args:
            pipeline_factory: Factory forwarded to ``BaseProvider.__init__``.
        """
        # Query the configured override once. An empty (or missing) value falls
        # back to the public Anthropic API instead of leaking through as-is.
        configured_url = self._get_base_url()
        self.base_url = configured_url or self._DEFAULT_BASE_URL

        completion_handler = AnthropicCompletion(stream_generator=stream_generator)
        # NOTE(review): the two leading positional ``None`` arguments are defined
        # by BaseProvider.__init__; confirm their meaning there.
        super().__init__(
            None,
            None,
            completion_handler,
            pipeline_factory,
        )

    @property
    def provider_route_name(self) -> str:
        """URL prefix under which this provider's routes are registered."""
        return "anthropic"

    def models(self, endpoint: str | None = None, api_key: str | None = None) -> List[str]:
        """Return the IDs of the models listed by the Anthropic API.

        Args:
            endpoint: Base host to query; ``/v1/models`` is appended. Defaults to
                ``https://api.anthropic.com`` when empty or ``None``. A trailing
                slash is ignored.
            api_key: Sent as the ``x-api-key`` header when provided.

        Returns:
            The ``id`` of every entry in the response's ``data`` list (an empty
            list if ``data`` is absent).

        Raises:
            ModelFetchError: If the API responds with a status other than 200.
        """
        headers = {
            "Content-Type": "application/json",
            "anthropic-version": self._ANTHROPIC_VERSION,
        }
        if api_key:
            headers["x-api-key"] = api_key

        # Strip a trailing slash so "https://host/" does not produce "//v1/models".
        base = (endpoint or self._DEFAULT_MODELS_ENDPOINT).rstrip("/")
        resp = httpx.get(f"{base}/v1/models", headers=headers)
        if resp.status_code != 200:
            raise ModelFetchError(f"Failed to fetch models from Anthropic API: {resp.text}")

        return [model["id"] for model in resp.json().get("data", [])]

    @profiled("anthropic")
    async def process_request(
        self,
        data: dict,
        api_key: str,
        base_url: str,
        is_fim_request: bool,
        client_type: ClientType,
        completion_handler: Callable | None = None,
        stream_generator: Callable | None = None,
    ):
        """Run a completion through the provider and build the HTTP response.

        Args:
            data: Request payload passed to ``self.complete``. ``create_message``
                passes a parsed ``ChatCompletionRequest`` here.
            api_key: API key forwarded to ``self.complete``.
            base_url: Upstream base URL forwarded to ``self.complete``.
            is_fim_request: Whether the request was classified as fill-in-the-middle.
            client_type: Detected client, forwarded to ``complete`` and
                ``create_response``.
            completion_handler: Optional override forwarded to ``complete``.
            stream_generator: Optional override forwarded to ``create_response``.

        Returns:
            The result of ``self._completion_handler.create_response`` for the
            completion stream.

        Raises:
            HTTPException: If ``complete`` raises an exception that has a
                ``status_code`` attribute; that status code is reused.
            Exception: Any other error from ``complete`` is re-raised unchanged.
        """
        try:
            stream = await self.complete(
                data,
                api_key,
                base_url,
                is_fim_request,
                client_type,
                completion_handler=completion_handler,
            )
        except Exception as e:
            # Errors carrying an HTTP status are surfaced to the client with the
            # same status; anything else propagates with its original traceback.
            if not hasattr(e, "status_code"):
                raise
            logger.exception("Error in AnthropicProvider completion")
            raise HTTPException(status_code=e.status_code, detail=str(e)) from e  # type: ignore

        return self._completion_handler.create_response(
            stream,
            client_type,
            stream_generator=stream_generator,
        )

    def _setup_routes(self):
        """Register the Anthropic Messages API routes on ``self.router``.

        The same handler is mounted on two paths:

        - ``/<provider_route_name>/messages``: the path used by Continue.dev.
        - ``/<provider_route_name>/v1/messages``: the path used by Cline.

        The API key is read from the ``x-api-key`` header and forwarded with the
        request to the completion handler.
        """

        @self.router.post(f"/{self.provider_route_name}/messages")
        @self.router.post(f"/{self.provider_route_name}/v1/messages")
        @DetectClient()
        async def create_message(
            request: Request,
            x_api_key: str = Header(None),
        ):
            # Header(None) yields None when the header is absent, so reject both
            # a missing key and an empty one before doing any work.
            if not x_api_key:
                raise HTTPException(status_code=401, detail="No API key provided")

            body = await request.body()
            if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
                print(body.decode("utf-8"))

            req = ChatCompletionRequest.model_validate_json(body)
            is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, req)

            # Streaming requests use the provider's default handlers. Non-streaming
            # requests swap in handlers built from single_message/single_response
            # (presumably yielding one non-streamed response; confirm in
            # codegate.types.anthropic).
            overrides = {}
            if not req.stream:
                overrides = {
                    "completion_handler": completion_handler_replacement(single_message),
                    "stream_generator": single_response,
                }

            return await self.process_request(
                req,
                x_api_key,
                self.base_url,
                is_fim_request,
                request.state.detected_client,
                **overrides,
            )
async def dumper(stream):
    """Debug helper that prints and re-yields each stream event as SSE text.

    Each event is rendered as an ``event:`` line (from ``event.type``) and a
    ``data:`` line (from ``event.json(exclude_defaults=True, exclude_unset=True)``).
    The frame is printed to stdout and then yielded. A separator line is printed
    before the first event and after the last one.

    Args:
        stream: Async iterable of events exposing ``.type`` and ``.json(...)``.

    Yields:
        One Server-Sent Events frame per event, terminated by a blank line.
    """
    print("==========")
    async for event in stream:
        payload = event.json(exclude_defaults=True, exclude_unset=True)
        # SSE dispatches an event only when it sees a blank line, so each frame
        # must end with "\n\n". A single "\n" would merge consecutive events.
        frame = f"event: {event.type}\ndata: {payload}\n\n"
        print(frame)
        yield frame
    print("==========")