Skip to content

Commit 19dccab

Browse files
authored
Merge branch 'main' into conda_pack_build_related_fixes
2 parents f07bf91 + 55740a6 commit 19dccab

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+2924
-473
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ repos:
4545
rev: v8.18.4
4646
hooks:
4747
- id: gitleaks
48-
exclude: .github/workflows/reusable-actions/set-dummy-conf.yml
48+
exclude: .github/workflows/reusable-actions/set-dummy-conf.yml|./tests/operators/common/test_load_data.py
4949
# Oracle copyright checker
5050
- repo: https://github.com/oracle-samples/oci-data-science-ai-samples/
5151
rev: 1bc5270a443b791c62f634233c0f4966dfcc0dd6

CODEOWNERS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
* @darenr @mayoor @mrDzurb @VipulMascarenhas @qiuosier
1+
* @darenr @mayoor @mrDzurb @VipulMascarenhas @qiuosier @ahosler

ads/aqua/app.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -175,7 +174,7 @@ def create_model_version_set(
175174
f"Invalid model version set name. Please provide a model version set with `{tag}` in tags."
176175
)
177176

178-
except:
177+
except Exception:
179178
logger.debug(
180179
f"Model version set {model_version_set_name} doesn't exist. "
181180
"Creating new model version set."
@@ -254,7 +253,7 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
254253

255254
try:
256255
response = self.ds_client.head_model_artifact(model_id=model_id, **kwargs)
257-
return True if response.status == 200 else False
256+
return response.status == 200
258257
except oci.exceptions.ServiceError as ex:
259258
if ex.status == 404:
260259
logger.info(f"Artifact not found in model {model_id}.")
@@ -302,15 +301,15 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
302301
config_path,
303302
config_file_name=config_file_name,
304303
)
305-
except:
304+
except Exception:
306305
# todo: temp fix for issue related to config load for byom models, update logic to choose the right path
307306
try:
308307
config_path = f"{artifact_path.rstrip('/')}/config/"
309308
config = load_config(
310309
config_path,
311310
config_file_name=config_file_name,
312311
)
313-
except:
312+
except Exception:
314313
pass
315314

316315
if not config:
@@ -343,7 +342,7 @@ def build_cli(self) -> str:
343342
params = [
344343
f"--{field.name} {getattr(self,field.name)}"
345344
for field in fields(self.__class__)
346-
if getattr(self, field.name)
345+
if getattr(self, field.name) is not None
347346
]
348347
cmd = f"{cmd} {' '.join(params)}"
349348
return cmd

ads/aqua/common/entities.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4+
5+
6+
class ContainerSpec:
7+
"""
8+
Class to hold to hold keys within the container spec.
9+
"""
10+
11+
CONTAINER_SPEC = "containerSpec"
12+
CLI_PARM = "cliParam"
13+
SERVER_PORT = "serverPort"
14+
HEALTH_CHECK_PORT = "healthCheckPort"
15+
ENV_VARS = "envVars"
16+
RESTRICTED_PARAMS = "restrictedParams"
17+
EVALUATION_CONFIGURATION = "evaluationConfiguration"

