Skip to content

Commit a0645cb

Browse files
Merge branch 'main' into ODSC-66822/allow_cache_delete
2 parents 583aaea + 410dbe0 commit a0645cb

File tree

29 files changed

+3556
-346
lines changed

29 files changed

+3556
-346
lines changed

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,6 @@ jobs:
5656
$CONDA/bin/conda init
5757
source /home/runner/.bashrc
5858
pip install -r test-requirements-operators.txt
59-
pip install "oracle-automlx[forecasting]>=24.4.0"
59+
pip install "oracle-automlx[forecasting]>=24.4.1"
6060
pip install pandas>=2.2.0
6161
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast

ads/aqua/common/enums.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5252
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5353
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
5454
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55+
56+
57+
class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5558
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
5659

5760

ads/aqua/extension/model_handler.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from tornado.web import HTTPError
99

1010
from ads.aqua.common.decorator import handle_exceptions
11+
from ads.aqua.common.enums import (
12+
CustomInferenceContainerTypeFamily,
13+
)
1114
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1215
from ads.aqua.common.utils import (
1316
get_hf_model_info,
@@ -169,7 +172,9 @@ def put(self, id):
169172
raise HTTPError(400, Errors.NO_INPUT_DATA)
170173

171174
inference_container = input_data.get("inference_container")
175+
inference_container_uri = input_data.get("inference_container_uri")
172176
inference_containers = AquaModelApp.list_valid_inference_containers()
177+
inference_containers.extend(CustomInferenceContainerTypeFamily.values())
173178
if (
174179
inference_container is not None
175180
and inference_container not in inference_containers
@@ -182,7 +187,13 @@ def put(self, id):
182187
task = input_data.get("task")
183188
app = AquaModelApp()
184189
self.finish(
185-
app.edit_registered_model(id, inference_container, enable_finetuning, task)
190+
app.edit_registered_model(
191+
id,
192+
inference_container,
193+
inference_container_uri,
194+
enable_finetuning,
195+
task,
196+
)
186197
)
187198
app.clear_model_details_cache(model_id=id)
188199

ads/aqua/model/model.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1616
from ads.aqua.app import AquaApp
1717
from ads.aqua.common.enums import (
18+
CustomInferenceContainerTypeFamily,
1819
FineTuningContainerTypeFamily,
1920
InferenceContainerTypeFamily,
2021
Tags,
@@ -377,8 +378,10 @@ def delete_model(self, model_id):
377378
f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted."
378379
)
379380

380-
@telemetry(entry_point="plugin=model&action=delete", name="aqua")
381-
def edit_registered_model(self, id, inference_container, enable_finetuning, task):
381+
@telemetry(entry_point="plugin=model&action=edit", name="aqua")
382+
def edit_registered_model(
383+
self, id, inference_container, inference_container_uri, enable_finetuning, task
384+
):
382385
"""Edits the default config of unverified registered model.
383386
384387
Parameters
@@ -387,6 +390,8 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
387390
The model OCID.
388391
inference_container: str.
389392
The inference container family name
393+
inference_container_uri: str
394+
The inference container uri for embedding models
390395
enable_finetuning: str
391396
Flag to enable or disable finetuning over the model. Defaults to None
392397
task:
@@ -402,19 +407,44 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
402407
if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None):
403408
if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None):
404409
raise AquaRuntimeError(
405-
f"Failed to edit model:{id}. Only registered unverified models can be edited."
410+
"Only registered unverified models can be edited."
406411
)
407412
else:
408413
custom_metadata_list = ds_model.custom_metadata_list
409414
freeform_tags = ds_model.freeform_tags
410415
if inference_container:
411-
custom_metadata_list.add(
412-
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
413-
value=inference_container,
414-
category=MetadataCustomCategory.OTHER,
415-
description="Deployment container mapping for SMC",
416-
replace=True,
417-
)
416+
if (
417+
inference_container in CustomInferenceContainerTypeFamily
418+
and inference_container_uri is None
419+
):
420+
raise AquaRuntimeError(
421+
"Inference container URI must be provided."
422+
)
423+
else:
424+
custom_metadata_list.add(
425+
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
426+
value=inference_container,
427+
category=MetadataCustomCategory.OTHER,
428+
description="Deployment container mapping for SMC",
429+
replace=True,
430+
)
431+
if inference_container_uri:
432+
if (
433+
inference_container in CustomInferenceContainerTypeFamily
434+
or inference_container is None
435+
):
436+
custom_metadata_list.add(
437+
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI,
438+
value=inference_container_uri,
439+
category=MetadataCustomCategory.OTHER,
440+
description=f"Inference container URI for {ds_model.display_name}",
441+
replace=True,
442+
)
443+
else:
444+
raise AquaRuntimeError(
445+
f"Inference container URI can be edited only with container values: {CustomInferenceContainerTypeFamily.values()}"
446+
)
447+
418448
if enable_finetuning is not None:
419449
if enable_finetuning.lower() == "true":
420450
custom_metadata_list.add(
@@ -449,9 +479,7 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
449479
)
450480
AquaApp().update_model(id, update_model_details)
451481
else:
452-
raise AquaRuntimeError(
453-
f"Failed to edit model:{id}. Only registered unverified models can be edited."
454-
)
482+
raise AquaRuntimeError("Only registered unverified models can be edited.")
455483

456484
def _fetch_metric_from_metadata(
457485
self,
@@ -870,8 +898,7 @@ def _create_model_catalog_entry(
870898
# only add cmd vars if inference container is not an SMC
871899
if (
872900
inference_container not in smc_container_set
873-
and inference_container
874-
== InferenceContainerTypeFamily.AQUA_TEI_CONTAINER_FAMILY
901+
and inference_container in CustomInferenceContainerTypeFamily.values()
875902
):
876903
cmd_vars = generate_tei_cmd_var(os_path)
877904
metadata.add(

ads/model/__init__.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,26 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

7-
from ads.model.generic_model import GenericModel, ModelState
86
from ads.model.datascience_model import DataScienceModel
9-
from ads.model.model_properties import ModelProperties
7+
from ads.model.deployment.model_deployer import ModelDeployer
8+
from ads.model.deployment.model_deployment import ModelDeployment
9+
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
1010
from ads.model.framework.automl_model import AutoMLModel
11+
from ads.model.framework.embedding_onnx_model import EmbeddingONNXModel
12+
from ads.model.framework.huggingface_model import HuggingFacePipelineModel
1113
from ads.model.framework.lightgbm_model import LightGBMModel
1214
from ads.model.framework.pytorch_model import PyTorchModel
1315
from ads.model.framework.sklearn_model import SklearnModel
16+
from ads.model.framework.spark_model import SparkPipelineModel
1417
from ads.model.framework.tensorflow_model import TensorFlowModel
1518
from ads.model.framework.xgboost_model import XGBoostModel
16-
from ads.model.framework.spark_model import SparkPipelineModel
17-
from ads.model.framework.huggingface_model import HuggingFacePipelineModel
18-
19-
from ads.model.deployment.model_deployer import ModelDeployer
20-
from ads.model.deployment.model_deployment import ModelDeployment
21-
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
22-
19+
from ads.model.generic_model import GenericModel, ModelState
20+
from ads.model.model_properties import ModelProperties
21+
from ads.model.model_version_set import ModelVersionSet, experiment
2322
from ads.model.serde.common import SERDE
2423
from ads.model.serde.model_input import ModelInputSerializer
25-
26-
from ads.model.model_version_set import ModelVersionSet, experiment
2724
from ads.model.service.oci_datascience_model_version_set import (
2825
ModelVersionSetNotExists,
2926
ModelVersionSetNotSaved,
@@ -42,6 +39,7 @@
4239
"XGBoostModel",
4340
"SparkPipelineModel",
4441
"HuggingFacePipelineModel",
42+
"EmbeddingONNXModel",
4543
"ModelDeployer",
4644
"ModelDeployment",
4745
"ModelDeploymentProperties",

ads/model/artifact.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2022, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import fnmatch
87
import importlib
98
import os
10-
import sys
119
import shutil
10+
import sys
1211
import tempfile
1312
import uuid
14-
import fsspec
13+
from datetime import datetime
1514
from typing import Dict, Optional, Tuple
15+
16+
import fsspec
17+
from jinja2 import Environment, PackageLoader
18+
19+
from ads import __version__
1620
from ads.common import auth as authutil
1721
from ads.common import logger, utils
1822
from ads.common.object_storage_details import ObjectStorageDetails
1923
from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
2024
from ads.model.runtime.env_info import EnvInfo, InferenceEnvInfo, TrainingEnvInfo
2125
from ads.model.runtime.runtime_info import RuntimeInfo
22-
from jinja2 import Environment, PackageLoader
23-
import warnings
24-
from ads import __version__
25-
from datetime import datetime
2626

2727
MODEL_ARTIFACT_VERSION = "3.0"
2828
REQUIRED_ARTIFACT_FILES = ("runtime.yaml", "score.py")
@@ -378,6 +378,45 @@ def prepare_score_py(
378378
) as f:
379379
f.write(scorefn_template.render(context))
380380

381+
def prepare_schema(self, schema_name: str):
382+
"""Copies schema to artifact directory.
383+
384+
Parameters
385+
----------
386+
schema_name: str
387+
The schema name
388+
389+
Returns
390+
-------
391+
None
392+
393+
Raises
394+
------
395+
FileExistsError
396+
If `schema_name` doesn't exist.
397+
"""
398+
uri_src = os.path.join(
399+
os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
400+
"templates",
401+
"schemas",
402+
f"{schema_name}",
403+
)
404+
405+
if not os.path.exists(uri_src):
406+
raise FileExistsError(
407+
f"{schema_name} does not exists. "
408+
"Ensure the schema name is valid or specify a different one."
409+
)
410+
411+
uri_dst = os.path.join(self.artifact_dir, os.path.basename(uri_src))
412+
413+
utils.copy_file(
414+
uri_src=uri_src,
415+
uri_dst=uri_dst,
416+
force_overwrite=True,
417+
auth=self.auth,
418+
)
419+
381420
def reload(self):
382421
"""Syncs the `score.py` to reload the model and predict function.
383422
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
from ads.common.decorator.runtime_dependency import (
7+
OptionalDependency,
8+
runtime_dependency,
9+
)
10+
from ads.model.extractor.model_info_extractor import ModelInfoExtractor
11+
from ads.model.model_metadata import Framework
12+
13+
14+
class EmbeddingONNXExtractor(ModelInfoExtractor):
15+
"""Class that extract model metadata from EmbeddingONNXModel models.
16+
17+
Attributes
18+
----------
19+
model: object
20+
The model to extract metadata from.
21+
22+
Methods
23+
-------
24+
framework(self) -> str
25+
Returns the framework of the model.
26+
algorithm(self) -> object
27+
Returns the algorithm of the model.
28+
version(self) -> str
29+
Returns the version of framework of the model.
30+
hyperparameter(self) -> dict
31+
Returns the hyperparameter of the model.
32+
"""
33+
34+
def __init__(self, model=None):
35+
self.model = model
36+
37+
@property
38+
def framework(self):
39+
"""Extracts the framework of the model.
40+
41+
Returns
42+
----------
43+
str:
44+
The framework of the model.
45+
"""
46+
return Framework.EMBEDDING_ONNX
47+
48+
@property
49+
def algorithm(self):
50+
"""Extracts the algorithm of the model.
51+
52+
Returns
53+
----------
54+
object:
55+
The algorithm of the model.
56+
"""
57+
return "Embedding_ONNX"
58+
59+
@property
60+
@runtime_dependency(module="onnxruntime", install_from=OptionalDependency.ONNX)
61+
def version(self):
62+
"""Extracts the framework version of the model.
63+
64+
Returns
65+
----------
66+
str:
67+
The framework version of the model.
68+
"""
69+
return onnxruntime.__version__
70+
71+
@property
72+
def hyperparameter(self):
73+
"""Extracts the hyperparameters of the model.
74+
75+
Returns
76+
----------
77+
dict:
78+
The hyperparameters of the model.
79+
"""
80+
return None

0 commit comments

Comments
 (0)