Skip to content

123 cli output format 지원 #125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,91 @@ def run_streamlit_cli_command(port: int) -> None:

logger.info("Executing 'run-streamlit' command on port %d...", port)
run_streamlit_command(port)


@cli.command(name="query")
@click.argument("question", type=str)
@click.option(
"--database-env",
default="clickhouse",
help="사용할 데이터베이스 환경 (기본값: clickhouse)",
)
@click.option(
"--retriever-name",
default="기본",
help="테이블 검색기 이름 (기본값: 기본)",
)
@click.option(
"--top-n",
type=int,
default=5,
help="검색된 상위 테이블 수 제한 (기본값: 5)",
)
@click.option(
"--device",
default="cpu",
help="LLM 실행에 사용할 디바이스 (기본값: cpu)",
)
@click.option(
"--use-enriched-graph",
is_flag=True,
help="확장된 그래프(프로파일 추출 + 컨텍스트 보강) 사용 여부",
)
def query_command(
question: str,
database_env: str,
retriever_name: str,
top_n: int,
device: str,
use_enriched_graph: bool,
) -> None:
"""
자연어 질문을 SQL 쿼리로 변환하여 출력하는 명령어입니다.

이 명령은 사용자가 입력한 자연어 질문을 받아서 SQL 쿼리로 변환하고,
생성된 SQL 쿼리만을 표준 출력으로 출력합니다.

매개변수:
question (str): SQL로 변환할 자연어 질문
database_env (str): 사용할 데이터베이스 환경
retriever_name (str): 테이블 검색기 이름
top_n (int): 검색된 상위 테이블 수 제한
device (str): LLM 실행에 사용할 디바이스
use_enriched_graph (bool): 확장된 그래프 사용 여부

예시:
lang2sql query "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
lang2sql query "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" --use-enriched-graph
"""

try:
from llm_utils.query_executor import execute_query, extract_sql_from_result

# 공용 함수를 사용하여 쿼리 실행
res = execute_query(
query=question,
database_env=database_env,
retriever_name=retriever_name,
top_n=top_n,
device=device,
use_enriched_graph=use_enriched_graph,
)

# SQL 추출 및 출력
sql = extract_sql_from_result(res)
if sql:
print(sql)
else:
# SQL 추출 실패 시 원본 쿼리 텍스트 출력
generated_query = res.get("generated_query")
if generated_query:
query_text = (
generated_query.content
if hasattr(generated_query, "content")
else str(generated_query)
)
print(query_text)

except Exception as e:
logger.error("쿼리 처리 중 오류 발생: %s", e)
raise
2 changes: 1 addition & 1 deletion evaluation/gen_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm
import uuid

from llm_utils.graph import builder
from llm_utils.graph_utils.basic_graph import builder


