
Commit 95cba5b

Merge branch 'main' into reranker
2 parents c1267d3 + 65c64ed commit 95cba5b

File tree

9 files changed: +400 -105 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,3 +1,4 @@
+.git
 .idea
 .venv
 .env
```

README.md

Lines changed: 4 additions & 1 deletion

````diff
@@ -64,7 +64,7 @@ The following environment variables are required to run the application:
 - `DEBUG_RAG_API`: (Optional) Set to "True" to show more verbose logging output in the server console, and to enable postgresql database routes
 - `DEBUG_PGVECTOR_QUERIES`: (Optional) Set to "True" to enable detailed PostgreSQL query logging for pgvector operations. Useful for debugging performance issues with vector database queries.
 - `CONSOLE_JSON`: (Optional) Set to "True" to log as json for Cloud Logging aggregations
-- `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "bedrock", "azure", "huggingface", "huggingfacetei", "vertexai", or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai"
+- `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "bedrock", "azure", "huggingface", "huggingfacetei", "google_genai", "vertexai", or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai"
 - `EMBEDDINGS_MODEL`: (Optional) Set a valid embeddings model to use from the configured provider.
   - **Defaults**
     - openai: "text-embedding-3-small"
@@ -74,6 +74,7 @@ The following environment variables are required to run the application:
     - vertexai: "text-embedding-004"
     - ollama: "nomic-embed-text"
     - bedrock: "amazon.titan-embed-text-v1"
+    - google_genai: "gemini-embedding-001"
 - `RAG_AZURE_OPENAI_API_VERSION`: (Optional) Default is `2023-05-15`. The version of the Azure OpenAI API.
 - `RAG_AZURE_OPENAI_API_KEY`: (Optional) The API key for Azure OpenAI service.
   - Note: `AZURE_OPENAI_API_KEY` will work but `RAG_AZURE_OPENAI_API_KEY` will override it in order to not conflict with LibreChat setting.
@@ -87,6 +88,7 @@ The following environment variables are required to run the application:
 - `AWS_DEFAULT_REGION`: (Optional) defaults to `us-east-1`
 - `AWS_ACCESS_KEY_ID`: (Optional) needed for bedrock embeddings
 - `AWS_SECRET_ACCESS_KEY`: (Optional) needed for bedrock embeddings
+- `GOOGLE_API_KEY`, `GOOGLE_KEY`, `RAG_GOOGLE_API_KEY`: (Optional) Google API key for Google GenAI embeddings. Priority order: RAG_GOOGLE_API_KEY > GOOGLE_KEY > GOOGLE_API_KEY
 - `AWS_SESSION_TOKEN`: (Optional) may be needed for bedrock embeddings
 - `GOOGLE_APPLICATION_CREDENTIALS`: (Optional) needed for Google VertexAI embeddings. This should be a path to a service account credential file in JSON format, as accepted by [langchain](https://python.langchain.com/api_reference/google_vertexai/index.html)
 - `RAG_CHECK_EMBEDDING_CTX_LENGTH` (Optional) Default is true, disabling this will send raw input to the embedder, use this for custom embedding models.
@@ -175,3 +177,4 @@ Run the following commands to install pre-commit formatter, which uses [black](h
 pip install pre-commit
 pre-commit install
 ```
+
````
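To see how the new provider slots into the documented defaults, here is a minimal sketch (illustrative only; the real resolution happens in `app/config.py`, and the key value is a placeholder):

```python
import os

# Illustrative configuration for the new provider; the key is a placeholder.
os.environ["EMBEDDINGS_PROVIDER"] = "google_genai"
os.environ["RAG_GOOGLE_API_KEY"] = "your-google-api-key"

# Per-provider default models, as listed in this README section.
DEFAULT_MODELS = {
    "openai": "text-embedding-3-small",
    "vertexai": "text-embedding-004",
    "ollama": "nomic-embed-text",
    "bedrock": "amazon.titan-embed-text-v1",
    "google_genai": "gemini-embedding-001",
}

provider = os.environ.get("EMBEDDINGS_PROVIDER", "openai")
model = os.environ.get("EMBEDDINGS_MODEL", DEFAULT_MODELS[provider])
print(provider, model)  # -> google_genai gemini-embedding-001
```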

app/config.py

Lines changed: 13 additions & 0 deletions

```diff
@@ -26,6 +26,7 @@ class EmbeddingsProvider(Enum):
     HUGGINGFACETEI = "huggingfacetei"
     OLLAMA = "ollama"
     BEDROCK = "bedrock"
+    GOOGLE_GENAI = "google_genai"
     GOOGLE_VERTEXAI = "vertexai"
 
 
@@ -186,6 +187,9 @@ async def dispatch(self, request, call_next):
 OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")
 AWS_ACCESS_KEY_ID = get_env_variable("AWS_ACCESS_KEY_ID", "")
 AWS_SECRET_ACCESS_KEY = get_env_variable("AWS_SECRET_ACCESS_KEY", "")
+GOOGLE_API_KEY = get_env_variable("GOOGLE_API_KEY", "")
+GOOGLE_KEY = get_env_variable("GOOGLE_KEY", GOOGLE_API_KEY)
+RAG_GOOGLE_API_KEY = get_env_variable("RAG_GOOGLE_API_KEY", GOOGLE_KEY)
 AWS_SESSION_TOKEN = get_env_variable("AWS_SESSION_TOKEN", "")
 GOOGLE_APPLICATION_CREDENTIALS = get_env_variable("GOOGLE_APPLICATION_CREDENTIALS", "")
 env_value = get_env_variable("RAG_CHECK_EMBEDDING_CTX_LENGTH", "True").lower()
@@ -231,6 +235,13 @@ def init_embeddings(provider, model):
         from langchain_ollama import OllamaEmbeddings
 
         return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL)
+    elif provider == EmbeddingsProvider.GOOGLE_GENAI:
+        from langchain_google_genai import GoogleGenerativeAIEmbeddings
+
+        return GoogleGenerativeAIEmbeddings(
+            model=model,
+            google_api_key=RAG_GOOGLE_API_KEY,
+        )
     elif provider == EmbeddingsProvider.GOOGLE_VERTEXAI:
         from langchain_google_vertexai import VertexAIEmbeddings
 
@@ -281,6 +292,8 @@ def init_embeddings(provider, model):
     EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-004")
 elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA:
     EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text")
+elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.GOOGLE_GENAI:
+    EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "gemini-embedding-001")
 elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.BEDROCK:
     EMBEDDINGS_MODEL = get_env_variable(
         "EMBEDDINGS_MODEL", "amazon.titan-embed-text-v1"
```

app/routes/document_routes.py

Lines changed: 172 additions & 57 deletions

```diff
@@ -44,6 +44,94 @@
 router = APIRouter()
 
 
+def get_user_id(request: Request, entity_id: str = None) -> str:
+    """Extract user ID from request or entity_id."""
+    if not hasattr(request.state, "user"):
+        return entity_id if entity_id else "public"
+    else:
+        return entity_id if entity_id else request.state.user.get("id")
+
+
+async def save_upload_file_async(file: UploadFile, temp_file_path: str) -> None:
+    """Save uploaded file asynchronously."""
+    try:
+        async with aiofiles.open(temp_file_path, "wb") as temp_file:
+            chunk_size = 64 * 1024  # 64 KB
+            while content := await file.read(chunk_size):
+                await temp_file.write(content)
+    except Exception as e:
+        logger.error(
+            "Failed to save uploaded file | Path: %s | Error: %s | Traceback: %s",
+            temp_file_path,
+            str(e),
+            traceback.format_exc(),
+        )
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to save the uploaded file. Error: {str(e)}",
+        )
+
+
+def save_upload_file_sync(file: UploadFile, temp_file_path: str) -> None:
+    """Save uploaded file synchronously."""
+    try:
+        with open(temp_file_path, "wb") as temp_file:
+            copyfileobj(file.file, temp_file)
+    except Exception as e:
+        logger.error(
+            "Failed to save uploaded file | Path: %s | Error: %s | Traceback: %s",
+            temp_file_path,
+            str(e),
+            traceback.format_exc(),
+        )
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to save the uploaded file. Error: {str(e)}",
+        )
+
+
+async def load_file_content(
+    filename: str, content_type: str, file_path: str, executor
+) -> tuple:
+    """Load file content using appropriate loader."""
+    loader, known_type, file_ext = get_loader(filename, content_type, file_path)
+    data = await run_in_executor(executor, loader.load)
+
+    # Clean up temporary UTF-8 file if it was created for encoding conversion
+    cleanup_temp_encoding_file(loader)
+
+    return data, known_type, file_ext
+
+
+def extract_text_from_documents(documents: List[Document], file_ext: str) -> str:
+    """Extract text content from loaded documents."""
+    text_content = ""
+    if documents:
+        for doc in documents:
+            if hasattr(doc, "page_content"):
+                # Clean text if it's a PDF
+                if file_ext == "pdf":
+                    text_content += clean_text(doc.page_content) + "\n"
+                else:
+                    text_content += doc.page_content + "\n"
+
+    # Remove trailing newline
+    return text_content.rstrip("\n")
+
+
+async def cleanup_temp_file_async(file_path: str) -> None:
+    """Clean up temporary file asynchronously."""
+    try:
+        await aiofiles.os.remove(file_path)
+    except Exception as e:
+        logger.error(
+            "Failed to remove temporary file | Path: %s | Error: %s | Traceback: %s",
+            file_path,
+            str(e),
+            traceback.format_exc(),
+        )
+
+
 @router.get("/ids")
 async def get_all_ids(request: Request):
     try:
```
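To see what the new `extract_text_from_documents` helper produces, here is a small standalone check; `clean_text` is stubbed as a pass-through, since the real cleaner lives elsewhere in this repo:

```python
from langchain_core.documents import Document

def clean_text(text: str) -> str:
    # Stand-in for the repo's PDF text cleaner; pass-through in this sketch.
    return text

def extract_text_from_documents(documents, file_ext):
    # Same joining logic as the new helper: one page per line,
    # with PDF pages routed through clean_text().
    text_content = ""
    for doc in documents:
        if hasattr(doc, "page_content"):
            page = doc.page_content
            text_content += (clean_text(page) if file_ext == "pdf" else page) + "\n"
    return text_content.rstrip("\n")

docs = [Document(page_content="page one"), Document(page_content="page two")]
assert extract_text_from_documents(docs, "txt") == "page one\npage two"
```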
```diff
@@ -251,7 +339,12 @@ async def query_embeddings_by_file_id(
 
 
 def generate_digest(page_content: str):
-    hash_obj = hashlib.md5(page_content.encode())
+    try:
+        hash_obj = hashlib.md5(page_content.encode("utf-8"))
+    except UnicodeEncodeError:
+        hash_obj = hashlib.md5(
+            page_content.encode("utf-8", "ignore").decode("utf-8").encode("utf-8")
+        )
     return hash_obj.hexdigest()
 
 
```
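The rewritten `generate_digest` only takes the fallback path for strings that cannot be UTF-8 encoded, which in CPython essentially means lone surrogates; the fallback strips them before hashing. A quick standalone check of both paths:

```python
import hashlib

def generate_digest(page_content: str):
    try:
        hash_obj = hashlib.md5(page_content.encode("utf-8"))
    except UnicodeEncodeError:
        # Drop unencodable code points (e.g. lone surrogates), then hash.
        hash_obj = hashlib.md5(
            page_content.encode("utf-8", "ignore").decode("utf-8").encode("utf-8")
        )
    return hash_obj.hexdigest()

print(generate_digest("hello"))           # normal path
print(generate_digest("bad\ud800data"))   # lone surrogate triggers the fallback
```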
```diff
@@ -383,40 +476,21 @@ async def embed_file(
     response_status = True
     response_message = "File processed successfully."
     known_type = None
-    if not hasattr(request.state, "user"):
-        user_id = entity_id if entity_id else "public"
-    else:
-        user_id = entity_id if entity_id else request.state.user.get("id")
 
+    user_id = get_user_id(request, entity_id)
     temp_base_path = os.path.join(RAG_UPLOAD_DIR, user_id)
     os.makedirs(temp_base_path, exist_ok=True)
     temp_file_path = os.path.join(RAG_UPLOAD_DIR, user_id, file.filename)
 
-    try:
-        async with aiofiles.open(temp_file_path, "wb") as temp_file:
-            chunk_size = 64 * 1024  # 64 KB
-            while content := await file.read(chunk_size):
-                await temp_file.write(content)
-    except Exception as e:
-        logger.error(
-            "Failed to save uploaded file | Path: %s | Error: %s | Traceback: %s",
-            temp_file_path,
-            str(e),
-            traceback.format_exc(),
-        )
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"Failed to save the uploaded file. Error: {str(e)}",
-        )
+    await save_upload_file_async(file, temp_file_path)
 
     try:
-        loader, known_type, file_ext = get_loader(
-            file.filename, file.content_type, temp_file_path
+        data, known_type, file_ext = await load_file_content(
+            file.filename,
+            file.content_type,
+            temp_file_path,
+            request.app.state.thread_pool,
         )
-        data = await run_in_executor(request.app.state.thread_pool, loader.load)
-
-        # Clean up temporary UTF-8 file if it was created for encoding conversion
-        cleanup_temp_encoding_file(loader)
 
         result = await store_data_in_vector_db(
             data=data,
@@ -465,15 +539,7 @@ async def embed_file(
             detail=f"Error during file processing: {str(e)}",
         )
     finally:
-        try:
-            await aiofiles.os.remove(temp_file_path)
-        except Exception as e:
-            logger.error(
-                "Failed to remove temporary file | Path: %s | Error: %s | Traceback: %s",
-                temp_file_path,
-                str(e),
-                traceback.format_exc(),
-            )
+        await cleanup_temp_file_async(temp_file_path)
 
     return {
         "status": response_status,
```
```diff
@@ -539,32 +605,19 @@ async def embed_file_upload(
     uploaded_file: UploadFile = File(...),
     entity_id: str = Form(None),
 ):
+    user_id = get_user_id(request, entity_id)
     temp_file_path = os.path.join(RAG_UPLOAD_DIR, uploaded_file.filename)
 
-    if not hasattr(request.state, "user"):
-        user_id = entity_id if entity_id else "public"
-    else:
-        user_id = entity_id if entity_id else request.state.user.get("id")
-
-    try:
-        with open(temp_file_path, "wb") as temp_file:
-            copyfileobj(uploaded_file.file, temp_file)
-    except Exception as e:
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"Failed to save the uploaded file. Error: {str(e)}",
-        )
+    save_upload_file_sync(uploaded_file, temp_file_path)
 
     try:
-        loader, known_type, file_ext = get_loader(
-            uploaded_file.filename, uploaded_file.content_type, temp_file_path
+        data, known_type, file_ext = await load_file_content(
+            uploaded_file.filename,
+            uploaded_file.content_type,
+            temp_file_path,
+            request.app.state.thread_pool,
         )
 
-        data = await run_in_executor(request.app.state.thread_pool, loader.load)
-
-        # Clean up temporary UTF-8 file if it was created for encoding conversion
-        cleanup_temp_encoding_file(loader)
-
         result = await store_data_in_vector_db(
             data,
             file_id,
```
```diff
@@ -651,7 +704,6 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody
         )
         raise HTTPException(status_code=500, detail=str(e))
 
-
 @router.post("/rerank")
 async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs):
     try:
```
```diff
@@ -690,3 +742,66 @@ async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs):
             traceback.format_exc(),
         )
         raise HTTPException(status_code=500, detail=str(e))
+
+@router.post("/text")
+async def extract_text_from_file(
+    request: Request,
+    file_id: str = Form(...),
+    file: UploadFile = File(...),
+    entity_id: str = Form(None),
+):
+    """
+    Extract text content from an uploaded file without creating embeddings.
+    Returns the raw text content for text parsing purposes.
+    """
+    user_id = get_user_id(request, entity_id)
+    temp_base_path = os.path.join(RAG_UPLOAD_DIR, user_id)
+    os.makedirs(temp_base_path, exist_ok=True)
+    temp_file_path = os.path.join(RAG_UPLOAD_DIR, user_id, file.filename)
+
+    await save_upload_file_async(file, temp_file_path)
+
+    try:
+        data, known_type, file_ext = await load_file_content(
+            file.filename,
+            file.content_type,
+            temp_file_path,
+            request.app.state.thread_pool,
+        )
+
+        # Extract text content from loaded documents
+        text_content = extract_text_from_documents(data, file_ext)
+
+        return {
+            "text": text_content,
+            "file_id": file_id,
+            "filename": file.filename,
+            "known_type": known_type,
+        }
+
+    except HTTPException as http_exc:
+        logger.error(
+            "HTTP Exception in extract_text_from_file | Status: %d | Detail: %s",
+            http_exc.status_code,
+            http_exc.detail,
+        )
+        raise http_exc
+    except Exception as e:
+        logger.error(
+            "Error during text extraction | File: %s | Error: %s | Traceback: %s",
+            file.filename,
+            str(e),
+            traceback.format_exc(),
+        )
+        if "No pandoc was found" in str(e):
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
+            )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"Error during text extraction: {str(e)}",
+            )
+    finally:
+        await cleanup_temp_file_async(temp_file_path)
```
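A hedged client-side sketch of exercising the new `/text` route with `requests`; the base URL, port, and absence of auth headers are assumptions that depend on how the API is deployed:

```python
import requests

# Assumes the RAG API is reachable locally without auth middleware;
# adjust the URL (and headers) for a real deployment.
with open("report.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/text",
        data={"file_id": "demo-file-id", "entity_id": "public"},
        files={"file": ("report.pdf", f, "application/pdf")},
    )

resp.raise_for_status()
payload = resp.json()
print(payload["known_type"], len(payload["text"]))
```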
