-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfastapi.py
222 lines (188 loc) · 7.76 KB
/
fastapi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import base64
import datetime
import io
import os
import secrets
from typing import List

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends, Security
from fastapi import UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from PIL import Image

from backend.utils.helper import process_images, save_btn
from backend.utils.web_agent import web_agent_flow
from backend.utils.llms import (
    get_bot_response, generate_prompt,
    generate_link_prompt, display_model_mapping,
    get_model, get_provider, follow_up_Q,
)
from backend.utils.db import init_db
# Load environment variables from api.env. This runs before the
# os.getenv("ApiKey") lookup below, so values defined in that file are visible
# to it.
load_dotenv("api.env")
app = FastAPI()
# Configure CORS
# NOTE(review): the "*" entry already allows every origin, so the explicit
# localhost entries add nothing. Paired with allow_credentials=True this is
# broader than the header-based X-API-Key auth below needs. Consider listing
# explicit origins only — confirm which frontends call this API first.
origins = [
    "http://localhost",
    "http://localhost:8501", # Default Streamlit port
    "*", # Allow all origins
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize database.
# One connection `conn` and one cursor `c` are created here and shared by every
# endpoint in this module.
# NOTE(review): the "?" placeholders used by the endpoints suggest sqlite3 —
# confirm against backend.utils.db.init_db.
conn, c = init_db()
# API Key security
# Expected key, read from the "ApiKey" environment variable (None when unset).
API_KEY = os.getenv("ApiKey")
# Pulls the client's key out of the "X-API-Key" request header.
api_key_header = APIKeyHeader(name="X-API-Key")
def get_api_key(api_key: str = Security(api_key_header)):
    """Validate the client's ``X-API-Key`` header against ``API_KEY``.

    Used through ``Depends(get_api_key)`` by the endpoints in this module.

    Args:
        api_key: Header value extracted by ``api_key_header``.

    Returns:
        The same key, when it matches the configured ``API_KEY``.

    Raises:
        HTTPException: 401 if no key is configured or the supplied key does
            not match.
    """
    # An unset ApiKey variable must never authenticate anyone, and the
    # comparison itself is constant-time so response timing does not leak
    # how many leading characters of a guess were correct. Comparing bytes
    # (not str) avoids compare_digest's TypeError on non-ASCII header values.
    if (
        api_key is not None
        and API_KEY is not None
        and secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8"))
    ):
        return api_key
    raise HTTPException(
        status_code=401,
        detail="Invalid API Key"
    )
@app.on_event("shutdown")
async def shutdown_event():
    """Close the module-wide database connection when the server stops."""
    db_connection = conn
    db_connection.close()
@app.post("/generate-from-images/")
async def generate_from_images(
    files: List[UploadFile] = File(...),
    tags: List[str] = Form(default=["General"]),
    model: str = Form(default="gpt-4o"),
    api_key: str = Depends(get_api_key),
):
    """Generate notes from uploaded images, store them, and return the result.

    Text extracted by ``process_images`` is turned into a prompt with
    ``generate_prompt``, sent to the resolved model via ``get_bot_response``,
    and the response is inserted into the ``notes`` table together with one
    JPEG-encoded image.

    Args:
        files: Uploaded image files.
        tags: Note categories; any value outside ``valid_tags`` is replaced
            by ``"General"``.
        model: Model name, resolved through ``get_model`` (falls back to the
            given name when ``get_model`` returns a falsy value).
        api_key: Injected by ``get_api_key``; not used in the body.

    Returns:
        JSONResponse with the generated text, the timestamp, the number of
        processed images and those images as Base64 JPEG strings.

    Raises:
        HTTPException: 400 when an upload cannot be read or no usable text is
            extracted; 500 when generation, encoding or storage fails.
    """
    # Buffer every upload in memory as a BytesIO stream for process_images.
    all_files = []
    for file in files:
        try:
            contents = await file.read()
        except Exception as file_error:
            raise HTTPException(
                status_code=400,
                detail=f"Error processing file {file.filename}: {str(file_error)}",
            ) from file_error
        all_files.append(io.BytesIO(contents))

    if not all_files:
        raise HTTPException(status_code=400, detail="No valid images provided")

    try:
        # Get model and provider information
        internal_model = get_model(model) or model
        provider_name = get_provider(internal_model)

        # Validate and normalize tags
        valid_tags = ["General", "Coding", "Math", "Student Notes"]
        normalized_tags = [tag if tag in valid_tags else "General" for tag in tags]

        text_list, processed_images = process_images(all_files)
        if not text_list:
            raise HTTPException(status_code=400, detail="Could not extract text from images")

        # Drop empty / whitespace-only fragments before building the prompt.
        text_list = [text for text in text_list if text and text.strip()]
        if not text_list:
            raise HTTPException(status_code=400, detail="No valid text extracted from images")

        prompt = generate_prompt("\n\n".join(text_list), normalized_tags)
        bot_response = get_bot_response(prompt, internal_model, provider_name)
        if not bot_response:
            raise HTTPException(status_code=500, detail="Failed to generate response")

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Encode every processed image as JPEG for the response. The notes
        # table has a single image column, so only the last encoded image is
        # persisted (None when there are no images, instead of a NameError).
        # NOTE(review): confirm that keeping the last image (rather than the
        # first, or all of them) is intended.
        base64_images = []
        stored_image = None
        for img in processed_images:
            with io.BytesIO() as buffered:
                img.save(buffered, format="JPEG")
                stored_image = buffered.getvalue()
            base64_images.append(base64.b64encode(stored_image).decode("utf-8"))

        c.execute(
            "INSERT INTO notes (content, image, timestamp) VALUES (?, ?, ?)",
            (bot_response, stored_image, timestamp),
        )
        conn.commit()

        return JSONResponse(
            content={
                "message": "Notes generated and saved successfully",
                "response": bot_response,
                "timestamp": timestamp,
                "image_count": len(processed_images),
                "images": base64_images,  # Include Base64 images in response
            },
            status_code=200,
        )
    except HTTPException:
        # Deliberate 400/500 responses raised above keep their own status and
        # message instead of being rewrapped as a generic 500.
        raise
    except Exception as process_error:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing images: {str(process_error)}",
        ) from process_error
