@@ -10,7 +10,7 @@
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_circles
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import FixedThresholdClassifier
from sklearn.metrics import precision_score
from sklearn.inspection import DecisionBoundaryDisplay
@@ -21,17 +21,71 @@
RANDOM_STATE = 1

##############################################################################
# Let us first load the dataset and fit an SVC on the training data.
# First, load the dataset and then split it into training, calibration
# (for conformalization), and test sets.

# Generate toy dataset
X, y = make_circles(n_samples=3000, noise=0.3, factor=0.3, random_state=RANDOM_STATE)
(X_train, X_calib, X_test,
y_train, y_calib, y_test) = train_conformalize_test_split(
X, y,
train_size=0.8, conformalize_size=0.1, test_size=0.1,
random_state=RANDOM_STATE
)

clf = SVC(probability=True, random_state=RANDOM_STATE)
(X_train, X_calib, X_test, y_train, y_calib, y_test) = train_conformalize_test_split(
X,
y,
train_size=0.8,
conformalize_size=0.1,
test_size=0.1,
random_state=RANDOM_STATE,
)

# Plot the three datasets to visualize the distribution of the two classes.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
titles = ["Training Data", "Calibration Data", "Test Data"]
datasets = [(X_train, y_train), (X_calib, y_calib), (X_test, y_test)]

for i, (ax, (X_data, y_data), title) in enumerate(zip(axes, datasets, titles)):
    ax.scatter(
        X_data[y_data == 0, 0],
        X_data[y_data == 0, 1],
        edgecolors="k",
        c="tab:blue",
        alpha=0.5,
        label='"negative" class',
    )
    ax.scatter(
        X_data[y_data == 1, 0],
        X_data[y_data == 1, 1],
        edgecolors="k",
        c="tab:red",
        alpha=0.5,
        label='"positive" class',
    )
    ax.set_title(title, fontsize=18)
    ax.set_xlabel("Feature 1", fontsize=16)
    ax.tick_params(labelsize=14)

    if i == 0:
        ax.set_ylabel("Feature 2", fontsize=16)
    else:
        ax.set_ylabel("")
        ax.set_yticks([])

handles, labels = axes[0].get_legend_handles_labels()
fig.legend(
handles,
labels,
loc="lower center",
bbox_to_anchor=(0.5, -0.01),
ncol=2,
fontsize=16,
)

plt.suptitle("Visualization of Train, Calibration, and Test Sets", fontsize=22)
plt.tight_layout(rect=[0, 0.05, 1, 0.95])
plt.show()

##############################################################################
# Second, fit a KNeighborsClassifier on the training data.

# Fit KNeighborsClassifier on training data
clf = KNeighborsClassifier(n_neighbors=10)
clf.fit(X_train, y_train)

##############################################################################
@@ -45,15 +99,18 @@
confidence_level = 0.9
bcc = BinaryClassificationController(
clf.predict_proba,
precision, target_level=target_precision,
confidence_level=confidence_level
)
precision,
target_level=target_precision,
confidence_level=confidence_level,
)
bcc.calibrate(X_calib, y_calib)

print(f'{len(bcc.valid_predict_params)} thresholds found that guarantee a precision of '
f'at least {target_precision} with a confidence of {confidence_level}.\n'
'Among those, the one that maximizes the secondary objective (recall here) is: '
f'{bcc.best_predict_param:.3f}.')
print(
f"{len(bcc.valid_predict_params)} thresholds found that guarantee a precision of "
f"at least {target_precision} with a confidence of {confidence_level}.\n"
"Among those, the one that maximizes the secondary objective (recall here) is: "
f"{bcc.best_predict_param:.3f}."
)


##############################################################################
@@ -69,29 +126,42 @@
precisions[i] = precision_score(y_calib, y_pred)

valid_thresholds_indices = np.array(
[t in bcc.valid_predict_params for t in tested_thresholds])
best_threshold_index = np.where(
tested_thresholds == bcc.best_predict_param)[0][0]
[t in bcc.valid_predict_params for t in tested_thresholds]
)
best_threshold_index = np.where(tested_thresholds == bcc.best_predict_param)[0][0]

plt.figure()
plt.scatter(
tested_thresholds[valid_thresholds_indices], precisions[valid_thresholds_indices],
c='tab:green', label='Valid thresholds'
)
tested_thresholds[valid_thresholds_indices],
precisions[valid_thresholds_indices],
c="tab:green",
label="Valid thresholds",
)
plt.scatter(
tested_thresholds[~valid_thresholds_indices], precisions[~valid_thresholds_indices],
c='tab:red', label='Invalid thresholds'
)
tested_thresholds[~valid_thresholds_indices],
precisions[~valid_thresholds_indices],
c="tab:red",
label="Invalid thresholds",
)
plt.scatter(
tested_thresholds[best_threshold_index], precisions[best_threshold_index],
c='tab:green', label='Best threshold', marker='*', edgecolors='k', s=300
)
plt.axhline(target_precision, color='tab:gray', linestyle='--')
tested_thresholds[best_threshold_index],
precisions[best_threshold_index],
c="tab:green",
label="Best threshold",
marker="*",
edgecolors="k",
s=300,
)
plt.axhline(target_precision, color="tab:gray", linestyle="--")
plt.text(
0.7, target_precision+0.02, 'Target precision', color='tab:gray', fontstyle='italic'
0.7,
target_precision + 0.02,
"Target precision",
color="tab:gray",
fontstyle="italic",
)
plt.xlabel('Threshold')
plt.ylabel('Precision')
plt.xlabel("Threshold")
plt.ylabel("Precision")
plt.legend()
plt.show()

@@ -126,16 +196,24 @@

disp = DecisionBoundaryDisplay.from_estimator(
clf_threshold, X_test, response_method="predict", cmap=plt.cm.coolwarm
)
)

plt.scatter(
X_test[y_test == 0, 0], X_test[y_test == 0, 1],
edgecolors='k', c='tab:blue', alpha=0.5, label='"negative" class'
)
X_test[y_test == 0, 0],
X_test[y_test == 0, 1],
edgecolors="k",
c="tab:blue",
alpha=0.5,
label='"negative" class',
)
plt.scatter(
X_test[y_test == 1, 0], X_test[y_test == 1, 1],
edgecolors='k', c='tab:red', alpha=0.5, label='"positive" class'
)
X_test[y_test == 1, 0],
X_test[y_test == 1, 1],
edgecolors="k",
c="tab:red",
alpha=0.5,
label='"positive" class',
)
plt.title("Decision Boundary of FixedThresholdClassifier")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
8 changes: 5 additions & 3 deletions mapie/risk_control.py
@@ -956,7 +956,7 @@ class BinaryClassificationController:
... target_level=0.6
... )

>>> controller.calibrate(X_calib, y_calib)
>>> controller = controller.calibrate(X_calib, y_calib)
>>> predictions = controller.predict(X_test)

References
@@ -1004,7 +1004,7 @@ def calibrate(  # pragma: no cover
self,
X_calibrate: ArrayLike,
y_calibrate: ArrayLike
) -> None:
) -> "BinaryClassificationController":
"""
Calibrate the BinaryClassificationController.
Sets attributes valid_predict_params and best_predict_param (if the risk
@@ -1020,7 +1020,8 @@

Returns
-------
None
BinaryClassificationController
The fitted controller instance (for chaining).
"""
y_calibrate_ = np.asarray(y_calibrate, dtype=int)

@@ -1056,6 +1057,7 @@
predictions_per_param,
valid_params_index,
)
return self

def predict(self, X_test: ArrayLike) -> NDArray:
"""
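Because calibrate now returns the controller, calibration and prediction can be chained in a single expression. A minimal sketch reusing the objects from the example above (clf, precision, target_precision, confidence_level, X_calib, y_calib, X_test):

# Sketch: method chaining enabled by calibrate returning self.
predictions = (
    BinaryClassificationController(
        clf.predict_proba,
        precision,
        target_level=target_precision,
        confidence_level=confidence_level,
    )
    .calibrate(X_calib, y_calib)
    .predict(X_test)
)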
71 changes: 23 additions & 48 deletions mapie/tests/risk_control/test_control_risk.py
@@ -2,6 +2,7 @@
Testing for control_risk module.
Testing for now risks for multilabel classification
"""

from typing import List, Union

import numpy as np
@@ -10,40 +11,28 @@
from numpy.typing import NDArray
from mapie.control_risk.ltt import find_lambda_control_star, ltt_procedure
from mapie.control_risk.p_values import compute_hoeffding_bentkus_p_value
from mapie.control_risk.risks import (compute_risk_precision,
compute_risk_recall)
from mapie.control_risk.risks import compute_risk_precision, compute_risk_recall

lambdas = np.array([0.5, 0.9])

y_toy = np.stack([
[1, 0, 1],
[0, 1, 0],
[1, 1, 0],
[1, 1, 1],
])
y_toy = np.stack(
[
[1, 0, 1],
[0, 1, 0],
[1, 1, 0],
[1, 1, 1],
]
)

y_preds_proba = np.stack([
[0.2, 0.6, 0.9],
[0.8, 0.2, 0.6],
[0.4, 0.8, 0.1],
[0.6, 0.8, 0.7]
])
y_preds_proba = np.stack(
[[0.2, 0.6, 0.9], [0.8, 0.2, 0.6], [0.4, 0.8, 0.1], [0.6, 0.8, 0.7]]
)

y_preds_proba = np.expand_dims(y_preds_proba, axis=2)

test_recall = np.array([
[1/2, 1.],
[1., 1.],
[1/2, 1.],
[0., 1.]
])
test_recall = np.array([[1 / 2, 1.0], [1.0, 1.0], [1 / 2, 1.0], [0.0, 1.0]])

test_precision = np.array([
[1/2, 1.],
[1., 1.],
[0., 1.],
[0., 1.]
])
test_precision = np.array([[1 / 2, 1.0], [1.0, 1.0], [0.0, 1.0], [0.0, 1.0]])

r_hat = np.array([0.5, 0.8])

@@ -55,10 +44,7 @@

wrong_alpha = 0

wrong_alpha_shape = np.array([
[0.1, 0.2],
[0.3, 0.4]
])
wrong_alpha_shape = np.array([[0.1, 0.2], [0.3, 0.4]])

random_state = 42
prng = np.random.RandomState(random_state)
@@ -115,13 +101,11 @@ def test_compute_recall_with_wrong_shape() -> None:


def test_compute_precision_with_wrong_shape() -> None:
"""Test shape when using _compute_precision"""
"""Test error when wrong shape in _compute_precision"""
with pytest.raises(ValueError, match=r".*y_pred_proba should be a 3d*"):
compute_risk_precision(lambdas, y_preds_proba.squeeze(), y_toy)
with pytest.raises(ValueError, match=r".*y should be a 2d*"):
compute_risk_precision(
lambdas, y_preds_proba, np.expand_dims(y_toy, 2)
)
compute_risk_precision(lambdas, y_preds_proba, np.expand_dims(y_toy, 2))
with pytest.raises(ValueError, match=r".*could not be broadcast*"):
compute_risk_precision(lambdas, y_preds_proba, y_toy[:-1])

@@ -146,10 +130,7 @@ def test_find_lambda_control_star() -> None:

@pytest.mark.parametrize("delta", [0.1, 0.8])
@pytest.mark.parametrize("alpha", [[0.5], [0.6, 0.8]])
def test_ltt_type_output_alpha_delta(
alpha: NDArray,
delta: float
) -> None:
def test_ltt_type_output_alpha_delta(alpha: NDArray, delta: float) -> None:
"""Test type output _ltt_procedure"""
valid_index = ltt_procedure(r_hat, alpha, delta, n)
assert isinstance(valid_index, list)
@@ -164,9 +145,7 @@ def test_find_lambda_control_star_output(valid_index: List[List[int]]) -> None:
def test_warning_valid_index_empty() -> None:
"""Test warning sent when empty list"""
valid_index = [[]] # type: List[List[int]]
with pytest.warns(
UserWarning, match=r".*At least one sequence is empty*"
):
with pytest.warns(UserWarning, match=r".*At least one sequence is empty*"):
find_lambda_control_star(r_hat, valid_index, lambdas)


@@ -189,14 +168,10 @@ def test_hb_p_values_n_obs_int_vs_array() -> None:
alpha = np.array([0.6, 0.7])

pval_0 = compute_hoeffding_bentkus_p_value(
np.array([r_hat[0]]),
int(n_obs[0]),
alpha
np.array([r_hat[0]]), int(n_obs[0]), alpha
)
pval_1 = compute_hoeffding_bentkus_p_value(
np.array([r_hat[1]]),
int(n_obs[1]),
alpha
np.array([r_hat[1]]), int(n_obs[1]), alpha
)
pval_manual = np.vstack([pval_0, pval_1])

@@ -211,7 +186,7 @@ def test_ltt_procedure_n_obs_negative() -> None:
This happens when the risk, defined as the conditional expectation of
a loss, is undefined because the condition is never met.
This should return an invalid lambda.
"""
"""
r_hat = np.array([0.5])
n_obs = np.array([-1])
alpha_np = np.array([0.6])