Skip to content

Commit d9e2042

Browse files
fanzhidongyzby, Aries-ckt, Appointat
authored
fix: fix unit test error (#2085)
Co-authored-by: aries_ckt <[email protected]>
Co-authored-by: Appointat <[email protected]>
1 parent 6d66678 commit d9e2042

File tree

11 files changed

+129
-113
lines changed

11 files changed

+129
-113
lines changed

dbgpt/datasource/conn_tugraph.py

+8 -8
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
"""TuGraph Connector."""
22

33
import json
4-
from typing import Dict, Generator, List, Tuple, cast
4+
from typing import Dict, Generator, Iterator, List, cast
55

66
from .base import BaseConnector
77

@@ -20,7 +20,7 @@ def __init__(self, driver, graph):
2020
self._graph = graph
2121
self._session = None
2222

23-
def create_graph(self, graph_name: str) -> None:
23+
def create_graph(self, graph_name: str) -> bool:
2424
"""Create a new graph in the database if it doesn't already exist."""
2525
try:
2626
with self._driver.session(database="default") as session:
@@ -33,6 +33,8 @@ def create_graph(self, graph_name: str) -> None:
3333
except Exception as e:
3434
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e
3535

36+
return not exists
37+
3638
def delete_graph(self, graph_name: str) -> None:
3739
"""Delete a graph in the database if it exists."""
3840
with self._driver.session(database="default") as session:
@@ -60,20 +62,18 @@ def from_uri_db(
6062
"`pip install neo4j`"
6163
) from err
6264

63-
def get_table_names(self) -> Tuple[List[str], List[str]]:
65+
def get_table_names(self) -> Iterator[str]:
6466
"""Get all table names from the TuGraph by Neo4j driver."""
6567
with self._driver.session(database=self._graph) as session:
6668
# Run the query to get vertex labels
67-
raw_vertex_labels: Dict[str, str] = session.run(
68-
"CALL db.vertexLabels()"
69-
).data()
69+
raw_vertex_labels = session.run("CALL db.vertexLabels()").data()
7070
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]
7171

7272
# Run the query to get edge labels
73-
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data()
73+
raw_edge_labels = session.run("CALL db.edgeLabels()").data()
7474
edge_labels = [table_name["label"] for table_name in raw_edge_labels]
7575

76-
return vertex_labels, edge_labels
76+
return iter(vertex_labels + edge_labels)
7777

7878
def get_grants(self):
7979
"""Get grants."""