ads/aqua/common/enums.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -8,6 +7,7 @@
87
~~~~~~~~~~~~~~
98
This module contains the set of enums used in AQUA.
109
"""
10+
1111
from ads.common.extended_enum import ExtendedEnumMeta
1212

1313

@@ -38,21 +38,34 @@ class Tags(str, metaclass=ExtendedEnumMeta):
3838
READY_TO_IMPORT = "ready_to_import"
3939
BASE_MODEL_CUSTOM = "aqua_custom_base_model"
4040
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
41+
MODEL_FORMAT = "model_format"
42+
MODEL_ARTIFACT_FILE = "model_file"
4143

4244

4345
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
4446
CONTAINER_TYPE_VLLM = "vllm"
4547
CONTAINER_TYPE_TGI = "tgi"
48+
CONTAINER_TYPE_LLAMA_CPP = "llama-cpp"
4649

4750

4851
class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
4952
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5053
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
54+
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
5155

5256

5357
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
5458
PARAM_TYPE_VLLM = "VLLM_PARAMS"
5559
PARAM_TYPE_TGI = "TGI_PARAMS"
60+
PARAM_TYPE_LLAMA_CPP = "LLAMA_CPP_PARAMS"
61+
62+
63+
class EvaluationContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
64+
AQUA_EVALUATION_CONTAINER_FAMILY = "odsc-llm-evaluate"
65+
66+
67+
class FineTuningContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
68+
AQUA_FINETUNING_CONTAINER_FAMILY = "odsc-llm-fine-tuning"
5669

5770

5871
class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):

ads/aqua/common/utils.py

Lines changed: 160 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,26 @@
1010
import os
1111
import random
1212
import re
13+
import shlex
14+
import subprocess
15+
from datetime import datetime, timedelta
1316
from functools import wraps
1417
from pathlib import Path
1518
from string import Template
1619
from typing import List, Union
1720

1821
import fsspec
1922
import oci
23+
from cachetools import TTLCache, cached
24+
from huggingface_hub.hf_api import HfApi, ModelInfo
25+
from huggingface_hub.utils import (
26+
GatedRepoError,
27+
HfHubHTTPError,
28+
RepositoryNotFoundError,
29+
RevisionNotFoundError,
30+
)
2031
from oci.data_science.models import JobRun, Model
32+
from oci.object_storage.models import ObjectSummary
2133

2234
from ads.aqua.common.enums import (
2335
InferenceContainerParamType,
@@ -34,6 +46,7 @@
3446
COMPARTMENT_MAPPING_KEY,
3547
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
3648
CONTAINER_INDEX,
49+
HF_LOGIN_DEFAULT_TIMEOUT,
3750
MAXIMUM_ALLOWED_DATASET_IN_BYTE,
3851
MODEL_BY_REFERENCE_OSS_PATH_KEY,
3952
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
@@ -44,8 +57,7 @@
4457
VLLM_INFERENCE_RESTRICTED_PARAMS,
4558
)
4659
from ads.aqua.data import AquaResourceIdentifier
47-
from ads.common.auth import default_signer
48-
from ads.common.decorator.threaded import threaded
60+
from ads.common.auth import AuthState, default_signer
4961
from ads.common.extended_enum import ExtendedEnumMeta
5062
from ads.common.object_storage_details import ObjectStorageDetails
5163
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
@@ -213,7 +225,6 @@ def read_file(file_path: str, **kwargs) -> str:
213225
return UNKNOWN
214226

215227

216-
@threaded()
217228
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
218229
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
219230
signer = default_signer() if artifact_path.startswith("oci://") else {}
@@ -228,6 +239,32 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
228239
return config
229240

230241

242+
def list_os_files_with_extension(oss_path: str, extension: str) -> [str]:
243+
"""
244+
List files in the specified directory with the given extension.
245+
246+
Parameters:
247+
- oss_path: The path to the directory where files are located.
248+
- extension: The file extension to filter by (e.g., 'txt' for text files).
249+
250+
Returns:
251+
- A list of file paths matching the specified extension.
252+
"""
253+
254+
oss_client = ObjectStorageDetails.from_path(oss_path)
255+
256+
# Ensure the extension is prefixed with a dot if not already
257+
if not extension.startswith("."):
258+
extension = "." + extension
259+
files: List[ObjectSummary] = oss_client.list_objects().objects
260+
261+
return [
262+
file.name[len(oss_client.filepath) :].lstrip("/")
263+
for file in files
264+
if file.name.endswith(extension)
265+
]
266+
267+
231268
def is_valid_ocid(ocid: str) -> bool:
232269
"""Checks if the given ocid is valid.
233270
@@ -503,6 +540,7 @@ def container_config_path():
503540
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
504541

505542

