|
17 | 17 | import abc |
18 | 18 | import json |
19 | 19 | import logging |
20 | | -from typing import Any, Optional |
| 20 | +from typing import Any, Optional, Union |
21 | 21 |
|
22 | 22 | from google.genai import _common |
23 | 23 | from google.genai import types as genai_types |
@@ -103,12 +103,20 @@ def _parse_request( |
103 | 103 | last_message.content.role if last_message.content else "user" |
104 | 104 | ) |
105 | 105 | if last_message_role in ["user", None]: |
106 | | - prompt = last_message.content |
| 106 | + prompt = ( |
| 107 | + last_message.content |
| 108 | + if last_message.content |
| 109 | + else genai_types.Content() |
| 110 | + ) |
107 | 111 | elif last_message_role == "model": |
108 | 112 | reference = types.ResponseCandidate(response=last_message.content) |
109 | 113 | if conversation_history: |
110 | 114 | second_to_last_message = conversation_history.pop() |
111 | | - prompt = second_to_last_message.content |
| 115 | + prompt = ( |
| 116 | + second_to_last_message.content |
| 117 | + if second_to_last_message.content |
| 118 | + else genai_types.Content() |
| 119 | + ) |
112 | 120 | else: |
113 | 121 | prompt = genai_types.Content() |
114 | 122 |
|
@@ -436,7 +444,7 @@ def convert(self, raw_data: list[dict[str, Any]]) -> types.EvaluationDataset: |
436 | 444 |
|
437 | 445 | def auto_detect_dataset_schema( |
438 | 446 | raw_dataset: list[dict[str, Any]], |
439 | | -) -> EvalDatasetSchema: |
| 447 | +) -> Union[EvalDatasetSchema, str]: |
440 | 448 | """Detects the schema of a raw dataset.""" |
441 | 449 | if not raw_dataset: |
442 | 450 | return EvalDatasetSchema.UNKNOWN |
@@ -522,7 +530,7 @@ def _validate_case_consistency( |
522 | 530 | current_case: types.EvalCase, |
523 | 531 | case_idx: int, |
524 | 532 | dataset_idx: int, |
525 | | -): |
| 533 | +) -> None: |
526 | 534 | """Logs warnings if prompt or reference mismatches occur.""" |
527 | 535 | if base_case.prompt != current_case.prompt: |
528 | 536 | base_prompt_text_preview = _get_first_part_text(base_case.prompt)[:50] |
@@ -609,7 +617,11 @@ def merge_response_datasets_into_canonical_format( |
609 | 617 | base_parsed_dataset = parsed_evaluation_datasets[0] |
610 | 618 |
|
611 | 619 | for case_idx in range(num_expected_cases): |
612 | | - base_eval_case: types.EvalCase = base_parsed_dataset.eval_cases[case_idx] |
| 620 | + base_eval_case: types.EvalCase = ( |
| 621 | + base_parsed_dataset.eval_cases[case_idx] |
| 622 | + if base_parsed_dataset.eval_cases |
| 623 | + else types.EvalCase() |
| 624 | + ) |
613 | 625 | candidate_responses: list[types.ResponseCandidate] = [] |
614 | 626 |
|
615 | 627 | if base_eval_case.responses: |
@@ -640,9 +652,11 @@ def merge_response_datasets_into_canonical_format( |
640 | 652 | for dataset_idx_offset, current_parsed_ds in enumerate( |
641 | 653 | parsed_evaluation_datasets[1:], start=1 |
642 | 654 | ): |
643 | | - current_ds_eval_case: types.EvalCase = current_parsed_ds.eval_cases[ |
644 | | - case_idx |
645 | | - ] |
| 655 | + current_ds_eval_case: types.EvalCase = ( |
| 656 | + current_parsed_ds.eval_cases[case_idx] |
| 657 | + if current_parsed_ds.eval_cases |
| 658 | + else types.EvalCase() |
| 659 | + ) |
646 | 660 |
|
647 | 661 | _validate_case_consistency( |
648 | 662 | base_eval_case, current_ds_eval_case, case_idx, dataset_idx_offset |
|
0 commit comments