1
1
from typing import List , Optional
2
+ from uuid import UUID
2
3
3
4
import requests
4
5
import structlog
5
- from fastapi import APIRouter , HTTPException , Response
6
+ from fastapi import APIRouter , Depends , HTTPException , Response
6
7
from fastapi .responses import StreamingResponse
7
8
from fastapi .routing import APIRoute
8
- from pydantic import ValidationError
9
+ from pydantic import BaseModel , ValidationError
9
10
10
11
from codegate import __version__
11
12
from codegate .api import v1_models , v1_processing
12
13
from codegate .db .connection import AlreadyExistsError , DbReader
14
+ from codegate .providers import crud as provendcrud
13
15
from codegate .workspaces import crud
14
16
15
17
logger = structlog .get_logger ("codegate" )
16
18
17
19
v1 = APIRouter ()
18
20
wscrud = crud .WorkspaceCrud ()
21
+ pcrud = provendcrud .ProviderCrud ()
19
22
20
23
# This is a singleton object
21
24
dbreader = DbReader ()
@@ -25,38 +28,78 @@ def uniq_name(route: APIRoute):
25
28
return f"v1_{ route .name } "
26
29
27
30
31
+ class FilterByNameParams (BaseModel ):
32
+ name : Optional [str ] = None
33
+
34
+
28
35
@v1 .get ("/provider-endpoints" , tags = ["Providers" ], generate_unique_id_function = uniq_name )
29
- async def list_provider_endpoints (name : Optional [str ] = None ) -> List [v1_models .ProviderEndpoint ]:
36
+ async def list_provider_endpoints (
37
+ filter_query : FilterByNameParams = Depends (),
38
+ ) -> List [v1_models .ProviderEndpoint ]:
30
39
"""List all provider endpoints."""
31
- # NOTE: This is a dummy implementation. In the future, we should have a proper
32
- # implementation that fetches the provider endpoints from the database.
33
- return [
34
- v1_models .ProviderEndpoint (
35
- id = 1 ,
36
- name = "dummy" ,
37
- description = "Dummy provider endpoint" ,
38
- endpoint = "http://example.com" ,
39
- provider_type = v1_models .ProviderType .openai ,
40
- auth_type = v1_models .ProviderAuthType .none ,
41
- )
42
- ]
40
+ if filter_query .name is None :
41
+ try :
42
+ return await pcrud .list_endpoints ()
43
+ except Exception :
44
+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
45
+
46
+ try :
47
+ provend = await pcrud .get_endpoint_by_name (filter_query .name )
48
+ except Exception :
49
+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
50
+
51
+ if provend is None :
52
+ raise HTTPException (status_code = 404 , detail = "Provider endpoint not found" )
53
+ return [provend ]
54
+
55
+
56
+ # This needs to be above /provider-endpoints/{provider_id} to avoid conflict
57
+ @v1 .get (
58
+ "/provider-endpoints/models" ,
59
+ tags = ["Providers" ],
60
+ generate_unique_id_function = uniq_name ,
61
+ )
62
+ async def list_all_models_for_all_providers () -> List [v1_models .ModelByProvider ]:
63
+ """List all models for all providers."""
64
+ try :
65
+ return await pcrud .get_all_models ()
66
+ except Exception :
67
+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
68
+
69
+
70
+ @v1 .get (
71
+ "/provider-endpoints/{provider_id}/models" ,
72
+ tags = ["Providers" ],
73
+ generate_unique_id_function = uniq_name ,
74
+ )
75
+ async def list_models_by_provider (
76
+ provider_id : UUID ,
77
+ ) -> List [v1_models .ModelByProvider ]:
78
+ """List models by provider."""
79
+
80
+ try :
81
+ return await pcrud .models_by_provider (provider_id )
82
+ except provendcrud .ProviderNotFoundError :
83
+ raise HTTPException (status_code = 404 , detail = "Provider not found" )
84
+ except Exception as e :
85
+ raise HTTPException (status_code = 500 , detail = str (e ))
43
86
44
87
45
88
@v1 .get (
46
89
"/provider-endpoints/{provider_id}" , tags = ["Providers" ], generate_unique_id_function = uniq_name
47
90
)
48
- async def get_provider_endpoint (provider_id : int ) -> v1_models .ProviderEndpoint :
91
+ async def get_provider_endpoint (
92
+ provider_id : UUID ,
93
+ ) -> v1_models .ProviderEndpoint :
49
94
"""Get a provider endpoint by ID."""
50
- # NOTE: This is a dummy implementation. In the future, we should have a proper
51
- # implementation that fetches the provider endpoint from the database.
52
- return v1_models .ProviderEndpoint (
53
- id = provider_id ,
54
- name = "dummy" ,
55
- description = "Dummy provider endpoint" ,
56
- endpoint = "http://example.com" ,
57
- provider_type = v1_models .ProviderType .openai ,
58
- auth_type = v1_models .ProviderAuthType .none ,
59
- )
95
+ try :
96
+ provend = await pcrud .get_endpoint_by_id (provider_id )
97
+ except Exception :
98
+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
99
+
100
+ if provend is None :
101
+ raise HTTPException (status_code = 404 , detail = "Provider endpoint not found" )
102
+ return provend
60
103
61
104
62
105
@v1 .post (
@@ -65,59 +108,65 @@ async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
65
108
generate_unique_id_function = uniq_name ,
66
109
status_code = 201 ,
67
110
)
68
- async def add_provider_endpoint (request : v1_models .ProviderEndpoint ) -> v1_models .ProviderEndpoint :
111
+ async def add_provider_endpoint (
112
+ request : v1_models .ProviderEndpoint ,
113
+ ) -> v1_models .ProviderEndpoint :
69
114
"""Add a provider endpoint."""
70
- # NOTE: This is a dummy implementation. In the future, we should have a proper
71
- # implementation that adds the provider endpoint to the database.
72
- return request
115
+ try :
116
+ provend = await pcrud .add_endpoint (request )
117
+ except AlreadyExistsError :
118
+ raise HTTPException (status_code = 409 , detail = "Provider endpoint already exists" )
119
+ except ValidationError as e :
120
+ # TODO: This should be more specific
121
+ raise HTTPException (
122
+ status_code = 400 ,
123
+ detail = str (e ),
124
+ )
125
+ except Exception :
126
+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
127
+
128
+ return provend
73
129
74
130
75
131
@v1 .put (
76
132
"/provider-endpoints/{provider_id}" , tags = ["Providers" ], generate_unique_id_function = uniq_name
77
133
)
78
134
async def update_provider_endpoint (
79
- provider_id : int , request : v1_models .ProviderEndpoint
135
+ provider_id : UUID ,
136
+ request : v1_models .ProviderEndpoint ,
80
137
) -> v1_models .ProviderEndpoint :
81
138
"""Update a provider endpoint by ID."""
82
- # NOTE: This is a dummy implementation. In the future, we should have a proper
83
- # implementation that updates the provider endpoint in the database.
84
- return request
139
+ try :
140
+ request .id = provider_id
141
+ provend = await pcrud .update_endpoint (request )
142
+ except ValidationError as e :
143
+ # TODO: This should be more specific
144
+ raise HTTPException (
145
+ status_code = 400 ,
146
+ detail = str (e ),
147
+ )
148
+ except Exception :
149
+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
150
+
151
+ return provend
85
152
86
153
87
154
@v1 .delete (
88
155
"/provider-endpoints/{provider_id}" , tags = ["Providers" ], generate_unique_id_function = uniq_name
89
156
)
90
- async def delete_provider_endpoint (provider_id : int ):
157
+ async def delete_provider_endpoint (
158
+ provider_id : UUID ,
159
+ ):
91
160
"""Delete a provider endpoint by id."""
92
- # NOTE: This is a dummy implementation. In the future, we should have a proper
93
- # implementation that deletes the provider endpoint from the database.
161
+ try :
162
+ await pcrud .delete_endpoint (provider_id )
163
+ except provendcrud .ProviderNotFoundError :
164
+ raise HTTPException (status_code = 404 , detail = "Provider endpoint not found" )
165
+ except Exception :
166
+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
94
167
return Response (status_code = 204 )
95
168
96
169
97
- @v1 .get (
98
- "/provider-endpoints/{provider_name}/models" ,
99
- tags = ["Providers" ],
100
- generate_unique_id_function = uniq_name ,
101
- )
102
- async def list_models_by_provider (provider_name : str ) -> List [v1_models .ModelByProvider ]:
103
- """List models by provider."""
104
- # NOTE: This is a dummy implementation. In the future, we should have a proper
105
- # implementation that fetches the models by provider from the database.
106
- return [v1_models .ModelByProvider (name = "dummy" , provider = "dummy" )]
107
-
108
-
109
- @v1 .get (
110
- "/provider-endpoints/models" ,
111
- tags = ["Providers" ],
112
- generate_unique_id_function = uniq_name ,
113
- )
114
- async def list_all_models_for_all_providers () -> List [v1_models .ModelByProvider ]:
115
- """List all models for all providers."""
116
- # NOTE: This is a dummy implementation. In the future, we should have a proper
117
- # implementation that fetches all the models for all providers from the database.
118
- return [v1_models .ModelByProvider (name = "dummy" , provider = "dummy" )]
119
-
120
-
121
170
@v1 .get ("/workspaces" , tags = ["Workspaces" ], generate_unique_id_function = uniq_name )
122
171
async def list_workspaces () -> v1_models .ListWorkspacesResponse :
123
172
"""List all workspaces."""
@@ -394,7 +443,9 @@ async def delete_workspace_custom_instructions(workspace_name: str):
394
443
tags = ["Workspaces" , "Muxes" ],
395
444
generate_unique_id_function = uniq_name ,
396
445
)
397
- async def get_workspace_muxes (workspace_name : str ) -> List [v1_models .MuxRule ]:
446
+ async def get_workspace_muxes (
447
+ workspace_name : str ,
448
+ ) -> List [v1_models .MuxRule ]:
398
449
"""Get the mux rules of a workspace.
399
450
400
451
The list is ordered in order of priority. That is, the first rule in the list
@@ -422,7 +473,10 @@ async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
422
473
generate_unique_id_function = uniq_name ,
423
474
status_code = 204 ,
424
475
)
425
- async def set_workspace_muxes (workspace_name : str , request : List [v1_models .MuxRule ]):
476
+ async def set_workspace_muxes (
477
+ workspace_name : str ,
478
+ request : List [v1_models .MuxRule ],
479
+ ):
426
480
"""Set the mux rules of a workspace."""
427
481
# TODO: This is a dummy implementation. In the future, we should have a proper
428
482
# implementation that sets the mux rules in the database.
0 commit comments