162
162
start = "start" ,
163
163
)
164
164
165
- __version__ = "0.9 .0"
165
+ __version__ = "0.10 .0"
166
166
167
167
168
168
_ALPHABET = string .ascii_lowercase + string .digits
@@ -314,17 +314,17 @@ def _get_edge(host: nx.DiGraph, mapping, match_path, u, v):
314
314
315
315
def and_ (cond_a , cond_b ) -> CONDITION :
316
316
def inner (match : dict , host : nx .DiGraph , return_endges : list ) -> bool :
317
- condition_a , where_a = cond_a (match , host , return_endges )
317
+ condition_a , where_a = cond_a (match , host , return_endges )
318
318
condition_b , where_b = cond_b (match , host , return_endges )
319
319
where_result = [a and b for a , b in zip (where_a , where_b )]
320
320
return (condition_a and condition_b ), where_result
321
-
321
+
322
322
return inner
323
323
324
324
325
325
def or_ (cond_a , cond_b ):
326
326
def inner (match : dict , host : nx .DiGraph , return_endges : list ) -> bool :
327
- condition_a , where_a = cond_a (match , host , return_endges )
327
+ condition_a , where_a = cond_a (match , host , return_endges )
328
328
condition_b , where_b = cond_b (match , host , return_endges )
329
329
where_result = [a or b for a , b in zip (where_a , where_b )]
330
330
return (condition_a or condition_b ), where_result
@@ -419,7 +419,7 @@ def _filter_edge(edge, where_results):
419
419
# exclude edge(s) from multiedge that don't satisfy the where condition
420
420
edge = {k : v for k , v in edge [0 ].items () if where_results [k ] is True }
421
421
return [edge ]
422
-
422
+
423
423
if not data_paths :
424
424
return {}
425
425
@@ -480,8 +480,13 @@ def _filter_edge(edge, where_results):
480
480
ret = (
481
481
_filter_edge (
482
482
_get_edge (
483
- self ._target_graph , mapping [0 ], match_path , mapping_u , mapping_v
484
- ), mapping [1 ]
483
+ self ._target_graph ,
484
+ mapping [0 ],
485
+ match_path ,
486
+ mapping_u ,
487
+ mapping_v ,
488
+ ),
489
+ mapping [1 ],
485
490
)
486
491
for mapping , match_path in true_matches
487
492
)
@@ -506,7 +511,11 @@ def _filter_edge(edge, where_results):
506
511
for r in ret :
507
512
508
513
r = {
509
- k : v for k , v in r .items () if v .get ("__labels__" , None ).intersection (motif_edge_labels )
514
+ k : v
515
+ for k , v in r .items ()
516
+ if v .get ("__labels__" , None ).intersection (
517
+ motif_edge_labels
518
+ )
510
519
}
511
520
if len (r ) > 0 :
512
521
filtered_ret .append (r )
@@ -542,7 +551,9 @@ def return_clause(self, clause):
542
551
if isinstance (item , Tree ) and item .data == "aggregation_function" :
543
552
func , entity = self ._parse_aggregation_token (item )
544
553
if alias :
545
- self ._entity2alias [self ._format_aggregation_key (func , entity )] = alias
554
+ self ._entity2alias [
555
+ self ._format_aggregation_key (func , entity )
556
+ ] = alias
546
557
self ._aggregation_attributes .add (entity )
547
558
self ._aggregate_functions .append ((func , entity ))
548
559
else :
@@ -556,26 +567,26 @@ def return_clause(self, clause):
556
567
self ._alias2entity .update ({v : k for k , v in self ._entity2alias .items ()})
557
568
558
569
def _extract_alias (self , item : Tree ):
559
- '''
570
+ """
560
571
Extract the alias from the return item (if it exists)
561
- '''
572
+ """
562
573
563
574
if len (item .children ) == 1 :
564
575
return None
565
576
item_keys = [it .data if isinstance (it , Tree ) else None for it in item .children ]
566
- if any (k == ' alias' for k in item_keys ):
567
- # get the index of the alias
568
- alias_index = item_keys .index (' alias' )
577
+ if any (k == " alias" for k in item_keys ):
578
+ # get the index of the alias
579
+ alias_index = item_keys .index (" alias" )
569
580
return str (item .children [alias_index ].children [0 ].value )
570
-
581
+
571
582
return None
572
-
583
+
573
584
def _parse_aggregation_token (self , item : Tree ):
574
- '''
585
+ """
575
586
Parse the aggregation function token and return the function and entity
576
587
input: Tree('aggregation_function', [Token('AGGREGATE_FUNC', 'SUM'), Token('CNAME', 'r'), Tree('attribute_id', [Token('CNAME', 'value')])])
577
588
output: ('SUM', 'r.value')
578
- '''
589
+ """
579
590
func = str (item .children [0 ].value ) # AGGREGATE_FUNC
580
591
entity = str (item .children [1 ].value )
581
592
if len (item .children ) > 2 :
@@ -589,12 +600,17 @@ def _format_aggregation_key(self, func, entity):
589
600
def order_clause (self , order_clause ):
590
601
self ._order_by = []
591
602
for item in order_clause [0 ].children :
592
- if isinstance (item .children [0 ], Tree ) and item .children [0 ].data == "aggregation_function" :
603
+ if (
604
+ isinstance (item .children [0 ], Tree )
605
+ and item .children [0 ].data == "aggregation_function"
606
+ ):
593
607
func , entity = self ._parse_aggregation_token (item .children [0 ])
594
608
field = self ._format_aggregation_key (func , entity )
595
609
self ._order_by_attributes .add (entity )
596
610
else :
597
- field = str (item .children [0 ]) # assuming the field name is the first child
611
+ field = str (
612
+ item .children [0 ]
613
+ ) # assuming the field name is the first child
598
614
self ._order_by_attributes .add (field )
599
615
600
616
# Default to 'ASC' if not specified
@@ -687,8 +703,12 @@ def _collate_data(data, unique_labels, func):
687
703
688
704
def returns (self , ignore_limit = False ):
689
705
690
- data_paths = self ._return_requests + list (self ._order_by_attributes ) + list (self ._aggregation_attributes )
691
- # aliases should already be requested in their original form, so we will remove them for lookup
706
+ data_paths = (
707
+ self ._return_requests
708
+ + list (self ._order_by_attributes )
709
+ + list (self ._aggregation_attributes )
710
+ )
711
+ # aliases should already be requested in their original form, so we will remove them for lookup
692
712
data_paths = [d for d in data_paths if d not in self ._alias2entity ]
693
713
results = self ._lookup (
694
714
data_paths ,
@@ -739,10 +759,10 @@ def _apply_order_by(self, results):
739
759
indices = range (
740
760
len (next (iter (results .values ())))
741
761
) # Safe because all lists are assumed to be of the same length
742
- for ( sort_list , field , direction ) in reversed (
762
+ for sort_list , field , direction in reversed (
743
763
sort_lists
744
764
): # reverse to ensure the first sort key is primary
745
-
765
+
746
766
if all (isinstance (item , dict ) for item in sort_list ):
747
767
# (for edge attributes) If all items in sort_list are dictionaries
748
768
# example: ([{(0, 'paid'): 9, (1, 'paid'): 40}, {(0, 'paid'): 14}], 'DESC')
@@ -761,7 +781,8 @@ def _apply_order_by(self, results):
761
781
# then sort the indices based on the sorted sublists
762
782
indices = sorted (
763
783
indices ,
764
- key = lambda i : list (sort_list [i ].values ())[0 ] or 0 , # 0 if `None`
784
+ key = lambda i : list (sort_list [i ].values ())[0 ]
785
+ or 0 , # 0 if `None`
765
786
reverse = (direction == "DESC" ),
766
787
)
767
788
# update results with sorted edge attributes list
0 commit comments