Skip to content

Commit 304d243

Browse files
authored
Kick off provider endpoint CRUD structure and registration (#790)
This structure will handle all the database operations and turn that into the right models. Note that for provider endpoints we already have a way of setting these via configuration, so this is taken into account to output some sample objects that users can leverage. Each provider will need to implement a `models` function which allows us to auto-discover models for a provider. Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent 62b5b63 commit 304d243

File tree

16 files changed

+606
-69
lines changed

16 files changed

+606
-69
lines changed

src/codegate/api/v1.py

+117-63
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
from typing import List, Optional
2+
from uuid import UUID
23

34
import requests
45
import structlog
5-
from fastapi import APIRouter, HTTPException, Response
6+
from fastapi import APIRouter, Depends, HTTPException, Response
67
from fastapi.responses import StreamingResponse
78
from fastapi.routing import APIRoute
8-
from pydantic import ValidationError
9+
from pydantic import BaseModel, ValidationError
910

1011
from codegate import __version__
1112
from codegate.api import v1_models, v1_processing
1213
from codegate.db.connection import AlreadyExistsError, DbReader
14+
from codegate.providers import crud as provendcrud
1315
from codegate.workspaces import crud
1416

1517
logger = structlog.get_logger("codegate")
1618

1719
v1 = APIRouter()
1820
wscrud = crud.WorkspaceCrud()
21+
pcrud = provendcrud.ProviderCrud()
1922

2023
# This is a singleton object
2124
dbreader = DbReader()
@@ -25,38 +28,78 @@ def uniq_name(route: APIRoute):
2528
return f"v1_{route.name}"
2629

2730

31+
class FilterByNameParams(BaseModel):
32+
name: Optional[str] = None
33+
34+
2835
@v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name)
29-
async def list_provider_endpoints(name: Optional[str] = None) -> List[v1_models.ProviderEndpoint]:
36+
async def list_provider_endpoints(
37+
filter_query: FilterByNameParams = Depends(),
38+
) -> List[v1_models.ProviderEndpoint]:
3039
"""List all provider endpoints."""
31-
# NOTE: This is a dummy implementation. In the future, we should have a proper
32-
# implementation that fetches the provider endpoints from the database.
33-
return [
34-
v1_models.ProviderEndpoint(
35-
id=1,
36-
name="dummy",
37-
description="Dummy provider endpoint",
38-
endpoint="http://example.com",
39-
provider_type=v1_models.ProviderType.openai,
40-
auth_type=v1_models.ProviderAuthType.none,
41-
)
42-
]
40+
if filter_query.name is None:
41+
try:
42+
return await pcrud.list_endpoints()
43+
except Exception:
44+
raise HTTPException(status_code=500, detail="Internal server error")
45+
46+
try:
47+
provend = await pcrud.get_endpoint_by_name(filter_query.name)
48+
except Exception:
49+
raise HTTPException(status_code=500, detail="Internal server error")
50+
51+
if provend is None:
52+
raise HTTPException(status_code=404, detail="Provider endpoint not found")
53+
return [provend]
54+
55+
56+
# This needs to be above /provider-endpoints/{provider_id} to avoid conflict
57+
@v1.get(
58+
"/provider-endpoints/models",
59+
tags=["Providers"],
60+
generate_unique_id_function=uniq_name,
61+
)
62+
async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]:
63+
"""List all models for all providers."""
64+
try:
65+
return await pcrud.get_all_models()
66+
except Exception:
67+
raise HTTPException(status_code=500, detail="Internal server error")
68+
69+
70+
@v1.get(
71+
"/provider-endpoints/{provider_id}/models",
72+
tags=["Providers"],
73+
generate_unique_id_function=uniq_name,
74+
)
75+
async def list_models_by_provider(
76+
provider_id: UUID,
77+
) -> List[v1_models.ModelByProvider]:
78+
"""List models by provider."""
79+
80+
try:
81+
return await pcrud.models_by_provider(provider_id)
82+
except provendcrud.ProviderNotFoundError:
83+
raise HTTPException(status_code=404, detail="Provider not found")
84+
except Exception as e:
85+
raise HTTPException(status_code=500, detail=str(e))
4386

4487

