Skip to content

Commit 4af8138

Browse files
committed
feat: 데이터허브에서 테이블 정보를 가져와 패르소나를 생성하는 기능 추가
1 parent ae1589e commit 4af8138

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

evaluation/gen_persona.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from dotenv import load_dotenv
2+
import os
3+
4+
from datahub_cls.metadata_fetcher import (
5+
DatahubMetadataFetcher,
6+
get_all_tables_info,
7+
)
8+
9+
from utils import save_persona_json, pretty_print_persona
10+
from persona_class import PersonaList
11+
12+
13+
from langchain_openai.chat_models import ChatOpenAI
14+
from langchain_core.prompts import ChatPromptTemplate
15+
from argparse import ArgumentParser
16+
17+
# Load settings from a local .env file (python-dotenv) at import time, so that
# later os.getenv() lookups such as DATAHUB_SERVER in main() can see them.
# NOTE(review): presumably the OpenAI credentials used by ChatOpenAI come from
# the same file — confirm.
load_dotenv()
18+
19+
20+
def drop_empty_tables(tables_df):
    """Keep only the tables that actually have a description.

    A description counts as missing when it is the empty string or a null
    value (``None`` / NaN). Null values are rejected explicitly because they
    compare unequal to ``""`` and would otherwise be kept and later rendered
    into the prompt as the literal text "None" or "nan".

    Args:
        tables_df: DataFrame with at least ``table_name`` and
            ``table_description`` columns.

    Returns:
        A new DataFrame containing only the ``table_name`` and
        ``table_description`` columns of the rows that have a description.
    """
    descriptions = tables_df["table_description"]
    # Vectorized mask: always boolean dtype, even for an empty frame.
    has_description = descriptions.notna() & (descriptions != "")
    return tables_df.loc[has_description, ["table_name", "table_description"]]
25+
26+
27+
def get_table_des_string(tables_df):
    """Render the table names and descriptions as one text block.

    The result starts with a header line and contains one
    ``<name> : <description>`` entry per row, each followed by a ``---``
    separator line.

    Args:
        tables_df: DataFrame with ``table_name`` and ``table_description``
            columns.

    Returns:
        The formatted string; only the header when the frame has no rows.
    """
    header = "table name : table description\n---\n"
    entries = (
        f"{record['table_name']} : {record['table_description']}\n---\n"
        for _, record in tables_df.iterrows()
    )
    return header + "".join(entries)
32+
33+
34+
def generate_persona(tables_df):
    """Ask the chat model for Text2SQL user personas based on table descriptions.

    Tables without a description are dropped, the remaining ones are rendered
    into a single text block, and that text is sent to the model together with
    the system instruction. The model answer is parsed via structured output
    using the ``PersonaList`` schema.

    Args:
        tables_df: DataFrame with ``table_name`` and ``table_description``
            columns.

    Returns:
        The structured result of the chain (expected to be a ``PersonaList``;
        the caller reads its ``personas`` attribute).
    """
    tables_df = drop_empty_tables(tables_df)
    description_string = get_table_des_string(tables_df)

    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    system_prompt = """주어진 Tabel description들을 참고하여 Text2SQL 서비스로 질문을 할만한 패르소나를 생성하세요"""

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            # Without this placeholder the "input" value given to invoke()
            # is never rendered, and the model would not see any of the
            # table descriptions it is asked to reason about.
            ("human", "{input}"),
        ]
    )

    chain = prompt | llm.with_structured_output(PersonaList)
    return chain.invoke({"input": description_string})
49+
50+
51+
def main(output_path):
    """Fetch table metadata from DataHub, generate personas, print and save them.

    Args:
        output_path: Destination passed to ``save_persona_json`` for the
            generated personas.
    """
    # Connect to the DataHub server named by the DATAHUB_SERVER variable.
    gms_server = os.getenv("DATAHUB_SERVER")
    metadata_fetcher = DatahubMetadataFetcher(gms_server=gms_server)

    persona_list = generate_persona(get_all_tables_info(metadata_fetcher))

    for entry in persona_list.personas:
        print(pretty_print_persona(entry))
    # Persist the complete result once every persona has been shown.
    save_persona_json(persona_list, output_path)
60+
61+
62+
if __name__ == "__main__":
    # Script entry point: the only option is where to write the personas.
    cli = ArgumentParser()
    cli.add_argument(
        "--output_path",
        type=str,
        default="data/persona/personas.json",
    )
    main(cli.parse_args().output_path)

0 commit comments

Comments
 (0)