1
1
from langchain_core .output_parsers import StrOutputParser
2
2
from langchain_core .prompts import PromptTemplate
3
3
import streamlit as st
4
- from deta import Deta
5
4
import sys
6
5
import os
6
+ import json
7
7
from backend import (
8
8
configure_page_styles ,
9
- create_oauth2_component ,
10
9
display_github_badge ,
11
- handle_google_login_if_needed ,
12
10
hide_main_menu_and_footer ,
13
11
)
14
12
from frontend import (
19
17
handle_new_chat ,
20
18
)
21
19
from model import create_huggingface_hub
22
-
23
-
24
20
sys .path .append (os .path .abspath (os .path .join (os .path .dirname (__file__ ), ".." )))
25
-
26
21
from src .auth import *
27
22
from src .constant import *
28
23
24
+ def format_chat_history (messages ):
25
+ """Format the chat history as a structured JSON string."""
26
+ history = []
27
+ for msg in messages [1 :]:
28
+ content = msg ['content' ]
29
+ if '```sql' in content :
30
+ content = content .replace ('```sql\n ' , '' ).replace ('\n ```' , '' ).strip ()
31
+
32
+ history .append ({
33
+ "role" : msg ['role' ],
34
+ "query" if msg ['role' ] == 'user' else "response" : content
35
+ })
36
+
37
+ formatted_history = json .dumps (history , indent = 2 )
38
+ print ("Formatted history:" , formatted_history )
39
+ return formatted_history
40
+
41
+ def extract_sql_code (response ):
42
+ """Extract clean SQL code from the response."""
43
+ sql_code_start = response .find ("```sql" )
44
+ if sql_code_start != - 1 :
45
+ sql_code_end = response .find ("```" , sql_code_start + 5 )
46
+ if sql_code_end != - 1 :
47
+ sql_code = response [sql_code_start + 6 :sql_code_end ].strip ()
48
+ return f"```sql\n { sql_code } \n ```"
49
+ return response
29
50
30
51
def main ():
31
52
"""Main function to configure and run the Querypls application."""
32
53
configure_page_styles ("static/css/styles.css" )
33
- deta = Deta ( DETA_PROJECT_KEY )
54
+
34
55
if "model" not in st .session_state :
35
56
llm = create_huggingface_hub ()
36
57
st .session_state ["model" ] = llm
37
- db = deta .Base ("users" )
38
- oauth2 = create_oauth2_component ()
39
-
40
- if "code" not in st .session_state or not st .session_state .code :
41
- st .session_state .code = False
42
-
43
- if "code" not in st .session_state :
44
- st .session_state .code = False
45
-
58
+
59
+ if "messages" not in st .session_state :
60
+ create_message ()
61
+
46
62
hide_main_menu_and_footer ()
47
- if st .session_state .code == False :
48
- col1 , col2 , col3 = st .columns (3 )
49
- with col1 :
50
- pass
51
- with col2 :
52
- with st .container ():
53
-
54
- display_github_badge ()
55
- display_logo_and_heading ()
56
-
57
- st .markdown ("`Made with 🤍`" )
58
- if "token" not in st .session_state :
59
- result = oauth2 .authorize_button (
60
- "Connect with Google" ,
61
- REDIRECT_URI ,
62
- SCOPE ,
63
- icon = "data:image/svg+xml;charset=utf-8,%3Csvg \
64
- xmlns='http://www.w3.org/2000/svg' \
65
- xmlns:xlink='http://www.w3.org/1999/xlink' \
66
- viewBox='0 0 48 48'%3E%3Cdefs%3E%3Cpath id='a' \
67
- d='M44.5 20H24v8.5h11.8C34.7 33.9 30.1 37 24 37c-7.2 \
68
- 0-13-5.8-13-13s5.8-13 13-13c3.1 0 5.9 1.1 8.1 \
69
- 2.9l6.4-6.4C34.6 4.1 29.6 2 24 2 11.8 2 2 11.8 2 \
70
- 24s9.8 22 22 22c11 0 21-8 21-22 \
71
- 0-1.3-.2-2.7-.5-4z'/%3E%3C/defs%3E%3CclipPath \
72
- id='b'%3E%3Cuse xlink:href='%23a' \
73
- overflow='visible'/%3E%3C/clipPath%3E%3Cpath \
74
- clip-path='url(%23b)' fill='%23FBBC05' \
75
- d='M0 37V11l17 13z'/%3E%3Cpath clip-path='url(%23b)' \
76
- fill='%23EA4335' d='M0 11l17 13 7-6.1L48 \
77
- 14V0H0z'/%3E%3Cpath clip-path='url(%23b)' \
78
- fill='%2334A853' d='M0 37l30-23 7.9 1L48 \
79
- 0v48H0z'/%3E%3Cpath clip-path='url(%23b)' \
80
- fill='%234285F4' d='M48 48L17 24l-4-3 \
81
- 35-10z'/%3E%3C/svg%3E" ,
82
- use_container_width = True ,
83
- )
84
- handle_google_login_if_needed (result )
85
- if st .session_state .code :
86
- st .rerun ()
87
- with col3 :
88
- pass
89
- else :
90
- with st .sidebar :
91
- display_github_badge ()
92
- display_logo_and_heading ()
93
- st .markdown ("`Made with 🤍`" )
94
- if st .session_state .code :
95
- handle_new_chat (db )
96
- if st .session_state .code :
97
- display_previous_chats (db )
98
-
99
- if "messages" not in st .session_state :
100
- create_message ()
101
- display_welcome_message ()
102
-
103
- for message in st .session_state .messages :
104
- with st .chat_message (message ["role" ]):
105
- st .markdown (message ["content" ], unsafe_allow_html = True )
106
-
107
- if prompt := st .chat_input (disabled = (st .session_state .code is False )):
108
- st .session_state .messages .append (
109
- {"role" : "user" , "content" : prompt }
110
- )
111
- with st .chat_message ("user" ):
112
- st .write (prompt )
113
-
114
- prompt_template = PromptTemplate (
115
- template = TEMPLATE , input_variables = ["question" ]
116
- )
117
-
118
- if "model" in st .session_state :
119
- llm_chain = (
120
- prompt_template
121
- | st .session_state .model
122
- | StrOutputParser ()
123
- )
124
- if st .session_state .messages [- 1 ]["role" ] != "assistant" :
125
- with st .chat_message ("assistant" ):
126
- with st .spinner ("Generating..." ):
127
- response = llm_chain .invoke (prompt )
128
- import re
129
-
130
- code_block_match = re .search (
131
- r"```sql(.*?)```" , response , re .DOTALL
132
- )
133
- if code_block_match :
134
- code_block = code_block_match .group (1 )
135
- st .markdown (
136
- f"```sql\n { code_block } \n ```" ,
137
- unsafe_allow_html = True ,
138
- )
139
- message = {
140
- "role" : "assistant" ,
141
- "content" : f"```sql\n { code_block } \n ```" ,
142
- }
143
- st .session_state .messages .append (message )
144
-
63
+
64
+ with st .sidebar :
65
+ display_github_badge ()
66
+ display_logo_and_heading ()
67
+ st .markdown ("`Made with 🤍`" )
68
+ handle_new_chat ()
69
+
70
+ display_welcome_message ()
71
+ for message in st .session_state .messages :
72
+ with st .chat_message (message ["role" ]):
73
+ st .markdown (message ["content" ])
74
+
75
+ if prompt := st .chat_input ():
76
+ st .session_state .messages .append ({"role" : "user" , "content" : prompt })
77
+ with st .chat_message ("user" ):
78
+ st .markdown (prompt )
79
+
80
+ conversation_history = format_chat_history (st .session_state .messages )
81
+ prompt_template = PromptTemplate (
82
+ template = TEMPLATE ,
83
+ input_variables = ["input" , "conversation_history" ]
84
+ )
85
+
86
+ if "model" in st .session_state :
87
+ llm_chain = prompt_template | st .session_state .model | StrOutputParser ()
88
+
89
+ with st .chat_message ("assistant" ):
90
+ with st .spinner ("Generating..." ):
91
+ response = llm_chain .invoke ({
92
+ "input" : prompt ,
93
+ "conversation_history" : conversation_history
94
+ })
95
+
96
+ # Clean and format the response
97
+ formatted_response = extract_sql_code (response )
98
+ st .markdown (formatted_response )
99
+
100
+ # Add to chat history
101
+ st .session_state .messages .append ({
102
+ "role" : "assistant" ,
103
+ "content" : formatted_response
104
+ })
145
105
146
106
if __name__ == "__main__" :
147
- main ()
107
+ main ()
0 commit comments