Skip to content

Commit 7757da3

Browse files
authoredJan 19, 2023
2.7.3 (#63)
2 parents d086c59 + 826213b commit 7757da3

37 files changed

+1540
-252
lines changed
 

‎ads/ads_version.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"version": "2.7.2"
2+
"version": "2.7.3"
33
}

‎ads/catalog/model.py

+14
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
)
5252
from ads.dataset.progress import DummyProgressBar, TqdmProgressBar
5353
from ads.feature_engineering.schema import Schema
54+
from ads.model.model_version_set import ModelVersionSet, _extract_model_version_set_id
5455
from ads.model.deployment.model_deployer import ModelDeployer
5556
from oci.data_science.data_science_client import DataScienceClient
5657
from oci.data_science.models import (
@@ -72,6 +73,8 @@
7273
"description",
7374
"freeform_tags",
7475
"defined_tags",
76+
"model_version_set_id",
77+
"version_label",
7578
]
7679
_MODEL_PROVENANCE_ATTRIBUTES = ModelProvenance().swagger_types.keys()
7780
_ETAG_KEY = "ETag"
@@ -1284,6 +1287,8 @@ def upload_model(
12841287
bucket_uri: Optional[str] = None,
12851288
remove_existing_artifact: Optional[bool] = True,
12861289
overwrite_existing_artifact: Optional[bool] = True,
1290+
model_version_set: Optional[Union[str, ModelVersionSet]] = None,
1291+
version_label: Optional[str] = None,
12871292
):
12881293
"""
12891294
Uploads the model artifact to cloud storage.
@@ -1315,6 +1320,10 @@ def upload_model(
13151320
Whether artifacts uploaded to object storage bucket need to be removed or not.
13161321
overwrite_existing_artifact: (bool, optional). Defaults to `True`.
13171322
Overwrite target bucket artifact if exists.
1323+
model_version_set: (Union[str, ModelVersionSet], optional). Defaults to None.
1324+
The Model version set OCID, or name, or `ModelVersionSet` instance.
1325+
version_label: (str, optional). Defaults to None.
1326+
The model version label.
13181327
13191328
Returns
13201329
-------
@@ -1340,6 +1349,9 @@ def upload_model(
13401349
)
13411350
copy_artifact_to_os = True
13421351

1352+
# extract model_version_set_id from model_version_set attribute or environment
1353+
# variables in case of saving model in context of model version set.
1354+
model_version_set_id = _extract_model_version_set_id(model_version_set)
13431355
# Set default display_name if not specified - randomly generated easy to remember name generated
13441356
display_name = display_name or utils.get_random_name_for_resource()
13451357

@@ -1373,6 +1385,8 @@ def upload_model(
13731385
else '{"schema": []}',
13741386
freeform_tags=freeform_tags,
13751387
defined_tags=defined_tags,
1388+
model_version_set_id=model_version_set_id,
1389+
version_label=version_label,
13761390
)
13771391
model = self.ds_client.create_model(create_model_details)
13781392

‎ads/catalog/summary.py

+4
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def __init__(self, entity_list, datetime_format=utils.date_format):
6868
self.df["compartment_id"] = "..." + self.df["compartment_id"].str[-6:]
6969
if "project_id" in ordered_columns:
7070
self.df["project_id"] = "..." + self.df["project_id"].str[-6:]
71+
if "model_version_set_id" in ordered_columns:
72+
self.df["model_version_set_id"] = (
73+
"..." + self.df["model_version_set_id"].str[-6:]
74+
)
7175
self.df["time_created"] = pd.to_datetime(
7276
self.df["time_created"]
7377
).dt.strftime(datetime_format)

‎ads/common/auth.py

+80-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
4+
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import copy
78
import os
89
from dataclasses import dataclass
910
from typing import Callable, Dict, Optional, Any
@@ -16,6 +17,7 @@
1617

1718
import ads.telemetry
1819
from ads.common import logger
20+
from ads.common.decorator.deprecate import deprecated
1921
from ads.common.extended_enum import ExtendedEnumMeta
2022

2123

@@ -255,7 +257,7 @@ def create_signer(
255257
256258
Parameters
257259
----------
258-
auth: Optional[str], default 'api_key'
260+
auth_type: Optional[str], default 'api_key'
259261
'api_key', 'resource_principal' or 'instance_principal'. Enable/disable resource principal identity,
260262
instance principal or keypair identity in a notebook session
261263
oci_config_location: Optional[str], default oci.config.DEFAULT_LOCATION, which is '~/.oci/config'
@@ -378,6 +380,10 @@ def default_signer(client_kwargs: Optional[Dict] = None) -> Dict:
378380
return signer_generator(signer_args).create_signer()
379381

380382

383+
@deprecated(
384+
"2.7.3",
385+
details="Deprecated, use: from ads.common.auth import create_signer. https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html#overriding-defaults.",
386+
)
381387
def get_signer(
382388
oci_config: Optional[str] = None, oci_profile: Optional[str] = None, **client_kwargs
383389
) -> Dict:
@@ -686,6 +692,10 @@ class OCIAuthContext:
686692
>>> df_run = DataFlowRun.from_ocid(run_id)
687693
"""
688694

