Skip to content

Commit 8bb9907

Browse files
authored
Merge pull request #50
feat: 테이블 및 컬럼 정보를 병렬로 수집하도록 개선하여 성능 향상
2 parents 2a1ba84 + 6026a1d commit 8bb9907

File tree

1 file changed

+106
-23
lines changed

1 file changed

+106
-23
lines changed

llm_utils/tools.py

Lines changed: 106 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,47 @@
11
import os
2-
from typing import List, Dict
2+
from typing import List, Dict, Optional, TypeVar, Callable, Iterable, Any
33

44
from langchain.schema import Document
55

66
from data_utils.datahub_source import DatahubMetadataFetcher
7+
from tqdm import tqdm
8+
from concurrent.futures import ThreadPoolExecutor
9+
10+
T = TypeVar("T")
11+
R = TypeVar("R")
12+
13+
14+
def parallel_process[T, R](
15+
items: Iterable[T],
16+
process_fn: Callable[[T], R],
17+
max_workers: int = 8,
18+
desc: Optional[str] = None,
19+
show_progress: bool = True,
20+
) -> List[R]:
21+
"""병렬 처리를 위한 유틸리티 함수
22+
23+
Args:
24+
items (Iterable[T]): 처리할 아이템들
25+
process_fn (Callable[[T], R]): 각 아이템을 처리할 함수
26+
max_workers (int, optional): 최대 쓰레드 수. Defaults to 8.
27+
desc (Optional[str], optional): 진행 상태 메시지. Defaults to None.
28+
show_progress (bool, optional): 진행 상태 표시 여부. Defaults to True.
29+
30+
Returns:
31+
List[R]: 처리 결과 리스트
32+
"""
33+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
34+
futures = [executor.submit(process_fn, item) for item in items]
35+
if show_progress:
36+
futures = tqdm(futures, desc=desc)
37+
return [future.result() for future in futures]
738

839

940
def set_gms_server(gms_server: str):
1041
try:
1142
os.environ["DATAHUB_SERVER"] = gms_server
1243
fetcher = DatahubMetadataFetcher(gms_server=gms_server)
1344
except ValueError as e:
14-
# 유효하지 않은 GMS 서버 주소일 경우 예외를 발생시킴
1545
raise ValueError(f"GMS 서버 설정 실패: {str(e)}")
1646

1747

@@ -22,49 +52,102 @@ def _get_fetcher():
2252
return DatahubMetadataFetcher(gms_server=gms_server)
2353

2454

25-
def _get_table_info() -> Dict[str, str]:
26-
"""전체 테이블 이름과 설명을 가져오는 함수"""
55+
def _process_urn(urn: str, fetcher: DatahubMetadataFetcher) -> tuple[str, str]:
56+
table_name = fetcher.get_table_name(urn)
57+
table_description = fetcher.get_table_description(urn)
58+
return (table_name, table_description)
59+
60+
61+
def _process_column_info(
62+
urn: str, table_name: str, fetcher: DatahubMetadataFetcher
63+
) -> Optional[List[Dict[str, str]]]:
64+
if fetcher.get_table_name(urn) == table_name:
65+
return fetcher.get_column_names_and_descriptions(urn)
66+
return None
67+
68+
69+
def _get_table_info(max_workers: int = 8) -> Dict[str, str]:
70+
"""전체 테이블 이름과 설명을 가져오는 함수
71+
72+
Args:
73+
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.
74+
75+
Returns:
76+
Dict[str, str]: 테이블 이름과 설명을 담은 딕셔너리
77+
"""
2778
fetcher = _get_fetcher()
2879
urns = fetcher.get_urns()
2980
table_info = {}
30-
for urn in urns:
31-
table_name = fetcher.get_table_name(urn)
32-
table_description = fetcher.get_table_description(urn)
81+
82+
results = parallel_process(
83+
urns,
84+
lambda urn: _process_urn(urn, fetcher),
85+
max_workers=max_workers,
86+
desc="테이블 정보 수집 중",
87+
)
88+
89+
for table_name, table_description in results:
3390
if table_name and table_description:
3491
table_info[table_name] = table_description
92+
3593
return table_info
3694

3795

38-
def _get_column_info(table_name: str) -> List[Dict[str, str]]:
39-
"""table_name에 해당하는 컬럼 이름과 설명을 가져오는 함수"""
96+
def _get_column_info(table_name: str, max_workers: int = 8) -> List[Dict[str, str]]:
97+
"""table_name에 해당하는 컬럼 이름과 설명을 가져오는 함수
98+
99+
Args:
100+
table_name (str): 테이블 이름
101+
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.
102+
103+
Returns:
104+
List[Dict[str, str]]: 컬럼 정보 리스트
105+
"""
40106
fetcher = _get_fetcher()
41107
urns = fetcher.get_urns()
42-
for urn in urns:
43-
if fetcher.get_table_name(urn) == table_name:
44-
return fetcher.get_column_names_and_descriptions(urn)
108+
109+
results = parallel_process(
110+
urns,
111+
lambda urn: _process_column_info(urn, table_name, fetcher),
112+
max_workers=max_workers,
113+
show_progress=False,
114+
)
115+
116+
for result in results:
117+
if result:
118+
return result
45119
return []
46120

47121

48-
def get_info_from_db() -> List[Document]:
49-
"""
50-
전체 테이블 이름과 설명, 컬럼 이름과 설명을 가져오는 함수
122+
def get_info_from_db(max_workers: int = 8) -> List[Document]:
123+
"""전체 테이블 이름과 설명, 컬럼 이름과 설명을 가져오는 함수
124+
125+
Args:
126+
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.
127+
128+
Returns:
129+
List[Document]: 테이블과 컬럼 정보를 담은 Document 객체 리스트
51130
"""
131+
table_info = _get_table_info(max_workers=max_workers)
52132

53-
table_info_str_list = []
54-
table_info = _get_table_info()
55-
for table_name, table_description in table_info.items():
56-
column_info = _get_column_info(table_name)
133+
def process_table_info(item: tuple[str, str]) -> str:
134+
table_name, table_description = item
135+
column_info = _get_column_info(table_name, max_workers=max_workers)
57136
column_info_str = "\n".join(
58137
[
59138
f"{col['column_name']}: {col['column_description']}"
60139
for col in column_info
61140
]
62141
)
63-
table_info_str_list.append(
64-
f"{table_name}: {table_description}\nColumns:\n {column_info_str}"
65-
)
142+
return f"{table_name}: {table_description}\nColumns:\n {column_info_str}"
143+
144+
table_info_str_list = parallel_process(
145+
table_info.items(),
146+
process_table_info,
147+
max_workers=max_workers,
148+
desc="컬럼 정보 수집 중",
149+
)
66150

67-
# table_info_str_list를 Document 객체 리스트로 변환
68151
return [Document(page_content=info) for info in table_info_str_list]
69152

70153

0 commit comments

Comments
 (0)