
Commit 10e0529

Merge pull request #63 from CausalInferenceLab/10-보여주는-결과의-확장-필요성
10 Need to expand the displayed results
2 parents 9c79b30 + 39a4cfb commit 10e0529

6 files changed: +95 -18 lines

interface/lang2sql.py

Lines changed: 45 additions & 6 deletions
@@ -2,6 +2,18 @@
 from langchain_core.messages import HumanMessage
 from llm_utils.graph import builder
 from langchain.chains.sql_database.prompt import SQL_PROMPTS
+import os
+from typing import Union
+import pandas as pd
+
+from clickhouse_driver import Client
+from llm_utils.connect_db import ConnectDB
+from dotenv import load_dotenv
+
+
+# Connect to ClickHouse
+db = ConnectDB()
+db.connect_to_clickhouse()
 
 # Streamlit app title
 st.title("Lang2SQL")
@@ -17,6 +29,22 @@
     options=SQL_PROMPTS.keys(),
     index=0,
 )
+st.sidebar.title("Output Settings")
+st.sidebar.checkbox("Show Total Token Usage", value=True, key="show_total_token_usage")
+st.sidebar.checkbox(
+    "Show Result Description", value=True, key="show_result_description"
+)
+st.sidebar.checkbox("Show SQL", value=True, key="show_sql")
+st.sidebar.checkbox(
+    "Show User Question Reinterpreted by AI",
+    value=True,
+    key="show_question_reinterpreted_by_ai",
+)
+st.sidebar.checkbox(
+    "Show List of Referenced Tables", value=True, key="show_referenced_tables"
+)
+st.sidebar.checkbox("Show Table", value=True, key="show_table")
+st.sidebar.checkbox("Show Chart", value=True, key="show_chart")
 
 
 # Define the token-usage aggregation function
@@ -43,9 +71,20 @@ def summarize_total_tokens(data):
     total_tokens = summarize_total_tokens(res["messages"])
 
     # Display results
-    st.write("총 토큰 사용량:", total_tokens)
-    # st.write("결과:", res["generated_query"].content)
-    st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
-    st.write("결과 설명:\n\n", res["messages"][-1].content)
-    st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
-    st.write("참고한 테이블 목록:", res["searched_tables"])
+    if st.session_state.get("show_total_token_usage", True):
+        st.write("총 토큰 사용량:", total_tokens)
+    if st.session_state.get("show_sql", True):
+        st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
+    if st.session_state.get("show_result_description", True):
+        st.write("결과 설명:\n\n", res["messages"][-1].content)
+    if st.session_state.get("show_question_reinterpreted_by_ai", True):
+        st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
+    if st.session_state.get("show_referenced_tables", True):
+        st.write("참고한 테이블 목록:", res["searched_tables"])
+    if st.session_state.get("show_table", True):
+        sql = res["generated_query"]
+        df = db.run_sql(sql)
+        if len(df) > 10:
+            st.dataframe(df.head(10))
+        else:
+            st.dataframe(df)

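The new sidebar checkboxes and the conditional output blocks communicate through Streamlit's session state: each st.sidebar.checkbox stores its current value under the name passed via key=, and the rendering code later reads that flag with st.session_state.get(...). A minimal, self-contained sketch of the same gating pattern (the DataFrame here is placeholder data, not a Lang2SQL result):

import pandas as pd
import streamlit as st

st.sidebar.title("Output Settings")
# The checkbox writes its value into st.session_state["show_table"].
st.sidebar.checkbox("Show Table", value=True, key="show_table")

df = pd.DataFrame({"id": range(25)})  # placeholder result set

# Read the flag back; default to True if the key is missing.
if st.session_state.get("show_table", True):
    # Mirror the commit's behaviour of capping the preview at 10 rows.
    st.dataframe(df.head(10) if len(df) > 10 else df)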
interface/streamlit_app.py

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 import streamlit as st
 
-
 pg = st.navigation(
     [
         st.Page("lang2sql.py", title="Lang2SQL"),

llm_utils/chains.py

Lines changed: 1 addition & 5 deletions
@@ -12,11 +12,7 @@
 else:
     print(f"⚠️ 환경변수 파일(.env)이 {os.getcwd()}에 없습니다!")
 
-llm = get_llm(
-    model_type="openai",
-    model_name="gpt-4o-mini",
-    openai_api_key=os.getenv("OPENAI_API_KEY"),
-)
+llm = get_llm()
 
 
 def create_query_refiner_chain(llm):

llm_utils/connect_db.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import os
+from typing import Union
+import pandas as pd
+from clickhouse_driver import Client
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+
+class ConnectDB:
+    def __init__(self):
+        self.client = None
+        self.host = os.getenv("CLICKHOUSE_HOST")
+        self.dbname = os.getenv("CLICKHOUSE_DATABASE")
+        self.user = os.getenv("CLICKHOUSE_USER")
+        self.password = os.getenv("CLICKHOUSE_PASSWORD")
+        self.port = os.getenv("CLICKHOUSE_PORT")
+
+    def connect_to_clickhouse(self):
+
+        # ClickHouse server connection info
+        self.client = Client(
+            host=self.host,
+            port=self.port,
+            user=self.user,
+            password=self.password,
+            database=self.dbname,  # e.g. '127.0.0.1'  # default TCP port
+        )
+
+    def run_sql(self, sql: str) -> Union[pd.DataFrame, None]:
+        if self.client:
+            try:
+                result = self.client.execute(sql, with_column_types=True)
+                # Split the result into rows and column metadata
+                rows, columns = result
+                column_names = [col[0] for col in columns]
+
+                # Create a pandas dataframe from the results
+                df = pd.DataFrame(rows, columns=column_names)
+                return df
+
+            except Exception as e:
+                raise e

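ConnectDB pulls every connection setting from the environment (via load_dotenv), so the ClickHouse credentials never appear in code. A usage sketch with the variable names the class reads; the values shown are illustrative placeholders, not taken from the repository:

# .env (illustrative values)
#   CLICKHOUSE_HOST=localhost
#   CLICKHOUSE_PORT=9000
#   CLICKHOUSE_USER=default
#   CLICKHOUSE_PASSWORD=secret
#   CLICKHOUSE_DATABASE=default
from llm_utils.connect_db import ConnectDB

db = ConnectDB()
db.connect_to_clickhouse()          # builds the clickhouse_driver Client
df = db.run_sql("SELECT 1 AS ok")   # returns a pandas DataFrame
print(df)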
llm_utils/graph.py

Lines changed: 1 addition & 6 deletions
@@ -62,7 +62,6 @@ def get_table_info_node(state: QueryMakerState):
     documents = get_info_from_db()
     db = FAISS.from_documents(documents, embeddings)
     db.save_local(os.getcwd() + "/table_info_db")
-    print("table_info_db not found")
     doc_res = db.similarity_search(state["messages"][-1].content)
     documents_dict = {}
 
@@ -112,11 +111,7 @@ class SQLResult(BaseModel):
 
 def query_maker_node_with_db_guide(state: QueryMakerState):
     sql_prompt = SQL_PROMPTS[state["user_database_env"]]
-    llm = get_llm(
-        model_type="openai",
-        model_name="gpt-4o-mini",
-        openai_api_key=os.getenv("OPENAI_API_KEY"),
-    )
+    llm = get_llm()
     chain = sql_prompt | llm.with_structured_output(SQLResult)
     res = chain.invoke(
         input={

setup.py

Lines changed: 4 additions & 0 deletions
@@ -24,6 +24,10 @@
         "streamlit==1.41.1",
         "python-dotenv==1.0.1",
         "faiss-cpu==1.10.0",
+        "langchain-aws>=0.2.21,<0.3.0",
+        "langchain-google-genai>=2.1.3,<3.0.0",
+        "langchain-ollama>=0.3.2,<0.4.0",
+        "langchain-huggingface>=0.1.2,<0.2.0",
     ],
     entry_points={
         "console_scripts": [

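Both call sites above drop the hardcoded OpenAI arguments in favour of a bare get_llm(), and setup.py picks up provider packages for AWS, Google, Ollama, and Hugging Face, which suggests the factory now resolves provider and model from configuration. The actual get_llm implementation is not part of this commit; the following is only a hypothetical sketch of such an environment-driven factory, with LLM_PROVIDER and LLM_MODEL as assumed variable names:

# Hypothetical sketch; not the repository's get_llm implementation.
import os

from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI


def get_llm():
    provider = os.getenv("LLM_PROVIDER", "openai")   # assumed env var
    model = os.getenv("LLM_MODEL", "gpt-4o-mini")    # assumed env var
    if provider == "openai":
        return ChatOpenAI(model=model, api_key=os.getenv("OPENAI_API_KEY"))
    if provider == "ollama":
        return ChatOllama(model=model)
    raise ValueError(f"Unsupported LLM provider: {provider}")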