Skip to content

Commit 8d50f76

Browse files
authored
Merge branch 'main' into enhance_extended_enum
2 parents d49476d + c92d0f6 commit 8d50f76

File tree

17 files changed

+507
-54
lines changed

17 files changed

+507
-54
lines changed

ads/aqua/app.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
import os
7+
import traceback
78
from dataclasses import fields
89
from typing import Dict, Union
910

@@ -23,7 +24,7 @@
2324
from ads.aqua.constants import UNKNOWN
2425
from ads.common import oci_client as oc
2526
from ads.common.auth import default_signer
26-
from ads.common.utils import extract_region
27+
from ads.common.utils import extract_region, is_path_exists
2728
from ads.config import (
2829
AQUA_TELEMETRY_BUCKET,
2930
AQUA_TELEMETRY_BUCKET_NS,
@@ -296,33 +297,46 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
296297
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
297298

298299
config = {}
299-
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
300+
# if the current model has a service model tag, then
301+
if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
302+
base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
303+
logger.info(
304+
f"Base model found for the model: {oci_model.id}. "
305+
f"Loading {config_file_name} for base model {base_model_ocid}."
306+
)
307+
base_model = self.ds_client.get_model(base_model_ocid).data
308+
artifact_path = get_artifact_path(base_model.custom_metadata_list)
309+
else:
310+
logger.info(f"Loading {config_file_name} for model {oci_model.id}...")
311+
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
312+
300313
if not artifact_path:
301314
logger.debug(
302315
f"Failed to get artifact path from custom metadata for the model: {model_id}"
303316
)
304317
return config
305318

306-
try:
307-
config_path = f"{os.path.dirname(artifact_path)}/config/"
308-
config = load_config(
309-
config_path,
310-
config_file_name=config_file_name,
311-
)
312-
except Exception:
313-
# todo: temp fix for issue related to config load for byom models, update logic to choose the right path
319+
config_path = f"{os.path.dirname(artifact_path)}/config/"
320+
if not is_path_exists(config_path):
321+
config_path = f"{artifact_path.rstrip('/')}/config/"
322+
323+
config_file_path = f"{config_path}{config_file_name}"
324+
if is_path_exists(config_file_path):
314325
try:
315-
config_path = f"{artifact_path.rstrip('/')}/config/"
316326
config = load_config(
317327
config_path,
318328
config_file_name=config_file_name,
319329
)
320330
except Exception:
321-
pass
331+
logger.debug(
332+
f"Error loading the {config_file_name} at path {config_path}.\n"
333+
f"{traceback.format_exc()}"
334+
)
322335

323336
if not config:
324-
logger.error(
325-
f"{config_file_name} is not available for the model: {model_id}. Check if the custom metadata has the artifact path set."
337+
logger.debug(
338+
f"{config_file_name} is not available for the model: {model_id}. "
339+
f"Check if the custom metadata has the artifact path set."
326340
)
327341
return config
328342

ads/aqua/finetuning/entities.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ class CreateFineTuningDetails(Serializable):
122122
The log group id for fine tuning job infrastructure.
123123
log_id: (str, optional). Defaults to `None`.
124124
The log id for fine tuning job infrastructure.
125+
watch_logs: (bool, optional). Defaults to `False`.
126+
The flag to watch the job run logs when a fine-tuning job is created.
125127
force_overwrite: (bool, optional). Defaults to `False`.
126128
Whether to force overwrite the existing file in object storage.
127129
freeform_tags: (dict, optional)
@@ -148,6 +150,7 @@ class CreateFineTuningDetails(Serializable):
148150
subnet_id: Optional[str] = None
149151
log_id: Optional[str] = None
150152
log_group_id: Optional[str] = None
153+
watch_logs: Optional[bool] = False
151154
force_overwrite: Optional[bool] = False
152155
freeform_tags: Optional[dict] = None
153156
defined_tags: Optional[dict] = None

ads/aqua/finetuning/finetuning.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import json
66
import os
7+
import time
8+
import traceback
79
from typing import Dict
810

911
from oci.data_science.models import (
@@ -149,6 +151,15 @@ def create(
149151
f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
150152
)
151153

154+
if create_fine_tuning_details.watch_logs and not (
155+
create_fine_tuning_details.log_id
156+
and create_fine_tuning_details.log_group_id
157+
):
158+
raise AquaValueError(
159+
"Logging is required for fine tuning if watch_logs is set to True. "
160+
"Please provide log_id and log_group_id with the request parameters."
161+
)
162+
152163
ft_parameters = self._get_finetuning_params(
153164
create_fine_tuning_details.ft_parameters
154165
)
@@ -422,6 +433,20 @@ def create(
422433
value=source.display_name,
423434
)
424435

436+
if create_fine_tuning_details.watch_logs:
437+
logger.info(
438+
f"Watching fine-tuning job run logs for {ft_job_run.id}. Press Ctrl+C to stop watching logs.\n"
439+
)
440+
try:
441+
ft_job_run.watch()
442+
except KeyboardInterrupt:
443+
logger.info(f"\nStopped watching logs for {ft_job_run.id}.\n")
444+
time.sleep(1)
445+
except Exception:
446+
logger.debug(
447+
f"Something unexpected occurred while watching logs.\n{traceback.format_exc()}"
448+
)
449+
425450
return AquaFineTuningSummary(
426451
id=ft_model.id,
427452
name=ft_model.display_name,

ads/aqua/model/model.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
LifecycleStatus,
3030
_build_resource_identifier,
3131
cleanup_local_hf_model_artifact,
32-
copy_model_config,
3332
create_word_icon,
3433
generate_tei_cmd_var,
3534
get_artifact_path,
@@ -969,24 +968,6 @@ def _create_model_catalog_entry(
969968
)
970969
tags[Tags.LICENSE] = validation_result.tags.get(Tags.LICENSE, UNKNOWN)
971970

972-
try:
973-
# If verified model already has a artifact json, use that.
974-
artifact_path = metadata.get(MODEL_BY_REFERENCE_OSS_PATH_KEY).value
975-
logger.info(
976-
f"Found model artifact in the service bucket. "
977-
f"Using artifact from service bucket instead of {os_path}."
978-
)
979-
980-
# todo: implement generic copy_folder method
981-
# copy model config from artifact path to user bucket
982-
copy_model_config(
983-
artifact_path=artifact_path, os_path=os_path, auth=default_signer()
984-
)
985-
except Exception:
986-
logger.debug(
987-
f"Proceeding with model registration without copying model config files at {os_path}. "
988-
f"Default configuration will be used for deployment and fine-tuning."
989-
)
990971
# Set artifact location to user bucket, and replace existing key if present.
991972
metadata.add(
992973
key=MODEL_BY_REFERENCE_OSS_PATH_KEY,

ads/cli.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#!/usr/bin/env python
2-
32
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

5+
import json
66
import logging
77
import sys
88
import traceback
99
import uuid
1010

1111
import fire
12+
from pydantic import BaseModel
1213

1314
from ads.common import logger
1415

@@ -84,7 +85,13 @@ def serialize(data):
8485
The string representation of each dataclass object.
8586
"""
8687
if isinstance(data, list):
87-
[print(str(item)) for item in data]
88+
for item in data:
89+
if isinstance(item, BaseModel):
90+
print(json.dumps(item.dict(), indent=4))
91+
else:
92+
print(str(item))
93+
elif isinstance(data, BaseModel):
94+
print(json.dumps(data.dict(), indent=4))
8895
else:
8996
print(str(data))
9097

ads/llm/__init__.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
try:
87
import langchain
9-
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
10-
OCIModelDeploymentVLLM,
11-
OCIModelDeploymentTGI,
12-
)
8+
9+
from ads.llm.chat_template import ChatTemplates
1310
from ads.llm.langchain.plugins.chat_models.oci_data_science import (
1411
ChatOCIModelDeployment,
15-
ChatOCIModelDeploymentVLLM,
1612
ChatOCIModelDeploymentTGI,
13+
ChatOCIModelDeploymentVLLM,
14+
)
15+
from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import (
16+
OCIDataScienceEmbedding,
17+
)
18+
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
19+
OCIModelDeploymentTGI,
20+
OCIModelDeploymentVLLM,
1721
)
18-
from ads.llm.chat_template import ChatTemplates
1922
except ImportError as ex:
2023
if ex.name == "langchain":
2124
raise ImportError(
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

0 commit comments

Comments
 (0)