-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Lib/model accessor #1
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,42 @@ | ||||||
# DKU Model Accessor | ||||||
|
||||||
## Description | ||||||
This lib provides tools to interact with dss saved models data (getting the original train/test set for example). | ||||||
|
||||||
It has a surrogate model and a doctor-like default preprocessor allowing to retrieve feature importance of any non-tree-based models. | ||||||
|
||||||
It uses an internal api, `dataiku.doctor.posttraining.model_information_handler.PredictionModelInformationHandler` (merci mamène Coni) so beware of future api break. | ||||||
|
||||||
|
||||||
## Examples | ||||||
|
||||||
|
||||||
```python | ||||||
from dku_model_accessor import get_model_handler, ModelAccessor | ||||||
|
||||||
model_id = 'XQyU0TO0' | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Maybe use a more explicit ID |
||||||
model = dataiku.Model(model_id) | ||||||
model_handler = get_model_handler(model) | ||||||
model_accessor = ModelAccessor(model_handler) | ||||||
|
||||||
original_test_set = model_accessor.get_original_test_df() | ||||||
feature_importance = model_accessor.get_feature_importance() # works for any models | ||||||
selected_features = model_accessor.get_selected_features() | ||||||
``` | ||||||
|
||||||
## Projects using the library | ||||||
|
||||||
Don't hesitate to check these plugins using the library for more examples : | ||||||
|
||||||
- [dss-plugin-model-drift](https://github.com/dataiku/dss-plugin-model-drift) | ||||||
- [dss-plugin-model-fairness-report](https://github.com/dataiku/dss-plugin-model-fairness-report) | ||||||
- [dss-plugin-model-error-analysis](https://github.com/dataiku/dss-plugin-model-error-analysis) | ||||||
|
||||||
## Version | ||||||
|
||||||
- Version: 0.1.0 | ||||||
- State: <span style="color:green">Supported</span> | ||||||
|
||||||
## Credit | ||||||
|
||||||
Library created by Du Phan. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from dku_model_accessor.model_accessor import ModelAccessor | ||
from dku_model_accessor.model_metadata import get_model_handler | ||
from dku_model_accessor.constants import DkuModelAccessorConstants |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,22 @@ | ||||||
# -*- coding: utf-8 -*- | ||||||
|
||||||
class DkuModelAccessorConstants(object): | ||||||
MODEL_ID = 'model_id' | ||||||
VERSION_ID = 'version_id' | ||||||
REGRRSSION_TYPE = 'REGRESSION' | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
typo |
||||||
DKU_MULTICLASS_CLASSIF = 'MULTICLASS' | ||||||
DKU_BINARY_CLASSIF = 'BINARY_CLASSIFICATION' | ||||||
CLASSIFICATION_TYPE = 'CLASSIFICATION' | ||||||
CLUSTERING_TYPE = 'CLUSTERING' | ||||||
MAX_NUM_ROW = 1000000 | ||||||
CUMULATIVE_PERCENTAGE_THRESHOLD = 90 | ||||||
SURROGATE_TARGET = "_dku_predicted_label_" | ||||||
FEAT_IMP_CUMULATIVE_PERCENTAGE_THRESHOLD = 95 | ||||||
CUMULATIVE_IMPORTANCE = 'cumulative_importance' | ||||||
FEATURE = 'feature' | ||||||
IMPORTANCE = 'importance' | ||||||
RANK = 'rank' | ||||||
CLASS = 'class' | ||||||
PERCENTAGE = 'percentage' | ||||||
DKU_XGBOOST_CLASSIF = 'XGBOOST_CLASSIFICATION' | ||||||
DKU_XGBOOST_REGRESSION = 'XGBOOST_REGRESSION' |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,136 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# -*- coding: utf-8 -*- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import pandas as pd | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from dku_model_accessor.constants import DkuModelAccessorConstants | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from dku_model_accessor.surrogate_model import SurrogateModel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
GradientBoostingRegressor, ExtraTreesClassifier, ExtraTreesRegressor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ALGORITHMS_WITH_VARIABLE_IMPORTANCE = [RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
GradientBoostingRegressor, ExtraTreesClassifier, ExtraTreesRegressor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DecisionTreeClassifier, DecisionTreeRegressor] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class ModelAccessor(object): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Wrapper for our internal object PredictionModelInformationHandler | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__(self, model_handler=None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_handler: PredictionModelInformationHandler object | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.model_handler = model_handler | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_prediction_type(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Wrap the prediction type accessor of the model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.model_handler.get_prediction_type() in [DkuModelAccessorConstants.DKU_BINARY_CLASSIF, DkuModelAccessorConstants.DKU_MULTICLASS_CLASSIF]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return DkuModelAccessorConstants.CLASSIFICATION_TYPE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif DkuModelAccessorConstants.REGRRSSION_TYPE == self.model_handler.get_prediction_type(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return DkuModelAccessorConstants.REGRRSSION_TYPE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return DkuModelAccessorConstants.CLUSTERING_TYPE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_target_variable(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Return the name of the target variable | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self.model_handler.get_target_variable() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_original_test_df(self, limit=DkuModelAccessorConstants.MAX_NUM_ROW): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
full_test_df = self.model_handler.get_test_df()[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
test_df = full_test_df[:limit] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.info('Loading {}/{} rows of the original test set'.format(len(test_df), len(full_test_df))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return test_df | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.warning('Can not retrieve original test set: {}. The plugin will take the whole original dataset.'.format(e)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
full_test_df = self.model_handler.get_full_df()[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
test_df = full_test_df[:limit] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.info('Loading {}/{} rows of the whole original test set'.format(len(test_df), len(full_test_df))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return test_df | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+44
to
+55
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Code repeated, I would only try...catch on what can really raise an exception. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_train_df(self, limit=DkuModelAccessorConstants.MAX_NUM_ROW): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
full_train_df = self.model_handler.get_train_df()[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_df = full_train_df[:limit] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.info('Loading {}/{} rows of the original train set'.format(len(train_df), len(full_train_df))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return train_df | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_per_feature(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self.model_handler.get_per_feature() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_predictor(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self.model_handler.get_predictor() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_feature_importance(self,cumulative_percentage_threshold=DkuModelAccessorConstants.FEAT_IMP_CUMULATIVE_PERCENTAGE_THRESHOLD): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:param cumulative_percentage_threshold: only return the top n features whose sum of importance reaches this threshold | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:return: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self._algorithm_is_tree_based(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
predictor = self.get_predictor() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
clf = predictor._clf | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
feature_names = predictor.get_features() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
feature_importances = clf.feature_importances_ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: # use surrogate model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.info('Fitting surrogate model ...') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
surrogate_model = SurrogateModel(self.get_prediction_type()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
original_test_df = self.get_original_test_df() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
predictions_on_original_test_df = self.get_predictor().predict(original_test_df) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
surrogate_df = original_test_df[self.get_selected_features()] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
surrogate_df[DkuModelAccessorConstants.SURROGATE_TARGET] = predictions_on_original_test_df['prediction'] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
surrogate_model.fit(surrogate_df, DkuModelAccessorConstants.SURROGATE_TARGET) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
feature_names = surrogate_model.get_features() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
feature_importances = surrogate_model.clf.feature_importances_ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+81
to
+89
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe it would be better to wrap into a new method fit_surrogate_model() |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
feature_importance = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for feature_name, feat_importance in zip(feature_names, feature_importances): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
feature_importance.append({ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DkuModelAccessorConstants.FEATURE: feature_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DkuModelAccessorConstants.IMPORTANCE: 100 * feat_importance / sum(feature_importances) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dfx = pd.DataFrame(feature_importance).sort_values(by=DkuModelAccessorConstants.IMPORTANCE, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ascending=False).reset_index(drop=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dfx[DkuModelAccessorConstants.CUMULATIVE_IMPORTANCE] = dfx[DkuModelAccessorConstants.IMPORTANCE].cumsum() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dfx_top = dfx.loc[dfx[DkuModelAccessorConstants.CUMULATIVE_IMPORTANCE] <= cumulative_percentage_threshold] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return dfx_top.rename_axis(DkuModelAccessorConstants.RANK).reset_index().set_index( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DkuModelAccessorConstants.FEATURE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_selected_features(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Return only features used in the model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
selected_features = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for feat, feat_info in self.get_per_feature().items(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if feat_info.get('role') == 'INPUT': | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
selected_features.append(feat) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return selected_features | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_selected_and_rejected_features(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Return all features in the input dataset except the target | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
selected_features = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for feat, feat_info in self.get_per_feature().items(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if feat_info.get('role') in ['INPUT', 'REJECT']: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
selected_features.append(feat) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return selected_features | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+105
to
+123
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
DRY |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def predict(self, df): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self.get_predictor().predict(df) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _algorithm_is_tree_based(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
predictor = self.get_predictor() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
algo = predictor._clf | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for algorithm in ALGORITHMS_WITH_VARIABLE_IMPORTANCE: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if isinstance(algo, algorithm): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif predictor.params.modeling_params.get('algorithm') in [DkuModelAccessorConstants.DKU_XGBOOST_CLASSIF, DkuModelAccessorConstants.DKU_XGBOOST_REGRESSION]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would place this array in a new var |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return False |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# -*- coding: utf-8 -*- | ||
import json | ||
import os | ||
import sys | ||
from dataiku.doctor.posttraining.model_information_handler import PredictionModelInformationHandler | ||
|
||
|
||
def get_model_handler(model, version_id=None): | ||
""" | ||
model: a dku saved model returned by dataiku.Model(model_id) | ||
version_id: if None, the active one is chosen | ||
""" | ||
saved_model_version_id = _get_saved_model_version_id(model, version_id) | ||
return _get_model_info_handler(saved_model_version_id) | ||
|
||
|
||
def _get_saved_model_version_id(model, version_id=None): | ||
model_def = model.get_definition() | ||
if version_id is None: | ||
version_id = model_def.get('activeVersion') | ||
saved_model_version_id = 'S-{0}-{1}-{2}'.format(model_def.get('projectKey'), model_def.get('id'), version_id) | ||
return saved_model_version_id | ||
|
||
|
||
def _get_model_info_handler(saved_model_version_id): | ||
infos = saved_model_version_id.split("-") | ||
if len(infos) != 4 or infos[0] != "S": | ||
raise ValueError("Invalid saved model id") | ||
pkey = infos[1] | ||
model_id = infos[2] | ||
version_id = infos[3] | ||
|
||
datadir_path = os.environ['DIP_HOME'] | ||
version_folder = os.path.join(datadir_path, "saved_models", pkey, model_id, "versions", version_id) | ||
|
||
# Loading and resolving paths in split_desc | ||
split_folder = os.path.join(version_folder, "split") | ||
with open(os.path.join(split_folder, "split.json")) as split_file: | ||
split_desc = json.load(split_file) | ||
|
||
path_field_names = ["trainPath", "testPath", "fullPath"] | ||
for field_name in path_field_names: | ||
if split_desc.get(field_name, None) is not None: | ||
split_desc[field_name] = os.path.join(split_folder, split_desc[field_name]) | ||
|
||
with open(os.path.join(version_folder, "core_params.json")) as core_params_file: | ||
core_params = json.load(core_params_file) | ||
|
||
try: | ||
return PredictionModelInformationHandler(split_desc, core_params, version_folder, version_folder) | ||
except Exception as e: | ||
from future.utils import raise_ | ||
if "ordinal not in range(128)" in str(e): | ||
raise_(Exception, "The plugin only supports python3, cannot load a python2 model. Original error: {}".format(e), sys.exc_info()[2]) | ||
elif str(e) == "non-string names in Numpy dtype unpickling": | ||
raise_(Exception, "The plugin is using a python2 code-env, cannot load a python3 model. Original error: {}".format(e), sys.exc_info()[2]) | ||
elif str(e) == "Using saved models in python recipes is limited to models trained using the python engine": | ||
raise_(Exception, "The plugin does not support Clustering model.", sys.exc_info()[2]) | ||
else: | ||
raise_(Exception, "Fail to load saved model: {}".format(e), sys.exc_info()[2]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.