7
7
import traceback
8
8
from dataclasses import fields
9
9
from datetime import datetime , timedelta
10
- from typing import Any , Dict , Optional , Union
10
+ from itertools import chain
11
+ from typing import Any , Dict , List , Optional , Union
11
12
12
13
import oci
13
14
from cachetools import TTLCache , cached
14
- from oci .data_science .models import UpdateModelDetails , UpdateModelProvenanceDetails
15
+ from oci .data_science .models import (
16
+ ContainerSummary ,
17
+ UpdateModelDetails ,
18
+ UpdateModelProvenanceDetails ,
19
+ )
15
20
16
21
from ads import set_auth
17
22
from ads .aqua import logger
24
29
is_valid_ocid ,
25
30
load_config ,
26
31
)
32
+ from ads .aqua .config .container_config import (
33
+ AquaContainerConfig ,
34
+ AquaContainerConfigItem ,
35
+ )
36
+ from ads .aqua .constants import SERVICE_MANAGED_CONTAINER_URI_SCHEME
27
37
from ads .common import oci_client as oc
28
38
from ads .common .auth import default_signer
29
39
from ads .common .utils import UNKNOWN , extract_region , is_path_exists
@@ -240,7 +250,9 @@ def create_model_catalog(
240
250
.with_custom_metadata_list (model_custom_metadata )
241
251
.with_defined_metadata_list (model_taxonomy_metadata )
242
252
.with_provenance_metadata (ModelProvenanceMetadata (training_id = UNKNOWN ))
243
- .with_defined_tags (** (defined_tags or {})) # Create defined tags when a model is created.
253
+ .with_defined_tags (
254
+ ** (defined_tags or {})
255
+ ) # Create defined tags when a model is created.
244
256
.create (
245
257
** kwargs ,
246
258
)
@@ -271,6 +283,43 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
271
283
logger .info (f"Artifact not found in model { model_id } ." )
272
284
return False
273
285
286
+ def get_config_from_metadata (
287
+ self , model_id : str , metadata_key : str
288
+ ) -> ModelConfigResult :
289
+ """Gets the config for the given Aqua model from model catalog metadata content.
290
+
291
+ Parameters
292
+ ----------
293
+ model_id: str
294
+ The OCID of the Aqua model.
295
+ metadata_key: str
296
+ The metadata key name where artifact content is stored
297
+ Returns
298
+ -------
299
+ ModelConfigResult
300
+ A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
301
+ """
302
+ config = {}
303
+ oci_model = self .ds_client .get_model (model_id ).data
304
+ try :
305
+ config = self .ds_client .get_model_defined_metadatum_artifact_content (
306
+ model_id , metadata_key
307
+ ).data .content .decode ("utf-8" )
308
+ return ModelConfigResult (config = json .loads (config ), model_details = oci_model )
309
+ except UnicodeDecodeError as ex :
310
+ logger .error (
311
+ f"Failed to decode content for '{ metadata_key } ' in defined metadata for model '{ model_id } ' : { ex } "
312
+ )
313
+ except json .JSONDecodeError as ex :
314
+ logger .error (
315
+ f"Invalid JSON format for '{ metadata_key } ' in defined metadata for model '{ model_id } ' : { ex } "
316
+ )
317
+ except Exception as ex :
318
+ logger .error (
319
+ f"Failed to retrieve defined metadata key '{ metadata_key } ' for model '{ model_id } ': { ex } "
320
+ )
321
+ return ModelConfigResult (config = config , model_details = oci_model )
322
+
274
323
@cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 1 ), timer = datetime .now ))
275
324
def get_config (
276
325
self ,
@@ -310,22 +359,7 @@ def get_config(
310
359
raise AquaRuntimeError (f"Target model { oci_model .id } is not an Aqua model." )
311
360
312
361
config : Dict [str , Any ] = {}
313
-
314
- # if the current model has a service model tag, then
315
- if Tags .AQUA_SERVICE_MODEL_TAG in oci_model .freeform_tags :
316
- base_model_ocid = oci_model .freeform_tags [Tags .AQUA_SERVICE_MODEL_TAG ]
317
- logger .info (
318
- f"Base model found for the model: { oci_model .id } . "
319
- f"Loading { config_file_name } for base model { base_model_ocid } ."
320
- )
321
- if config_folder == ConfigFolder .ARTIFACT :
322
- artifact_path = get_artifact_path (oci_model .custom_metadata_list )
323
- else :
324
- base_model = self .ds_client .get_model (base_model_ocid ).data
325
- artifact_path = get_artifact_path (base_model .custom_metadata_list )
326
- else :
327
- logger .info (f"Loading { config_file_name } for model { oci_model .id } ..." )
328
- artifact_path = get_artifact_path (oci_model .custom_metadata_list )
362
+ artifact_path = get_artifact_path (oci_model .custom_metadata_list )
329
363
if not artifact_path :
330
364
logger .debug (
331
365
f"Failed to get artifact path from custom metadata for the model: { model_id } "
@@ -340,7 +374,7 @@ def get_config(
340
374
config_file_path = os .path .join (config_path , config_file_name )
341
375
if is_path_exists (config_file_path ):
342
376
try :
343
- logger .debug (
377
+ logger .info (
344
378
f"Loading config: `{ config_file_name } ` from `{ config_path } `"
345
379
)
346
380
config = load_config (
@@ -361,6 +395,85 @@ def get_config(
361
395
362
396
return ModelConfigResult (config = config , model_details = oci_model )
363
397
398
+ def get_container_image (self , container_type : str = None ) -> str :
399
+ """
400
+ Gets the latest smc container complete image name from the given container type.
401
+
402
+ Parameters
403
+ ----------
404
+ container_type: str
405
+ type of container, can be either odsc-vllm-serving, odsc-llm-fine-tuning, odsc-llm-evaluate
406
+
407
+ Returns
408
+ -------
409
+ str:
410
+ A complete container name along with version. ex: dsmc://odsc-vllm-serving:0.7.4.1
411
+ """
412
+
413
+ containers = self .list_service_containers ()
414
+ container = next (
415
+ (c for c in containers if c .is_latest and c .family_name == container_type ),
416
+ None ,
417
+ )
418
+ if not container :
419
+ raise AquaValueError (f"Invalid container type : { container_type } " )
420
+ container_image = (
421
+ SERVICE_MANAGED_CONTAINER_URI_SCHEME
422
+ + container .container_name
423
+ + ":"
424
+ + container .tag
425
+ )
426
+ return container_image
427
+
428
+ @cached (cache = TTLCache (maxsize = 20 , ttl = timedelta (minutes = 30 ), timer = datetime .now ))
429
+ def list_service_containers (self ) -> List [ContainerSummary ]:
430
+ """
431
+ List containers from containers.conf in OCI Datascience control plane
432
+ """
433
+ containers = self .ds_client .list_containers ().data
434
+ return containers
435
+
436
+ def get_container_config (self ) -> AquaContainerConfig :
437
+ """
438
+ Fetches latest containers from containers.conf in OCI Datascience control plane
439
+
440
+ Returns
441
+ -------
442
+ AquaContainerConfig
443
+ An Object that contains latest container info for the given container family
444
+
445
+ """
446
+ return AquaContainerConfig .from_service_config (
447
+ service_containers = self .list_service_containers ()
448
+ )
449
+
450
+ def get_container_config_item (
451
+ self , container_family : str
452
+ ) -> AquaContainerConfigItem :
453
+ """
454
+ Fetches latest container for given container_family_name from containers.conf in OCI Datascience control plane
455
+
456
+ Returns
457
+ -------
458
+ AquaContainerConfigItem
459
+ An Object that contains latest container info for the given container family
460
+
461
+ """
462
+
463
+ aqua_container_config = self .get_container_config ()
464
+ inference_config = aqua_container_config .inference .values ()
465
+ ft_config = aqua_container_config .finetune .values ()
466
+ eval_config = aqua_container_config .evaluate .values ()
467
+ container = next (
468
+ (
469
+ container
470
+ for container in chain (inference_config , ft_config , eval_config )
471
+ if container .family .lower () == container_family .lower ()
472
+ ),
473
+ None ,
474
+ )
475
+ return container
476
+
364
477
@property
365
478
def telemetry (self ):
366
479
if not self ._telemetry :
0 commit comments