Skip to content

Commit 160ec4e

Browse files
committed
refactor: reseting values before every validate call in validators
added unit test for that case
1 parent dc5d603 commit 160ec4e

File tree

3 files changed

+82
-4
lines changed

3 files changed

+82
-4
lines changed

src/rai_bench/rai_bench/tool_calling_agent/interfaces.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -410,15 +410,26 @@ def add_subtask_errors(self, idx: int, msgs: List[str]):
410410

411411
def reset(self):
412412
"""
413-
reset all values refering previous validation
414-
before next validation
413+
resets all values refering previous validation.
414+
Use it before next validation.
415415
"""
416416
self.subtasks_errors = [[] for _ in range(len(self.subtasks))]
417-
self.subtasks_passed: List[bool] = [False for _ in range(len(self.subtasks))]
417+
self.subtasks_passed = [False for _ in range(len(self.subtasks))]
418418
self.extra_calls_used = 0
419419
self.passed = None
420420

421421
def dump_results(self) -> ValidatorResult:
422+
"""Get results for last validate() call
423+
424+
Returns
425+
-------
426+
ValidatorResult
427+
428+
Raises
429+
------
430+
ValueError
431+
When called before validate()
432+
"""
422433
if self.passed is None:
423434
raise ValueError("Run validator validation before dumping results")
424435
subtasks_results: List[SubTaskResult] = []
@@ -436,7 +447,6 @@ def dump_results(self) -> ValidatorResult:
436447
extra_tool_calls_used=self.extra_calls_used,
437448
passed=self.passed,
438449
)
439-
self.reset()
440450
return result
441451

442452
@abstractmethod

src/rai_bench/rai_bench/tool_calling_agent/validators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def type(self) -> str:
4444
return "ordered"
4545

4646
def validate(self, tool_calls: List[ToolCall]) -> Tuple[bool, List[ToolCall]]:
47+
self.reset()
4748
# Before validation create new iterator, in case validator
4849
# was used before in other task
4950
subtask_iter = iter(enumerate(self.subtasks))
@@ -93,6 +94,7 @@ def type(self) -> str:
9394
return "not ordered"
9495

9596
def validate(self, tool_calls: List[ToolCall]) -> Tuple[bool, List[ToolCall]]:
97+
self.reset()
9698
if len(tool_calls) < 1:
9799
self.logger.error("Not a single tool call to validate")
98100
self.passed = False

tests/rai_bench/tool_calling_agent/test_validators.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,43 @@ def test_validate_extra_calls_when_subtask_eventually_passes(self):
443443
expected_errors_counts=[0, 5],
444444
)
445445

446+
def test_validate_reset(self):
447+
subtasks = [
448+
DummySubTask("task1"),
449+
DummySubTask("task2", outcomes=10 * [False]),
450+
]
451+
validator = OrderedCallsValidator(subtasks=subtasks)
452+
453+
tool_calls = [
454+
ToolCall(name="tool1"),
455+
ToolCall(name="tool2"),
456+
ToolCall(name="tool2"),
457+
ToolCall(name="tool2"),
458+
ToolCall(name="tool2"),
459+
ToolCall(name="tool2"),
460+
]
461+
# additional call
462+
validator.validate(tool_calls=tool_calls)
463+
success, remaining = validator.validate(tool_calls=tool_calls)
464+
465+
assert not success
466+
assert remaining == []
467+
assert validator.subtasks_passed[0] is True
468+
assert validator.subtasks_passed[1] is False
469+
assert len(validator.subtasks_errors[1]) == 5
470+
assert "error in task2" in validator.subtasks_errors[1][0]
471+
assert validator.passed is False
472+
assert validator.extra_calls_used == 4
473+
474+
assert_dumped(
475+
validator,
476+
expected_type="ordered",
477+
expected_passed=False,
478+
expected_extra_calls=4,
479+
expected_subtasks_passed=[True, False],
480+
expected_errors_counts=[0, 5],
481+
)
482+
446483

447484
class TestNotOrderedCallsValidator:
448485
def test_init_with_empty_subtasks(self):
@@ -608,3 +645,32 @@ def test_validate_all_subtasks_fail(self):
608645
expected_subtasks_passed=[False, False],
609646
expected_errors_counts=[2, 2],
610647
)
648+
649+
def test_validate_reset(self):
650+
subtasks = [
651+
DummySubTask("task1", outcomes=4 * [False]),
652+
DummySubTask("task2", outcomes=4 * [False]),
653+
]
654+
validator = NotOrderedCallsValidator(subtasks=subtasks)
655+
tool_calls = [ToolCall(), ToolCall()]
656+
657+
# additional call
658+
validator.validate(tool_calls=tool_calls)
659+
success, remaining = validator.validate(tool_calls=tool_calls)
660+
661+
assert not success
662+
assert remaining == []
663+
assert all(not passed for passed in validator.subtasks_passed)
664+
assert len(validator.subtasks_errors[0]) == 2
665+
assert len(validator.subtasks_errors[1]) == 2
666+
assert validator.passed is False
667+
assert validator.extra_calls_used == 0
668+
669+
assert_dumped(
670+
validator,
671+
expected_type="not ordered",
672+
expected_passed=False,
673+
expected_extra_calls=0,
674+
expected_subtasks_passed=[False, False],
675+
expected_errors_counts=[2, 2],
676+
)

0 commit comments

Comments
 (0)