dbgpt/rag/summary/gdbms_db_summary.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -76,8 +76,8 @@ def _parse_db_summary(
7676
table_info_summaries = None
7777
if isinstance(conn, TuGraphConnector):
7878
table_names = conn.get_table_names()
79-
v_tables = table_names.get("vertex_tables", [])
80-
e_tables = table_names.get("edge_tables", [])
79+
v_tables = table_names.get("vertex_tables", []) # type: ignore
80+
e_tables = table_names.get("edge_tables", []) # type: ignore
8181
table_info_summaries = [
8282
_parse_table_summary(conn, summary_template, table_name, "vertex")
8383
for table_name in v_tables

dbgpt/storage/graph_store/tugraph_store.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -141,16 +141,16 @@ def _upload_plugin(self):
141141
if len(missing_plugins):
142142
for name in missing_plugins:
143143
try:
144-
from dbgpt_tugraph_plugins import (
145-
get_plugin_binary_path, # type:ignore[import-untyped]
144+
from dbgpt_tugraph_plugins import ( # type: ignore
145+
get_plugin_binary_path,
146146
)
147147
except ImportError:
148148
logger.error(
149149
"dbgpt-tugraph-plugins is not installed, "
150150
"pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i "
151151
"https://pypi.org/simple"
152152
)
153-
plugin_path = get_plugin_binary_path("leiden")
153+
plugin_path = get_plugin_binary_path("leiden") # type: ignore
154154
with open(plugin_path, "rb") as f:
155155
content = f.read()
156156
content = base64.b64encode(content).decode()

dbgpt/storage/knowledge_graph/community/base.py

+12 -3
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import logging
44
from abc import ABC, abstractmethod
55
from dataclasses import dataclass
6-
from typing import AsyncGenerator, Iterator, List, Optional, Union
6+
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union
77

88
from dbgpt.storage.graph_store.base import GraphStoreBase
99
from dbgpt.storage.graph_store.graph import (
@@ -156,7 +156,11 @@ def create_graph(self, graph_name: str) -> None:
156156
"""Create graph."""
157157

158158
@abstractmethod
159-
def create_graph_label(self) -> None:
159+
def create_graph_label(
160+
self,
161+
graph_elem_type: GraphElemType,
162+
graph_properties: List[Dict[str, Union[str, bool]]],
163+
) -> None:
160164
"""Create a graph label.
161165
162166
The graph label is used to identify and distinguish different types of nodes
@@ -176,7 +180,12 @@ def explore(
176180
self,
177181
subs: List[str],
178182
direct: Direction = Direction.BOTH,
179-
depth: Optional[int] = None,
183+
depth: int = 3,
184+
fan: Optional[int] = None,
185+
limit: Optional[int] = None,
186+
search_scope: Optional[
187+
Literal["knowledge_graph", "document_graph"]
188+
] = "knowledge_graph",
180189
) -> MemoryGraph:
181190
"""Explore the graph from given subjects up to a depth."""
182191

dbgpt/storage/knowledge_graph/community/memgraph_store_adapter.py

+9 -4
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import logging
5-
from typing import AsyncGenerator, Iterator, List, Optional, Tuple, Union
5+
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Tuple, Union
66

77
from dbgpt.storage.graph_store.graph import (
88
Direction,
@@ -173,6 +173,8 @@ def create_graph(self, graph_name: str):
173173

174174
def create_graph_label(
175175
self,
176+
graph_elem_type: GraphElemType,
177+
graph_properties: List[Dict[str, Union[str, bool]]],
176178
) -> None:
177179
"""Create a graph label.
178180
@@ -201,9 +203,12 @@ def explore(
201203
self,
202204
subs: List[str],
203205
direct: Direction = Direction.BOTH,
204-
depth: int | None = None,
205-
fan: int | None = None,
206-
limit: int | None = None,
206+
depth: int = 3,
207+
fan: Optional[int] = None,
208+
limit: Optional[int] = None,
209+
search_scope: Optional[
210+
Literal["knowledge_graph", "document_graph"]
211+
] = "knowledge_graph",
207212
) -> MemoryGraph:
208213
"""Explore the graph from given subjects up to a depth."""
209214
return self._graph_store._graph.search(subs, direct, depth, fan, limit)

dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py

+75 -72
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ async def get_community(self, community_id: str) -> Community:
7979
@property
8080
def graph_store(self) -> TuGraphStore:
8181
"""Get the graph store."""
82-
return self._graph_store
82+
return self._graph_store # type: ignore[return-value]
8383

8484
def get_graph_config(self):
8585
"""Get the graph store config."""
@@ -176,29 +176,23 @@ def upsert_edge(
176176
[{self._convert_dict_to_str(edge_list)}])"""
177177
self.graph_store.conn.run(query=relation_query)
178178

179-
def upsert_chunks(
180-
self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
181-
) -> None:
179+
def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None:
182180
"""Upsert chunks."""
183-
chunks_list = list(chunks)
184-
if chunks_list and isinstance(chunks_list[0], ParagraphChunk):
185-
chunk_list = [
186-
{
187-
"id": self._escape_quotes(chunk.chunk_id),
188-
"name": self._escape_quotes(chunk.chunk_name),
189-
"content": self._escape_quotes(chunk.content),
190-
}
191-
for chunk in chunks_list
192-
]
193-
else:
194-
chunk_list = [
195-
{
196-
"id": self._escape_quotes(chunk.vid),
197-
"name": self._escape_quotes(chunk.name),
198-
"content": self._escape_quotes(chunk.get_prop("content")),
199-
}
200-
for chunk in chunks_list
201-
]
181+
chunk_list = [
182+
{
183+
"id": self._escape_quotes(chunk.chunk_id),
184+
"name": self._escape_quotes(chunk.chunk_name),
185+
"content": self._escape_quotes(chunk.content),
186+
}
187+
if isinstance(chunk, ParagraphChunk)
188+
else {
189+
"id": self._escape_quotes(chunk.vid),
190+
"name": self._escape_quotes(chunk.name),
191+
"content": self._escape_quotes(chunk.get_prop("content")),
192+
}
193+
for chunk in chunks
194+
]
195+
202196
chunk_query = (
203197
f"CALL db.upsertVertex("
204198
f'"{GraphElemType.CHUNK.value}", '
@@ -207,28 +201,24 @@ def upsert_chunks(
207201
self.graph_store.conn.run(query=chunk_query)
208202

209203
def upsert_documents(
210-
self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
204+
self, documents: Iterator[Union[Vertex, ParagraphChunk]]
211205
) -> None:
212206
"""Upsert documents."""
213-
documents_list = list(documents)
214-
if documents_list and isinstance(documents_list[0], ParagraphChunk):
215-
document_list = [
216-
{
217-
"id": self._escape_quotes(document.chunk_id),
218-
"name": self._escape_quotes(document.chunk_name),
219-
"content": "",
220-
}
221-
for document in documents_list
222-
]
223-
else:
224-
document_list = [
225-
{
226-
"id": self._escape_quotes(document.vid),
227-
"name": self._escape_quotes(document.name),
228-
"content": self._escape_quotes(document.get_prop("content")) or "",
229-
}
230-
for document in documents_list
231-
]
207+
document_list = [
208+
{
209+
"id": self._escape_quotes(document.chunk_id),
210+
"name": self._escape_quotes(document.chunk_name),
211+
"content": "",
212+
}
213+
if isinstance(document, ParagraphChunk)
214+
else {
215+
"id": self._escape_quotes(document.vid),
216+
"name": self._escape_quotes(document.name),
217+
"content": "",
218+
}
219+
for document in documents
220+
]
221+
232222
document_query = (
233223
"CALL db.upsertVertex("
234224
f'"{GraphElemType.DOCUMENT.value}", '
@@ -258,7 +248,7 @@ def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
258248
self.graph_store.conn.run(query=vertex_query)
259249
self.graph_store.conn.run(query=edge_query)
260250

261-
def upsert_graph(self, graph: MemoryGraph) -> None:
251+
def upsert_graph(self, graph: Graph) -> None:
262252
"""Add graph to the graph store.
263253
264254
Args:
@@ -362,7 +352,8 @@ def drop(self):
362352

363353
def create_graph(self, graph_name: str):
364354
"""Create a graph."""
365-
self.graph_store.conn.create_graph(graph_name=graph_name)
355+
if not self.graph_store.conn.create_graph(graph_name=graph_name):
356+
return
366357

367358
# Create the graph schema
368359
def _format_graph_propertity_schema(
@@ -474,12 +465,14 @@ def create_graph_label(
474465
(vertices) and edges in the graph.
475466
"""
476467
if graph_elem_type.is_vertex(): # vertex
477-
data = json.dumps({
478-
"label": graph_elem_type.value,
479-
"type": "VERTEX",
480-
"primary": "id",
481-
"properties": graph_properties,
482-
})
468+
data = json.dumps(
469+
{
470+
"label": graph_elem_type.value,
471+
"type": "VERTEX",
472+
"primary": "id",
473+
"properties": graph_properties,
474+
}
475+
)
483476
gql = f"""CALL db.createVertexLabelByJson('{data}')"""
484477
else: # edge
485478

@@ -505,12 +498,14 @@ def edge_direction(graph_elem_type: GraphElemType) -> List[List[str]]:
505498
else:
506499
raise ValueError("Invalid graph element type.")
507500

508-
data = json.dumps({
509-
"label": graph_elem_type.value,
510-
"type": "EDGE",
511-
"constraints": edge_direction(graph_elem_type),
512-
"properties": graph_properties,
513-
})
501+
data = json.dumps(
502+
{
503+
"label": graph_elem_type.value,
504+
"type": "EDGE",
505+
"constraints": edge_direction(graph_elem_type),
506+
"properties": graph_properties,
507+
}
508+
)
514509
gql = f"""CALL db.createEdgeLabelByJson('{data}')"""
515510

516511
self.graph_store.conn.run(gql)
@@ -530,18 +525,16 @@ def check_label(self, graph_elem_type: GraphElemType) -> bool:
530525
True if the label exists in the specified graph element type, otherwise
531526
False.
532527
"""
533-
vertex_tables, edge_tables = self.graph_store.conn.get_table_names()
528+
tables = self.graph_store.conn.get_table_names()
534529

535-
if graph_elem_type.is_vertex():
536-
return graph_elem_type in vertex_tables
537-
else:
538-
return graph_elem_type in edge_tables
530+
return graph_elem_type.value in tables
539531

540532
def explore(
541533
self,
542534
subs: List[str],
543535
direct: Direction = Direction.BOTH,
544536
depth: int = 3,
537+
fan: Optional[int] = None,
545538
limit: Optional[int] = None,
546539
search_scope: Optional[
547540
Literal["knowledge_graph", "document_graph"]
@@ -621,11 +614,17 @@ def query(self, query: str, **kwargs) -> MemoryGraph:
621614
mg.append_edge(edge)
622615
return mg
623616

624-
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
617+
# type: ignore[override]
618+
# mypy: ignore-errors
619+
async def stream_query( # type: ignore[override]
620+
self,
621+
query: str,
622+
**kwargs,
623+
) -> AsyncGenerator[Graph, None]:
625624
"""Execute a stream query."""
626625
from neo4j import graph
627626

628-
async for record in self.graph_store.conn.run_stream(query):
627+
async for record in self.graph_store.conn.run_stream(query): # type: ignore
629628
mg = MemoryGraph()
630629
for key in record.keys():
631630
value = record[key]
@@ -650,15 +649,19 @@ async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None
650649
rels = list(record["p"].relationships)
651650
formatted_path = []
652651
for i in range(len(nodes)):
653-
formatted_path.append({
654-
"id": nodes[i]._properties["id"],
655-
"description": nodes[i]._properties["description"],
656-
})
652+
formatted_path.append(
653+
{
654+
"id": nodes[i]._properties["id"],
655+
"description": nodes[i]._properties["description"],
656+
}
657+
)
657658
if i < len(rels):
658-
formatted_path.append({
659-
"id": rels[i]._properties["id"],
660-
"description": rels[i]._properties["description"],
661-
})
659+
formatted_path.append(
660+
{
661+
"id": rels[i]._properties["id"],
662+
"description": rels[i]._properties["description"],
663+
}
664+
)
662665
for i in range(0, len(formatted_path), 2):
663666
mg.upsert_vertex(
664667
Vertex(

0 commit comments

Comments
 (0)