def get_eval_result(
Expand Down
148 changes: 81 additions & 67 deletions interface/lang2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

import streamlit as st
from langchain.chains.sql_database.prompt import SQL_PROMPTS
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage

from llm_utils.connect_db import ConnectDB
from llm_utils.display_chart import DisplayChart
from llm_utils.enriched_graph import builder as enriched_builder
from llm_utils.graph import builder
from llm_utils.query_executor import execute_query as execute_query_common
from llm_utils.llm_response_parser import LLMResponseParser
from llm_utils.token_utils import TokenUtils
from llm_utils.graph_utils.enriched_graph import builder as enriched_builder
from llm_utils.graph_utils.basic_graph import builder

TITLE = "Lang2SQL"
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
Expand All @@ -40,9 +41,8 @@ def execute_query(
"""
자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 Lang2SQL 그래프 인터페이스 함수입니다.

이 함수는 Lang2SQL 파이프라인(graph)을 세션 상태에서 가져오거나 새로 컴파일한 뒤,
사용자의 자연어 질문을 SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다.
내부적으로 LangChain의 `graph.invoke` 메서드를 호출합니다.
이 함수는 공용 execute_query 함수를 호출하여 Lang2SQL 파이프라인을 실행합니다.
Streamlit 세션 상태를 활용하여 그래프를 재사용합니다.

Args:
query (str): 사용자가 입력한 자연어 기반 질문.
Expand All @@ -59,27 +59,16 @@ def execute_query(
- "searched_tables": 참조된 테이블 목록 등 추가 정보
"""

graph = st.session_state.get("graph")
if graph is None:
graph_builder = (
enriched_builder if st.session_state.get("use_enriched") else builder
)
graph = graph_builder.compile()
st.session_state["graph"] = graph

res = graph.invoke(
input={
"messages": [HumanMessage(content=query)],
"user_database_env": database_env,
"best_practice_query": "",
"retriever_name": retriever_name,
"top_n": top_n,
"device": device,
}
return execute_query_common(
query=query,
database_env=database_env,
retriever_name=retriever_name,
top_n=top_n,
device=device,
use_enriched_graph=st.session_state.get("use_enriched", False),
session_state=st.session_state,
)

return res


def display_result(
*,
Expand Down Expand Up @@ -120,40 +109,50 @@ def should_show(_key: str) -> bool:
if should_show("show_sql"):
st.markdown("---")
generated_query = res.get("generated_query")
query_text = (
generated_query.content
if isinstance(generated_query, AIMessage)
else str(generated_query)
)
if generated_query:
query_text = (
generated_query.content
if isinstance(generated_query, AIMessage)
else str(generated_query)
)

try:
sql = LLMResponseParser.extract_sql(query_text)
st.markdown("**생성된 SQL 쿼리:**")
st.code(sql, language="sql")
except ValueError:
st.warning("SQL 블록을 추출할 수 없습니다.")
st.text(query_text)

interpretation = LLMResponseParser.extract_interpretation(query_text)
if interpretation:
st.markdown("**결과 해석:**")
st.code(interpretation)
# query_text가 문자열인지 확인
if isinstance(query_text, str):
try:
sql = LLMResponseParser.extract_sql(query_text)
st.markdown("**생성된 SQL 쿼리:**")
st.code(sql, language="sql")
except ValueError:
st.warning("SQL 블록을 추출할 수 없습니다.")
st.text(query_text)

interpretation = LLMResponseParser.extract_interpretation(query_text)
if interpretation:
st.markdown("**결과 해석:**")
st.code(interpretation)
else:
st.warning("쿼리 텍스트가 문자열이 아닙니다.")
st.text(str(query_text))

if should_show("show_result_description"):
st.markdown("---")
st.markdown("**결과 설명:**")
result_message = res["messages"][-1].content

try:
sql = LLMResponseParser.extract_sql(result_message)
st.code(sql, language="sql")
except ValueError:
st.warning("SQL 블록을 추출할 수 없습니다.")
st.text(result_message)

interpretation = LLMResponseParser.extract_interpretation(result_message)
if interpretation:
st.code(interpretation, language="plaintext")
if isinstance(result_message, str):
try:
sql = LLMResponseParser.extract_sql(result_message)
st.code(sql, language="sql")
except ValueError:
st.warning("SQL 블록을 추출할 수 없습니다.")
st.text(result_message)

interpretation = LLMResponseParser.extract_interpretation(result_message)
if interpretation:
st.code(interpretation, language="plaintext")
else:
st.warning("결과 메시지가 문자열이 아닙니다.")
st.text(str(result_message))

if should_show("show_question_reinterpreted_by_ai"):
st.markdown("---")
Expand All @@ -173,26 +172,41 @@ def should_show(_key: str) -> bool:
if isinstance(res["generated_query"], AIMessage)
else str(res["generated_query"])
)
sql = LLMResponseParser.extract_sql(sql_raw)
df = database.run_sql(sql)
st.dataframe(df.head(10) if len(df) > 10 else df)
if isinstance(sql_raw, str):
sql = LLMResponseParser.extract_sql(sql_raw)
df = database.run_sql(sql)
st.dataframe(df.head(10) if len(df) > 10 else df)
else:
st.error("SQL 원본이 문자열이 아닙니다.")
except Exception as e:
st.error(f"쿼리 실행 중 오류 발생: {e}")

if should_show("show_chart"):
st.markdown("---")
df = database.run_sql(sql)
st.markdown("**쿼리 결과 시각화:**")
display_code = DisplayChart(
question=res["refined_input"].content,
sql=sql,
df_metadata=f"Running df.dtypes gives:\n{df.dtypes}",
)
# plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
fig = display_code.get_plotly_figure(
plotly_code=display_code.generate_plotly_code(), df=df
)
st.plotly_chart(fig)
try:
sql_raw = (
res["generated_query"].content
if isinstance(res["generated_query"], AIMessage)
else str(res["generated_query"])
)
if isinstance(sql_raw, str):
sql = LLMResponseParser.extract_sql(sql_raw)
df = database.run_sql(sql)
st.markdown("**쿼리 결과 시각화:**")
display_code = DisplayChart(
question=res["refined_input"].content,
sql=sql,
df_metadata=f"Running df.dtypes gives:\n{df.dtypes}",
)
# plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
fig = display_code.get_plotly_figure(
plotly_code=display_code.generate_plotly_code(), df=df
)
st.plotly_chart(fig)
else:
st.error("SQL 원본이 문자열이 아닙니다.")
except Exception as e:
st.error(f"차트 생성 중 오류 발생: {e}")


db = ConnectDB()
Expand Down
45 changes: 45 additions & 0 deletions llm_utils/graph_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
그래프 관련 유틸리티 모듈입니다.

이 패키지는 Lang2SQL의 워크플로우 그래프 구성과 관련된 모듈들을 포함합니다.
"""

from .base import (
QueryMakerState,
GET_TABLE_INFO,
QUERY_REFINER,
QUERY_MAKER,
PROFILE_EXTRACTION,
CONTEXT_ENRICHMENT,
get_table_info_node,
query_refiner_node,
query_maker_node,
profile_extraction_node,
query_refiner_with_profile_node,
context_enrichment_node,
query_maker_node_with_db_guide,
)

from .basic_graph import builder as basic_builder
from .enriched_graph import builder as enriched_builder

__all__ = [
# 상태 및 노드 식별자
"QueryMakerState",
"GET_TABLE_INFO",
"QUERY_REFINER",
"QUERY_MAKER",
"PROFILE_EXTRACTION",
"CONTEXT_ENRICHMENT",
# 노드 함수들
"get_table_info_node",
"query_refiner_node",
"query_maker_node",
"profile_extraction_node",
"query_refiner_with_profile_node",
"context_enrichment_node",
"query_maker_node_with_db_guide",
# 그래프 빌더들
"basic_builder",
"enriched_builder",
]
24 changes: 2 additions & 22 deletions llm_utils/graph.py → llm_utils/graph_utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langgraph.graph.message import add_messages
from langchain.chains.sql_database.prompt import SQL_PROMPTS
from pydantic import BaseModel, Field
from .llm_factory import get_llm
from llm_utils.llm_factory import get_llm

from llm_utils.chains import (
query_refiner_chain,
Expand Down Expand Up @@ -119,7 +119,7 @@ def context_enrichment_node(state: QueryMakerState):
주요 작업:
- 주어진 질문의 메타데이터 (`question_profile` 및 `searched_tables`)를 활용하여, 질문을 수정하거나 추가 정보를 삽입합니다.
- 질문이 시계열 분석 또는 집계 함수 관련인 경우, 이를 명시적으로 강조합니다 (예: "지난 30일 동안").
- 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: ‘미국’USA).
- 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: '미국''USA').
- 보강된 질문을 출력합니다.

Args:
Expand Down Expand Up @@ -207,23 +207,3 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
state["generated_query"] = res.sql
state["messages"].append(res.explanation)
return state


# StateGraph 생성 및 구성
builder = StateGraph(QueryMakerState)
builder.set_entry_point(GET_TABLE_INFO)

# 노드 추가
builder.add_node(GET_TABLE_INFO, get_table_info_node)
builder.add_node(QUERY_REFINER, query_refiner_node)
builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
# builder.add_node(
# QUERY_MAKER, query_maker_node_with_db_guide
# ) # query_maker_node_with_db_guide

# 기본 엣지 설정
builder.add_edge(GET_TABLE_INFO, QUERY_REFINER)
builder.add_edge(QUERY_REFINER, QUERY_MAKER)

# QUERY_MAKER 노드 후 종료
builder.add_edge(QUERY_MAKER, END)
Loading