Skip to content

Commit 09bbfb9

Browse files
committed
Fix bug
1 parent 956106f commit 09bbfb9

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

get_max_scores.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def getMaxScoreMethod(label, arr):
1717
def removeRandomStateParam(model):
1818
if (model != None):
1919
d=model.get_params()
20-
d['random_state'] = None
20+
if ('random_state' in d):
21+
d['random_state'] = None
2122
model.set_params(**d)
2223
return model
2324

main.py

+5
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,13 @@ def check_models(models, model_type):
8484
raise ValueError("The {model_type}s should be a list of ({model_type} name, {model_type}) pairs or ({model_type} name, {model_type}, parameters grid) triplets.".format(model_type=model_type))
8585
return flat_models
8686

87+
#import matplotlib.pyplot as plt
88+
#from sklearn.metrics import RocCurveDisplay
89+
8790
def score_method(X_train, X_test, y_train, y_test, oversampler, classifier):
8891
y_predict = classifier[1].predict(X_test.values)
92+
#RocCurveDisplay.from_predictions(y_test.values, y_predict)
93+
#plt.show()
8994
return {
9095
'train_score': classifier[1].score(X_train.values, y_train.values),
9196
'test_score': classifier[1].score(X_test.values, y_test.values),

0 commit comments

Comments
 (0)