Skip to content

Commit 5079c7f

Browse files
authored
Merge pull request #129 from #126
feat: 단순화된 그래프 워크플로우 추가 및 관련 코드 수정
2 parents 24b1c3d + 230293e commit 5079c7f

File tree

6 files changed

+139
-13
lines changed

6 files changed

+139
-13
lines changed

cli/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,13 +229,19 @@ def run_streamlit_cli_command(port: int) -> None:
229229
is_flag=True,
230230
help="확장된 그래프(프로파일 추출 + 컨텍스트 보강) 사용 여부",
231231
)
232+
@click.option(
233+
"--use-simplified-graph",
234+
is_flag=True,
235+
help="단순화된 그래프(QUERY_REFINER 제거) 사용 여부",
236+
)
232237
def query_command(
233238
question: str,
234239
database_env: str,
235240
retriever_name: str,
236241
top_n: int,
237242
device: str,
238243
use_enriched_graph: bool,
244+
use_simplified_graph: bool,
239245
) -> None:
240246
"""
241247
자연어 질문을 SQL 쿼리로 변환하여 출력하는 명령어입니다.
@@ -267,6 +273,7 @@ def query_command(
267273
top_n=top_n,
268274
device=device,
269275
use_enriched_graph=use_enriched_graph,
276+
use_simplified_graph=use_simplified_graph,
270277
)
271278

272279
# SQL 추출 및 출력

interface/lang2sql.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from llm_utils.token_utils import TokenUtils
2121
from llm_utils.graph_utils.enriched_graph import builder as enriched_builder
2222
from llm_utils.graph_utils.basic_graph import builder
23+
from llm_utils.graph_utils.simplified_graph import builder as simplified_builder
2324

2425

2526
TITLE = "Lang2SQL"
@@ -71,6 +72,7 @@ def execute_query(
7172
top_n=top_n,
7273
device=device,
7374
use_enriched_graph=st.session_state.get("use_enriched", False),
75+
use_simplified_graph=st.session_state.get("use_simplified", False),
7476
session_state=st.session_state,
7577
)
7678

@@ -219,29 +221,53 @@ def should_show(_key: str) -> bool:
219221
st.title(TITLE)
220222

221223
# 워크플로우 선택(UI)
224+
st.sidebar.markdown("### 워크플로우 선택")
222225
use_enriched = st.sidebar.checkbox(
223226
"프로파일 추출 & 컨텍스트 보강 워크플로우 사용", value=False
224227
)
228+
use_simplified = st.sidebar.checkbox(
229+
"단순화된 워크플로우 사용 (QUERY_REFINER 제거)", value=False
230+
)
225231

226232
# 세션 상태 초기화
227233
if (
228234
"graph" not in st.session_state
229235
or st.session_state.get("use_enriched") != use_enriched
236+
or st.session_state.get("use_simplified") != use_simplified
230237
):
231-
graph_builder = enriched_builder if use_enriched else builder
232-
st.session_state["graph"] = graph_builder.compile()
238+
# 그래프 선택 로직
239+
if use_simplified:
240+
graph_builder = simplified_builder
241+
graph_type = "단순화된"
242+
elif use_enriched:
243+
graph_builder = enriched_builder
244+
graph_type = "확장된"
245+
else:
246+
graph_builder = builder
247+
graph_type = "기본"
233248

234-
# 프로파일 추출 & 컨텍스트 보강 그래프
249+
st.session_state["graph"] = graph_builder.compile()
235250
st.session_state["use_enriched"] = use_enriched
236-
st.info("Lang2SQL이 성공적으로 시작되었습니다.")
251+
st.session_state["use_simplified"] = use_simplified
252+
st.info(f"Lang2SQL이 성공적으로 시작되었습니다. ({graph_type} 워크플로우)")
237253

238254
# 새로고침 버튼 추가
239255
if st.sidebar.button("Lang2SQL 새로고침"):
240-
graph_builder = (
241-
enriched_builder if st.session_state.get("use_enriched") else builder
242-
)
256+
# 그래프 선택 로직
257+
if st.session_state.get("use_simplified"):
258+
graph_builder = simplified_builder
259+
graph_type = "단순화된"
260+
elif st.session_state.get("use_enriched"):
261+
graph_builder = enriched_builder
262+
graph_type = "확장된"
263+
else:
264+
graph_builder = builder
265+
graph_type = "기본"
266+
243267
st.session_state["graph"] = graph_builder.compile()
244-
st.sidebar.success("Lang2SQL이 성공적으로 새로고침되었습니다.")
268+
st.sidebar.success(
269+
f"Lang2SQL이 성공적으로 새로고침되었습니다. ({graph_type} 워크플로우)"
270+
)
245271

246272
user_query = st.text_area(
247273
"쿼리를 입력하세요:",

llm_utils/graph_utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
query_refiner_with_profile_node,
1919
context_enrichment_node,
2020
query_maker_node_with_db_guide,
21+
query_maker_node_without_refiner,
2122
)
2223

2324
from .basic_graph import builder as basic_builder
2425
from .enriched_graph import builder as enriched_builder
26+
from .simplified_graph import builder as simplified_builder
2527

2628
__all__ = [
2729
# 상태 및 노드 식별자
@@ -39,7 +41,9 @@
3941
"query_refiner_with_profile_node",
4042
"context_enrichment_node",
4143
"query_maker_node_with_db_guide",
44+
"query_maker_node_without_refiner",
4245
# 그래프 빌더들
4346
"basic_builder",
4447
"enriched_builder",
48+
"simplified_builder",
4549
]

llm_utils/graph_utils/base.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,21 @@ def context_enrichment_node(state: QueryMakerState):
140140
searched_tables = state["searched_tables"]
141141
searched_tables_json = json.dumps(searched_tables, ensure_ascii=False, indent=2)
142142

143-
question_profile = state["question_profile"].model_dump()
143+
# question_profile이 BaseModel인 경우 model_dump() 사용, dict인 경우 그대로 사용
144+
if hasattr(state["question_profile"], "model_dump"):
145+
question_profile = state["question_profile"].model_dump()
146+
else:
147+
question_profile = state["question_profile"]
144148
question_profile_json = json.dumps(question_profile, ensure_ascii=False, indent=2)
145149

150+
# refined_input이 없는 경우 초기 사용자 입력 사용
151+
refined_question = state.get("refined_input", state["messages"][0].content)
152+
if hasattr(refined_question, "content"):
153+
refined_question = refined_question.content
154+
146155
enriched_text = query_enrichment_chain.invoke(
147156
input={
148-
"refined_question": state["refined_input"],
157+
"refined_question": refined_question,
149158
"profiles": question_profile_json,
150159
"related_tables": searched_tables_json,
151160
}
@@ -207,3 +216,33 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
207216
state["generated_query"] = res.sql
208217
state["messages"].append(res.explanation)
209218
return state
219+
220+
221+
# 노드 함수: QUERY_MAKER 노드 (refined_input 없이)
222+
def query_maker_node_without_refiner(state: QueryMakerState):
223+
"""
224+
refined_input 없이 초기 사용자 입력만을 사용하여 SQL을 생성하는 노드입니다.
225+
226+
이 노드는 QUERY_REFINER 단계를 건너뛰고, 초기 사용자 입력, 프로파일 정보,
227+
컨텍스트 보강 정보를 모두 활용하여 SQL을 생성합니다.
228+
"""
229+
# 컨텍스트 보강된 질문 (refined_input이 없는 경우 초기 입력 사용)
230+
enriched_question = state.get("refined_input", state["messages"][0])
231+
232+
# enriched_question이 AIMessage인 경우 content 추출, 문자열인 경우 그대로 사용
233+
if hasattr(enriched_question, "content"):
234+
enriched_question_content = enriched_question.content
235+
else:
236+
enriched_question_content = str(enriched_question)
237+
238+
res = query_maker_chain.invoke(
239+
input={
240+
"user_input": [state["messages"][0].content],
241+
"refined_input": [enriched_question_content],
242+
"searched_tables": [json.dumps(state["searched_tables"])],
243+
"user_database_env": [state["user_database_env"]],
244+
}
245+
)
246+
state["generated_query"] = res
247+
state["messages"].append(res)
248+
return state
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import json
2+
3+
from langgraph.graph import StateGraph, END
4+
from llm_utils.graph_utils.base import (
5+
QueryMakerState,
6+
GET_TABLE_INFO,
7+
PROFILE_EXTRACTION,
8+
CONTEXT_ENRICHMENT,
9+
QUERY_MAKER,
10+
get_table_info_node,
11+
profile_extraction_node,
12+
context_enrichment_node,
13+
query_maker_node_without_refiner,
14+
)
15+
16+
"""
17+
QUERY_REFINER 단계를 제거한 단순화된 워크플로우입니다.
18+
GET_TABLE_INFO → PROFILE_EXTRACTION → CONTEXT_ENRICHMENT → QUERY_MAKER 순서로 실행됩니다.
19+
초기 사용자 입력만을 사용하여 더 정확한 쿼리를 생성합니다.
20+
"""
21+
22+
# StateGraph 생성 및 구성
23+
builder = StateGraph(QueryMakerState)
24+
builder.set_entry_point(GET_TABLE_INFO)
25+
26+
# 노드 추가
27+
builder.add_node(GET_TABLE_INFO, get_table_info_node)
28+
builder.add_node(PROFILE_EXTRACTION, profile_extraction_node)
29+
builder.add_node(CONTEXT_ENRICHMENT, context_enrichment_node)
30+
builder.add_node(QUERY_MAKER, query_maker_node_without_refiner)
31+
32+
# 기본 엣지 설정
33+
builder.add_edge(GET_TABLE_INFO, PROFILE_EXTRACTION)
34+
builder.add_edge(PROFILE_EXTRACTION, CONTEXT_ENRICHMENT)
35+
builder.add_edge(CONTEXT_ENRICHMENT, QUERY_MAKER)
36+
37+
# QUERY_MAKER 노드 후 종료
38+
builder.add_edge(QUERY_MAKER, END)

llm_utils/query_executor.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from llm_utils.graph_utils.enriched_graph import builder as enriched_builder
1414
from llm_utils.graph_utils.basic_graph import builder as basic_builder
15+
from llm_utils.graph_utils.simplified_graph import builder as simplified_builder
1516
from llm_utils.llm_response_parser import LLMResponseParser
1617

1718
logger = logging.getLogger(__name__)
@@ -25,6 +26,7 @@ def execute_query(
2526
top_n: int = 5,
2627
device: str = "cpu",
2728
use_enriched_graph: bool = False,
29+
use_simplified_graph: bool = False,
2830
session_state: Optional[Union[Dict[str, Any], Any]] = None,
2931
) -> Dict[str, Any]:
3032
"""
@@ -52,19 +54,29 @@ def execute_query(
5254
"""
5355

5456
logger.info("Processing query: %s", query)
55-
logger.info("Using %s graph", "enriched" if use_enriched_graph else "basic")
57+
58+
# 그래프 선택
59+
if use_simplified_graph:
60+
graph_type = "simplified"
61+
graph_builder = simplified_builder
62+
elif use_enriched_graph:
63+
graph_type = "enriched"
64+
graph_builder = enriched_builder
65+
else:
66+
graph_type = "basic"
67+
graph_builder = basic_builder
68+
69+
logger.info("Using %s graph", graph_type)
5670

5771
# 그래프 선택 및 컴파일
5872
if session_state is not None:
5973
# Streamlit 환경: 세션 상태에서 그래프 재사용
6074
graph = session_state.get("graph")
6175
if graph is None:
62-
graph_builder = enriched_builder if use_enriched_graph else basic_builder
6376
graph = graph_builder.compile()
6477
session_state["graph"] = graph
6578
else:
6679
# CLI 환경: 매번 새로운 그래프 컴파일
67-
graph_builder = enriched_builder if use_enriched_graph else basic_builder
6880
graph = graph_builder.compile()
6981

7082
# 그래프 실행

0 commit comments

Comments
 (0)