|
20 | 20 | from aiperf.common.config.output_config import OutputConfig |
21 | 21 | from aiperf.common.config.tokenizer_config import TokenizerConfig |
22 | 22 | from aiperf.common.enums import CustomDatasetType, GPUTelemetryMode |
| 23 | +from aiperf.common.enums.plugin_enums import EndpointType |
23 | 24 | from aiperf.common.enums.timing_enums import RequestRateMode, TimingMode |
24 | 25 | from aiperf.common.utils import load_json_str |
25 | 26 |
|
@@ -398,3 +399,68 @@ def validate_concurrency_limits(self) -> Self: |
398 | 399 | ) |
399 | 400 |
|
400 | 401 | return self |
| 402 | + |
| 403 | + @model_validator(mode="after") |
| 404 | + def validate_rankings_token_options(self) -> Self: |
| 405 | + """Validate rankings token options usage.""" |
| 406 | + |
| 407 | + # Check if prompt input tokens have been changed from defaults |
| 408 | + prompt_tokens_modified = any( |
| 409 | + field in self.input.prompt.input_tokens.model_fields_set |
| 410 | + for field in ["mean", "stddev"] |
| 411 | + ) |
| 412 | + |
| 413 | + # Check if any rankings-specific token options have been changed from defaults |
| 414 | + rankings_token_fields = [ |
| 415 | + "rankings_passages_prompt_token_mean", |
| 416 | + "rankings_passages_prompt_token_stddev", |
| 417 | + "rankings_query_prompt_token_mean", |
| 418 | + "rankings_query_prompt_token_stddev", |
| 419 | + ] |
| 420 | + rankings_tokens_modified = any( |
| 421 | + field in self.input.model_fields_set for field in rankings_token_fields |
| 422 | + ) |
| 423 | + |
| 424 | + # Check if any rankings-specific passage options have been changed from defaults |
| 425 | + rankings_passages_fields = [ |
| 426 | + "rankings_passages_mean", |
| 427 | + "rankings_passages_stddev", |
| 428 | + ] |
| 429 | + rankings_passages_modified = any( |
| 430 | + field in self.input.model_fields_set for field in rankings_passages_fields |
| 431 | + ) |
| 432 | + |
| 433 | + rankings_options_modified = ( |
| 434 | + rankings_tokens_modified or rankings_passages_modified |
| 435 | + ) |
| 436 | + |
| 437 | + endpoint_type_is_rankings = "rankings" in self.endpoint.type.lower() |
| 438 | + |
| 439 | + # Validate that rankings options are only used with rankings endpoints |
| 440 | + rankings_endpoints = [ |
| 441 | + endpoint_type |
| 442 | + for endpoint_type in EndpointType |
| 443 | + if "rankings" in endpoint_type.lower() |
| 444 | + ] |
| 445 | + if rankings_options_modified and not endpoint_type_is_rankings: |
| 446 | + raise ValueError( |
| 447 | + f"Rankings-specific options (--rankings-passages-mean, --rankings-passages-stddev, " |
| 448 | + "--rankings-passages-prompt-token-mean, --rankings-passages-prompt-token-stddev, " |
| 449 | + "--rankings-query-prompt-token-mean, --rankings-query-prompt-token-stddev) " |
| 450 | + "can only be used with rankings endpoint types " |
| 451 | + f"Rankings endpoints: ({', '.join(rankings_endpoints)})." |
| 452 | + ) |
| 453 | + |
| 454 | + # Validate that prompt tokens and rankings tokens are not both set |
| 455 | + if prompt_tokens_modified and ( |
| 456 | + rankings_tokens_modified or endpoint_type_is_rankings |
| 457 | + ): |
| 458 | + raise ValueError( |
| 459 | + "The --prompt-input-tokens-mean/--prompt-input-tokens-stddev options " |
| 460 | + "cannot be used together with rankings-specific token options or the rankings endpoints" |
| 461 | + "Ranking options: (--rankings-passages-prompt-token-mean, --rankings-passages-prompt-token-stddev, " |
| 462 | + "--rankings-query-prompt-token-mean, --rankings-query-prompt-token-stddev, ). " |
| 463 | + f"Rankings endpoints: ({', '.join(rankings_endpoints)})." |
| 464 | + "Please use only one set of options." |
| 465 | + ) |
| 466 | + return self |
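
The validator above reduces to two checks: rankings-specific options require a rankings endpoint type, and the generic prompt token options cannot be combined with rankings token options or a rankings endpoint. Below is a minimal standalone sketch of that rule; the helper name `check_rankings_option_usage` and its boolean parameters are illustrative only and not part of the aiperf API.

```python
# Illustrative sketch only: this helper and its parameters are hypothetical,
# distilled from the validator above; it is not an aiperf function.
def check_rankings_option_usage(
    prompt_tokens_modified: bool,
    rankings_tokens_modified: bool,
    rankings_passages_modified: bool,
    endpoint_is_rankings: bool,
) -> None:
    """Raise ValueError for the same conflicting combinations the validator rejects."""
    rankings_options_modified = rankings_tokens_modified or rankings_passages_modified

    # Rankings-specific options are only valid with rankings endpoint types.
    if rankings_options_modified and not endpoint_is_rankings:
        raise ValueError("rankings-specific options require a rankings endpoint type")

    # Generic prompt token options conflict with rankings token options or endpoints.
    if prompt_tokens_modified and (rankings_tokens_modified or endpoint_is_rankings):
        raise ValueError("prompt token options conflict with rankings token options/endpoints")


# Allowed: a rankings endpoint configured with only rankings-specific options.
check_rankings_option_usage(
    prompt_tokens_modified=False,
    rankings_tokens_modified=True,
    rankings_passages_modified=False,
    endpoint_is_rankings=True,
)
```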