Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9cc430b
add ignore validation flag while registering model
VipulMascarenhas Dec 12, 2024
bd4205e
update logging
VipulMascarenhas Dec 12, 2024
9881dcb
Merge branch 'main' into ODSC-65657/ignore_config_validation
VipulMascarenhas Dec 12, 2024
6aaa3ef
improve error message logging
VipulMascarenhas Dec 14, 2024
6b2e4aa
update handler tests
VipulMascarenhas Dec 14, 2024
71269a7
update logging level
VipulMascarenhas Dec 16, 2024
bff4771
Merge branch 'main' into ODSC-65657/ignore_config_validation
VipulMascarenhas Dec 17, 2024
59a042b
log request ids
VipulMascarenhas Jan 3, 2025
a8321a1
update logging for model operations
VipulMascarenhas Jan 3, 2025
fe1eada
Merge branch 'main' into ODSC-65657/ignore_config_validation
VipulMascarenhas Jan 3, 2025
ee8dbf8
update logging for deployment operations
VipulMascarenhas Jan 3, 2025
68f325a
update logging for deployment operations
VipulMascarenhas Jan 4, 2025
86ef7ad
update logging for finetuning operations
VipulMascarenhas Jan 4, 2025
adace88
update logging for finetuning operations
VipulMascarenhas Jan 4, 2025
38e046c
Merge branch 'main' into ODSC-65657/ignore_config_validation
VipulMascarenhas Jan 6, 2025
e3cb8d5
update evaluation validation for create API
VipulMascarenhas Jan 6, 2025
8005294
merge ODSC-65657/ignore_config_validation changes
VipulMascarenhas Jan 6, 2025
b0827e5
update evaluation logging
VipulMascarenhas Jan 6, 2025
ed9504d
update evaluation logging
VipulMascarenhas Jan 6, 2025
fc0caa1
Merge branch 'main' into ODSC-65657/ignore_config_validation
VipulMascarenhas Jan 6, 2025
b83e1e9
Merge branch 'ODSC-65657/ignore_config_validation' of github.com:orac…
VipulMascarenhas Jan 6, 2025
8b603b7
Merge branch 'ODSC-65657/ignore_config_validation' into ODSC-65743/ad…
VipulMascarenhas Jan 6, 2025
ca02f03
update tests
VipulMascarenhas Jan 6, 2025
44d41ce
Merge branch 'main' into ODSC-65657/ignore_config_validation
VipulMascarenhas Jan 6, 2025
d39a1ff
Merge branch 'ODSC-65657/ignore_config_validation' of github.com:orac…
VipulMascarenhas Jan 6, 2025
f4240ce
Merge branch 'ODSC-65657/ignore_config_validation' into ODSC-65743/ad…
VipulMascarenhas Jan 6, 2025
d7ff57d
Merge branch 'main' into ODSC-65657/ignore_config_validation
VipulMascarenhas Jan 8, 2025
fde6c46
Merge branch 'main' into ODSC-65657/ignore_config_validation
VipulMascarenhas Jan 9, 2025
ed4098d
add missing request id
VipulMascarenhas Jan 10, 2025
8db09df
Additional logging statements for AI Quick Actions operations (#1034)
VipulMascarenhas Jan 10, 2025
6b70096
Merge branch 'main' into ODSC-65657/ignore_config_validation
VipulMascarenhas Jan 10, 2025
5621a06
Merge branch 'main' into ODSC-65657/ignore_config_validation
VipulMascarenhas Jan 13, 2025
4cbd7e0
revert to previous validation
VipulMascarenhas Jan 14, 2025
3bc31e5
Resolve merge conflicts
VipulMascarenhas Jan 31, 2025
7616ff2
fix tests after merge
VipulMascarenhas Jan 31, 2025
85a825b
Merge branch 'main' into ODSC-65657/ignore_config_validation
mayoor Feb 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def post(self, *args, **kwargs): # noqa: ARG002
ignore_patterns = input_data.get("ignore_patterns")
freeform_tags = input_data.get("freeform_tags")
defined_tags = input_data.get("defined_tags")
ignore_model_artifact_check = (
str(input_data.get("ignore_model_artifact_check", "false")).lower()
== "true"
)

return self.finish(
AquaModelApp().register(
Expand All @@ -149,6 +153,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
ignore_patterns=ignore_patterns,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
ignore_model_artifact_check=ignore_model_artifact_check,
)
)

Expand Down
1 change: 1 addition & 0 deletions ads/aqua/model/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ class ImportModelDetails(CLIBuilderMixin):
ignore_patterns: Optional[List[str]] = None
freeform_tags: Optional[dict] = None
defined_tags: Optional[dict] = None
ignore_model_artifact_check: Optional[bool] = None

def __post_init__(self):
self._command = "model register"
116 changes: 78 additions & 38 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
InferenceContainerTypeFamily,
Tags,
)
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.common.errors import (
AquaFileNotFoundError,
AquaRuntimeError,
AquaValueError,
)
from ads.aqua.common.utils import (
LifecycleStatus,
_build_resource_identifier,
Expand Down Expand Up @@ -972,13 +976,23 @@ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]:
# todo: revisit this logic to account for .bin files. In the current state, .bin and .safetensor models
# are grouped in one category and validation checks for config.json files only.
if model_format == ModelFormat.SAFETENSORS:
model_files.extend(
list_os_files_with_extension(oss_path=os_path, extension=".safetensors")
)
try:
load_config(
file_path=os_path,
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
)
except Exception:
pass
except Exception as ex:
message = (
f"The model path {os_path} does not contain the file config.json. "
f"Please check if the path is correct or the model artifacts are available at this location."
)
logger.warning(
f"{message}\n"
f"Details: {ex.reason if isinstance(ex, AquaFileNotFoundError) else str(ex)}\n"
)
else:
model_files.append(AQUA_MODEL_ARTIFACT_CONFIG)

