Skip to content

Commit 282fa0e

Browse files
authored
feat(assistants): mock api (QuivrHQ#3195)
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate):
1 parent 4390d31 commit 282fa0e

File tree

109 files changed

+922
-882
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

109 files changed

+922
-882
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,4 @@ backend/core/examples/chatbot/.chainlit/translations/en-US.json
102102
# Tox
103103
.tox
104104
Pipfile
105+
*.pkl

backend/api/pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ dependencies = [
1919
"pydantic-settings>=2.4.0",
2020
"python-dotenv>=1.0.1",
2121
"unidecode>=1.3.8",
22-
"fpdf>=1.7.2",
2322
"colorlog>=6.8.2",
2423
"posthog>=3.5.0",
2524
"pyinstrument>=4.7.2",

backend/api/quivr_api/logger.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from colorlog import (
66
ColoredFormatter,
7-
) # You need to install this package: pip install colorlog
7+
)
88

99

1010
def get_logger(logger_name, log_file="application.log"):

backend/api/quivr_api/middlewares/auth/jwt_token_handler.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from jose import jwt
66
from jose.exceptions import JWTError
7+
78
from quivr_api.modules.user.entity.user_identity import UserIdentity
89

910
SECRET_KEY = os.environ.get("JWT_SECRET_KEY")

backend/api/quivr_api/models/brains_subscription_invitations.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from uuid import UUID
22

33
from pydantic import BaseModel, ConfigDict
4+
45
from quivr_api.logger import get_logger
56

67
logger = get_logger(__name__)

backend/api/quivr_api/modules/analytics/controller/analytics_routes.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from uuid import UUID
22

33
from fastapi import APIRouter, Depends, Query
4+
45
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
56
from quivr_api.modules.analytics.entity.analytics import Range
67
from quivr_api.modules.analytics.service.analytics_service import AnalyticsService
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
from datetime import date
12
from enum import IntEnum
23
from typing import List
4+
35
from pydantic import BaseModel
4-
from datetime import date
6+
57

68
class Range(IntEnum):
79
WEEK = 7
810
MONTH = 30
911
QUARTER = 90
1012

13+
1114
class Usage(BaseModel):
1215
date: date
1316
usage_count: int
1417

18+
1519
class BrainsUsages(BaseModel):
16-
usages: List[Usage]
20+
usages: List[Usage]

backend/api/quivr_api/modules/analytics/service/analytics_service.py

-1
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,4 @@ def __init__(self):
1111
self.repository = Analytics()
1212

1313
def get_brains_usages(self, user_id, graph_range, brain_id=None):
14-
1514
return self.repository.get_brains_usages(user_id, graph_range, brain_id)

backend/api/quivr_api/modules/api_key/controller/api_key_routes.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from uuid import uuid4
44

55
from fastapi import APIRouter, Depends
6+
67
from quivr_api.logger import get_logger
78
from quivr_api.middlewares.auth import AuthBearer, get_current_user
89
from quivr_api.modules.api_key.dto.outputs import ApiKeyInfo

backend/api/quivr_api/modules/api_key/service/api_key_service.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import datetime
22

33
from fastapi import HTTPException
4+
45
from quivr_api.logger import get_logger
56
from quivr_api.modules.api_key.repository.api_key_interface import ApiKeysInterface
67
from quivr_api.modules.api_key.repository.api_keys import ApiKeys
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
1+
# noqa:
12
from .assistant_routes import assistant_router
3+
4+
__all__ = [
5+
"assistant_router",
6+
]
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,176 @@
1-
from typing import List
1+
import io
2+
from typing import Annotated, List
3+
from uuid import uuid4
24

3-
from fastapi import APIRouter, Depends, HTTPException, UploadFile
5+
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
6+
7+
from quivr_api.celery_config import celery
48
from quivr_api.logger import get_logger
5-
from quivr_api.middlewares.auth import AuthBearer, get_current_user
6-
from quivr_api.modules.assistant.dto.inputs import InputAssistant
9+
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
10+
from quivr_api.modules.assistant.controller.assistants_definition import (
11+
assistants,
12+
validate_assistant_input,
13+
)
14+
from quivr_api.modules.assistant.dto.inputs import CreateTask, InputAssistant
715
from quivr_api.modules.assistant.dto.outputs import AssistantOutput
8-
from quivr_api.modules.assistant.ito.difference import DifferenceAssistant
9-
from quivr_api.modules.assistant.ito.summary import SummaryAssistant, summary_inputs
10-
from quivr_api.modules.assistant.service.assistant import Assistant
16+
from quivr_api.modules.assistant.entity.assistant_entity import (
17+
AssistantSettings,
18+
)
19+
from quivr_api.modules.assistant.services.tasks_service import TasksService
20+
from quivr_api.modules.dependencies import get_service
21+
from quivr_api.modules.upload.service.upload_file import (
22+
upload_file_storage,
23+
)
1124
from quivr_api.modules.user.entity.user_identity import UserIdentity
1225

13-
assistant_router = APIRouter()
1426
logger = get_logger(__name__)
1527

16-
assistant_service = Assistant()
28+
29+
assistant_router = APIRouter()
30+
31+
32+
TasksServiceDep = Annotated[TasksService, Depends(get_service(TasksService))]
33+
UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)]
1734

1835