4588
@v1.get(
4689
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
4790
)
48-
async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
91+
async def get_provider_endpoint(
92+
provider_id: UUID,
93+
) -> v1_models.ProviderEndpoint:
4994
"""Get a provider endpoint by ID."""
50-
# NOTE: This is a dummy implementation. In the future, we should have a proper
51-
# implementation that fetches the provider endpoint from the database.
52-
return v1_models.ProviderEndpoint(
53-
id=provider_id,
54-
name="dummy",
55-
description="Dummy provider endpoint",
56-
endpoint="http://example.com",
57-
provider_type=v1_models.ProviderType.openai,
58-
auth_type=v1_models.ProviderAuthType.none,
59-
)
95+
try:
96+
provend = await pcrud.get_endpoint_by_id(provider_id)
97+
except Exception:
98+
raise HTTPException(status_code=500, detail="Internal server error")
99+
100+
if provend is None:
101+
raise HTTPException(status_code=404, detail="Provider endpoint not found")
102+
return provend
60103

61104

62105
@v1.post(
@@ -65,59 +108,65 @@ async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
65108
generate_unique_id_function=uniq_name,
66109
status_code=201,
67110
)
68-
async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint:
111+
async def add_provider_endpoint(
112+
request: v1_models.ProviderEndpoint,
113+
) -> v1_models.ProviderEndpoint:
69114
"""Add a provider endpoint."""
70-
# NOTE: This is a dummy implementation. In the future, we should have a proper
71-
# implementation that adds the provider endpoint to the database.
72-
return request
115+
try:
116+
provend = await pcrud.add_endpoint(request)
117+
except AlreadyExistsError:
118+
raise HTTPException(status_code=409, detail="Provider endpoint already exists")
119+
except ValidationError as e:
120+
# TODO: This should be more specific
121+
raise HTTPException(
122+
status_code=400,
123+
detail=str(e),
124+
)
125+
except Exception:
126+
raise HTTPException(status_code=500, detail="Internal server error")
127+
128+
return provend
73129

74130

75131
@v1.put(
76132
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
77133
)
78134
async def update_provider_endpoint(
79-
provider_id: int, request: v1_models.ProviderEndpoint
135+
provider_id: UUID,
136+
request: v1_models.ProviderEndpoint,
80137
) -> v1_models.ProviderEndpoint:
81138
"""Update a provider endpoint by ID."""
82-
# NOTE: This is a dummy implementation. In the future, we should have a proper
83-
# implementation that updates the provider endpoint in the database.
84-
return request
139+
try:
140+
request.id = provider_id
141+
provend = await pcrud.update_endpoint(request)
142+
except ValidationError as e:
143+
# TODO: This should be more specific
144+
raise HTTPException(
145+
status_code=400,
146+
detail=str(e),
147+
)
148+
except Exception:
149+
raise HTTPException(status_code=500, detail="Internal server error")
150+
151+
return provend
85152

86153

87154
@v1.delete(
88155
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
89156
)
90-
async def delete_provider_endpoint(provider_id: int):
157+
async def delete_provider_endpoint(
158+
provider_id: UUID,
159+
):
91160
"""Delete a provider endpoint by id."""
92-
# NOTE: This is a dummy implementation. In the future, we should have a proper
93-
# implementation that deletes the provider endpoint from the database.
161+
try:
162+
await pcrud.delete_endpoint(provider_id)
163+
except provendcrud.ProviderNotFoundError:
164+
raise HTTPException(status_code=404, detail="Provider endpoint not found")
165+
except Exception:
166+
raise HTTPException(status_code=500, detail="Internal server error")
94167
return Response(status_code=204)
95168

96169