Expand Down Expand Up @@ -1022,10 +1036,12 @@ def get_hf_model_files(model_name: str, model_format: ModelFormat) -> List[str]:

for model_sibling in model_siblings:
extension = pathlib.Path(model_sibling.rfilename).suffix[1:].upper()
if model_format == ModelFormat.SAFETENSORS:
if model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG:
model_files.append(model_sibling.rfilename)
elif extension == model_format.value:
if (
model_format == ModelFormat.SAFETENSORS
and model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG
):
model_files.append(model_sibling.rfilename)
if extension == model_format.value:
model_files.append(model_sibling.rfilename)

return model_files
Expand Down Expand Up @@ -1061,7 +1077,10 @@ def _validate_model(
safetensors_model_files = self.get_hf_model_files(
model_name, ModelFormat.SAFETENSORS
)
if safetensors_model_files:
if (
safetensors_model_files
and AQUA_MODEL_ARTIFACT_CONFIG in safetensors_model_files
):
hf_download_config_present = True
gguf_model_files = self.get_hf_model_files(model_name, ModelFormat.GGUF)
else:
Expand Down Expand Up @@ -1117,8 +1136,11 @@ def _validate_model(
Tags.LICENSE: license_value,
}
validation_result.tags = hf_tags
except Exception:
pass
except Exception as ex:
logger.debug(
f"An error occurred while getting tag information for model {model_name}. "
f"Error: {str(ex)}"
)

validation_result.model_formats = model_formats

Expand Down Expand Up @@ -1173,40 +1195,55 @@ def _validate_safetensor_format(
model_name: str = None,
):
if import_model_details.download_from_hf:
# validates config.json exists for safetensors model from hugginface
if not hf_download_config_present:
# validates config.json exists for safetensors model from huggingface
if not (
hf_download_config_present
or import_model_details.ignore_model_artifact_check
):
raise AquaRuntimeError(
f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required "
f"by {ModelFormat.SAFETENSORS.value} format model."
f" Please check if the model name is correct in Hugging Face repository."
)
validation_result.telemetry_model_name = model_name
else:
# validate if config.json is available from object storage, and get model name for telemetry
model_config = None
try:
model_config = load_config(
file_path=import_model_details.os_path,
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
)
except Exception as ex:
logger.error(
f"Exception occurred while loading config file from {import_model_details.os_path}"
f"Exception message: {ex}"
)
raise AquaRuntimeError(
message = (
f"The model path {import_model_details.os_path} does not contain the file config.json. "
f"Please check if the path is correct or the model artifacts are available at this location."
) from ex
else:
)
if not import_model_details.ignore_model_artifact_check:
logger.error(
f"{message}\n"
f"Details: {ex.reason if isinstance(ex, AquaFileNotFoundError) else str(ex)}"
)
raise AquaRuntimeError(message) from ex
else:
logger.warning(
f"{message}\n"
f"Proceeding with model registration as ignore_model_artifact_check field is set."
)

if verified_model:
# model_type validation, log message if metadata field doesn't match.
try:
metadata_model_type = verified_model.custom_metadata_list.get(
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
).value
if metadata_model_type:
if metadata_model_type and model_config is not None:
if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
if (
model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]
!= metadata_model_type
):
raise AquaRuntimeError(
logger.debug(
f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
f"the model {model_name}. Please check if the path is correct or "
Expand All @@ -1218,22 +1255,26 @@ def _validate_safetensor_format(
f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
)
except Exception:
pass
if verified_model:
validation_result.telemetry_model_name = verified_model.display_name
elif (
model_config is not None
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
):
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
elif (
model_config is not None
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
):
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
else:
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
except Exception as ex:
# todo: raise exception if model_type doesn't match. Currently log message and pass since service
# models do not have this metadata.
logger.debug(
f"Error occurred while processing metadata for model {model_name}. "
f"Exception: {str(ex)}"
)
validation_result.telemetry_model_name = verified_model.display_name
elif (
model_config is not None
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
):
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
elif (
model_config is not None
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
):
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
else:
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM

@staticmethod
def _validate_gguf_format(
Expand Down Expand Up @@ -1416,7 +1457,6 @@ def register(
).rstrip("/")
else:
artifact_path = import_model_details.os_path.rstrip("/")

# Create Model catalog entry with pass by reference
ds_model = self._create_model_catalog_entry(
os_path=artifact_path,
Expand Down
85 changes: 61 additions & 24 deletions tests/unitary/with_extras/aqua/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,10 +920,18 @@ def test_import_model_with_project_compartment_override(
assert model.project_id == project_override

@pytest.mark.parametrize(
"download_from_hf",
[True, False],
("ignore_artifact_check", "download_from_hf"),
[
(True, True),
(True, False),
(False, True),
(False, False),
(None, False),
(None, True),
],
)
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
@patch("ads.model.datascience_model.DataScienceModel.sync")
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
@patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError)
Expand All @@ -936,45 +944,65 @@ def test_import_model_with_missing_config(
mock_load_config,
mock_list_objects,
mock_upload_artifact,
mock_sync,
mock_ocidsc_create,
mock_get_container_config,
ignore_artifact_check,
download_from_hf,
mock_get_hf_model_info,
mock_init_client,
):
"""Test for validating if error is returned when model artifacts are incomplete or not available."""

os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
model_name = "oracle/aqua-1t-mega-model"
my_model = "oracle/aqua-1t-mega-model"
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
mock_list_objects.return_value = MagicMock(objects=[])
reload(ads.aqua.model.model)
app = AquaModelApp()
app.list = MagicMock(return_value=[])
# set object list from OSS without config.json
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"

# set object list from HF without config.json
if download_from_hf:
with pytest.raises(AquaValueError):
mock_get_hf_model_info.return_value.siblings = []
with tempfile.TemporaryDirectory() as tmpdir:
model: AquaModel = app.register(
model=model_name,
os_path=os_path,
local_dir=str(tmpdir),
download_from_hf=True,
)
mock_get_hf_model_info.return_value.siblings = [
MagicMock(rfilename="model.safetensors")
]
else:
with pytest.raises(AquaRuntimeError):
obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150)
obj1.name = f"prefix/path/model.safetensors"
objects = [obj1]
mock_list_objects.return_value = MagicMock(objects=objects)

reload(ads.aqua.model.model)
app = AquaModelApp()
with patch.object(AquaModelApp, "list") as aqua_model_mock_list:
aqua_model_mock_list.return_value = [
AquaModelSummary(
id="test_id1",
name="organization1/name1",
organization="organization1",
)
]

if ignore_artifact_check:
model: AquaModel = app.register(
model=model_name,
model=my_model,
os_path=os_path,
download_from_hf=False,
inference_container="odsc-vllm-or-tgi-container",
finetuning_container="odsc-llm-fine-tuning",
download_from_hf=download_from_hf,
ignore_model_artifact_check=ignore_artifact_check,
)
assert model.ready_to_deploy is True
else:
with pytest.raises(AquaRuntimeError):
model: AquaModel = app.register(
model=my_model,
os_path=os_path,
inference_container="odsc-vllm-or-tgi-container",
finetuning_container="odsc-llm-fine-tuning",
download_from_hf=download_from_hf,
ignore_model_artifact_check=ignore_artifact_check,
)

@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
@patch("ads.model.datascience_model.DataScienceModel.sync")
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
@patch.object(HfApi, "model_info")
@patch("ads.aqua.common.utils.load_config", return_value={})
def test_import_any_model_smc_container(
self,
Expand Down Expand Up @@ -1230,6 +1258,15 @@ def test_import_model_with_input_tags(
"--download_from_hf True --inference_container odsc-vllm-serving --freeform_tags "
'{"ftag1": "fvalue1", "ftag2": "fvalue2"} --defined_tags {"dtag1": "dvalue1", "dtag2": "dvalue2"}',
),
(
{
"os_path": "oci://aqua-bkt@aqua-ns/path",
"model": "oracle/oracle-1it",
"inference_container": "odsc-vllm-serving",
"ignore_model_artifact_check": True,
},
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-vllm-serving --ignore_model_artifact_check True",
),
],
)
def test_import_cli(self, data, expected_output):
Expand Down
Loading
Loading