Skip to content

Commit 1cdd58e

Browse files
authored
Merge branch 'main' into feature/what-if-analysis
2 parents 898111e + 410dbe0 commit 1cdd58e

40 files changed: +3722 additions, -498 deletions (lines changed)

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,6 @@ jobs:
5656
$CONDA/bin/conda init
5757
source /home/runner/.bashrc
5858
pip install -r test-requirements-operators.txt
59-
pip install "oracle-automlx[forecasting]>=24.4.0"
59+
pip install "oracle-automlx[forecasting]>=24.4.1"
6060
pip install pandas>=2.2.0
6161
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast

ads/aqua/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,12 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65

76
import os
7+
from logging import getLogger
88

9-
from ads import logger, set_auth
9+
from ads import set_auth
1010
from ads.aqua.common.utils import fetch_service_compartment
1111
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
1212

@@ -19,6 +19,7 @@ def get_logger_level():
1919
return level
2020

2121

22+
logger = getLogger(__name__)
2223
logger.setLevel(get_logger_level())
2324

2425

@@ -27,7 +28,6 @@ def set_log_level(log_level: str):
2728

2829
log_level = log_level.upper()
2930
logger.setLevel(log_level.upper())
30-
logger.handlers[0].setLevel(log_level)
3131

3232

3333
if OCI_RESOURCE_PRINCIPAL_VERSION:

ads/aqua/common/enums.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5252
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5353
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
5454
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55+
56+
57+
class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5558
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
5659

5760

ads/aqua/common/utils.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
"""AQUA utils and constants."""
55

@@ -30,6 +30,7 @@
3030
)
3131
from oci.data_science.models import JobRun, Model
3232
from oci.object_storage.models import ObjectSummary
33+
from pydantic import ValidationError
3334

