@@ -332,7 +332,7 @@ def test_vectorizer_dtype_mismatch(routes, redis_url):
332
332
)
333
333
334
334
335
- def test_invalid_vectorizer (routes , redis_url ):
335
+ def test_invalid_vectorizer (redis_url ):
336
336
with pytest .raises (TypeError ):
337
337
SemanticRouter (
338
338
name = "test_invalid_vectorizer" ,
@@ -473,53 +473,69 @@ def test_add_route_references_cls(redis_url, routes):
473
473
474
474
def test_add_route_references_cls_missing_inputs (redis_url ):
475
475
476
- # Add new references to an existing route
477
476
with pytest .raises (ValueError ):
478
477
SemanticRouter .add_route_references (
479
478
route_name = "farewell" ,
480
479
references = ["peace out" ],
481
- router_name = "test-router" ,
482
480
vectorizer = HFTextVectorizer (),
483
481
)
484
482
485
- with pytest .raises (ValueError ):
486
- SemanticRouter .add_route_references (
487
- route_name = "farewell" ,
488
- references = ["peace out" ],
489
- redis_url = redis_url ,
490
- vectorizer = HFTextVectorizer (),
491
- )
492
483
484
+ def test_get_route_references (redis_url ):
485
+ routes = [
486
+ Route (
487
+ name = "test" ,
488
+ references = ["hello" , "hi" ],
489
+ metadata = {"type" : "test" },
490
+ distance_threshold = 0.3 ,
491
+ ),
492
+ ]
493
+
494
+ # Get references for a specific route
495
+ router = SemanticRouter (
496
+ name = "get-router" ,
497
+ routes = routes ,
498
+ routing_config = RoutingConfig (max_k = 2 ),
499
+ redis_url = redis_url ,
500
+ )
493
501
494
- def test_get_route_references (semantic_router ):
495
502
# Get references for a specific route
496
- refs = semantic_router .get_route_references (route_name = "greeting " )
503
+ refs = router .get_route_references (route_name = "test " )
497
504
498
505
# Should return at least the initial references
499
506
assert len (refs ) >= 2
500
507
501
508
# Reference IDs should be present
502
509
reference_id = refs [0 ]["reference_id" ]
503
510
# Get references by ID
504
- id_refs = semantic_router .get_route_references (reference_ids = [reference_id ])
511
+ id_refs = router .get_route_references (reference_ids = [reference_id ])
505
512
assert len (id_refs ) == 1
506
513
507
514
with pytest .raises (ValueError ):
508
- semantic_router .get_route_references ()
515
+ router .get_route_references ()
509
516
510
517
511
- def test_get_route_references_cls (routes , redis_url ):
518
+ def test_get_route_references_cls (redis_url ):
519
+ routes = [
520
+ Route (
521
+ name = "test" ,
522
+ references = ["hello" , "hi" ],
523
+ metadata = {"type" : "test" },
524
+ distance_threshold = 0.3 ,
525
+ ),
526
+ ]
527
+
512
528
# Get references for a specific route
513
529
_ = SemanticRouter (
514
- name = "new -router" ,
530
+ name = "get -router" ,
515
531
routes = routes ,
516
532
routing_config = RoutingConfig (max_k = 2 ),
517
533
redis_url = redis_url ,
518
534
)
519
535
520
536
refs = SemanticRouter .get_route_references (
521
- route_name = "greeting " ,
522
- router_name = "new -router" ,
537
+ route_name = "test " ,
538
+ router_name = "get -router" ,
523
539
redis_url = redis_url ,
524
540
)
525
541
@@ -536,52 +552,39 @@ def test_get_route_references_cls(routes, redis_url):
536
552
SemanticRouter .get_route_references ()
537
553
538
554
539
- def test_delete_route_references (semantic_router ):
540
- redis_version = semantic_router ._index .client .info ()["redis_version" ]
541
- if not compare_versions (redis_version , "7.0.0" ):
542
- pytest .skip ("Not using a late enough version of Redis" )
543
-
544
- # Delete specific reference
545
- deleted_count = semantic_router .delete_route_references (
546
- route_name = "greeting" ,
547
- )
548
-
549
- assert deleted_count == 2
550
-
551
- # Verify the reference is gone
552
- refs = semantic_router .get_route_references (route_name = "farewell" )
553
- ref_id = refs [0 ]["reference_id" ]
554
- deleted = semantic_router .delete_route_references (
555
- route_name = "farewell" , reference_ids = [ref_id ]
556
- )
557
- assert deleted == 1
558
-
555
+ def test_delete_route_references (redis_url ):
556
+ routes = [
557
+ Route (
558
+ name = "test" ,
559
+ references = ["hello" , "hi" ],
560
+ metadata = {"type" : "test" },
561
+ distance_threshold = 0.3 ,
562
+ ),
563
+ Route (
564
+ name = "test2" ,
565
+ references = ["by" , "boy" ],
566
+ metadata = {"type" : "test" },
567
+ distance_threshold = 0.3 ,
568
+ ),
569
+ ]
559
570
560
- def test_delete_route_references_cls (routes , redis_url ):
561
571
# Get references for a specific route
562
- _ = SemanticRouter (
563
- name = "new -router" ,
572
+ router = SemanticRouter (
573
+ name = "delete -router" ,
564
574
routes = routes ,
565
575
routing_config = RoutingConfig (max_k = 2 ),
566
576
redis_url = redis_url ,
567
577
)
568
578
569
579
# Delete specific reference
570
- deleted_count = SemanticRouter .delete_route_references (
571
- route_name = "greeting" ,
572
- router_name = "new-router" ,
573
- redis_url = redis_url ,
580
+ deleted_count = router .delete_route_references (
581
+ route_name = "test" ,
574
582
)
575
583
576
584
assert deleted_count == 2
577
585
578
586
# Verify the reference is gone
579
- refs = SemanticRouter .get_route_references (route_name = "farewell " )
587
+ refs = router .get_route_references (route_name = "test2 " )
580
588
ref_id = refs [0 ]["reference_id" ]
581
- deleted = SemanticRouter .delete_route_references (
582
- route_name = "farewell" ,
583
- reference_ids = [ref_id ],
584
- router_name = "new-router" ,
585
- redis_url = redis_url ,
586
- )
589
+ deleted = router .delete_route_references (route_name = "test2" , reference_ids = [ref_id ])
587
590
assert deleted == 1
0 commit comments