Skip to content

Commit c4bae9a

Browse files
committed
update: 테이블 정보를 가져오는 방식 수정
1 parent 4af8138 commit c4bae9a

File tree

1 file changed

+8
-25
lines changed

1 file changed

+8
-25
lines changed

evaluation/gen_persona.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,23 @@
1-
from dotenv import load_dotenv
21
import os
32

4-
from datahub_cls.metadata_fetcher import (
5-
DatahubMetadataFetcher,
6-
get_all_tables_info,
7-
)
8-
93
from utils import save_persona_json, pretty_print_persona
104
from persona_class import PersonaList
115

12-
6+
from llm_utils.tools import _get_table_info
137
from langchain_openai.chat_models import ChatOpenAI
148
from langchain_core.prompts import ChatPromptTemplate
159
from argparse import ArgumentParser
1610

17-
load_dotenv()
18-
19-
20-
def drop_empty_tables(tables_df):
21-
drop_empty_tables = tables_df[
22-
tables_df["table_description"].apply(lambda x: x != "")
23-
]
24-
return drop_empty_tables[["table_name", "table_description"]]
25-
2611

27-
def get_table_des_string(tables_df):
12+
def get_table_des_string(tables_desc):
2813
return_string = "table name : table description\n---\n"
29-
for _, row in tables_df.iterrows():
30-
return_string += f"{row['table_name']} : {row['table_description']}\n---\n"
14+
for table_name, table_desc in tables_desc.items():
15+
return_string += f"{table_name} : {table_desc}\n---\n"
3116
return return_string
3217

3318

34-
def generate_persona(tables_df):
35-
tables_df = drop_empty_tables(tables_df)
36-
description_string = get_table_des_string(tables_df)
19+
def generate_persona(tables_desc):
20+
description_string = get_table_des_string(tables_desc)
3721

3822
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
3923
system_prompt = """주어진 Tabel description들을 참고하여 Text2SQL 서비스로 질문을 할만한 패르소나를 생성하세요"""
@@ -50,9 +34,8 @@ def generate_persona(tables_df):
5034

5135
def main(output_path):
5236
# 데이터허브 서버 연결
53-
fetcher = DatahubMetadataFetcher(gms_server=os.getenv("DATAHUB_SERVER"))
54-
tables_df = get_all_tables_info(fetcher)
55-
personas = generate_persona(tables_df)
37+
tables_desc = _get_table_info()
38+
personas = generate_persona(tables_desc)
5639

5740
for persona in personas.personas:
5841
print(pretty_print_persona(persona))

0 commit comments

Comments
 (0)