|
44 | 44 | router = APIRouter() |
45 | 45 |
|
46 | 46 |
|
| 47 | +def get_user_id(request: Request, entity_id: str = None) -> str: |
| 48 | +    """Prefer entity_id; else the authenticated user's ID; else "public".""" |
| 49 | +    if entity_id: |
| 50 | +        return entity_id |
| 51 | +    user = getattr(request.state, "user", None) |
| 52 | +    return user.get("id") if user else "public" |
| 53 | + |
| 54 | + |
| 55 | +async def save_upload_file_async(file: UploadFile, temp_file_path: str) -> None: |
| 56 | + """Save uploaded file asynchronously.""" |
| 57 | + try: |
| 58 | + async with aiofiles.open(temp_file_path, "wb") as temp_file: |
| 59 | +            chunk_size = 64 * 1024  # 64 KB per read; read() returns b"" at EOF |
| 60 | + while content := await file.read(chunk_size): |
| 61 | + await temp_file.write(content) |
| 62 | + except Exception as e: |
| 63 | + logger.error( |
| 64 | + "Failed to save uploaded file | Path: %s | Error: %s | Traceback: %s", |
| 65 | + temp_file_path, |
| 66 | + str(e), |
| 67 | + traceback.format_exc(), |
| 68 | + ) |
| 69 | + raise HTTPException( |
| 70 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 71 | + detail=f"Failed to save the uploaded file. Error: {str(e)}", |
| 72 | + ) |
| 73 | + |
| 74 | + |
| 75 | +def save_upload_file_sync(file: UploadFile, temp_file_path: str) -> None: |
| 76 | +    """Save uploaded file synchronously (blocking; avoid on the event loop).""" |
| 77 | + try: |
| 78 | + with open(temp_file_path, "wb") as temp_file: |
| 79 | + copyfileobj(file.file, temp_file) |
| 80 | + except Exception as e: |
| 81 | + logger.error( |
| 82 | + "Failed to save uploaded file | Path: %s | Error: %s | Traceback: %s", |
| 83 | + temp_file_path, |
| 84 | + str(e), |
| 85 | + traceback.format_exc(), |
| 86 | + ) |
| 87 | + raise HTTPException( |
| 88 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 89 | + detail=f"Failed to save the uploaded file. Error: {str(e)}", |
| 90 | + ) |
| 91 | + |
| 92 | + |
| 93 | +async def load_file_content( |
| 94 | + filename: str, content_type: str, file_path: str, executor |
| 95 | +) -> tuple: |
| 96 | +    """Load file content with the appropriate loader; returns (documents, known_type, file_ext).""" |
| 97 | + loader, known_type, file_ext = get_loader(filename, content_type, file_path) |
| 98 | + data = await run_in_executor(executor, loader.load) |
| 99 | + |
| 100 | + # Clean up temporary UTF-8 file if it was created for encoding conversion |
| 101 | + cleanup_temp_encoding_file(loader) |
| 102 | + |
| 103 | + return data, known_type, file_ext |
| 104 | + |
| 105 | + |
| 106 | +def extract_text_from_documents(documents: List[Document], file_ext: str) -> str: |
| 107 | +    """Extract text content from loaded documents.""" |
| 108 | +    parts = [] |
| 109 | +    for doc in documents or []: |
| 110 | +        if not hasattr(doc, "page_content"): |
| 111 | +            continue |
| 112 | +        # Clean extracted text when it came from a PDF |
| 113 | +        if file_ext == "pdf": |
| 114 | +            parts.append(clean_text(doc.page_content)) |
| 115 | +        else: |
| 116 | +            parts.append(doc.page_content) |
| 117 | + |
| 118 | +    # Join pages with newlines (avoids quadratic string concatenation) |
| 119 | +    return "\n".join(parts) |
| 120 | + |
| 121 | + |
| 122 | +async def cleanup_temp_file_async(file_path: str) -> None: |
| 123 | + """Clean up temporary file asynchronously.""" |
| 124 | + try: |
| 125 | + await aiofiles.os.remove(file_path) |
| 126 | + except Exception as e: |
| 127 | + logger.error( |
| 128 | + "Failed to remove temporary file | Path: %s | Error: %s | Traceback: %s", |
| 129 | + file_path, |
| 130 | + str(e), |
| 131 | + traceback.format_exc(), |
| 132 | + ) |
| 133 | + |
| 134 | + |
47 | 135 | @router.get("/ids") |
48 | 136 | async def get_all_ids(request: Request): |
49 | 137 | try: |
@@ -251,7 +339,12 @@ async def query_embeddings_by_file_id( |
251 | 339 |
|
252 | 340 |
|
253 | 341 | def generate_digest(page_content: str): |
254 | | - hash_obj = hashlib.md5(page_content.encode()) |
| 342 | +    try: |
| 343 | +        hash_obj = hashlib.md5(page_content.encode("utf-8")) |
| 344 | +    except UnicodeEncodeError: |
| 345 | +        # Text with unpaired surrogates can't be UTF-8 encoded; drop the |
| 346 | +        # offending characters and hash what remains. |
| 347 | +        hash_obj = hashlib.md5(page_content.encode("utf-8", "ignore")) |
255 | 348 | return hash_obj.hexdigest() |
256 | 349 |
|
257 | 350 |
|
@@ -383,40 +476,21 @@ async def embed_file( |
383 | 476 | response_status = True |
384 | 477 | response_message = "File processed successfully." |
385 | 478 | known_type = None |
386 | | - if not hasattr(request.state, "user"): |
387 | | - user_id = entity_id if entity_id else "public" |
388 | | - else: |
389 | | - user_id = entity_id if entity_id else request.state.user.get("id") |
390 | 479 |
|
| 480 | + user_id = get_user_id(request, entity_id) |
391 | 481 | temp_base_path = os.path.join(RAG_UPLOAD_DIR, user_id) |
392 | 482 | os.makedirs(temp_base_path, exist_ok=True) |
393 | 483 | temp_file_path = os.path.join(RAG_UPLOAD_DIR, user_id, file.filename) |
394 | 484 |
|
395 | | - try: |
396 | | - async with aiofiles.open(temp_file_path, "wb") as temp_file: |
397 | | - chunk_size = 64 * 1024 # 64 KB |
398 | | - while content := await file.read(chunk_size): |
399 | | - await temp_file.write(content) |
400 | | - except Exception as e: |
401 | | - logger.error( |
402 | | - "Failed to save uploaded file | Path: %s | Error: %s | Traceback: %s", |
403 | | - temp_file_path, |
404 | | - str(e), |
405 | | - traceback.format_exc(), |
406 | | - ) |
407 | | - raise HTTPException( |
408 | | - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
409 | | - detail=f"Failed to save the uploaded file. Error: {str(e)}", |
410 | | - ) |
| 485 | + await save_upload_file_async(file, temp_file_path) |
411 | 486 |
|
412 | 487 | try: |
413 | | - loader, known_type, file_ext = get_loader( |
414 | | - file.filename, file.content_type, temp_file_path |
| 488 | + data, known_type, file_ext = await load_file_content( |
| 489 | + file.filename, |
| 490 | + file.content_type, |
| 491 | + temp_file_path, |
| 492 | + request.app.state.thread_pool, |
415 | 493 | ) |
416 | | - data = await run_in_executor(request.app.state.thread_pool, loader.load) |
417 | | - |
418 | | - # Clean up temporary UTF-8 file if it was created for encoding conversion |
419 | | - cleanup_temp_encoding_file(loader) |
420 | 494 |
|
421 | 495 | result = await store_data_in_vector_db( |
422 | 496 | data=data, |
@@ -465,15 +539,7 @@ async def embed_file( |
465 | 539 | detail=f"Error during file processing: {str(e)}", |
466 | 540 | ) |
467 | 541 | finally: |
468 | | - try: |
469 | | - await aiofiles.os.remove(temp_file_path) |
470 | | - except Exception as e: |
471 | | - logger.error( |
472 | | - "Failed to remove temporary file | Path: %s | Error: %s | Traceback: %s", |
473 | | - temp_file_path, |
474 | | - str(e), |
475 | | - traceback.format_exc(), |
476 | | - ) |
| 542 | + await cleanup_temp_file_async(temp_file_path) |
477 | 543 |
|
478 | 544 | return { |
479 | 545 | "status": response_status, |
@@ -539,32 +605,19 @@ async def embed_file_upload( |
539 | 605 | uploaded_file: UploadFile = File(...), |
540 | 606 | entity_id: str = Form(None), |
541 | 607 | ): |
| 608 | + user_id = get_user_id(request, entity_id) |
542 | 609 | temp_file_path = os.path.join(RAG_UPLOAD_DIR, uploaded_file.filename) |
543 | 610 |
|
544 | | - if not hasattr(request.state, "user"): |
545 | | - user_id = entity_id if entity_id else "public" |
546 | | - else: |
547 | | - user_id = entity_id if entity_id else request.state.user.get("id") |
548 | | - |
549 | | - try: |
550 | | - with open(temp_file_path, "wb") as temp_file: |
551 | | - copyfileobj(uploaded_file.file, temp_file) |
552 | | - except Exception as e: |
553 | | - raise HTTPException( |
554 | | - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
555 | | - detail=f"Failed to save the uploaded file. Error: {str(e)}", |
556 | | - ) |
| 611 | + save_upload_file_sync(uploaded_file, temp_file_path) |
557 | 612 |
|
558 | 613 | try: |
559 | | - loader, known_type, file_ext = get_loader( |
560 | | - uploaded_file.filename, uploaded_file.content_type, temp_file_path |
| 614 | + data, known_type, file_ext = await load_file_content( |
| 615 | + uploaded_file.filename, |
| 616 | + uploaded_file.content_type, |
| 617 | + temp_file_path, |
| 618 | + request.app.state.thread_pool, |
561 | 619 | ) |
562 | 620 |
|
563 | | - data = await run_in_executor(request.app.state.thread_pool, loader.load) |
564 | | - |
565 | | - # Clean up temporary UTF-8 file if it was created for encoding conversion |
566 | | - cleanup_temp_encoding_file(loader) |
567 | | - |
568 | 621 | result = await store_data_in_vector_db( |
569 | 622 | data, |
570 | 623 | file_id, |
@@ -651,7 +704,6 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody |
651 | 704 | ) |
652 | 705 | raise HTTPException(status_code=500, detail=str(e)) |
653 | 706 |
|
654 | | - |
655 | 707 | @router.post("/rerank") |
656 | 708 | async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs): |
657 | 709 | try: |
@@ -690,3 +742,66 @@ async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs): |
690 | 742 | traceback.format_exc(), |
691 | 743 | ) |
692 | 744 | raise HTTPException(status_code=500, detail=str(e)) |
| 745 | + |
| 746 | +@router.post("/text") |
| 747 | +async def extract_text_from_file( |
| 748 | + request: Request, |
| 749 | + file_id: str = Form(...), |
| 750 | + file: UploadFile = File(...), |
| 751 | + entity_id: str = Form(None), |
| 752 | +): |
| 753 | + """ |
| 754 | + Extract text content from an uploaded file without creating embeddings. |
| 755 | +    Returns the raw text so callers can parse it directly. |
| 756 | + """ |
| 757 | + user_id = get_user_id(request, entity_id) |
| 758 | + temp_base_path = os.path.join(RAG_UPLOAD_DIR, user_id) |
| 759 | + os.makedirs(temp_base_path, exist_ok=True) |
| 760 | +    temp_file_path = os.path.join(temp_base_path, file.filename) |
| 761 | + |
| 762 | + await save_upload_file_async(file, temp_file_path) |
| 763 | + |
| 764 | + try: |
| 765 | + data, known_type, file_ext = await load_file_content( |
| 766 | + file.filename, |
| 767 | + file.content_type, |
| 768 | + temp_file_path, |
| 769 | + request.app.state.thread_pool, |
| 770 | + ) |
| 771 | + |
| 772 | + # Extract text content from loaded documents |
| 773 | + text_content = extract_text_from_documents(data, file_ext) |
| 774 | + |
| 775 | + return { |
| 776 | + "text": text_content, |
| 777 | + "file_id": file_id, |
| 778 | + "filename": file.filename, |
| 779 | + "known_type": known_type, |
| 780 | + } |
| 781 | + |
| 782 | + except HTTPException as http_exc: |
| 783 | + logger.error( |
| 784 | + "HTTP Exception in extract_text_from_file | Status: %d | Detail: %s", |
| 785 | + http_exc.status_code, |
| 786 | + http_exc.detail, |
| 787 | + ) |
| 788 | + raise http_exc |
| 789 | + except Exception as e: |
| 790 | + logger.error( |
| 791 | + "Error during text extraction | File: %s | Error: %s | Traceback: %s", |
| 792 | + file.filename, |
| 793 | + str(e), |
| 794 | + traceback.format_exc(), |
| 795 | + ) |
| 796 | + if "No pandoc was found" in str(e): |
| 797 | + raise HTTPException( |
| 798 | + status_code=status.HTTP_400_BAD_REQUEST, |
| 799 | + detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, |
| 800 | + ) |
| 801 | + else: |
| 802 | + raise HTTPException( |
| 803 | + status_code=status.HTTP_400_BAD_REQUEST, |
| 804 | + detail=f"Error during text extraction: {str(e)}", |
| 805 | + ) |
| 806 | + finally: |
| 807 | + await cleanup_temp_file_async(temp_file_path) |
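| 808 | + |
| 809 | + |
| 810 | +# A minimal, hypothetical client call for the new /text route (base URL, |
| 811 | +# file names, and auth are assumptions; adjust to your deployment): |
| 812 | +# |
| 813 | +#   import httpx |
| 814 | +# |
| 815 | +#   with open("report.pdf", "rb") as f: |
| 816 | +#       resp = httpx.post( |
| 817 | +#           "http://localhost:8000/text", |
| 818 | +#           data={"file_id": "doc-123"}, |
| 819 | +#           files={"file": ("report.pdf", f, "application/pdf")}, |
| 820 | +#       ) |
| 821 | +#   print(resp.json()["text"]) |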