@@ -79,7 +79,7 @@ async def get_community(self, community_id: str) -> Community:
79
79
@property
80
80
def graph_store (self ) -> TuGraphStore :
81
81
"""Get the graph store."""
82
- return self ._graph_store
82
+ return self ._graph_store # type: ignore[return-value]
83
83
84
84
def get_graph_config (self ):
85
85
"""Get the graph store config."""
@@ -176,29 +176,23 @@ def upsert_edge(
176
176
[{ self ._convert_dict_to_str (edge_list )} ])"""
177
177
self .graph_store .conn .run (query = relation_query )
178
178
179
- def upsert_chunks (
180
- self , chunks : Union [Iterator [Vertex ], Iterator [ParagraphChunk ]]
181
- ) -> None :
179
+ def upsert_chunks (self , chunks : Iterator [Union [Vertex , ParagraphChunk ]]) -> None :
182
180
"""Upsert chunks."""
183
- chunks_list = list (chunks )
184
- if chunks_list and isinstance (chunks_list [0 ], ParagraphChunk ):
185
- chunk_list = [
186
- {
187
- "id" : self ._escape_quotes (chunk .chunk_id ),
188
- "name" : self ._escape_quotes (chunk .chunk_name ),
189
- "content" : self ._escape_quotes (chunk .content ),
190
- }
191
- for chunk in chunks_list
192
- ]
193
- else :
194
- chunk_list = [
195
- {
196
- "id" : self ._escape_quotes (chunk .vid ),
197
- "name" : self ._escape_quotes (chunk .name ),
198
- "content" : self ._escape_quotes (chunk .get_prop ("content" )),
199
- }
200
- for chunk in chunks_list
201
- ]
181
+ chunk_list = [
182
+ {
183
+ "id" : self ._escape_quotes (chunk .chunk_id ),
184
+ "name" : self ._escape_quotes (chunk .chunk_name ),
185
+ "content" : self ._escape_quotes (chunk .content ),
186
+ }
187
+ if isinstance (chunk , ParagraphChunk )
188
+ else {
189
+ "id" : self ._escape_quotes (chunk .vid ),
190
+ "name" : self ._escape_quotes (chunk .name ),
191
+ "content" : self ._escape_quotes (chunk .get_prop ("content" )),
192
+ }
193
+ for chunk in chunks
194
+ ]
195
+
202
196
chunk_query = (
203
197
f"CALL db.upsertVertex("
204
198
f'"{ GraphElemType .CHUNK .value } ", '
@@ -207,28 +201,24 @@ def upsert_chunks(
207
201
self .graph_store .conn .run (query = chunk_query )
208
202
209
203
def upsert_documents (
210
- self , documents : Union [ Iterator [Vertex ], Iterator [ ParagraphChunk ]]
204
+ self , documents : Iterator [Union [ Vertex , ParagraphChunk ]]
211
205
) -> None :
212
206
"""Upsert documents."""
213
- documents_list = list (documents )
214
- if documents_list and isinstance (documents_list [0 ], ParagraphChunk ):
215
- document_list = [
216
- {
217
- "id" : self ._escape_quotes (document .chunk_id ),
218
- "name" : self ._escape_quotes (document .chunk_name ),
219
- "content" : "" ,
220
- }
221
- for document in documents_list
222
- ]
223
- else :
224
- document_list = [
225
- {
226
- "id" : self ._escape_quotes (document .vid ),
227
- "name" : self ._escape_quotes (document .name ),
228
- "content" : self ._escape_quotes (document .get_prop ("content" )) or "" ,
229
- }
230
- for document in documents_list
231
- ]
207
+ document_list = [
208
+ {
209
+ "id" : self ._escape_quotes (document .chunk_id ),
210
+ "name" : self ._escape_quotes (document .chunk_name ),
211
+ "content" : "" ,
212
+ }
213
+ if isinstance (document , ParagraphChunk )
214
+ else {
215
+ "id" : self ._escape_quotes (document .vid ),
216
+ "name" : self ._escape_quotes (document .name ),
217
+ "content" : "" ,
218
+ }
219
+ for document in documents
220
+ ]
221
+
232
222
document_query = (
233
223
"CALL db.upsertVertex("
234
224
f'"{ GraphElemType .DOCUMENT .value } ", '
@@ -258,7 +248,7 @@ def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
258
248
self .graph_store .conn .run (query = vertex_query )
259
249
self .graph_store .conn .run (query = edge_query )
260
250
261
- def upsert_graph (self , graph : MemoryGraph ) -> None :
251
+ def upsert_graph (self , graph : Graph ) -> None :
262
252
"""Add graph to the graph store.
263
253
264
254
Args:
@@ -362,7 +352,8 @@ def drop(self):
362
352
363
353
def create_graph (self , graph_name : str ):
364
354
"""Create a graph."""
365
- self .graph_store .conn .create_graph (graph_name = graph_name )
355
+ if not self .graph_store .conn .create_graph (graph_name = graph_name ):
356
+ return
366
357
367
358
# Create the graph schema
368
359
def _format_graph_propertity_schema (
@@ -474,12 +465,14 @@ def create_graph_label(
474
465
(vertices) and edges in the graph.
475
466
"""
476
467
if graph_elem_type .is_vertex (): # vertex
477
- data = json .dumps ({
478
- "label" : graph_elem_type .value ,
479
- "type" : "VERTEX" ,
480
- "primary" : "id" ,
481
- "properties" : graph_properties ,
482
- })
468
+ data = json .dumps (
469
+ {
470
+ "label" : graph_elem_type .value ,
471
+ "type" : "VERTEX" ,
472
+ "primary" : "id" ,
473
+ "properties" : graph_properties ,
474
+ }
475
+ )
483
476
gql = f"""CALL db.createVertexLabelByJson('{ data } ')"""
484
477
else : # edge
485
478
@@ -505,12 +498,14 @@ def edge_direction(graph_elem_type: GraphElemType) -> List[List[str]]:
505
498
else :
506
499
raise ValueError ("Invalid graph element type." )
507
500
508
- data = json .dumps ({
509
- "label" : graph_elem_type .value ,
510
- "type" : "EDGE" ,
511
- "constraints" : edge_direction (graph_elem_type ),
512
- "properties" : graph_properties ,
513
- })
501
+ data = json .dumps (
502
+ {
503
+ "label" : graph_elem_type .value ,
504
+ "type" : "EDGE" ,
505
+ "constraints" : edge_direction (graph_elem_type ),
506
+ "properties" : graph_properties ,
507
+ }
508
+ )
514
509
gql = f"""CALL db.createEdgeLabelByJson('{ data } ')"""
515
510
516
511
self .graph_store .conn .run (gql )
@@ -530,18 +525,16 @@ def check_label(self, graph_elem_type: GraphElemType) -> bool:
530
525
True if the label exists in the specified graph element type, otherwise
531
526
False.
532
527
"""
533
- vertex_tables , edge_tables = self .graph_store .conn .get_table_names ()
528
+ tables = self .graph_store .conn .get_table_names ()
534
529
535
- if graph_elem_type .is_vertex ():
536
- return graph_elem_type in vertex_tables
537
- else :
538
- return graph_elem_type in edge_tables
530
+ return graph_elem_type .value in tables
539
531
540
532
def explore (
541
533
self ,
542
534
subs : List [str ],
543
535
direct : Direction = Direction .BOTH ,
544
536
depth : int = 3 ,
537
+ fan : Optional [int ] = None ,
545
538
limit : Optional [int ] = None ,
546
539
search_scope : Optional [
547
540
Literal ["knowledge_graph" , "document_graph" ]
@@ -621,11 +614,17 @@ def query(self, query: str, **kwargs) -> MemoryGraph:
621
614
mg .append_edge (edge )
622
615
return mg
623
616
624
- async def stream_query (self , query : str , ** kwargs ) -> AsyncGenerator [Graph , None ]:
617
+ # type: ignore[override]
618
+ # mypy: ignore-errors
619
+ async def stream_query ( # type: ignore[override]
620
+ self ,
621
+ query : str ,
622
+ ** kwargs ,
623
+ ) -> AsyncGenerator [Graph , None ]:
625
624
"""Execute a stream query."""
626
625
from neo4j import graph
627
626
628
- async for record in self .graph_store .conn .run_stream (query ):
627
+ async for record in self .graph_store .conn .run_stream (query ): # type: ignore
629
628
mg = MemoryGraph ()
630
629
for key in record .keys ():
631
630
value = record [key ]
@@ -650,15 +649,19 @@ async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None
650
649
rels = list (record ["p" ].relationships )
651
650
formatted_path = []
652
651
for i in range (len (nodes )):
653
- formatted_path .append ({
654
- "id" : nodes [i ]._properties ["id" ],
655
- "description" : nodes [i ]._properties ["description" ],
656
- })
652
+ formatted_path .append (
653
+ {
654
+ "id" : nodes [i ]._properties ["id" ],
655
+ "description" : nodes [i ]._properties ["description" ],
656
+ }
657
+ )
657
658
if i < len (rels ):
658
- formatted_path .append ({
659
- "id" : rels [i ]._properties ["id" ],
660
- "description" : rels [i ]._properties ["description" ],
661
- })
659
+ formatted_path .append (
660
+ {
661
+ "id" : rels [i ]._properties ["id" ],
662
+ "description" : rels [i ]._properties ["description" ],
663
+ }
664
+ )
662
665
for i in range (0 , len (formatted_path ), 2 ):
663
666
mg .upsert_vertex (
664
667
Vertex (
0 commit comments