Skip to content

Commit

Permalink
address some of the mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
marlonjan committed Sep 12, 2024
1 parent d461782 commit a1f3673
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 18 deletions.
3 changes: 2 additions & 1 deletion src/metriculous/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from IPython.display import HTML, display

from metriculous.evaluation import Evaluation, Evaluator
from numpy import floating


@dataclass(frozen=True)
Expand Down Expand Up @@ -217,7 +218,7 @@ def _model_evaluations_to_data_frame(
# create one row per quantity
data = []
for i_q, quantity_name in enumerate(quantity_names):
row: List[Union[str, float]] = [quantity_name]
row: List[Union[str, float, floating]] = [quantity_name]
for evaluation in model_evaluations:
quantity = evaluation.quantities[i_q]
assert_that(quantity.name).is_equal_to(quantity_name)
Expand Down
8 changes: 5 additions & 3 deletions src/metriculous/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
from dataclasses import dataclass, field, replace
from typing import Callable, Generic, Optional, Sequence, TypeVar, Union

from bokeh.models import LayoutDOM
from bokeh.plotting import figure
from numpy import floating


@dataclass(frozen=True)
class Quantity:
name: str
value: Union[float, str]
value: Union[float, str, floating]
higher_is_better: Optional[bool] = None
description: Optional[str] = None

Expand All @@ -21,7 +23,7 @@ class Quantity:
class Evaluation:
model_name: str
quantities: Sequence[Quantity] = field(default_factory=list)
lazy_figures: Sequence[Callable[[], figure]] = field(default_factory=list)
lazy_figures: Sequence[Callable[[], figure | LayoutDOM]] = field(default_factory=list)
primary_metric: Optional[str] = None

def get_by_name(self, quantity_name: str) -> Quantity:
Expand All @@ -37,7 +39,7 @@ def get_primary(self) -> Optional[Quantity]:
return None
return self.get_by_name(self.primary_metric)

def figures(self) -> Sequence[figure]:
def figures(self) -> Sequence[LayoutDOM]:
return [f() for f in self.lazy_figures]

def filtered(
Expand Down
4 changes: 2 additions & 2 deletions src/metriculous/evaluators/bokeh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
from bokeh.embed import file_html
from bokeh.models import Div, Title
from bokeh.models import Div, Title, LayoutDOM
from bokeh.plotting import figure
from bokeh.resources import CDN

Expand Down Expand Up @@ -80,7 +80,7 @@ def scatter_plot_circle_size(
return max(smallest, biggest - slope * num_points)


def check_that_all_figures_can_be_rendered(figures: Sequence[figure]) -> None:
def check_that_all_figures_can_be_rendered(figures: Sequence[LayoutDOM]) -> None:
"""Generates HTML for each figure.
In some cases this reveals issues that might not be noticed if we just instantiated the figures
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy.testing as npt
from assertpy import assert_that
from bokeh.plotting import figure
from numpy import ndarray, floating
from scipy.stats import entropy
from sklearn import metrics as sklmetrics

Expand Down Expand Up @@ -119,7 +120,7 @@ def evaluate(
ground_truth: ClassificationGroundTruth,
model_prediction: ClassificationPrediction,
model_name: str,
sample_weights: Optional[Sequence[float]] = None,
sample_weights: Optional[Sequence[float] | ndarray] = None,
) -> Evaluation:
"""
Computes Quantities and generates Figures that are useful for most
Expand Down Expand Up @@ -590,7 +591,7 @@ def check_input(


def _sample_weights(
sample_weights: Optional[Sequence[float]],
sample_weights: Optional[Sequence[float] | ndarray],
simulated_class_distribution: Optional[Sequence[float]],
y_true: Integers,
) -> Optional[np.ndarray]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -604,12 +604,12 @@ def create_figure() -> figure:

# Make sure something is visible if lines consist of just a single point
p.scatter(
x=source.data["automation_rate"][[0, -1]],
y=source.data["accuracy"][[0, -1]],
x=np.array(source.data["automation_rate"])[[0, -1]],
y=np.array(source.data["accuracy"])[[0, -1]],
)
p.scatter(
x=source.data["automation_rate"][[0, -1]],
y=source.data["threshold"][[0, -1]],
x=np.array(source.data["automation_rate"])[[0, -1]],
y=np.array(source.data["threshold"])[[0, -1]],
color="grey",
)

Expand Down
6 changes: 4 additions & 2 deletions src/metriculous/evaluators/regression/regression_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Callable, Optional, Sequence, Tuple

import numpy as np
from bokeh.models import LayoutDOM
from bokeh.plotting import figure
from numpy import ndarray
from sklearn import metrics as sklmetrics

from metriculous.evaluation import Evaluation, Evaluator, Quantity
Expand Down Expand Up @@ -61,7 +63,7 @@ def evaluate(
ground_truth: Floats,
model_prediction: Floats,
model_name: str,
sample_weights: Optional[Sequence[float]] = None,
sample_weights: Optional[Sequence[float] | ndarray] = None,
) -> Evaluation:
"""
Computes Quantities and generates figures that are useful for most
Expand Down Expand Up @@ -130,7 +132,7 @@ def _lazy_figures(
model_name: str,
maybe_sample_weights: Optional[Floats],
n_histogram_bins: int,
) -> Sequence[Tuple[str, Callable[[], figure]]]:
) -> Sequence[Tuple[str, Callable[[], LayoutDOM]]]:
F = FigureNames

if maybe_sample_weights is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def assert_all_close(
) -> None:
assert len(a) == len(b)
for qa, qb in zip(a, b):
if isinstance(qa.value, float):
if isinstance(qa.value, float) and isinstance(qb.value, float):
npt.assert_allclose(qa.value, qb.value, atol=atol, rtol=rtol)
assert replace(qa, value="any") == replace(qb, value="any")
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable, Iterable, Optional, Sequence, Tuple

import numpy as np
from bokeh.models import LayoutDOM
from bokeh.plotting import figure
from sklearn import metrics as sklmetrics

Expand Down Expand Up @@ -173,9 +174,9 @@ def evaluate(

def _lazy_figures(
self, model_name: str, y_pred: np.ndarray, y_true: np.ndarray
) -> Sequence[Tuple[str, Callable[[], figure]]]:
) -> Sequence[Tuple[str, Callable[[], LayoutDOM]]]:

lazy_figures = []
lazy_figures: list[Tuple[str, Callable[[], LayoutDOM]]] = []

class_distribution_figure_name = "Class Distribution"

Expand Down
3 changes: 2 additions & 1 deletion src/metriculous/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from assertpy import assert_that
from numpy import floating
from sklearn.metrics import roc_auc_score, roc_curve


Expand All @@ -17,7 +18,7 @@ def normalized(matrix: np.ndarray) -> np.ndarray:


def cross_entropy(
target_probas: np.ndarray, pred_probas: np.ndarray, epsilon: float = 1e-15
target_probas: np.ndarray, pred_probas: np.ndarray, epsilon: float | floating = 1e-15
) -> float:
"""Returns the cross-entropy for probabilistic ground truth labels.
Expand Down
2 changes: 2 additions & 0 deletions src/metriculous/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def test_a_vs_b_auroc_symmetry() -> None:
a2b1 = metrics.a_vs_b_auroc(
target_ints=target_ints, predicted_probas=probas, class_a=2, class_b=1
)
assert isinstance(a1b2, float)
assert isinstance(a2b1, float)
np.testing.assert_allclose(a1b2, a2b1, atol=1e-15)


Expand Down

0 comments on commit a1f3673

Please sign in to comment.