|
1 |
| -from typing import List |
| 1 | +import io |
| 2 | +from typing import Annotated, List |
| 3 | +from uuid import uuid4 |
2 | 4 |
|
3 |
| -from fastapi import APIRouter, Depends, HTTPException, UploadFile |
| 5 | +from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile |
| 6 | + |
| 7 | +from quivr_api.celery_config import celery |
4 | 8 | from quivr_api.logger import get_logger
|
5 |
| -from quivr_api.middlewares.auth import AuthBearer, get_current_user |
6 |
| -from quivr_api.modules.assistant.dto.inputs import InputAssistant |
| 9 | +from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user |
| 10 | +from quivr_api.modules.assistant.controller.assistants_definition import ( |
| 11 | + assistants, |
| 12 | + validate_assistant_input, |
| 13 | +) |
| 14 | +from quivr_api.modules.assistant.dto.inputs import CreateTask, InputAssistant |
7 | 15 | from quivr_api.modules.assistant.dto.outputs import AssistantOutput
|
8 |
| -from quivr_api.modules.assistant.ito.difference import DifferenceAssistant |
9 |
| -from quivr_api.modules.assistant.ito.summary import SummaryAssistant, summary_inputs |
10 |
| -from quivr_api.modules.assistant.service.assistant import Assistant |
| 16 | +from quivr_api.modules.assistant.entity.assistant_entity import ( |
| 17 | + AssistantSettings, |
| 18 | +) |
| 19 | +from quivr_api.modules.assistant.services.tasks_service import TasksService |
| 20 | +from quivr_api.modules.dependencies import get_service |
| 21 | +from quivr_api.modules.upload.service.upload_file import ( |
| 22 | + upload_file_storage, |
| 23 | +) |
11 | 24 | from quivr_api.modules.user.entity.user_identity import UserIdentity
|
12 | 25 |
|
13 |
| -assistant_router = APIRouter() |
14 | 26 | logger = get_logger(__name__)
|
15 | 27 |
|
16 |
| -assistant_service = Assistant() |
| 28 | + |
| 29 | +assistant_router = APIRouter() |
| 30 | + |
| 31 | + |
| 32 | +TasksServiceDep = Annotated[TasksService, Depends(get_service(TasksService))] |
| 33 | +UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)] |
17 | 34 |
|
18 | 35 |
|
19 | 36 | @assistant_router.get(
|
20 | 37 | "/assistants", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
|
21 | 38 | )
|
22 |
| -async def list_assistants( |
| 39 | +async def get_assistants( |
| 40 | + request: Request, |
23 | 41 | current_user: UserIdentity = Depends(get_current_user),
|
24 | 42 | ) -> List[AssistantOutput]:
|
25 |
| - """ |
26 |
| - Retrieve and list all the knowledge in a brain. |
27 |
| - """ |
| 43 | + logger.info("Getting assistants") |
| 44 | + |
| 45 | + return assistants |
28 | 46 |
|
29 |
| - summary = summary_inputs() |
30 |
| - # difference = difference_inputs() |
31 |
| - # crawler = crawler_inputs() |
32 |
| - return [summary] |
| 47 | + |
| 48 | +@assistant_router.get( |
| 49 | + "/assistants/tasks", dependencies=[Depends(AuthBearer())], tags=["Assistant"] |
| 50 | +) |
| 51 | +async def get_tasks( |
| 52 | + request: Request, |
| 53 | + current_user: UserIdentityDep, |
| 54 | + tasks_service: TasksServiceDep, |
| 55 | +): |
| 56 | + logger.info("Getting tasks") |
| 57 | + return await tasks_service.get_tasks_by_user_id(current_user.id) |
33 | 58 |
|
34 | 59 |
|
35 | 60 | @assistant_router.post(
|
36 |
| - "/assistant/process", |
37 |
| - dependencies=[Depends(AuthBearer())], |
38 |
| - tags=["Assistant"], |
| 61 | + "/assistants/task", dependencies=[Depends(AuthBearer())], tags=["Assistant"] |
39 | 62 | )
|
40 |
| -async def process_assistant( |
| 63 | +async def create_task( |
| 64 | + current_user: UserIdentityDep, |
| 65 | + tasks_service: TasksServiceDep, |
| 66 | + request: Request, |
41 | 67 | input: InputAssistant,
|
42 | 68 | files: List[UploadFile] = None,
|
43 |
| - current_user: UserIdentity = Depends(get_current_user), |
44 | 69 | ):
|
45 |
| - if input.name.lower() == "summary": |
46 |
| - summary_assistant = SummaryAssistant( |
47 |
| - input=input, files=files, current_user=current_user |
48 |
| - ) |
49 |
| - try: |
50 |
| - summary_assistant.check_input() |
51 |
| - return await summary_assistant.process_assistant() |
52 |
| - except ValueError as e: |
53 |
| - raise HTTPException(status_code=400, detail=str(e)) |
54 |
| - elif input.name.lower() == "difference": |
55 |
| - difference_assistant = DifferenceAssistant( |
56 |
| - input=input, files=files, current_user=current_user |
57 |
| - ) |
| 70 | + assistant = next( |
| 71 | + (assistant for assistant in assistants if assistant.id == input.id), None |
| 72 | + ) |
| 73 | + if assistant is None: |
| 74 | + raise HTTPException(status_code=404, detail="Assistant not found") |
| 75 | + |
| 76 | + is_valid, validation_errors = validate_assistant_input(input, assistant) |
| 77 | + if not is_valid: |
| 78 | + for error in validation_errors: |
| 79 | + print(error) |
| 80 | + raise HTTPException(status_code=400, detail=error) |
| 81 | + else: |
| 82 | + print("Assistant input is valid.") |
| 83 | + notification_uuid = uuid4() |
| 84 | + |
| 85 | + # Process files dynamically |
| 86 | + for upload_file in files: |
| 87 | + file_name_path = f"{input.id}/{notification_uuid}/{upload_file.filename}" |
| 88 | + buff_reader = io.BufferedReader(upload_file.file) # type: ignore |
58 | 89 | try:
|
59 |
| - difference_assistant.check_input() |
60 |
| - return await difference_assistant.process_assistant() |
61 |
| - except ValueError as e: |
62 |
| - raise HTTPException(status_code=400, detail=str(e)) |
63 |
| - return {"message": "Assistant not found"} |
| 90 | + await upload_file_storage(buff_reader, file_name_path) |
| 91 | + except Exception as e: |
| 92 | + logger.exception(f"Exception in upload_route {e}") |
| 93 | + raise HTTPException( |
| 94 | + status_code=500, detail=f"Failed to upload file to storage. {e}" |
| 95 | + ) |
| 96 | + |
| 97 | + task = CreateTask( |
| 98 | + assistant_id=input.id, |
| 99 | + pretty_id=str(notification_uuid), |
| 100 | + settings=input.model_dump(mode="json"), |
| 101 | + ) |
| 102 | + |
| 103 | + task_created = await tasks_service.create_task(task, current_user.id) |
| 104 | + |
| 105 | + celery.send_task( |
| 106 | + "process_assistant_task", |
| 107 | + kwargs={ |
| 108 | + "assistant_id": input.id, |
| 109 | + "notification_uuid": notification_uuid, |
| 110 | + "task_id": task_created.id, |
| 111 | + "user_id": str(current_user.id), |
| 112 | + }, |
| 113 | + ) |
| 114 | + return task_created |
| 115 | + |
| 116 | + |
| 117 | +@assistant_router.get( |
| 118 | + "/assistants/task/{task_id}", |
| 119 | + dependencies=[Depends(AuthBearer())], |
| 120 | + tags=["Assistant"], |
| 121 | +) |
| 122 | +async def get_task( |
| 123 | + request: Request, |
| 124 | + task_id: str, |
| 125 | + current_user: UserIdentityDep, |
| 126 | + tasks_service: TasksServiceDep, |
| 127 | +): |
| 128 | + return await tasks_service.get_task_by_id(task_id, current_user.id) # type: ignore |
| 129 | + |
| 130 | + |
| 131 | +@assistant_router.delete( |
| 132 | + "/assistants/task/{task_id}", |
| 133 | + dependencies=[Depends(AuthBearer())], |
| 134 | + tags=["Assistant"], |
| 135 | +) |
| 136 | +async def delete_task( |
| 137 | + request: Request, |
| 138 | + task_id: int, |
| 139 | + current_user: UserIdentityDep, |
| 140 | + tasks_service: TasksServiceDep, |
| 141 | +): |
| 142 | + return await tasks_service.delete_task(task_id, current_user.id) |
| 143 | + |
| 144 | + |
| 145 | +@assistant_router.get( |
| 146 | + "/assistants/task/{task_id}/download", |
| 147 | + dependencies=[Depends(AuthBearer())], |
| 148 | + tags=["Assistant"], |
| 149 | +) |
| 150 | +async def get_download_link_task( |
| 151 | + request: Request, |
| 152 | + task_id: int, |
| 153 | + current_user: UserIdentityDep, |
| 154 | + tasks_service: TasksServiceDep, |
| 155 | +): |
| 156 | + return await tasks_service.get_download_link_task(task_id, current_user.id) |
| 157 | + |
| 158 | + |
| 159 | +@assistant_router.get( |
| 160 | + "/assistants/{assistant_id}/config", |
| 161 | + dependencies=[Depends(AuthBearer())], |
| 162 | + tags=["Assistant"], |
| 163 | + response_model=AssistantSettings, |
| 164 | + summary="Retrieve assistant configuration", |
| 165 | + description="Get the settings and file requirements for the specified assistant.", |
| 166 | +) |
| 167 | +async def get_assistant_config( |
| 168 | + assistant_id: int, |
| 169 | + current_user: UserIdentityDep, |
| 170 | +): |
| 171 | + assistant = next( |
| 172 | + (assistant for assistant in assistants if assistant.id == assistant_id), None |
| 173 | + ) |
| 174 | + if assistant is None: |
| 175 | + raise HTTPException(status_code=404, detail="Assistant not found") |
| 176 | + return assistant.settings |
0 commit comments