Skip to content

Commit 03c0292

Browse files
committed
style: apply Black formatter to Python files
1 parent 9fd52e3 commit 03c0292

File tree

5 files changed

+24
-27
lines changed

5 files changed

+24
-27
lines changed

interface/streamlit_app.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@
88
)
99

1010
pg.run()
11-

llm_utils/chains.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import os
2-
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate
2+
from langchain_core.prompts import (
3+
ChatPromptTemplate,
4+
MessagesPlaceholder,
5+
SystemMessagePromptTemplate,
6+
)
37

48
from .llm_factory import get_llm
59

@@ -15,8 +19,9 @@
1519

1620
llm = get_llm()
1721

22+
1823
def create_query_refiner_chain(llm):
19-
prompt = get_prompt_template('query_refiner_prompt')
24+
prompt = get_prompt_template("query_refiner_prompt")
2025
tool_choice_prompt = ChatPromptTemplate.from_messages(
2126
[
2227
SystemMessagePromptTemplate.from_template(prompt),
@@ -38,7 +43,7 @@ def create_query_refiner_chain(llm):
3843
# QueryMakerChain
3944
def create_query_maker_chain(llm):
4045
# SystemPrompt만 yaml 파일로 불러와서 사용
41-
prompt = get_prompt_template('query_maker_prompt')
46+
prompt = get_prompt_template("query_maker_prompt")
4247
query_maker_prompt = ChatPromptTemplate.from_messages(
4348
[
4449
SystemMessagePromptTemplate.from_template(prompt),
@@ -67,4 +72,4 @@ def create_query_maker_chain(llm):
6772
query_maker_chain = create_query_maker_chain(llm)
6873

6974
if __name__ == "__main__":
70-
query_refiner_chain.invoke()
75+
query_refiner_chain.invoke()

llm_utils/graph.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class QueryMakerState(TypedDict):
3535

3636
# 노드 함수: QUERY_REFINER 노드
3737
def query_refiner_node(state: QueryMakerState):
38-
print('query_refiner_node 진입 [md]')
3938
res = query_refiner_chain.invoke(
4039
input={
4140
"user_input": [state["messages"][0].content],
@@ -61,7 +60,6 @@ def get_table_info_node(state: QueryMakerState):
6160
)
6261
except:
6362
documents = get_info_from_db()
64-
print("db_embedding 진입")
6563
db = FAISS.from_documents(documents, embeddings)
6664
db.save_local(os.getcwd() + "/table_info_db")
6765
doc_res = db.similarity_search(state["messages"][-1].content)

llm_utils/prompts_class.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,26 @@
44
from langchain_core.prompts import load_prompt
55

66

7-
class SQLPrompt():
7+
class SQLPrompt:
88
def __init__(self):
9-
# os library를 확인해서 SQL_PROMPTS key에 해당하는ㅁ prompt가 있으면, 이를 교체
9+
# os library를 확인해서 SQL_PROMPTS key에 해당하는 prompt가 있으면, 이를 교체
1010
self.sql_prompts = SQL_PROMPTS
1111
self.target_db_list = list(SQL_PROMPTS.keys())
12-
self.prompt_path = '../prompt'
12+
self.prompt_path = "../prompt"
1313

1414
def update_prompt_from_path(self):
1515
if os.path.exists(self.prompt_path):
1616
path_list = os.listdir(self.prompt_path)
1717
# yaml 파일만 가져옴
18-
file_list = [file for file in path_list if file.endswith('.yaml')]
19-
key_path_dict = {key.split('.')[0]: os.path.join(self.prompt_path, key) for key in file_list if key.split('.')[0] in self.target_db_list}
18+
file_list = [file for file in path_list if file.endswith(".yaml")]
19+
key_path_dict = {
20+
key.split(".")[0]: os.path.join(self.prompt_path, key)
21+
for key in file_list
22+
if key.split(".")[0] in self.target_db_list
23+
}
2024
# file_list에서 sql_prompts의 key에 해당하는 파일이 있는 것만 가져옴
2125
for key, path in key_path_dict.items():
22-
self.sql_prompts[key] = load_prompt(path, encoding='utf-8')
26+
self.sql_prompts[key] = load_prompt(path, encoding="utf-8")
2327
else:
2428
raise FileNotFoundError(f"Prompt file not found in {self.prompt_path}")
2529
return False
26-
27-
if __name__ == '__main__':
28-
sql_prompts_class = SQLPrompt()
29-
print(sql_prompts_class.sql_prompts['mysql'])
30-
print(sql_prompts_class.update_prompt_from_path())
31-
32-
print(sql_prompts_class.sql_prompts['mysql'])
33-
print(sql_prompts_class.sql_prompts)

prompt/template_loader.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import os
22

33

4-
54
def get_prompt_template(prompt_name: str) -> str:
65
try:
7-
with open(os.path.join(os.path.dirname(__file__), f"{prompt_name}.md"), "r", encoding="utf-8") as f:
6+
with open(
7+
os.path.join(os.path.dirname(__file__), f"{prompt_name}.md"),
8+
"r",
9+
encoding="utf-8",
10+
) as f:
811
template = f.read()
912
except FileNotFoundError:
1013
raise FileNotFoundError(f"경고: '{prompt_name}.md' 파일을 찾을 수 없습니다.")
1114
return template
12-
13-
if __name__ == "__main__":
14-
print(get_prompt_template("system_prompt"))
15-
# print(apply_prompt_template("prompt_md_sample", {"messages": []}))

0 commit comments

Comments
 (0)