695+
@deprecated(
696+
"2.7.3",
697+
details="Deprecated, use: from ads.common.auth import AuthContext",
698+
)
689699
def __init__(self, profile: str = None):
690700
"""
691701
Initialize class OCIAuthContext and saves global state of authentication type and configuration profile.
@@ -700,6 +710,10 @@ def __init__(self, profile: str = None):
700710
self.prev_profile = AuthState().oci_key_profile
701711
self.oci_cli_auth = AuthState().oci_cli_auth
702712

713+
@deprecated(
714+
"2.7.3",
715+
details="Deprecated, use: from ads.common.auth import AuthContext",
716+
)
703717
def __enter__(self):
704718
"""
705719
When called by the 'with' statement and if 'profile' provided - 'api_key' authentication with 'profile' used.
@@ -718,3 +732,67 @@ def __exit__(self, exc_type, exc_val, exc_tb):
718732
When called by the 'with' statement restores initial state of authentication type and profile value.
719733
"""
720734
ads.set_auth(auth=self.prev_mode, profile=self.prev_profile)
735+
736+
737+
class AuthContext:
738+
"""
739+
AuthContext used in 'with' statement for properly managing global authentication type, signer, config
740+
and global configuration parameters.
741+
742+
Examples
743+
--------
744+
>>> from ads import set_auth
745+
>>> from ads.jobs import DataFlowRun
746+
>>> with AuthContext(auth='resource_principal'):
747+
>>> df_run = DataFlowRun.from_ocid(run_id)
748+
749+
>>> from ads.model.framework.sklearn_model import SklearnModel
750+
>>> model = SklearnModel.from_model_artifact(uri="model_artifact_path", artifact_dir="model_artifact_path")
751+
>>> set_auth(auth='api_key', oci_config_location="~/.oci/config")
752+
>>> with AuthContext(auth='api_key', oci_config_location="~/another_config_location/config"):
753+
>>> # upload model to Object Storage using config from another_config_location/config
754+
>>> model.upload_artifact(uri="oci://bucket@namespace/prefix/")
755+
>>> # upload model to Object Storage using config from ~/.oci/config, which was set before 'with AuthContext():'
756+
>>> model.upload_artifact(uri="oci://bucket@namespace/prefix/")
757+
"""
758+
759+
def __init__(self, **kwargs):
760+
"""
761+
Initialize class AuthContext and saves global state of authentication type, signer, config
762+
and global configuration parameters.
763+
764+
Parameters
765+
----------
766+
**kwargs: optional, list of parameters passed to ads.set_auth() method, which can be:
767+
auth: Optional[str], default 'api_key'
768+
'api_key', 'resource_principal' or 'instance_principal'. Enable/disable resource principal
769+
identity, instance principal or keypair identity
770+
oci_config_location: Optional[str], default oci.config.DEFAULT_LOCATION, which is '~/.oci/config'
771+
config file location
772+
profile: Optional[str], default is DEFAULT_PROFILE, which is 'DEFAULT'
773+
profile name for api keys config file
774+
config: Optional[Dict], default {}
775+
created config dictionary
776+
signer: Optional[Any], default None
777+
created signer, can be resource principals signer, instance principal signer or other
778+
signer_callable: Optional[Callable], default None
779+
a callable object that returns signer
780+
signer_kwargs: Optional[Dict], default None
781+
parameters accepted by the signer
782+
"""
783+
self.kwargs = kwargs
784+
785+
def __enter__(self):
786+
"""
787+
When called by the 'with' statement current state of authentication type, signer, config
788+
and configuration parameters saved.
789+
"""
790+
self.previous_state = copy.deepcopy(AuthState())
791+
set_auth(**self.kwargs)
792+
793+
def __exit__(self, exc_type, exc_val, exc_tb):
794+
"""
795+
When called by the 'with' statement initial state of authentication type, signer, config
796+
and configuration parameters restored.
797+
"""
798+
AuthState().__dict__.update(self.previous_state.__dict__)

‎ads/common/ipython.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,22 @@
44
# Copyright (c) 2022 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import logging
78
import sys
8-
from ads.common import logger
9+
10+
# TODO - Revisit this as part of ADS logging changes https://jira.oci.oraclecorp.com/browse/ODSC-36245
11+
# Use a unique logger that we can individually configure without impacting other log statements.
12+
# We don't want the logger name to mention "ads", since this logger will report any exception that happens in a
13+
# notebook cell, and we don't want customers incorrectly assuming that ADS is somehow responsible for every error.
14+
logger = logging.getLogger("ipython.traceback")
15+
# Set propagate to False so logs aren't passed back up to the root logger handlers. There are some places in ADS
16+
# where logging.basicConfig() is called. This changes root logger configurations. The user could import/use code that
17+
# invokes the logging.basicConfig() function at any time, making the behavior of the root logger unpredictable.
18+
logger.propagate = False
19+
logger.handlers.clear()
20+
traceback_handler = logging.StreamHandler()
21+
traceback_handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s"))
22+
logger.addHandler(traceback_handler)
923

1024

1125
def _log_traceback(self, exc_tuple=None, **kwargs):
@@ -15,7 +29,8 @@ def _log_traceback(self, exc_tuple=None, **kwargs):
1529
print("No traceback available to show.", file=sys.stderr)
1630
return
1731
msg = etype.__name__, str(value)
18-
logger.error("ADS Exception", exc_info=(etype, value, tb))
32+
# User a generic message that makes no mention of ADS.
33+
logger.error("Exception", exc_info=(etype, value, tb))
1934
sys.stderr.write("{0}: {1}".format(*msg))
2035

2136

‎ads/common/model_artifact.py

+9
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
from ads.common.decorator.deprecate import deprecated
8282
from ads.feature_engineering.schema import DataSizeTooWide, Schema, SchemaSizeTooLarge
8383
from ads.model.extractor.model_info_extractor_factory import ModelInfoExtractorFactory
84+
from ads.model.model_version_set import ModelVersionSet
8485
from ads.model.common.utils import fetch_manifest_from_conda_location
8586
from git import InvalidGitRepositoryError, Repo
8687

@@ -714,6 +715,8 @@ def save(
714715
defined_tags=None,
715716
bucket_uri: Optional[str] = None,
716717
remove_existing_artifact: Optional[bool] = True,
718+
model_version_set: Optional[Union[str, ModelVersionSet]] = None,
719+
version_label: Optional[str] = None,
717720
):
718721
"""
719722
Saves the model artifact in the model catalog.
@@ -757,6 +760,10 @@ def save(
757760
size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`
758761
remove_existing_artifact: (bool, optional). Defaults to `True`.
759762
Whether artifacts uploaded to object storage bucket need to be removed or not.
763+
model_version_set: (Union[str, ModelVersionSet], optional). Defaults to None.
764+
The Model version set OCID, or name, or `ModelVersionSet` instance.
765+
version_label: (str, optional). Defaults to None.
766+
The model version label.
760767
761768
Examples
762769
________
@@ -894,6 +901,8 @@ def save(
894901
defined_tags=defined_tags,
895902
bucket_uri=bucket_uri,
896903
remove_existing_artifact=remove_existing_artifact,
904+
model_version_set=model_version_set,
905+
version_label=version_label,
897906
)
898907
except oci.exceptions.RequestException as e:
899908
if "The write operation timed out" in str(e):

‎ads/common/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,9 @@ def default(self, obj):
737737
),
738738
):
739739
return int(obj)
740-
elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
740+
elif isinstance(
741+
obj, (np.float_, np.float16, np.float32, np.float64, np.double)
742+
):
741743
return float(obj)
742744
elif isinstance(obj, (np.ndarray,)):
743745
return obj.tolist()

‎ads/dataset/classification_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def mapper(df, column_name, arg):
229229
df[column_name] = df[column_name].map(arg)
230230
return df
231231

232-
df = df.map_partitions(mapper, target, update_arg)
232+
df = mapper(df, target, update_arg)
233233
sampled_df = mapper(sampled_df, target, update_arg)
234234
ClassificationDataset.__init__(
235235
self, df, sampled_df, target, target_type, shape, **kwargs

‎ads/dataset/dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1034,7 +1034,7 @@ def _convert_dtypes_to_avro_types(self):
10341034
avro_dtype = "double"
10351035
elif "float" in str(dtype):
10361036
avro_dtype = "float"
1037-
elif dtype == np.bool:
1037+
elif dtype == np.bool_:
10381038
avro_dtype = "boolean"
10391039
else:
10401040
avro_dtype = "string"

‎ads/dataset/factory.py

-3
Original file line numberDiff line numberDiff line change
@@ -828,9 +828,6 @@ def read_log(path, **kwargs):
828828
},
829829
**kwargs,
830830
)
831-
df["time"] = df["time"].map_partitions(
832-
pd.to_datetime, utc=True, meta="datetime64[ns]"
833-
)
834831
return df
835832

836833
@staticmethod

‎ads/dbmixin/db_pandas_accessor.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
33

4-
# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
4+
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
from ads.bds.big_data_service import ADSHiveConnection
@@ -33,7 +33,7 @@ def get(cls, engine="oracle"):
3333

3434
if not Connection:
3535
if engine == "mysql":
36-
print("Requires mysql-connection-python package to use mysql engine")
36+
print("Requires mysql-connector-python package to use mysql engine")
3737
elif engine == "oracle":
3838
print(
3939
f"The `oracledb` or `cx_Oracle` module was not found. Please run "
@@ -102,6 +102,7 @@ def to_sql(
102102
if_exists: str = "fail",
103103
batch_size=100000,
104104
engine="oracle",
105+
encoding="utf-8",
105106
):
106107
"""To save the dataframe df to database.
107108
@@ -120,6 +121,8 @@ def to_sql(
120121
Inserting in batches improves insertion performance. Choose this value based on available memore and network bandwidth.
121122
engine: {'oracle', 'mysql'}, default 'oracle'
122123
Select the database type - MySQL or Oracle to store the data
124+
encoding: str, default is "utf-8"
125+
Encoding provided will be used for ecoding all columns, when inserting into table
123126
124127
125128
Returns
@@ -146,5 +149,5 @@ def to_sql(
146149

147150
Connection = ConnectionFactory.get(engine)
148151
return Connection(**connection_parameters).insert(
149-
table_name, self._obj, if_exists, batch_size
152+
table_name, self._obj, if_exists, batch_size, encoding
150153
)

‎ads/jobs/builders/infrastructure/dataflow.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@
3535
CONDA_PACK_SUFFIX = "#conda"
3636

3737

38+
def conda_pack_name_to_dataflow_config(conda_uri):
39+
return {
40+
"spark.archives": conda_uri + CONDA_PACK_SUFFIX, # .replace(" ", "%20")
41+
"dataflow.auth": "resource_principal",
42+
}
43+
44+
3845
class DataFlowApp(OCIModelMixin, oci.data_flow.models.Application):
3946
@classmethod
4047
def init_client(cls, **kwargs) -> oci.data_flow.data_flow_client.DataFlowClient:
@@ -778,12 +785,7 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
778785
else:
779786
raise ValueError(f"Conda built type not understood: {conda_type}.")
780787
runtime_config = runtime.configuration or dict()
781-
runtime_config.update(
782-
{
783-
"spark.archives": conda_uri.replace(" ", "%20") + CONDA_PACK_SUFFIX,
784-
"dataflow.auth": "resource_principal",
785-
}
786-
)
788+
runtime_config.update(conda_pack_name_to_dataflow_config(conda_uri))
787789
runtime.with_configuration(runtime_config)
788790
payload.update(
789791
{

0 commit comments

Comments
 (0)
Please sign in to comment.