Skip to content

Commit a733c59

Browse files
committed
feat: Refactor LLM workflow and Streamlit interface for Lang2SQL
- Enhance query generation process with refined input and table info retrieval - Update Streamlit app to support database environment input - Modify graph and chain logic to improve query generation - Add DataHub server configuration to README
1 parent 5140536 commit a733c59

File tree

5 files changed

+214
-127
lines changed

5 files changed

+214
-127
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ LANGCHAIN_TRACING_V2=true
6262
LANGCHAIN_PROJECT=autosql
6363
LANGCHAIN_ENDPOINT=https://api.smith.langchain.com
6464
LANGCHAIN_API_KEY=your-langchain-api-key
65+
DATAHUB_SERVER=http://localhost:8080
6566
```
6667

6768
---

interface/streamlit_app.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@
33
from llm_utils.graph import builder
44

55
# Streamlit 앱 제목
6-
st.title("AutoSQL")
6+
st.title("Lang2SQL")
77

88
# 사용자 입력 받기
99
user_query = st.text_area(
1010
"쿼리를 입력하세요:",
1111
value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리",
1212
)
1313

14+
user_database_env = st.text_area(
15+
"db 환경정보를 입력하세요:",
16+
value="duckdb",
17+
)
18+
1419

1520
# Token usage 집계 함수 정의
1621
def summarize_total_tokens(data):
@@ -25,10 +30,18 @@ def summarize_total_tokens(data):
2530
if st.button("쿼리 실행"):
2631
# 그래프 컴파일 및 쿼리 실행
2732
graph = builder.compile()
28-
human_message = HumanMessage(content=user_query)
29-
res = graph.invoke(input=human_message)
30-
total_tokens = summarize_total_tokens(res)
33+
34+
res = graph.invoke(
35+
input={
36+
"messages": [HumanMessage(content=user_query)],
37+
"user_database_env": user_database_env,
38+
"best_practice_query": "",
39+
}
40+
)
41+
total_tokens = summarize_total_tokens(res["messages"])
3142

3243
# 결과 출력
3344
st.write("총 토큰 사용량:", total_tokens)
34-
st.write("결과:", res[-1].content)
45+
st.write("결과:", res["generated_query"].content)
46+
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
47+
st.write("참고한 테이블 목록:", res["searched_tables"])

llm_utils/chains.py

Lines changed: 76 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,120 @@
11
import os
22
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
33

4-
from .tools import get_table_info, get_column_info
54
from .llm_factory import get_llm
65

76
llm = get_llm(
8-
model_type="openai", model_name="gpt-4o", openai_api_key=os.getenv("OPENAI_API_KEY")
7+
model_type="openai",
8+
model_name="gpt-4o-mini",
9+
openai_api_key=os.getenv("OPENAI_API_KEY"),
910
)
1011

1112

12-
# ToolChoiceChain
13-
def create_tool_choice_chain(llm):
13+
def create_query_refiner_chain(llm):
1414
tool_choice_prompt = ChatPromptTemplate.from_messages(
1515
[
1616
(
1717
"system",
1818
"""
19-
너는 User Input에 대해 관련된 테이블, 컬럼을 찾는 Assistance이다.
19+
당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다.
20+
현재 subscription_activities, contract_activities, marketing_activities,
21+
sales_activities, success_activities, support_activities, trial_activities 데이터를
22+
보유하고 있으며, 사용자의 질문이 모호할 경우에도 우리가 가진 데이터를 기반으로
23+
충분히 답변 가능한 형태로 질문을 구체화해 주세요.
24+
25+
주의:
26+
- 사용자에게 추가 정보를 요구하는 ‘재질문(추가 질문)’을 하지 마세요.
27+
- 질문에 포함해야 할 요소(예: 특정 기간, 대상 유저 그룹, 분석 대상 로그 종류 등)가
28+
불충분하더라도, 합리적으로 추론해 가정한 뒤
29+
답변에 충분한 질문 형태로 완성해 주세요.
30+
31+
예시:
32+
사용자가 "유저 이탈 원인이 궁금해요"라고 했다면,
33+
재질문 형식이 아니라
34+
"최근 1개월 간의 접속·결제 로그를 기준으로,
35+
주로 어떤 사용자가 어떤 과정을 거쳐 이탈하는지를 분석해야 한다"처럼
36+
분석 방향이 명확해진 질문 한 문장(또는 한 문단)으로 정리해 주세요.
37+
38+
최종 출력 형식 예시:
39+
------------------------------
40+
구체화된 질문:
41+
"최근 1개월 동안 고액 결제 경험이 있는 유저가
42+
행동 로그에서 이탈 전 어떤 패턴을 보였는지 분석"
43+
44+
가정한 조건:
45+
- 최근 1개월치 행동 로그와 결제 로그 중심
46+
- 고액 결제자(월 결제액 10만 원 이상) 그룹 대상으로 한정
47+
------------------------------
2048
""",
2149
),
2250
MessagesPlaceholder(variable_name="user_input"),
2351
(
2452
"system",
2553
"""
26-
위의 질의와 관련된 테이블을 찾아주세요
27-
다음 tool을 사용할 수 있습니다:
28-
get_table_info - 전체 table_name과 table_description을 가져옵니다.
29-
get_column_info - table_name을 input으로 넣으면 column_name과 column description을 가져옵니다.
30-
아래 툴을 사용해주세요
54+
위 사용자의 입력을 바탕으로
55+
분석 관점에서 **충분히 답변 가능한 형태**로
56+
"구체화된 질문"을 작성하고,
57+
필요한 경우 가정이나 전제 조건을 함께 제시해 주세요.
3158
""",
3259
),
33-
MessagesPlaceholder(variable_name="tool_choice"),
3460
]
3561
)
3662

37-
tools = [get_table_info, get_column_info]
63+
return tool_choice_prompt | llm
3864

39-
return tool_choice_prompt | llm.bind_tools(tools)
4065

41-
42-
# TableFilterChain
43-
def create_table_filter_chain(llm):
44-
table_filter_prompt = ChatPromptTemplate.from_messages(
66+
# QueryMakerChain
67+
def create_query_maker_chain(llm):
68+
query_maker_prompt = ChatPromptTemplate.from_messages(
4569
[
46-
MessagesPlaceholder(variable_name="user_input"),
4770
(
4871
"system",
4972
"""
50-
너는 위의 User Input에 대해 관련된 테이블을 찾는 Assistance이다.
51-
아래 테이블 목록을 참고해서 관련된 테이블을 찾아주세요.
52-
참고사항은
53-
dim_~: 테이블은 metadata 테이블임
54-
fact_~: 테이블은 실제 데이터가 저장된 테이블임
55-
stg_~: 테이블은 데이터 적재 테이블임
56-
응답형태는 'table_name - table_description' 형태로 출력해주세요.
57-
최소 2개 이상의 테이블을 출력해주세요.
58-
테이블 목록은 다음과 같습니다:
73+
당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다.
74+
사용자의 질문을 기반으로, 주어진 테이블과 컬럼 정보를 활용하여 적절한 SQL 쿼리를 생성하세요.
75+
76+
주의사항:
77+
- 사용자의 질문이 다소 모호하더라도, 주어진 데이터를 참고하여 합리적인 가정을 통해 SQL 쿼리를 완성하세요.
78+
- 불필요한 재질문 없이, 가능한 가장 명확한 분석 쿼리를 만들어 주세요.
79+
- 최종 출력 형식은 반드시 아래와 같아야 합니다.
80+
81+
최종 형태 예시:
82+
83+
<SQL>
84+
```sql
85+
SELECT COUNT(DISTINCT user_id)
86+
FROM stg_users
87+
```
88+
89+
<해석>
90+
```plaintext (max_length_per_line=100)
91+
이 쿼리는 stg_users 테이블에서 고유한 사용자의 수를 계산합니다.
92+
사용자는 유니크한 user_id를 가지고 있으며
93+
중복을 제거하기 위해 COUNT(DISTINCT user_id)를 사용했습니다.
94+
```
95+
5996
""",
6097
),
61-
MessagesPlaceholder(variable_name="searched_tables"),
62-
]
63-
)
64-
return table_filter_prompt | llm
65-
66-
67-
# QueryMakerChain
68-
def create_query_maker_chain(llm):
69-
query_maker_prompt = ChatPromptTemplate.from_messages(
70-
[
98+
(
99+
"system",
100+
"아래는 사용자의 질문 및 구체화된 질문입니다:",
101+
),
71102
MessagesPlaceholder(variable_name="user_input"),
72-
("system", "너는 위의 User Input에 대해 쿼리를 작성하는 Assistance이다."),
73-
("system", "다음 테이블과 컬럼을 참고해서 쿼리를 작성해주세요."),
74-
("system", "테이블 목록은 다음과 같습니다:"),
103+
MessagesPlaceholder(variable_name="refined_input"),
104+
(
105+
"system",
106+
"다음은 사용자의 db 환경정보와 사용 가능한 테이블 및 컬럼 정보입니다:",
107+
),
108+
MessagesPlaceholder(variable_name="user_database_env"),
75109
MessagesPlaceholder(variable_name="searched_tables"),
76-
("system", "컬럼 목록은 다음과 같습니다:"),
77-
MessagesPlaceholder(variable_name="searched_columns"),
78110
(
79111
"system",
80-
"""최종 형태는 반드시 아래와 같아야합니다.
81-
최종 쿼리:
82-
SELECT COUNT(DISTINCT user_id) FROM stg_users WHERE user_id = 1
83-
참고한 테이블 목록:
84-
stg_users, dim_users
85-
참고한 컬럼 목록:
86-
stg_users.user_id, dim_users.user_id
87-
""",
112+
"위 정보를 바탕으로 사용자 질문에 대한 최적의 SQL 쿼리를 최종 형태 예시와 같은 형태로 생성하세요.",
88113
),
89114
]
90115
)
91116
return query_maker_prompt | llm
92117

93118

94-
tool_choice_chain = create_tool_choice_chain(llm)
95-
table_filter_chain = create_table_filter_chain(llm)
119+
query_refiner_chain = create_query_refiner_chain(llm)
96120
query_maker_chain = create_query_maker_chain(llm)

0 commit comments

Comments
 (0)