97-
@v1.get(
98-
"/provider-endpoints/{provider_name}/models",
99-
tags=["Providers"],
100-
generate_unique_id_function=uniq_name,
101-
)
102-
async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]:
103-
"""List models by provider."""
104-
# NOTE: This is a dummy implementation. In the future, we should have a proper
105-
# implementation that fetches the models by provider from the database.
106-
return [v1_models.ModelByProvider(name="dummy", provider="dummy")]
107-
108-
109-
@v1.get(
110-
"/provider-endpoints/models",
111-
tags=["Providers"],
112-
generate_unique_id_function=uniq_name,
113-
)
114-
async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]:
115-
"""List all models for all providers."""
116-
# NOTE: This is a dummy implementation. In the future, we should have a proper
117-
# implementation that fetches all the models for all providers from the database.
118-
return [v1_models.ModelByProvider(name="dummy", provider="dummy")]
119-
120-
121170
@v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name)
122171
async def list_workspaces() -> v1_models.ListWorkspacesResponse:
123172
"""List all workspaces."""
@@ -394,7 +443,9 @@ async def delete_workspace_custom_instructions(workspace_name: str):
394443
tags=["Workspaces", "Muxes"],
395444
generate_unique_id_function=uniq_name,
396445
)
397-
async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
446+
async def get_workspace_muxes(
447+
workspace_name: str,
448+
) -> List[v1_models.MuxRule]:
398449
"""Get the mux rules of a workspace.
399450
400451
The list is ordered in order of priority. That is, the first rule in the list
@@ -422,7 +473,10 @@ async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
422473
generate_unique_id_function=uniq_name,
423474
status_code=204,
424475
)
425-
async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]):
476+
async def set_workspace_muxes(
477+
workspace_name: str,
478+
request: List[v1_models.MuxRule],
479+
):
426480
"""Set the mux rules of a workspace."""
427481
# TODO: This is a dummy implementation. In the future, we should have a proper
428482
# implementation that sets the mux rules in the database.

src/codegate/api/v1_models.py

+33-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from codegate.db import models as db_models
88
from codegate.pipeline.base import CodeSnippet
9+
from codegate.providers.base import BaseProvider
10+
from codegate.providers.registry import ProviderRegistry
911

1012

1113
class Workspace(pydantic.BaseModel):
@@ -122,6 +124,8 @@ class ProviderType(str, Enum):
122124
openai = "openai"
123125
anthropic = "anthropic"
124126
vllm = "vllm"
127+
ollama = "ollama"
128+
lm_studio = "lm_studio"
125129

126130

127131
class TokenUsageByModel(pydantic.BaseModel):
@@ -191,13 +195,38 @@ class ProviderEndpoint(pydantic.BaseModel):
191195
so we can use this for muxing messages.
192196
"""
193197

194-
id: int
198+
# This will be set on creation
199+
id: Optional[str] = ""
195200
name: str
196201
description: str = ""
197202
provider_type: ProviderType
198203
endpoint: str
199204
auth_type: ProviderAuthType
200205

206+
@staticmethod
207+
def from_db_model(db_model: db_models.ProviderEndpoint) -> "ProviderEndpoint":
208+
return ProviderEndpoint(
209+
id=db_model.id,
210+
name=db_model.name,
211+
description=db_model.description,
212+
provider_type=db_model.provider_type,
213+
endpoint=db_model.endpoint,
214+
auth_type=db_model.auth_type,
215+
)
216+
217+
def to_db_model(self) -> db_models.ProviderEndpoint:
218+
return db_models.ProviderEndpoint(
219+
id=self.id,
220+
name=self.name,
221+
description=self.description,
222+
provider_type=self.provider_type,
223+
endpoint=self.endpoint,
224+
auth_type=self.auth_type,
225+
)
226+
227+
def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider]:
228+
return registry.get_provider(self.provider_type)
229+
201230

202231
class ModelByProvider(pydantic.BaseModel):
203232
"""
@@ -207,10 +236,11 @@ class ModelByProvider(pydantic.BaseModel):
207236
"""
208237

209238
name: str
210-
provider: str
239+
provider_id: str
240+
provider_name: str
211241

212242
def __str__(self):
213-
return f"{self.provider}/{self.name}"
243+
return f"{self.provider_name} / {self.name}"
214244

215245

216246
class MuxMatcherType(str, Enum):

src/codegate/cli.py

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from codegate.db.connection import init_db_sync, init_session_if_not_exists
1818
from codegate.pipeline.factory import PipelineFactory
1919
from codegate.pipeline.secrets.manager import SecretsManager
20+
from codegate.providers import crud as provendcrud
2021
from codegate.providers.copilot.provider import CopilotProvider
2122
from codegate.server import init_app
2223
from codegate.storage.utils import restore_storage_backup
@@ -338,6 +339,9 @@ def serve( # noqa: C901
338339
loop = asyncio.new_event_loop()
339340
asyncio.set_event_loop(loop)
340341

342+
registry = app.provider_registry
343+
loop.run_until_complete(provendcrud.initialize_provider_endpoints(registry))
344+
341345
# Run the server
342346
try:
343347
loop.run_until_complete(run_servers(cfg, app))

0 commit comments

Comments
 (0)