543+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
506544
def get_container_config():
507545
config = load_config(
508546
file_path=container_config_path(),
@@ -743,6 +781,33 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
743781
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
744782

745783

784+
def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
785+
"""Upload the local folder to the object storage
786+
787+
Args:
788+
os_path (str): object storage URI with prefix. This is the path to upload
789+
local_dir (str): Local directory where the object is downloaded
790+
model_name (str): Name of the huggingface model
791+
Retuns:
792+
str: Object name inside the bucket
793+
"""
794+
os_details: ObjectStorageDetails = ObjectStorageDetails.from_path(os_path)
795+
if not os_details.is_bucket_versioned():
796+
raise ValueError(f"Version is not enabled at object storage location {os_path}")
797+
auth_state = AuthState()
798+
object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
799+
command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite"
800+
try:
801+
logger.info(f"Running: {command}")
802+
subprocess.check_call(shlex.split(command))
803+
except subprocess.CalledProcessError as e:
804+
logger.error(
805+
f"Error uploading the object. Exit code: {e.returncode} with error {e.stdout}"
806+
)
807+
808+
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
809+
810+
746811
def is_service_managed_container(container):
747812
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
748813

@@ -881,6 +946,8 @@ def get_container_params_type(container_type_name: str) -> str:
881946
return InferenceContainerParamType.PARAM_TYPE_VLLM
882947
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
883948
return InferenceContainerParamType.PARAM_TYPE_TGI
949+
elif InferenceContainerType.CONTAINER_TYPE_LLAMA_CPP in container_type_name.lower():
950+
return InferenceContainerParamType.PARAM_TYPE_LLAMA_CPP
884951
else:
885952
return UNKNOWN
886953

@@ -905,3 +972,93 @@ def get_restricted_params_by_container(container_type_name: str) -> set:
905972
return TGI_INFERENCE_RESTRICTED_PARAMS
906973
else:
907974
return set()
975+
976+
977+
def get_huggingface_login_timeout() -> int:
978+
"""This helper function returns the huggingface login timeout, returns default if not set via
979+
env var.
980+
Returns
981+
-------
982+
timeout: int
983+
huggingface login timeout.
984+
985+
"""
986+
timeout = HF_LOGIN_DEFAULT_TIMEOUT
987+
try:
988+
timeout = int(
989+
os.environ.get("HF_LOGIN_DEFAULT_TIMEOUT", HF_LOGIN_DEFAULT_TIMEOUT)
990+
)
991+
except ValueError:
992+
pass
993+
return timeout
994+
995+
996+
def format_hf_custom_error_message(error: HfHubHTTPError):
997+
"""
998+
Formats a custom error message based on the Hugging Face error response.
999+
1000+
Parameters
1001+
----------
1002+
error (HfHubHTTPError): The caught exception.
1003+
1004+
Raises
1005+
------
1006+
AquaRuntimeError: A user-friendly error message.
1007+
"""
1008+
# Extract the repository URL from the error message if present
1009+
match = re.search(r"(https://huggingface.co/[^\s]+)", str(error))
1010+
url = match.group(1) if match else "the requested Hugging Face URL."
1011+
1012+
if isinstance(error, RepositoryNotFoundError):
1013+
raise AquaRuntimeError(
1014+
reason=f"Failed to access `{url}`. Please check if the provided repository name is correct. "
1015+
"If the repo is private, make sure you are authenticated and have a valid HF token registered. "
1016+
"To register your token, run this command in your terminal: `huggingface-cli login`",
1017+
service_payload={"error": "RepositoryNotFoundError"},
1018+
)
1019+
1020+
if isinstance(error, GatedRepoError):
1021+
raise AquaRuntimeError(
1022+
reason=f"Access denied to `{url}` "
1023+
"This repository is gated. Access is restricted to authorized users. "
1024+
"Please request access or check with the repository administrator. "
1025+
"If you are trying to access a gated repository, ensure you have a valid HF token registered. "
1026+
"To register your token, run this command in your terminal: `huggingface-cli login`",
1027+
service_payload={"error": "GatedRepoError"},
1028+
)
1029+
1030+
if isinstance(error, RevisionNotFoundError):
1031+
raise AquaRuntimeError(
1032+
reason=f"The specified revision could not be found at `{url}` "
1033+
"Please check the revision identifier and try again.",
1034+
service_payload={"error": "RevisionNotFoundError"},
1035+
)
1036+
1037+
raise AquaRuntimeError(
1038+
reason=f"An error occurred while accessing `{url}` "
1039+
"Please check your network connection and try again. "
1040+
"If you are trying to access a gated repository, ensure you have a valid HF token registered. "
1041+
"To register your token, run this command in your terminal: `huggingface-cli login`",
1042+
service_payload={"error": "Error"},
1043+
)
1044+
1045+
1046+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
1047+
def get_hf_model_info(repo_id: str) -> ModelInfo:
1048+
"""Gets the model information object for the given model repository name. For models that requires a token,
1049+
this method assumes that the token validation is already done.
1050+
1051+
Parameters
1052+
----------
1053+
repo_id: str
1054+
hugging face model repository name
1055+
1056+
Returns
1057+
-------
1058+
instance of ModelInfo object
1059+
1060+
"""
1061+
try:
1062+
return HfApi().model_info(repo_id=repo_id)
1063+
except HfHubHTTPError as err:
1064+
raise format_hf_custom_error_message(err) from err

ads/aqua/config/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -14,5 +13,6 @@ def get_finetuning_config_defaults():
1413
"BM.GPU.A10.4": {"batch_size": 1, "replica": 1},
1514
"BM.GPU4.8": {"batch_size": 4, "replica": 1},
1615
"BM.GPU.A100-v2.8": {"batch_size": 6, "replica": 1},
16+
"BM.GPU.H100.8": {"batch_size": 6, "replica": 1},
1717
}
1818
}

0 commit comments

Comments
 (0)