Skip to content

Commit d50c68e

Browse files
authored
Fix Recursion Issue in OCIDatascienceModel Due to is_model_by_reference Conflict in New OCI SDK (#1073)
2 parents 3ca41c6 + e49bf66 commit d50c68e

File tree

6 files changed

+20
-23
lines changed

6 files changed

+20
-23
lines changed

ads/model/artifact_downloader.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

43
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -12,9 +11,9 @@
1211
from zipfile import ZipFile
1312

1413
from ads.common import utils
14+
from ads.common.object_storage_details import ObjectStorageDetails
1515
from ads.common.utils import extract_region
1616
from ads.model.service.oci_datascience_model import OCIDataScienceModel
17-
from ads.common.object_storage_details import ObjectStorageDetails
1817

1918

2019
class ArtifactDownloader(ABC):
@@ -169,9 +168,9 @@ def __init__(
169168

170169
def _download(self):
171170
"""Downloads model artifacts."""
172-
self.progress.update(f"Importing model artifacts from catalog")
171+
self.progress.update("Importing model artifacts from catalog")
173172

174-
if self.dsc_model.is_model_by_reference() and self.model_file_description:
173+
if self.dsc_model._is_model_by_reference() and self.model_file_description:
175174
self.download_from_model_file_description()
176175
self.progress.update()
177176
return

ads/model/datascience_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1778,7 +1778,7 @@ def _update_from_oci_dsc_model(
17781778
artifact_info = self.dsc_model.get_artifact_info()
17791779
_, file_name_info = cgi.parse_header(artifact_info["Content-Disposition"])
17801780

1781-
if self.dsc_model.is_model_by_reference():
1781+
if self.dsc_model._is_model_by_reference():
17821782
_, file_extension = os.path.splitext(file_name_info["filename"])
17831783
if file_extension.lower() == ".json":
17841784
bucket_uri, _ = self._download_file_description_artifact()

ads/model/service/oci_datascience_model.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,32 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

43
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import logging
8-
import time
97
from functools import wraps
108
from io import BytesIO
119
from typing import Callable, Dict, List, Optional
1210

1311
import oci.data_science
14-
from ads.common import utils
15-
from ads.common.object_storage_details import ObjectStorageDetails
16-
from ads.common.oci_datascience import OCIDataScienceMixin
17-
from ads.common.oci_mixin import OCIWorkRequestMixin
18-
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
19-
from ads.common.utils import extract_region
20-
from ads.common.work_request import DataScienceWorkRequest
21-
from ads.model.deployment import ModelDeployment
2212
from oci.data_science.models import (
2313
ArtifactExportDetailsObjectStorage,
2414
ArtifactImportDetailsObjectStorage,
2515
CreateModelDetails,
2616
ExportModelArtifactDetails,
2717
ImportModelArtifactDetails,
2818
UpdateModelDetails,
29-
WorkRequest,
3019
)
3120
from oci.exceptions import ServiceError
3221

22+
from ads.common.object_storage_details import ObjectStorageDetails
23+
from ads.common.oci_datascience import OCIDataScienceMixin
24+
from ads.common.oci_mixin import OCIWorkRequestMixin
25+
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
26+
from ads.common.utils import extract_region
27+
from ads.common.work_request import DataScienceWorkRequest
28+
from ads.model.deployment import ModelDeployment
29+
3330
logger = logging.getLogger(__name__)
3431

3532
_REQUEST_INTERVAL_IN_SEC = 3
@@ -282,7 +279,7 @@ def get_artifact_info(self) -> Dict:
282279
msg="Model needs to be restored before the archived artifact content can be accessed."
283280
)
284281
def restore_archived_model_artifact(
285-
self, restore_model_for_hours_specified: Optional[int] = None
282+
self, restore_model_for_hours_specified: Optional[int] = None
286283
) -> None:
287284
"""Restores the archived model artifact.
288285
@@ -304,7 +301,8 @@ def restore_archived_model_artifact(
304301
"""
305302
return self.client.restore_archived_model_artifact(
306303
model_id=self.id,
307-
restore_model_for_hours_specified=restore_model_for_hours_specified).headers["opc-work-request-id"]
304+
restore_model_for_hours_specified=restore_model_for_hours_specified,
305+
).headers["opc-work-request-id"]
308306

309307
@check_for_model_id(
310308
msg="Model needs to be saved to the Model Catalog before the artifact content can be read."
@@ -581,7 +579,7 @@ def from_id(cls, ocid: str) -> "OCIDataScienceModel":
581579
raise ValueError("Model OCID not provided.")
582580
return super().from_ocid(ocid)
583581

584-
def is_model_by_reference(self):
582+
def _is_model_by_reference(self):
585583
"""Checks if model is created by reference
586584
Returns
587585
-------

tests/unitary/default_setup/model/test_artifact_downloader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def test_download_large_artifact_from_model_file_description(self, mock_download
189189
"""Tests whether model_file_description is loaded within downloader and is parsed, and also if
190190
# download_from_model_file_description is appropriately called."""
191191

192-
self.mock_dsc_model.is_model_by_reference.return_value = True
192+
self.mock_dsc_model._is_model_by_reference.return_value = True
193193
self.mock_artifact_file_path = os.path.join(
194194
self.curr_dir, "test_files/model_description.json"
195195
)

tests/unitary/default_setup/model/test_datascience_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ def test__to_oci_dsc_model(self):
613613
True,
614614
],
615615
)
616-
@patch.object(OCIDataScienceModel, "is_model_by_reference")
616+
@patch.object(OCIDataScienceModel, "_is_model_by_reference")
617617
@patch.object(OCIDataScienceModel, "get_artifact_info")
618618
@patch.object(OCIDataScienceModel, "get_model_provenance")
619619
@patch.object(DataScienceModel, "_download_file_description_artifact")

tests/unitary/default_setup/model/test_oci_datascience_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def test_is_model_by_reference(self):
473473
category="Other",
474474
)
475475
self.mock_model.custom_metadata_list = [metadata_item]
476-
assert not self.mock_model.is_model_by_reference()
476+
assert not self.mock_model._is_model_by_reference()
477477

478478
metadata_item = ModelCustomMetadataItem(
479479
key="modelDescription",
@@ -483,4 +483,4 @@ def test_is_model_by_reference(self):
483483
)
484484
self.mock_model.custom_metadata_list = [metadata_item]
485485

486-
assert self.mock_model.is_model_by_reference()
486+
assert self.mock_model._is_model_by_reference()

0 commit comments

Comments
 (0)