@app.post("/generate-from-link/")
async def generate_from_link(url: str = Form(...), api_key: str = Depends(get_api_key)):
    """Generate notes for a web page with ``web_agent_flow`` and store them.

    The note is inserted without an image value.

    Args:
        url: Address of the page to summarize, passed to ``web_agent_flow``.
        api_key: Injected by ``get_api_key``; not used in the body.

    Returns:
        JSONResponse with a status message and the generated text.

    Raises:
        HTTPException: 500 if the agent returns nothing or any step fails.
    """
    try:
        bot_response = web_agent_flow(url)
        # Same rule as the image endpoint: never persist an empty note.
        if not bot_response:
            raise HTTPException(status_code=500, detail="Failed to generate response")
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        c.execute("INSERT INTO notes (content, timestamp) VALUES (?, ?)", (bot_response, timestamp))
        conn.commit()
        return JSONResponse(content={"message": "Notes generated and saved.", "response": bot_response})
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e
@app.get("/notes/")
async def get_notes(api_key: str = Depends(get_api_key)):
    """Return every stored note, with its image Base64-encoded.

    Args:
        api_key: Injected by ``get_api_key``; not used in the body.

    Returns:
        JSONResponse ``{"notes": [...]}`` where each entry has ``id``,
        ``content``, ``image`` (Base64 string, or None when the note has no
        image) and ``timestamp``.

    Raises:
        HTTPException: 500 on any database or encoding error.
    """
    try:
        c.execute("SELECT id, content, image, timestamp FROM notes")
        notes = []
        for note_id, content, image_bytes, timestamp in c.fetchall():
            # Notes from /generate-from-link/ are inserted without an image,
            # so the column is presumably NULL for them; b64encode(None)
            # would raise and fail the whole listing.
            image_base64 = (
                base64.b64encode(image_bytes).decode("utf-8")
                if image_bytes is not None
                else None
            )
            notes.append({
                "id": note_id,
                "content": content,
                "image": image_base64,
                "timestamp": timestamp,
            })
        return JSONResponse(content={"notes": notes})
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e
@app.delete("/notes/{note_id}")
async def delete_note(note_id: int, api_key: str = Depends(get_api_key)):
    """Delete the note with id ``note_id``.

    Args:
        note_id: Primary key of the note to delete.
        api_key: Injected by ``get_api_key``; not used in the body.

    Returns:
        JSONResponse confirming the deletion.

    Raises:
        HTTPException: 404 if no note has that id (same response as
            ``follow_up_question``); 500 on database errors.
    """
    try:
        c.execute("DELETE FROM notes WHERE id=?", (note_id,))
        # Capture before commit; DB-API rowcount is the number of rows removed.
        deleted = c.rowcount
        conn.commit()
        if deleted == 0:
            raise HTTPException(status_code=404, detail="Note not found.")
        return JSONResponse(content={"message": "Note deleted."})
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e
@app.post("/follow-up-question/")
async def follow_up_question(note_id: int = Form(...), user_prompt: str = Form(...), api_key: str = Depends(get_api_key)):
    """Answer a follow-up question about a stored note and append the answer.

    Args:
        note_id: Id of the note the question refers to.
        user_prompt: The user's follow-up question.
        api_key: Injected by ``get_api_key``; not used in the body.

    Returns:
        JSONResponse ``{"follow_up_response": ...}``.

    Raises:
        HTTPException: 404 if the note does not exist; 500 if no response is
            generated or any other step fails.
    """
    try:
        # Fetch the specific note from the database
        c.execute("SELECT content FROM notes WHERE id=?", (note_id,))
        row = c.fetchone()
        if not row:
            raise HTTPException(status_code=404, detail="Note not found.")
        note_content = row[0]

        prompt = follow_up_Q(user_prompt, note_content)
        # NOTE(review): the other endpoints call get_bot_response(prompt,
        # model, provider); this relies on get_bot_response having defaults
        # for the last two arguments — confirm.
        follow_up_response = get_bot_response(prompt)
        if not follow_up_response:
            raise HTTPException(status_code=500, detail="Failed to generate follow-up response.")

        # Append to the existing content. The leading blank line keeps the
        # "###" heading on its own line instead of gluing it onto the last
        # line of the previous text.
        c.execute(
            "UPDATE notes SET content = content || ? WHERE id=?",
            (f"\n\n### follow-up response:\n{follow_up_response}", note_id),
        )
        conn.commit()
        return JSONResponse(content={"follow_up_response": follow_up_response})
    except HTTPException:
        # Keep the 404 / 500 raised above instead of rewrapping them as
        # a generic 500 "Error: 404: ..." response.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e