10
10
import os
11
11
import random
12
12
import re
13
+ import shlex
14
+ import subprocess
15
+ from datetime import datetime , timedelta
13
16
from functools import wraps
14
17
from pathlib import Path
15
18
from string import Template
16
19
from typing import List , Union
17
20
18
21
import fsspec
19
22
import oci
23
+ from cachetools import TTLCache , cached
24
+ from huggingface_hub .hf_api import HfApi , ModelInfo
25
+ from huggingface_hub .utils import (
26
+ GatedRepoError ,
27
+ HfHubHTTPError ,
28
+ RepositoryNotFoundError ,
29
+ RevisionNotFoundError ,
30
+ )
20
31
from oci .data_science .models import JobRun , Model
32
+ from oci .object_storage .models import ObjectSummary
21
33
22
34
from ads .aqua .common .enums import (
23
35
InferenceContainerParamType ,
34
46
COMPARTMENT_MAPPING_KEY ,
35
47
CONSOLE_LINK_RESOURCE_TYPE_MAPPING ,
36
48
CONTAINER_INDEX ,
49
+ HF_LOGIN_DEFAULT_TIMEOUT ,
37
50
MAXIMUM_ALLOWED_DATASET_IN_BYTE ,
38
51
MODEL_BY_REFERENCE_OSS_PATH_KEY ,
39
52
SERVICE_MANAGED_CONTAINER_URI_SCHEME ,
44
57
VLLM_INFERENCE_RESTRICTED_PARAMS ,
45
58
)
46
59
from ads .aqua .data import AquaResourceIdentifier
47
- from ads .common .auth import default_signer
48
- from ads .common .decorator .threaded import threaded
60
+ from ads .common .auth import AuthState , default_signer
49
61
from ads .common .extended_enum import ExtendedEnumMeta
50
62
from ads .common .object_storage_details import ObjectStorageDetails
51
63
from ads .common .oci_resource import SEARCH_TYPE , OCIResource
@@ -213,7 +225,6 @@ def read_file(file_path: str, **kwargs) -> str:
213
225
return UNKNOWN
214
226
215
227
216
- @threaded ()
217
228
def load_config (file_path : str , config_file_name : str , ** kwargs ) -> dict :
218
229
artifact_path = f"{ file_path .rstrip ('/' )} /{ config_file_name } "
219
230
signer = default_signer () if artifact_path .startswith ("oci://" ) else {}
@@ -228,6 +239,32 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
228
239
return config
229
240
230
241
242
+ def list_os_files_with_extension (oss_path : str , extension : str ) -> [str ]:
243
+ """
244
+ List files in the specified directory with the given extension.
245
+
246
+ Parameters:
247
+ - oss_path: The path to the directory where files are located.
248
+ - extension: The file extension to filter by (e.g., 'txt' for text files).
249
+
250
+ Returns:
251
+ - A list of file paths matching the specified extension.
252
+ """
253
+
254
+ oss_client = ObjectStorageDetails .from_path (oss_path )
255
+
256
+ # Ensure the extension is prefixed with a dot if not already
257
+ if not extension .startswith ("." ):
258
+ extension = "." + extension
259
+ files : List [ObjectSummary ] = oss_client .list_objects ().objects
260
+
261
+ return [
262
+ file .name [len (oss_client .filepath ) :].lstrip ("/" )
263
+ for file in files
264
+ if file .name .endswith (extension )
265
+ ]
266
+
267
+
231
268
def is_valid_ocid (ocid : str ) -> bool :
232
269
"""Checks if the given ocid is valid.
233
270
@@ -503,6 +540,7 @@ def container_config_path():
503
540
return f"oci://{ AQUA_SERVICE_MODELS_BUCKET } @{ CONDA_BUCKET_NS } /service_models/config"
504
541
505
542
543
+ @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (hours = 5 ), timer = datetime .now ))
506
544
def get_container_config ():
507
545
config = load_config (
508
546
file_path = container_config_path (),
@@ -743,6 +781,33 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
743
781
return ocid [- key_len :] if ocid and len (ocid ) > key_len else ""
744
782
745
783
784
+ def upload_folder (os_path : str , local_dir : str , model_name : str ) -> str :
785
+ """Upload the local folder to the object storage
786
+
787
+ Args:
788
+ os_path (str): object storage URI with prefix. This is the path to upload
789
+ local_dir (str): Local directory where the object is downloaded
790
+ model_name (str): Name of the huggingface model
791
+ Retuns:
792
+ str: Object name inside the bucket
793
+ """
794
+ os_details : ObjectStorageDetails = ObjectStorageDetails .from_path (os_path )
795
+ if not os_details .is_bucket_versioned ():
796
+ raise ValueError (f"Version is not enabled at object storage location { os_path } " )
797
+ auth_state = AuthState ()
798
+ object_path = os_details .filepath .rstrip ("/" ) + "/" + model_name + "/"
799
+ command = f"oci os object bulk-upload --src-dir { local_dir } --prefix { object_path } -bn { os_details .bucket } -ns { os_details .namespace } --auth { auth_state .oci_iam_type } --profile { auth_state .oci_key_profile } --no-overwrite"
800
+ try :
801
+ logger .info (f"Running: { command } " )
802
+ subprocess .check_call (shlex .split (command ))
803
+ except subprocess .CalledProcessError as e :
804
+ logger .error (
805
+ f"Error uploading the object. Exit code: { e .returncode } with error { e .stdout } "
806
+ )
807
+
808
+ return f"oci://{ os_details .bucket } @{ os_details .namespace } " + "/" + object_path
809
+
810
+
746
811
def is_service_managed_container (container ):
747
812
return container and container .startswith (SERVICE_MANAGED_CONTAINER_URI_SCHEME )
748
813
@@ -881,6 +946,8 @@ def get_container_params_type(container_type_name: str) -> str:
881
946
return InferenceContainerParamType .PARAM_TYPE_VLLM
882
947
elif InferenceContainerType .CONTAINER_TYPE_TGI in container_type_name .lower ():
883
948
return InferenceContainerParamType .PARAM_TYPE_TGI
949
+ elif InferenceContainerType .CONTAINER_TYPE_LLAMA_CPP in container_type_name .lower ():
950
+ return InferenceContainerParamType .PARAM_TYPE_LLAMA_CPP
884
951
else :
885
952
return UNKNOWN
886
953
@@ -905,3 +972,93 @@ def get_restricted_params_by_container(container_type_name: str) -> set:
905
972
return TGI_INFERENCE_RESTRICTED_PARAMS
906
973
else :
907
974
return set ()
975
+
976
+
977
+ def get_huggingface_login_timeout () -> int :
978
+ """This helper function returns the huggingface login timeout, returns default if not set via
979
+ env var.
980
+ Returns
981
+ -------
982
+ timeout: int
983
+ huggingface login timeout.
984
+
985
+ """
986
+ timeout = HF_LOGIN_DEFAULT_TIMEOUT
987
+ try :
988
+ timeout = int (
989
+ os .environ .get ("HF_LOGIN_DEFAULT_TIMEOUT" , HF_LOGIN_DEFAULT_TIMEOUT )
990
+ )
991
+ except ValueError :
992
+ pass
993
+ return timeout
994
+
995
+
996
+ def format_hf_custom_error_message (error : HfHubHTTPError ):
997
+ """
998
+ Formats a custom error message based on the Hugging Face error response.
999
+
1000
+ Parameters
1001
+ ----------
1002
+ error (HfHubHTTPError): The caught exception.
1003
+
1004
+ Raises
1005
+ ------
1006
+ AquaRuntimeError: A user-friendly error message.
1007
+ """
1008
+ # Extract the repository URL from the error message if present
1009
+ match = re .search (r"(https://huggingface.co/[^\s]+)" , str (error ))
1010
+ url = match .group (1 ) if match else "the requested Hugging Face URL."
1011
+
1012
+ if isinstance (error , RepositoryNotFoundError ):
1013
+ raise AquaRuntimeError (
1014
+ reason = f"Failed to access `{ url } `. Please check if the provided repository name is correct. "
1015
+ "If the repo is private, make sure you are authenticated and have a valid HF token registered. "
1016
+ "To register your token, run this command in your terminal: `huggingface-cli login`" ,
1017
+ service_payload = {"error" : "RepositoryNotFoundError" },
1018
+ )
1019
+
1020
+ if isinstance (error , GatedRepoError ):
1021
+ raise AquaRuntimeError (
1022
+ reason = f"Access denied to `{ url } ` "
1023
+ "This repository is gated. Access is restricted to authorized users. "
1024
+ "Please request access or check with the repository administrator. "
1025
+ "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
1026
+ "To register your token, run this command in your terminal: `huggingface-cli login`" ,
1027
+ service_payload = {"error" : "GatedRepoError" },
1028
+ )
1029
+
1030
+ if isinstance (error , RevisionNotFoundError ):
1031
+ raise AquaRuntimeError (
1032
+ reason = f"The specified revision could not be found at `{ url } ` "
1033
+ "Please check the revision identifier and try again." ,
1034
+ service_payload = {"error" : "RevisionNotFoundError" },
1035
+ )
1036
+
1037
+ raise AquaRuntimeError (
1038
+ reason = f"An error occurred while accessing `{ url } ` "
1039
+ "Please check your network connection and try again. "
1040
+ "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
1041
+ "To register your token, run this command in your terminal: `huggingface-cli login`" ,
1042
+ service_payload = {"error" : "Error" },
1043
+ )
1044
+
1045
+
1046
+ @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (hours = 5 ), timer = datetime .now ))
1047
+ def get_hf_model_info (repo_id : str ) -> ModelInfo :
1048
+ """Gets the model information object for the given model repository name. For models that requires a token,
1049
+ this method assumes that the token validation is already done.
1050
+
1051
+ Parameters
1052
+ ----------
1053
+ repo_id: str
1054
+ hugging face model repository name
1055
+
1056
+ Returns
1057
+ -------
1058
+ instance of ModelInfo object
1059
+
1060
+ """
1061
+ try :
1062
+ return HfApi ().model_info (repo_id = repo_id )
1063
+ except HfHubHTTPError as err :
1064
+ raise format_hf_custom_error_message (err ) from err
0 commit comments