Skip to content

Commit 3bc31e5

Browse files
Resolve merge conflicts
2 parents 4cbd7e0 + 21ba00b commit 3bc31e5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+4605
-532
lines changed

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,6 @@ jobs:
5656
$CONDA/bin/conda init
5757
source /home/runner/.bashrc
5858
pip install -r test-requirements-operators.txt
59-
pip install "oracle-automlx[forecasting]>=24.4.0"
59+
pip install "oracle-automlx[forecasting]>=24.4.1"
6060
pip install pandas>=2.2.0
6161
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast

ads/aqua/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65

76
import os
7+
from logging import getLogger
88

9-
from ads import logger, set_auth
9+
from ads import set_auth
1010
from ads.aqua.common.utils import fetch_service_compartment
1111
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
1212

@@ -19,6 +19,7 @@ def get_logger_level():
1919
return level
2020

2121

22+
logger = getLogger(__name__)
2223
logger.setLevel(get_logger_level())
2324

2425

@@ -27,7 +28,6 @@ def set_log_level(log_level: str):
2728

2829
log_level = log_level.upper()
2930
logger.setLevel(log_level.upper())
30-
logger.handlers[0].setLevel(log_level)
3131

3232

3333
if OCI_RESOURCE_PRINCIPAL_VERSION:

ads/aqua/common/enums.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5252
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5353
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
5454
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55+
56+
57+
class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5558
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
5659

5760

ads/aqua/common/utils.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import random
1212
import re
1313
import shlex
14+
import shutil
1415
import subprocess
1516
from datetime import datetime, timedelta
1617
from functools import wraps
@@ -21,6 +22,8 @@
2122
import fsspec
2223
import oci
2324
from cachetools import TTLCache, cached
25+
from huggingface_hub.constants import HF_HUB_CACHE
26+
from huggingface_hub.file_download import repo_folder_name
2427
from huggingface_hub.hf_api import HfApi, ModelInfo
2528
from huggingface_hub.utils import (
2629
GatedRepoError,
@@ -30,6 +33,7 @@
3033
)
3134
from oci.data_science.models import JobRun, Model
3235
from oci.object_storage.models import ObjectSummary
36+
from pydantic import ValidationError
3337

