Skip to content

Commit 410dbe0

Browse files
Adding inference container URI for edit model handler (#1030)
2 parents 4d323cc + 1ba166a commit 410dbe0

File tree

3 files changed

+57
-16
lines changed

3 files changed

+57
-16
lines changed

ads/aqua/common/enums.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5252
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5353
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
5454
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55+
56+
57+
class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5558
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
5659

5760

ads/aqua/extension/model_handler.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from tornado.web import HTTPError
99

1010
from ads.aqua.common.decorator import handle_exceptions
11+
from ads.aqua.common.enums import (
12+
CustomInferenceContainerTypeFamily,
13+
)
1114
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1215
from ads.aqua.common.utils import (
1316
get_hf_model_info,
@@ -163,7 +166,9 @@ def put(self, id):
163166
raise HTTPError(400, Errors.NO_INPUT_DATA)
164167

165168
inference_container = input_data.get("inference_container")
169+
inference_container_uri = input_data.get("inference_container_uri")
166170
inference_containers = AquaModelApp.list_valid_inference_containers()
171+
inference_containers.extend(CustomInferenceContainerTypeFamily.values())
167172
if (
168173
inference_container is not None
169174
and inference_container not in inference_containers
@@ -176,7 +181,13 @@ def put(self, id):
176181
task = input_data.get("task")
177182
app = AquaModelApp()
178183
self.finish(
179-
app.edit_registered_model(id, inference_container, enable_finetuning, task)
184+
app.edit_registered_model(
185+
id,
186+
inference_container,
187+
inference_container_uri,
188+
enable_finetuning,
189+
task,
190+
)
180191
)
181192
app.clear_model_details_cache(model_id=id)
182193

ads/aqua/model/model.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1616
from ads.aqua.app import AquaApp
1717
from ads.aqua.common.enums import (
18+
CustomInferenceContainerTypeFamily,
1819
FineTuningContainerTypeFamily,
1920
InferenceContainerTypeFamily,
2021
Tags,
@@ -376,8 +377,10 @@ def delete_model(self, model_id):
376377
f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted."
377378
)
378379

379-
@telemetry(entry_point="plugin=model&action=delete", name="aqua")
380-
def edit_registered_model(self, id, inference_container, enable_finetuning, task):
380+
@telemetry(entry_point="plugin=model&action=edit", name="aqua")
381+
def edit_registered_model(
382+
self, id, inference_container, inference_container_uri, enable_finetuning, task
383+
):
381384
"""Edits the default config of unverified registered model.
382385
383386
Parameters
@@ -386,6 +389,8 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
386389
The model OCID.
387390
inference_container: str.
388391
The inference container family name
392+
inference_container_uri: str
393+
The inference container uri for embedding models
389394
enable_finetuning: str
390395
Flag to enable or disable finetuning over the model. Defaults to None
391396
task:
@@ -401,19 +406,44 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
401406
if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None):
402407
if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None):
403408
raise AquaRuntimeError(
404-
f"Failed to edit model:{id}. Only registered unverified models can be edited."
409+
"Only registered unverified models can be edited."
405410
)
406411
else:
407412
custom_metadata_list = ds_model.custom_metadata_list
408413
freeform_tags = ds_model.freeform_tags
409414
if inference_container:
410-
custom_metadata_list.add(
411-
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
412-
value=inference_container,
413-
category=MetadataCustomCategory.OTHER,
414-
description="Deployment container mapping for SMC",
415-
replace=True,
416-
)
415+
if (
416+
inference_container in CustomInferenceContainerTypeFamily
417+
and inference_container_uri is None
418+
):
419+
raise AquaRuntimeError(
420+
"Inference container URI must be provided."
421+
)
422+
else:
423+
custom_metadata_list.add(
424+
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
425+
value=inference_container,
426+
category=MetadataCustomCategory.OTHER,
427+
description="Deployment container mapping for SMC",
428+
replace=True,
429+
)
430+
if inference_container_uri:
431+
if (
432+
inference_container in CustomInferenceContainerTypeFamily
433+
or inference_container is None
434+
):
435+
custom_metadata_list.add(
436+
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI,
437+
value=inference_container_uri,
438+
category=MetadataCustomCategory.OTHER,
439+
description=f"Inference container URI for {ds_model.display_name}",
440+
replace=True,
441+
)
442+
else:
443+
raise AquaRuntimeError(
444+
f"Inference container URI can be edited only with container values: {CustomInferenceContainerTypeFamily.values()}"
445+
)
446+
417447
if enable_finetuning is not None:
418448
if enable_finetuning.lower() == "true":
419449
custom_metadata_list.add(
@@ -448,9 +478,7 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
448478
)
449479
AquaApp().update_model(id, update_model_details)
450480
else:
451-
raise AquaRuntimeError(
452-
f"Failed to edit model:{id}. Only registered unverified models can be edited."
453-
)
481+
raise AquaRuntimeError("Only registered unverified models can be edited.")
454482

455483
def _fetch_metric_from_metadata(
456484
self,
@@ -869,8 +897,7 @@ def _create_model_catalog_entry(
869897
# only add cmd vars if inference container is not an SMC
870898
if (
871899
inference_container not in smc_container_set
872-
and inference_container
873-
== InferenceContainerTypeFamily.AQUA_TEI_CONTAINER_FAMILY
900+
and inference_container in CustomInferenceContainerTypeFamily.values()
874901
):
875902
cmd_vars = generate_tei_cmd_var(os_path)
876903
metadata.add(

0 commit comments

Comments
 (0)