Skip to content

Commit 4cbd7e0

Browse files
revert to previous validation
1 parent 5621a06 commit 4cbd7e0

File tree

4 files changed

+40
-64
lines changed

4 files changed

+40
-64
lines changed

ads/aqua/common/utils.py

Lines changed: 1 addition & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@
1212
import re
1313
import shlex
1414
import subprocess
15-
from dataclasses import fields
1615
from datetime import datetime, timedelta
1716
from functools import wraps
1817
from pathlib import Path
1918
from string import Template
20-
from typing import Any, List, Optional, Type, TypeVar, Union
19+
from typing import List, Union
2120

2221
import fsspec
2322
import oci
@@ -31,7 +30,6 @@
3130
)
3231
from oci.data_science.models import JobRun, Model
3332
from oci.object_storage.models import ObjectSummary
34-
from pydantic import BaseModel, ValidationError
3533

3634
from ads.aqua.common.enums import (
3735
InferenceContainerParamType,
@@ -76,7 +74,6 @@
7674
from ads.model import DataScienceModel, ModelVersionSet
7775

7876
logger = logging.getLogger("ads.aqua")
79-
T = TypeVar("T", bound=Union[BaseModel, Any])
8077

8178

8279
class LifecycleStatus(str, metaclass=ExtendedEnumMeta):
@@ -1164,49 +1161,3 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
11641161

11651162
combined_cmd_var = cmd_var + overrides
11661163
return combined_cmd_var
1167-
1168-
1169-
def validate_dataclass_params(dataclass_type: Type[T], **kwargs: Any) -> Optional[T]:
1170-
"""This method tries to initialize a dataclass with the provided keyword arguments. It handles
1171-
errors related to missing, unexpected or invalid arguments.
1172-
1173-
Parameters
1174-
----------
1175-
dataclass_type (Type[T]):
1176-
the dataclass type to instantiate.
1177-
kwargs (Any):
1178-
the keyword arguments to initialize the dataclass
1179-
Returns
1180-
-------
1181-
Optional[T]
1182-
instance of dataclass if successfully initialized
1183-
"""
1184-
1185-
try:
1186-
return dataclass_type(**kwargs)
1187-
except TypeError as ex:
1188-
error_message = str(ex)
1189-
allowed_params = ", ".join(
1190-
field.name for field in fields(dataclass_type)
1191-
).rstrip()
1192-
if "__init__() missing" in error_message:
1193-
missing_params = error_message.split("missing ")[1]
1194-
raise AquaValueError(
1195-
"Error: Missing required parameters: "
1196-
f"{missing_params}. Allowable parameters are: {allowed_params}."
1197-
) from ex
1198-
elif "__init__() got an unexpected keyword argument" in error_message:
1199-
unexpected_param = error_message.split("argument '")[1].rstrip("'")
1200-
raise AquaValueError(
1201-
"Error: Unexpected parameter: "
1202-
f"{unexpected_param}. Allowable parameters are: {allowed_params}."
1203-
) from ex
1204-
else:
1205-
raise AquaValueError(
1206-
"Invalid parameters. Allowable parameters are: " f"{allowed_params}."
1207-
) from ex
1208-
except ValidationError as ex:
1209-
custom_errors = {".".join(map(str, e["loc"])): e["msg"] for e in ex.errors()}
1210-
raise AquaValueError(
1211-
f"Invalid parameters. Error details: {custom_errors}."
1212-
) from ex

ads/aqua/evaluation/evaluation.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
get_container_image,
4444
is_valid_ocid,
4545
upload_local_to_os,
46-
validate_dataclass_params,
4746
)
4847
from ads.aqua.config.config import get_evaluation_service_config
4948
from ads.aqua.constants import (
@@ -156,9 +155,16 @@ def create(
156155
The instance of AquaEvaluationSummary.
157156
"""
158157
if not create_aqua_evaluation_details:
159-
create_aqua_evaluation_details = validate_dataclass_params(
160-
CreateAquaEvaluationDetails, **kwargs
161-
)
158+
try:
159+
create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs)
160+
except Exception as ex:
161+
custom_errors = {
162+
".".join(map(str, e["loc"])): e["msg"]
163+
for e in json.loads(ex.json())
164+
}
165+
raise AquaValueError(
166+
f"Invalid create evaluation parameters. Error details: {custom_errors}."
167+
) from ex
162168

163169
if not is_valid_ocid(create_aqua_evaluation_details.evaluation_source_id):
164170
raise AquaValueError(

ads/aqua/extension/finetune_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
8989

9090
params = input_data.get("params", None)
9191
return self.finish(
92-
AquaFineTuningApp.validate_finetuning_params(
92+
AquaFineTuningApp().validate_finetuning_params(
9393
params=params,
9494
)
9595
)

ads/aqua/finetuning/finetuning.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import json
66
import os
7-
from dataclasses import asdict, fields
7+
from dataclasses import MISSING, asdict, fields
88
from typing import Dict
99

1010
from oci.data_science.models import (
@@ -20,7 +20,6 @@
2020
from ads.aqua.common.utils import (
2121
get_container_image,
2222
upload_local_to_os,
23-
validate_dataclass_params,
2423
)
2524
from ads.aqua.constants import (
2625
DEFAULT_FT_BATCH_SIZE,
@@ -103,9 +102,16 @@ def create(
103102
The instance of AquaFineTuningSummary.
104103
"""
105104
if not create_fine_tuning_details:
106-
create_fine_tuning_details = validate_dataclass_params(
107-
CreateFineTuningDetails, **kwargs
108-
)
105+
try:
106+
create_fine_tuning_details = CreateFineTuningDetails(**kwargs)
107+
except Exception as ex:
108+
allowed_create_fine_tuning_details = ", ".join(
109+
field.name for field in fields(CreateFineTuningDetails)
110+
).rstrip()
111+
raise AquaValueError(
112+
"Invalid create fine tuning parameters. Allowable parameters are: "
113+
f"{allowed_create_fine_tuning_details}."
114+
) from ex
109115

110116
source = self.get_source(create_fine_tuning_details.ft_source_id)
111117

@@ -615,8 +621,7 @@ def get_finetuning_default_params(self, model_id: str) -> Dict:
615621

616622
return default_params
617623

618-
@staticmethod
619-
def validate_finetuning_params(params: Dict = None) -> Dict:
624+
def validate_finetuning_params(self, params: Dict = None) -> Dict:
620625
"""Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not
621626
validated, only param keys are validated.
622627
@@ -627,7 +632,21 @@ def validate_finetuning_params(params: Dict = None) -> Dict:
627632
628633
Returns
629634
-------
630-
Return a dict with value true if valid, else raises AquaValueError.
635+
Return a list of restricted params.
631636
"""
632-
validate_dataclass_params(AquaFineTuningParams, **(params or {}))
637+
try:
638+
AquaFineTuningParams(
639+
**params,
640+
)
641+
except Exception as e:
642+
logger.debug(str(e))
643+
allowed_fine_tuning_parameters = ", ".join(
644+
f"{field.name} (required)" if field.default is MISSING else field.name
645+
for field in fields(AquaFineTuningParams)
646+
).rstrip()
647+
raise AquaValueError(
648+
f"Invalid fine tuning parameters. Allowable parameters are: "
649+
f"{allowed_fine_tuning_parameters}."
650+
) from e
651+
633652
return {"valid": True}

0 commit comments

Comments
 (0)