3438
from ads.aqua.common.enums import (
3539
InferenceContainerParamType,
@@ -820,6 +824,48 @@ def upload_folder(
820824
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
821825

822826

827+
def cleanup_local_hf_model_artifact(
828+
model_name: str,
829+
local_dir: str = None,
830+
):
831+
"""
832+
Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space.
833+
Parameters
834+
----------
835+
model_name (str): Name of the huggingface model
836+
local_dir (str): Local directory where the object is downloaded
837+
838+
"""
839+
if local_dir and os.path.exists(local_dir):
840+
model_dir = os.path.join(local_dir, model_name)
841+
model_dir = (
842+
os.path.dirname(model_dir)
843+
if "/" in model_name or os.sep in model_name
844+
else model_dir
845+
)
846+
shutil.rmtree(model_dir, ignore_errors=True)
847+
if os.path.exists(model_dir):
848+
logger.debug(
849+
f"Could not delete local model artifact directory: {model_dir}"
850+
)
851+
else:
852+
logger.debug(f"Deleted local model artifact directory: {model_dir}.")
853+
854+
hf_local_path = os.path.join(
855+
HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
856+
)
857+
shutil.rmtree(hf_local_path, ignore_errors=True)
858+
859+
if os.path.exists(hf_local_path):
860+
logger.debug(
861+
f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
862+
)
863+
else:
864+
logger.debug(
865+
f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
866+
)
867+
868+
823869
def is_service_managed_container(container):
824870
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
825871

@@ -1161,3 +1207,15 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
11611207

11621208
combined_cmd_var = cmd_var + overrides
11631209
return combined_cmd_var
1210+
1211+
1212+
def build_pydantic_error_message(ex: ValidationError):
1213+
"""Added to handle error messages from pydantic model validator.
1214+
Combine both loc and msg for errors where loc (field) is present in error details, else only build error
1215+
message using msg field."""
1216+
1217+
return {
1218+
".".join(map(str, e["loc"])): e["msg"]
1219+
for e in ex.errors()
1220+
if "loc" in e and e["loc"]
1221+
} or "; ".join(e["msg"] for e in ex.errors())

ads/aqua/data.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

6-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass
76

87
from ads.common.serializer import DataClassSerializable
98

@@ -13,19 +12,3 @@ class AquaResourceIdentifier(DataClassSerializable):
1312
id: str = ""
1413
name: str = ""
1514
url: str = ""
16-
17-
18-
@dataclass(repr=False)
19-
class AquaJobSummary(DataClassSerializable):
20-
"""Represents an Aqua job summary."""
21-
22-
id: str
23-
name: str
24-
console_url: str
25-
lifecycle_state: str
26-
lifecycle_details: str
27-
time_created: str
28-
tags: dict
29-
experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
30-
source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
31-
job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)

ads/aqua/extension/finetune_handler.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
from ads.aqua.common.decorator import handle_exceptions
1111
from ads.aqua.extension.base_handler import AquaAPIhandler
1212
from ads.aqua.extension.errors import Errors
13-
from ads.aqua.extension.utils import validate_function_parameters
1413
from ads.aqua.finetuning import AquaFineTuningApp
15-
from ads.aqua.finetuning.entities import CreateFineTuningDetails
1614

1715

1816
class AquaFineTuneHandler(AquaAPIhandler):
@@ -49,11 +47,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
4947
if not input_data:
5048
raise HTTPError(400, Errors.NO_INPUT_DATA)
5149

52-
validate_function_parameters(
53-
data_class=CreateFineTuningDetails, input_data=input_data
54-
)
55-
56-
self.finish(AquaFineTuningApp().create(CreateFineTuningDetails(**input_data)))
50+
self.finish(AquaFineTuningApp().create(**input_data))
5751

5852
def get_finetuning_config(self, model_id):
5953
"""Gets the finetuning config for Aqua model."""

ads/aqua/extension/model_handler.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
from typing import Optional
@@ -8,6 +8,9 @@
88
from tornado.web import HTTPError
99

1010
from ads.aqua.common.decorator import handle_exceptions
11+
from ads.aqua.common.enums import (
12+
CustomInferenceContainerTypeFamily,
13+
)
1114
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1215
from ads.aqua.common.utils import (
1316
get_hf_model_info,
@@ -128,6 +131,10 @@ def post(self, *args, **kwargs): # noqa: ARG002
128131
download_from_hf = (
129132
str(input_data.get("download_from_hf", "false")).lower() == "true"
130133
)
134+
local_dir = input_data.get("local_dir")
135+
cleanup_model_cache = (
136+
str(input_data.get("cleanup_model_cache", "true")).lower() == "true"
137+
)
131138
inference_container_uri = input_data.get("inference_container_uri")
132139
allow_patterns = input_data.get("allow_patterns")
133140
ignore_patterns = input_data.get("ignore_patterns")
@@ -143,6 +150,8 @@ def post(self, *args, **kwargs): # noqa: ARG002
143150
model=model,
144151
os_path=os_path,
145152
download_from_hf=download_from_hf,
153+
local_dir=local_dir,
154+
cleanup_model_cache=cleanup_model_cache,
146155
inference_container=inference_container,
147156
finetuning_container=finetuning_container,
148157
compartment_id=compartment_id,
@@ -168,7 +177,9 @@ def put(self, id):
168177
raise HTTPError(400, Errors.NO_INPUT_DATA)
169178

170179
inference_container = input_data.get("inference_container")
180+
inference_container_uri = input_data.get("inference_container_uri")
171181
inference_containers = AquaModelApp.list_valid_inference_containers()
182+
inference_containers.extend(CustomInferenceContainerTypeFamily.values())
172183
if (
173184
inference_container is not None
174185
and inference_container not in inference_containers
@@ -181,7 +192,13 @@ def put(self, id):
181192
task = input_data.get("task")
182193
app = AquaModelApp()
183194
self.finish(
184-
app.edit_registered_model(id, inference_container, enable_finetuning, task)
195+
app.edit_registered_model(
196+
id,
197+
inference_container,
198+
inference_container_uri,
199+
enable_finetuning,
200+
task,
201+
)
185202
)
186203
app.clear_model_details_cache(model_id=id)
187204

ads/aqua/finetuning/constants.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65
from ads.common.extended_enum import ExtendedEnumMeta
@@ -17,4 +16,8 @@ class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta):
1716
SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container"
1817

1918

19+
class FineTuningRestrictedParams(str, metaclass=ExtendedEnumMeta):
20+
OPTIMIZER = "optimizer"
21+
22+
2023
ENV_AQUA_FINE_TUNING_CONTAINER = "AQUA_FINE_TUNING_CONTAINER"

ads/aqua/finetuning/entities.py

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4-
from dataclasses import dataclass, field
5-
from typing import List, Optional
64

7-
from ads.aqua.data import AquaJobSummary
8-
from ads.common.serializer import DataClassSerializable
5+
import json
6+
from typing import List, Literal, Optional, Union
97

8+
from pydantic import Field, model_validator
109

11-
@dataclass(repr=False)
12-
class AquaFineTuningParams(DataClassSerializable):
13-
epochs: int
10+
from ads.aqua.common.errors import AquaValueError
11+
from ads.aqua.config.utils.serializer import Serializable
12+
from ads.aqua.data import AquaResourceIdentifier
13+
from ads.aqua.finetuning.constants import FineTuningRestrictedParams
14+
15+
16+
class AquaFineTuningParams(Serializable):
17+
"""Class for maintaining aqua fine-tuning model parameters"""
18+
19+
epochs: Optional[int] = None
1420
learning_rate: Optional[float] = None
15-
sample_packing: Optional[bool] = "auto"
21+
sample_packing: Union[bool, None, Literal["auto"]] = "auto"
1622
batch_size: Optional[int] = (
1723
None # make it batch_size for user, but internally this is micro_batch_size
1824
)
@@ -22,21 +28,59 @@ class AquaFineTuningParams(DataClassSerializable):
2228
lora_alpha: Optional[int] = None
2329
lora_dropout: Optional[float] = None
2430
lora_target_linear: Optional[bool] = None
25-
lora_target_modules: Optional[List] = None
31+
lora_target_modules: Optional[List[str]] = None
2632
early_stopping_patience: Optional[int] = None
2733
early_stopping_threshold: Optional[float] = None
2834

35+
class Config:
36+
extra = "allow"
37+
38+
def to_dict(self) -> dict:
39+
return json.loads(super().to_json(exclude_none=True))
40+
41+
@model_validator(mode="before")
42+
@classmethod
43+
def validate_restricted_fields(cls, data: dict):
44+
# we may want to skip validation if loading data from config files instead of user entered parameters
45+
validate = data.pop("_validate", True)
46+
if not (validate and isinstance(data, dict)):
47+
return data
48+
restricted_params = [
49+
param for param in data if param in FineTuningRestrictedParams.values()
50+
]
51+
if restricted_params:
52+
raise AquaValueError(
53+
f"Found restricted parameter name: {restricted_params}"
54+
)
55+
return data
2956

30-
@dataclass(repr=False)
31-
class AquaFineTuningSummary(AquaJobSummary, DataClassSerializable):
32-
parameters: AquaFineTuningParams = field(default_factory=AquaFineTuningParams)
3357

58+
class AquaFineTuningSummary(Serializable):
59+
"""Represents a summary of Aqua Finetuning job."""
3460

35-
@dataclass(repr=False)
36-
class CreateFineTuningDetails(DataClassSerializable):
37-
"""Dataclass to create aqua model fine tuning.
61+
id: str
62+
name: str
63+
console_url: str
64+
lifecycle_state: str
65+
lifecycle_details: str
66+
time_created: str
67+
tags: dict
68+
experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
69+
source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
70+
job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
71+
parameters: AquaFineTuningParams = Field(default_factory=AquaFineTuningParams)
3872

39-
Fields
73+
class Config:
74+
extra = "ignore"
75+
76+
def to_dict(self) -> dict:
77+
return json.loads(super().to_json(exclude_none=True))
78+
79+
80+
class CreateFineTuningDetails(Serializable):
81+
"""Class to create aqua model fine-tuning instance.
82+
83+
Properties
4084
------
4185
ft_source_id: str
4286
The fine tuning source id. Must be model ocid.
@@ -107,3 +151,6 @@ class CreateFineTuningDetails(DataClassSerializable):
107151
force_overwrite: Optional[bool] = False
108152
freeform_tags: Optional[dict] = None
109153
defined_tags: Optional[dict] = None
154+
155+
class Config:
156+
extra = "ignore"

0 commit comments

Comments
 (0)