1936
@assistant_router.get(
2037
"/assistants", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
2138
)
22-
async def list_assistants(
39+
async def get_assistants(
40+
request: Request,
2341
current_user: UserIdentity = Depends(get_current_user),
2442
) -> List[AssistantOutput]:
25-
"""
26-
Retrieve and list all the knowledge in a brain.
27-
"""
43+
logger.info("Getting assistants")
44+
45+
return assistants
2846

29-
summary = summary_inputs()
30-
# difference = difference_inputs()
31-
# crawler = crawler_inputs()
32-
return [summary]
47+
48+
@assistant_router.get(
49+
"/assistants/tasks", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
50+
)
51+
async def get_tasks(
52+
request: Request,
53+
current_user: UserIdentityDep,
54+
tasks_service: TasksServiceDep,
55+
):
56+
logger.info("Getting tasks")
57+
return await tasks_service.get_tasks_by_user_id(current_user.id)
3358

3459

3560
@assistant_router.post(
36-
"/assistant/process",
37-
dependencies=[Depends(AuthBearer())],
38-
tags=["Assistant"],
61+
"/assistants/task", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
3962
)
40-
async def process_assistant(
63+
async def create_task(
64+
current_user: UserIdentityDep,
65+
tasks_service: TasksServiceDep,
66+
request: Request,
4167
input: InputAssistant,
4268
files: List[UploadFile] = None,
43-
current_user: UserIdentity = Depends(get_current_user),
4469
):
45-
if input.name.lower() == "summary":
46-
summary_assistant = SummaryAssistant(
47-
input=input, files=files, current_user=current_user
48-
)
49-
try:
50-
summary_assistant.check_input()
51-
return await summary_assistant.process_assistant()
52-
except ValueError as e:
53-
raise HTTPException(status_code=400, detail=str(e))
54-
elif input.name.lower() == "difference":
55-
difference_assistant = DifferenceAssistant(
56-
input=input, files=files, current_user=current_user
57-
)
70+
assistant = next(
71+
(assistant for assistant in assistants if assistant.id == input.id), None
72+
)
73+
if assistant is None:
74+
raise HTTPException(status_code=404, detail="Assistant not found")
75+
76+
is_valid, validation_errors = validate_assistant_input(input, assistant)
77+
if not is_valid:
78+
for error in validation_errors:
79+
print(error)
80+
raise HTTPException(status_code=400, detail=error)
81+
else:
82+
print("Assistant input is valid.")
83+
notification_uuid = uuid4()
84+
85+
# Process files dynamically
86+
for upload_file in files:
87+
file_name_path = f"{input.id}/{notification_uuid}/{upload_file.filename}"
88+
buff_reader = io.BufferedReader(upload_file.file) # type: ignore
5889
try:
59-
difference_assistant.check_input()
60-
return await difference_assistant.process_assistant()
61-
except ValueError as e:
62-
raise HTTPException(status_code=400, detail=str(e))
63-
return {"message": "Assistant not found"}
90+
await upload_file_storage(buff_reader, file_name_path)
91+
except Exception as e:
92+
logger.exception(f"Exception in upload_route {e}")
93+
raise HTTPException(
94+
status_code=500, detail=f"Failed to upload file to storage. {e}"
95+
)
96+
97+
task = CreateTask(
98+
assistant_id=input.id,
99+
pretty_id=str(notification_uuid),
100+
settings=input.model_dump(mode="json"),
101+
)
102+
103+
task_created = await tasks_service.create_task(task, current_user.id)
104+
105+
celery.send_task(
106+
"process_assistant_task",
107+
kwargs={
108+
"assistant_id": input.id,
109+
"notification_uuid": notification_uuid,
110+
"task_id": task_created.id,
111+
"user_id": str(current_user.id),
112+
},
113+
)
114+
return task_created
115+
116+
117+
@assistant_router.get(
118+
"/assistants/task/{task_id}",
119+
dependencies=[Depends(AuthBearer())],
120+
tags=["Assistant"],
121+
)
122+
async def get_task(
123+
request: Request,
124+
task_id: str,
125+
current_user: UserIdentityDep,
126+
tasks_service: TasksServiceDep,
127+
):
128+
return await tasks_service.get_task_by_id(task_id, current_user.id) # type: ignore
129+
130+
131+
@assistant_router.delete(
132+
"/assistants/task/{task_id}",
133+
dependencies=[Depends(AuthBearer())],
134+
tags=["Assistant"],
135+
)
136+
async def delete_task(
137+
request: Request,
138+
task_id: int,
139+
current_user: UserIdentityDep,
140+
tasks_service: TasksServiceDep,
141+
):
142+
return await tasks_service.delete_task(task_id, current_user.id)
143+
144+
145+
@assistant_router.get(
146+
"/assistants/task/{task_id}/download",
147+
dependencies=[Depends(AuthBearer())],
148+
tags=["Assistant"],
149+
)
150+
async def get_download_link_task(
151+
request: Request,
152+
task_id: int,
153+
current_user: UserIdentityDep,
154+
tasks_service: TasksServiceDep,
155+
):
156+
return await tasks_service.get_download_link_task(task_id, current_user.id)
157+
158+
159+
@assistant_router.get(
160+
"/assistants/{assistant_id}/config",
161+
dependencies=[Depends(AuthBearer())],
162+
tags=["Assistant"],
163+
response_model=AssistantSettings,
164+
summary="Retrieve assistant configuration",
165+
description="Get the settings and file requirements for the specified assistant.",
166+
)
167+
async def get_assistant_config(
168+
assistant_id: int,
169+
current_user: UserIdentityDep,
170+
):
171+
assistant = next(
172+
(assistant for assistant in assistants if assistant.id == assistant_id), None
173+
)
174+
if assistant is None:
175+
raise HTTPException(status_code=404, detail="Assistant not found")
176+
return assistant.settings

0 commit comments

Comments
 (0)