|
4 | 4 | from typing_extensions import TypedDict, Annotated
|
5 | 5 | from langgraph.graph import END, StateGraph
|
6 | 6 | from langgraph.graph.message import add_messages
|
| 7 | +from langchain.chains.sql_database.prompt import SQL_PROMPTS |
| 8 | +from pydantic import BaseModel, Field |
| 9 | +from .llm_factory import get_llm |
7 | 10 |
|
8 | 11 | from llm_utils.chains import (
|
9 | 12 | query_refiner_chain,
|
@@ -102,14 +105,44 @@ def query_maker_node(state: QueryMakerState):
|
102 | 105 | return state
|
103 | 106 |
|
104 | 107 |
|
| 108 | +class SQLResult(BaseModel): |
| 109 | + sql: str = Field(description="SQL 쿼리 문자열") |
| 110 | + explanation: str = Field(description="SQL 쿼리 설명") |
| 111 | + |
| 112 | + |
| 113 | +def query_maker_node_with_db_guide(state: QueryMakerState): |
| 114 | + sql_prompt = SQL_PROMPTS[state["user_database_env"]] |
| 115 | + llm = get_llm( |
| 116 | + model_type="openai", |
| 117 | + model_name="gpt-4o-mini", |
| 118 | + openai_api_key=os.getenv("OPENAI_API_KEY"), |
| 119 | + ) |
| 120 | + chain = sql_prompt | llm.with_structured_output(SQLResult) |
| 121 | + res = chain.invoke( |
| 122 | + input={ |
| 123 | + "input": "\n\n---\n\n".join( |
| 124 | + [state["messages"][0].content] + [state["refined_input"].content] |
| 125 | + ), |
| 126 | + "table_info": [json.dumps(state["searched_tables"])], |
| 127 | + "top_k": 10, |
| 128 | + } |
| 129 | + ) |
| 130 | + state["generated_query"] = res.sql |
| 131 | + state["messages"].append(res.explanation) |
| 132 | + return state |
| 133 | + |
| 134 | + |
105 | 135 | # StateGraph 생성 및 구성
|
106 | 136 | builder = StateGraph(QueryMakerState)
|
107 | 137 | builder.set_entry_point(QUERY_REFINER)
|
108 | 138 |
|
109 | 139 | # 노드 추가
|
110 | 140 | builder.add_node(QUERY_REFINER, query_refiner_node)
|
111 | 141 | builder.add_node(GET_TABLE_INFO, get_table_info_node)
|
112 |
| -builder.add_node(QUERY_MAKER, query_maker_node) |
| 142 | +# builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide |
| 143 | +builder.add_node( |
| 144 | + QUERY_MAKER, query_maker_node_with_db_guide |
| 145 | +) # query_maker_node_with_db_guide |
113 | 146 |
|
114 | 147 | # 기본 엣지 설정
|
115 | 148 | builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)
|
|
0 commit comments