Skip to content

Commit 82a9a95

Browse files
authored
Removed DETA db support, due to shutdown of the app (#12)
* Fixed bug * enh txt * Update app.py * Update app.py * Update app.py * Update model.py * Update app.py * Refactor app.py and frontend.py for improved code organization and readability
1 parent e7218c5 commit 82a9a95

File tree

4 files changed

+119
-157
lines changed

4 files changed

+119
-157
lines changed

requirements.txt

+5-5
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ h11==0.14.0
2525
httpcore==0.17.3
2626
httpx==0.24.1
2727
httpx-oauth==0.13.0
28-
huggingface-hub==0.24.2
28+
huggingface-hub==0.23.4
2929
idna==3.7
3030
importlib-metadata==6.11.0
3131
iniconfig==2.0.0
@@ -35,9 +35,9 @@ jsonpatch==1.33
3535
jsonpointer==3.0.0
3636
jsonschema==4.23.0
3737
jsonschema-specifications==2023.12.1
38-
langchain==0.2.11
39-
langchain-community==0.2.10
40-
langchain-core==0.2.24
38+
langchain==0.2.14
39+
langchain-core==0.2.32
40+
langchain-community>=0.0.37
4141
langchain-huggingface==0.0.3
4242
langchain-text-splitters==0.2.2
4343
langsmith==0.1.93
@@ -93,7 +93,7 @@ six==1.16.0
9393
smmap==5.0.1
9494
sniffio==1.3.1
9595
SQLAlchemy==2.0.31
96-
streamlit==1.28.0
96+
streamlit==1.36.0
9797
streamlit-oauth==0.1.5
9898
sympy==1.13.1
9999
tenacity==8.5.0

src/app.py

+75-115
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from langchain_core.output_parsers import StrOutputParser
22
from langchain_core.prompts import PromptTemplate
33
import streamlit as st
4-
from deta import Deta
54
import sys
65
import os
6+
import json
77
from backend import (
88
configure_page_styles,
9-
create_oauth2_component,
109
display_github_badge,
11-
handle_google_login_if_needed,
1210
hide_main_menu_and_footer,
1311
)
1412
from frontend import (
@@ -19,129 +17,91 @@
1917
handle_new_chat,
2018
)
2119
from model import create_huggingface_hub
22-
23-
2420
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
25-
2621
from src.auth import *
2722
from src.constant import *
2823

24+
def format_chat_history(messages):
25+
"""Format the chat history as a structured JSON string."""
26+
history = []
27+
for msg in messages[1:]:
28+
content = msg['content']
29+
if '```sql' in content:
30+
content = content.replace('```sql\n', '').replace('\n```', '').strip()
31+
32+
history.append({
33+
"role": msg['role'],
34+
"query" if msg['role'] == 'user' else "response": content
35+
})
36+
37+
formatted_history = json.dumps(history, indent=2)
38+
print("Formatted history:", formatted_history)
39+
return formatted_history
40+
41+
def extract_sql_code(response):
42+
"""Extract clean SQL code from the response."""
43+
sql_code_start = response.find("```sql")
44+
if sql_code_start != -1:
45+
sql_code_end = response.find("```", sql_code_start + 5)
46+
if sql_code_end != -1:
47+
sql_code = response[sql_code_start + 6:sql_code_end].strip()
48+
return f"```sql\n{sql_code}\n```"
49+
return response
2950

3051
def main():
3152
"""Main function to configure and run the Querypls application."""
3253
configure_page_styles("static/css/styles.css")
33-
deta = Deta(DETA_PROJECT_KEY)
54+
3455
if "model" not in st.session_state:
3556
llm = create_huggingface_hub()
3657
st.session_state["model"] = llm
37-
db = deta.Base("users")
38-
oauth2 = create_oauth2_component()
39-
40-
if "code" not in st.session_state or not st.session_state.code:
41-
st.session_state.code = False
42-
43-
if "code" not in st.session_state:
44-
st.session_state.code = False
45-
58+
59+
if "messages" not in st.session_state:
60+
create_message()
61+
4662
hide_main_menu_and_footer()
47-
if st.session_state.code == False:
48-
col1, col2, col3 = st.columns(3)
49-
with col1:
50-
pass
51-
with col2:
52-
with st.container():
53-
54-
display_github_badge()
55-
display_logo_and_heading()
56-
57-
st.markdown("`Made with 🤍`")
58-
if "token" not in st.session_state:
59-
result = oauth2.authorize_button(
60-
"Connect with Google",
61-
REDIRECT_URI,
62-
SCOPE,
63-
icon="data:image/svg+xml;charset=utf-8,%3Csvg \
64-
xmlns='http://www.w3.org/2000/svg' \
65-
xmlns:xlink='http://www.w3.org/1999/xlink' \
66-
viewBox='0 0 48 48'%3E%3Cdefs%3E%3Cpath id='a' \
67-
d='M44.5 20H24v8.5h11.8C34.7 33.9 30.1 37 24 37c-7.2 \
68-
0-13-5.8-13-13s5.8-13 13-13c3.1 0 5.9 1.1 8.1 \
69-
2.9l6.4-6.4C34.6 4.1 29.6 2 24 2 11.8 2 2 11.8 2 \
70-
24s9.8 22 22 22c11 0 21-8 21-22 \
71-
0-1.3-.2-2.7-.5-4z'/%3E%3C/defs%3E%3CclipPath \
72-
id='b'%3E%3Cuse xlink:href='%23a' \
73-
overflow='visible'/%3E%3C/clipPath%3E%3Cpath \
74-
clip-path='url(%23b)' fill='%23FBBC05' \
75-
d='M0 37V11l17 13z'/%3E%3Cpath clip-path='url(%23b)' \
76-
fill='%23EA4335' d='M0 11l17 13 7-6.1L48 \
77-
14V0H0z'/%3E%3Cpath clip-path='url(%23b)' \
78-
fill='%2334A853' d='M0 37l30-23 7.9 1L48 \
79-
0v48H0z'/%3E%3Cpath clip-path='url(%23b)' \
80-
fill='%234285F4' d='M48 48L17 24l-4-3 \
81-
35-10z'/%3E%3C/svg%3E",
82-
use_container_width=True,
83-
)
84-
handle_google_login_if_needed(result)
85-
if st.session_state.code:
86-
st.rerun()
87-
with col3:
88-
pass
89-
else:
90-
with st.sidebar:
91-
display_github_badge()
92-
display_logo_and_heading()
93-
st.markdown("`Made with 🤍`")
94-
if st.session_state.code:
95-
handle_new_chat(db)
96-
if st.session_state.code:
97-
display_previous_chats(db)
98-
99-
if "messages" not in st.session_state:
100-
create_message()
101-
display_welcome_message()
102-
103-
for message in st.session_state.messages:
104-
with st.chat_message(message["role"]):
105-
st.markdown(message["content"], unsafe_allow_html=True)
106-
107-
if prompt := st.chat_input(disabled=(st.session_state.code is False)):
108-
st.session_state.messages.append(
109-
{"role": "user", "content": prompt}
110-
)
111-
with st.chat_message("user"):
112-
st.write(prompt)
113-
114-
prompt_template = PromptTemplate(
115-
template=TEMPLATE, input_variables=["question"]
116-
)
117-
118-
if "model" in st.session_state:
119-
llm_chain = (
120-
prompt_template
121-
| st.session_state.model
122-
| StrOutputParser()
123-
)
124-
if st.session_state.messages[-1]["role"] != "assistant":
125-
with st.chat_message("assistant"):
126-
with st.spinner("Generating..."):
127-
response = llm_chain.invoke(prompt)
128-
import re
129-
130-
code_block_match = re.search(
131-
r"```sql(.*?)```", response, re.DOTALL
132-
)
133-
if code_block_match:
134-
code_block = code_block_match.group(1)
135-
st.markdown(
136-
f"```sql\n{code_block}\n```",
137-
unsafe_allow_html=True,
138-
)
139-
message = {
140-
"role": "assistant",
141-
"content": f"```sql\n{code_block}\n```",
142-
}
143-
st.session_state.messages.append(message)
144-
63+
64+
with st.sidebar:
65+
display_github_badge()
66+
display_logo_and_heading()
67+
st.markdown("`Made with 🤍`")
68+
handle_new_chat()
69+
70+
display_welcome_message()
71+
for message in st.session_state.messages:
72+
with st.chat_message(message["role"]):
73+
st.markdown(message["content"])
74+
75+
if prompt := st.chat_input():
76+
st.session_state.messages.append({"role": "user", "content": prompt})
77+
with st.chat_message("user"):
78+
st.markdown(prompt)
79+
80+
conversation_history = format_chat_history(st.session_state.messages)
81+
prompt_template = PromptTemplate(
82+
template=TEMPLATE,
83+
input_variables=["input", "conversation_history"]
84+
)
85+
86+
if "model" in st.session_state:
87+
llm_chain = prompt_template | st.session_state.model | StrOutputParser()
88+
89+
with st.chat_message("assistant"):
90+
with st.spinner("Generating..."):
91+
response = llm_chain.invoke({
92+
"input": prompt,
93+
"conversation_history": conversation_history
94+
})
95+
96+
# Clean and format the response
97+
formatted_response = extract_sql_code(response)
98+
st.markdown(formatted_response)
99+
100+
# Add to chat history
101+
st.session_state.messages.append({
102+
"role": "assistant",
103+
"content": formatted_response
104+
})
145105

146106
if __name__ == "__main__":
147-
main()
107+
main()

src/frontend.py

+38-36
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import streamlit as st
2-
from src.database import database, get_previous_chats
3-
42

53
def display_logo_and_heading():
64
"""Displays the Querypls logo."""
@@ -14,77 +12,81 @@ def display_welcome_message():
1412
st.markdown(f"#### Welcome to \n ## 🗃️💬Querypls - Prompt to SQL")
1513

1614

17-
def handle_new_chat(db, max_chat_histories=5):
15+
def handle_new_chat(max_chat_histories=5):
1816
"""Handles the initiation of a new chat session.
1917
2018
Displays the remaining chat history count and provides a button to start a new chat.
2119
2220
Args:
23-
db: Deta Base instance.
2421
max_chat_histories (int, optional): Maximum number of chat histories to retain.
2522
2623
Returns:
2724
None
2825
"""
29-
remaining_chats = max_chat_histories - len(
30-
get_previous_chats(db, st.session_state.user_email)
31-
)
26+
remaining_chats = max_chat_histories - len(st.session_state.get("previous_chats", []))
3227
st.markdown(
33-
f" #### Remaining Chat Histories: \
34-
`{remaining_chats}/{max_chat_histories}`"
28+
f" #### Remaining Chat Histories: `{remaining_chats}/{max_chat_histories}`"
3529
)
3630
st.markdown(
37-
"You can create up to 5 chat histories. Each history \
38-
can contain unlimited messages."
31+
"You can create up to 5 chat histories. Each history can contain unlimited messages."
3932
)
4033

4134
if st.button("➕ New chat"):
42-
database(db, previous_key=st.session_state.key)
35+
save_chat_history() # Save current chat before creating a new one
4336
create_message()
4437

4538

46-
def display_previous_chats(db):
47-
"""Displays previous chat records.
39+
def display_previous_chats():
40+
"""Displays previous chat records stored in session state.
4841
49-
Retrieves and displays a list of previous chat records for the user.
5042
Allows the user to select a chat to view.
51-
52-
Args:
53-
db: Deta Base instance.
54-
55-
Returns:
56-
None
5743
"""
58-
previous_chats = get_previous_chats(db, st.session_state.user_email)
59-
reversed_chats = reversed(previous_chats)
44+
if "previous_chats" in st.session_state:
45+
reversed_chats = reversed(st.session_state["previous_chats"])
6046

61-
for chat in reversed_chats:
62-
if st.button(chat["title"], key=chat["key"]):
63-
update_session_state(db, chat)
47+
for chat in reversed_chats:
48+
if st.button(chat["title"], key=chat["key"]):
49+
update_session_state(chat)
6450

6551

6652
def create_message():
6753
"""Creates a default assistant message and initializes a session key."""
68-
6954
st.session_state["messages"] = [
7055
{"role": "assistant", "content": "How may I help you?"}
7156
]
7257
st.session_state["key"] = "key"
73-
return
7458

7559

76-
def update_session_state(db, chat):
60+
def update_session_state(chat):
7761
"""Updates the session state with selected chat information.
7862
7963
Args:
80-
db: Deta Base instance.
8164
chat (dict): Selected chat information.
82-
83-
Returns:
84-
None
8565
"""
86-
previous_chat = st.session_state["messages"]
87-
previous_key = st.session_state["key"]
8866
st.session_state["messages"] = chat["chat"]
8967
st.session_state["key"] = chat["key"]
90-
database(db, previous_key, previous_chat)
68+
69+
70+
def save_chat_history():
71+
"""Saves the current chat to session state if it contains messages."""
72+
if "messages" in st.session_state and len(st.session_state["messages"]) > 1:
73+
# Initialize previous chats list if it doesn't exist
74+
if "previous_chats" not in st.session_state:
75+
st.session_state["previous_chats"] = []
76+
77+
# Create a chat summary to store in session
78+
title = st.session_state["messages"][1]["content"]
79+
chat_summary = {
80+
"title": title[:25] + "....." if len(title) > 25 else title,
81+
"chat": st.session_state["messages"],
82+
"key": f"chat_{len(st.session_state['previous_chats']) + 1}"
83+
}
84+
85+
st.session_state["previous_chats"].append(chat_summary)
86+
87+
# Limit chat histories to a maximum number
88+
if len(st.session_state["previous_chats"]) > 5:
89+
st.session_state["previous_chats"].pop(0) # Remove oldest chat
90+
st.warning(
91+
f"The oldest chat history has been removed as you reached the limit of 5 chat histories."
92+
)

src/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@ def create_huggingface_hub():
1818
return HuggingFaceHub(
1919
huggingfacehub_api_token=HUGGINGFACE_API_TOKEN,
2020
repo_id=REPO_ID,
21-
model_kwargs={"temperature": 0.2, "max_new_tokens": 180},
21+
model_kwargs={"temperature": 0.7, "max_new_tokens": 180},
2222
)

0 commit comments

Comments
 (0)