3435
from ads.aqua.common.enums import (
3536
InferenceContainerParamType,
@@ -788,7 +789,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
788789
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
789790

790791

791-
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
792+
def upload_folder(
793+
os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
794+
) -> str:
792795
"""Upload the local folder to the object storage
793796
794797
Args:
@@ -1159,3 +1162,15 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
11591162

11601163
combined_cmd_var = cmd_var + overrides
11611164
return combined_cmd_var
1165+
1166+
1167+
def build_pydantic_error_message(ex: ValidationError):
1168+
"""Added to handle error messages from pydantic model validator.
1169+
Combine both loc and msg for errors where loc (field) is present in error details, else only build error
1170+
message using msg field."""
1171+
1172+
return {
1173+
".".join(map(str, e["loc"])): e["msg"]
1174+
for e in ex.errors()
1175+
if "loc" in e and e["loc"]
1176+
} or "; ".join(e["msg"] for e in ex.errors())

ads/aqua/data.py

Lines changed: 2 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,8 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

6-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass
76

87
from ads.common.serializer import DataClassSerializable
98

@@ -13,19 +12,3 @@ class AquaResourceIdentifier(DataClassSerializable):
1312
id: str = ""
1413
name: str = ""
1514
url: str = ""
16-
17-
18-
@dataclass(repr=False)
19-
class AquaJobSummary(DataClassSerializable):
20-
"""Represents an Aqua job summary."""
21-
22-
id: str
23-
name: str
24-
console_url: str
25-
lifecycle_state: str
26-
lifecycle_details: str
27-
time_created: str
28-
tags: dict
29-
experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
30-
source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
31-
job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)

ads/aqua/extension/finetune_handler.py

Lines changed: 8 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55

@@ -10,9 +10,7 @@
1010
from ads.aqua.common.decorator import handle_exceptions
1111
from ads.aqua.extension.base_handler import AquaAPIhandler
1212
from ads.aqua.extension.errors import Errors
13-
from ads.aqua.extension.utils import validate_function_parameters
1413
from ads.aqua.finetuning import AquaFineTuningApp
15-
from ads.aqua.finetuning.entities import CreateFineTuningDetails
1614

1715

1816
class AquaFineTuneHandler(AquaAPIhandler):
@@ -33,7 +31,7 @@ def get(self, id=""):
3331
raise HTTPError(400, f"The request {self.request.path} is invalid.")
3432

3533
@handle_exceptions
36-
def post(self, *args, **kwargs):
34+
def post(self, *args, **kwargs): # noqa: ARG002
3735
"""Handles post request for the fine-tuning API
3836
3937
Raises
@@ -43,17 +41,13 @@ def post(self, *args, **kwargs):
4341
"""
4442
try:
4543
input_data = self.get_json_body()
46-
except Exception:
47-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
44+
except Exception as ex:
45+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
4846

4947
if not input_data:
5048
raise HTTPError(400, Errors.NO_INPUT_DATA)
5149

52-
validate_function_parameters(
53-
data_class=CreateFineTuningDetails, input_data=input_data
54-
)
55-
56-
self.finish(AquaFineTuningApp().create(CreateFineTuningDetails(**input_data)))
50+
self.finish(AquaFineTuningApp().create(**input_data))
5751

5852
def get_finetuning_config(self, model_id):
5953
"""Gets the finetuning config for Aqua model."""
@@ -71,7 +65,7 @@ def get(self, model_id):
7165
)
7266

7367
@handle_exceptions
74-
def post(self, *args, **kwargs):
68+
def post(self, *args, **kwargs): # noqa: ARG002
7569
"""Handles post request for the finetuning param handler API.
7670
7771
Raises
@@ -81,8 +75,8 @@ def post(self, *args, **kwargs):
8175
"""
8276
try:
8377
input_data = self.get_json_body()
84-
except Exception:
85-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
78+
except Exception as ex:
79+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
8680

8781
if not input_data:
8882
raise HTTPError(400, Errors.NO_INPUT_DATA)

ads/aqua/extension/model_handler.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,9 @@
88
from tornado.web import HTTPError
99

1010
from ads.aqua.common.decorator import handle_exceptions
11+
from ads.aqua.common.enums import (
12+
CustomInferenceContainerTypeFamily,
13+
)
1114
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1215
from ads.aqua.common.utils import (
1316
get_hf_model_info,
@@ -163,7 +166,9 @@ def put(self, id):
163166
raise HTTPError(400, Errors.NO_INPUT_DATA)
164167

165168
inference_container = input_data.get("inference_container")
169+
inference_container_uri = input_data.get("inference_container_uri")
166170
inference_containers = AquaModelApp.list_valid_inference_containers()
171+
inference_containers.extend(CustomInferenceContainerTypeFamily.values())
167172
if (
168173
inference_container is not None
169174
and inference_container not in inference_containers
@@ -176,7 +181,13 @@ def put(self, id):
176181
task = input_data.get("task")
177182
app = AquaModelApp()
178183
self.finish(
179-
app.edit_registered_model(id, inference_container, enable_finetuning, task)
184+
app.edit_registered_model(
185+
id,
186+
inference_container,
187+
inference_container_uri,
188+
enable_finetuning,
189+
task,
190+
)
180191
)
181192
app.clear_model_details_cache(model_id=id)
182193

ads/aqua/finetuning/constants.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65
from ads.common.extended_enum import ExtendedEnumMeta
@@ -17,4 +16,8 @@ class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta):
1716
SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container"
1817

1918

19+
class FineTuningRestrictedParams(str, metaclass=ExtendedEnumMeta):
20+
OPTIMIZER = "optimizer"
21+
22+
2023
ENV_AQUA_FINE_TUNING_CONTAINER = "AQUA_FINE_TUNING_CONTAINER"

ads/aqua/finetuning/entities.py

Lines changed: 64 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,24 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4-
from dataclasses import dataclass, field
5-
from typing import List, Optional
64

7-
from ads.aqua.data import AquaJobSummary
8-
from ads.common.serializer import DataClassSerializable
5+
import json
6+
from typing import List, Literal, Optional, Union
97

8+
from pydantic import Field, model_validator
109

11-
@dataclass(repr=False)
12-
class AquaFineTuningParams(DataClassSerializable):
13-
epochs: int
10+
from ads.aqua.common.errors import AquaValueError
11+
from ads.aqua.config.utils.serializer import Serializable
12+
from ads.aqua.data import AquaResourceIdentifier
13+
from ads.aqua.finetuning.constants import FineTuningRestrictedParams
14+
15+
16+
class AquaFineTuningParams(Serializable):
17+
"""Class for maintaining aqua fine-tuning model parameters"""
18+
19+
epochs: Optional[int] = None
1420
learning_rate: Optional[float] = None
15-
sample_packing: Optional[bool] = "auto"
21+
sample_packing: Union[bool, None, Literal["auto"]] = "auto"
1622
batch_size: Optional[int] = (
1723
None # make it batch_size for user, but internally this is micro_batch_size
1824
)
@@ -22,21 +28,59 @@ class AquaFineTuningParams(DataClassSerializable):
2228
lora_alpha: Optional[int] = None
2329
lora_dropout: Optional[float] = None
2430
lora_target_linear: Optional[bool] = None
25-
lora_target_modules: Optional[List] = None
31+
lora_target_modules: Optional[List[str]] = None
2632
early_stopping_patience: Optional[int] = None
2733
early_stopping_threshold: Optional[float] = None
2834

35+
class Config:
36+
extra = "allow"
37+
38+
def to_dict(self) -> dict:
39+
return json.loads(super().to_json(exclude_none=True))
40+
41+
@model_validator(mode="before")
42+
@classmethod
43+
def validate_restricted_fields(cls, data: dict):
44+
# we may want to skip validation if loading data from config files instead of user entered parameters
45+
validate = data.pop("_validate", True)
46+
if not (validate and isinstance(data, dict)):
47+
return data
48+
restricted_params = [
49+
param for param in data if param in FineTuningRestrictedParams.values()
50+
]
51+
if restricted_params:
52+
raise AquaValueError(
53+
f"Found restricted parameter name: {restricted_params}"
54+
)
55+
return data
2956

30-
@dataclass(repr=False)
31-
class AquaFineTuningSummary(AquaJobSummary, DataClassSerializable):
32-
parameters: AquaFineTuningParams = field(default_factory=AquaFineTuningParams)
3357

58+
class AquaFineTuningSummary(Serializable):
59+
"""Represents a summary of Aqua Finetuning job."""
3460

35-
@dataclass(repr=False)
36-
class CreateFineTuningDetails(DataClassSerializable):
37-
"""Dataclass to create aqua model fine tuning.
61+
id: str
62+
name: str
63+
console_url: str
64+
lifecycle_state: str
65+
lifecycle_details: str
66+
time_created: str
67+
tags: dict
68+
experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
69+
source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
70+
job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
71+
parameters: AquaFineTuningParams = Field(default_factory=AquaFineTuningParams)
3872

39-
Fields
73+
class Config:
74+
extra = "ignore"
75+
76+
def to_dict(self) -> dict:
77+
return json.loads(super().to_json(exclude_none=True))
78+
79+
80+
class CreateFineTuningDetails(Serializable):
81+
"""Class to create aqua model fine-tuning instance.
82+
83+
Properties
4084
------
4185
ft_source_id: str
4286
The fine tuning source id. Must be model ocid.
@@ -107,3 +151,6 @@ class CreateFineTuningDetails(DataClassSerializable):
107151
force_overwrite: Optional[bool] = False
108152
freeform_tags: Optional[dict] = None
109153
defined_tags: Optional[dict] = None
154+
155+
class Config:
156+
extra = "ignore"

0 commit comments

Comments
 (0)