Skip to content

Commit 8dd00ee

Browse files
authored
Merge pull request #14 Resolve #6
#6 이슈 해결 #11 이슈에서 `쿼리 설명이 영어로 나오는` 부분 해결 기대
2 parents 4d186ba + 36af176 commit 8dd00ee

File tree

3 files changed

+42
-5
lines changed

3 files changed

+42
-5
lines changed

interface/streamlit_app.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import streamlit as st
22
from langchain_core.messages import HumanMessage
33
from llm_utils.graph import builder
4+
from langchain.chains.sql_database.prompt import SQL_PROMPTS
45

56
# Streamlit 앱 제목
67
st.title("Lang2SQL")
@@ -11,9 +12,10 @@
1112
value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리",
1213
)
1314

14-
user_database_env = st.text_area(
15+
user_database_env = st.selectbox(
1516
"db 환경정보를 입력하세요:",
16-
value="duckdb",
17+
options=SQL_PROMPTS.keys(),
18+
index=0,
1719
)
1820

1921

@@ -42,6 +44,8 @@ def summarize_total_tokens(data):
4244

4345
# 결과 출력
4446
st.write("총 토큰 사용량:", total_tokens)
45-
st.write("결과:", res["generated_query"].content)
47+
# st.write("결과:", res["generated_query"].content)
48+
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
49+
st.write("결과 설명:\n\n", res["messages"][-1].content)
4650
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
4751
st.write("참고한 테이블 목록:", res["searched_tables"])

llm_utils/graph.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from typing_extensions import TypedDict, Annotated
55
from langgraph.graph import END, StateGraph
66
from langgraph.graph.message import add_messages
7+
from langchain.chains.sql_database.prompt import SQL_PROMPTS
8+
from pydantic import BaseModel, Field
9+
from .llm_factory import get_llm
710

811
from llm_utils.chains import (
912
query_refiner_chain,
@@ -102,14 +105,44 @@ def query_maker_node(state: QueryMakerState):
102105
return state
103106

104107

108+
class SQLResult(BaseModel):
109+
sql: str = Field(description="SQL 쿼리 문자열")
110+
explanation: str = Field(description="SQL 쿼리 설명")
111+
112+
113+
def query_maker_node_with_db_guide(state: QueryMakerState):
114+
sql_prompt = SQL_PROMPTS[state["user_database_env"]]
115+
llm = get_llm(
116+
model_type="openai",
117+
model_name="gpt-4o-mini",
118+
openai_api_key=os.getenv("OPENAI_API_KEY"),
119+
)
120+
chain = sql_prompt | llm.with_structured_output(SQLResult)
121+
res = chain.invoke(
122+
input={
123+
"input": "\n\n---\n\n".join(
124+
[state["messages"][0].content] + [state["refined_input"].content]
125+
),
126+
"table_info": [json.dumps(state["searched_tables"])],
127+
"top_k": 10,
128+
}
129+
)
130+
state["generated_query"] = res.sql
131+
state["messages"].append(res.explanation)
132+
return state
133+
134+
105135
# StateGraph 생성 및 구성
106136
builder = StateGraph(QueryMakerState)
107137
builder.set_entry_point(QUERY_REFINER)
108138

109139
# 노드 추가
110140
builder.add_node(QUERY_REFINER, query_refiner_node)
111141
builder.add_node(GET_TABLE_INFO, get_table_info_node)
112-
builder.add_node(QUERY_MAKER, query_maker_node)
142+
# builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
143+
builder.add_node(
144+
QUERY_MAKER, query_maker_node_with_db_guide
145+
) # query_maker_node_with_db_guide
113146

114147
# 기본 엣지 설정
115148
builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)

llm_utils/llm_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def get_llm(
1717
if model_type == "openai":
1818
return ChatOpenAI(
1919
model=model_name,
20-
openai_api_key=openai_api_key,
20+
api_key=openai_api_key,
2121
**kwargs,
2222
)
2323

0 commit comments

Comments
 (0)