
Commit

update code to work with more recent versions of dependencies
marlonjan committed Sep 12, 2024
1 parent ff46590 commit d461782
Showing 12 changed files with 67 additions and 67 deletions.
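The changes adapt the code to the Bokeh 3.x API: Bokeh 3 removed the Figure class from bokeh.plotting (the figure class itself is now used in type annotations), inner helpers named figure are renamed to create_figure so they no longer shadow that import, and the sizing keywords plot_height/plot_width become height/width. A minimal before/after sketch of the pattern, assuming Bokeh >= 3.0 is installed (the function name is illustrative):

    # Bokeh 2.x style, no longer valid on Bokeh 3.x:
    #   from bokeh.plotting import figure, Figure
    #   def make_plot(title: str) -> Figure: ...

    # Bokeh 3.x style, as applied throughout this commit:
    from bokeh.plotting import figure  # usable as a type annotation on Bokeh 3.x

    def make_plot(title: str) -> figure:
        p = figure(title=title, height=350, width=350)  # was plot_height= / plot_width=
        p.line([0, 1, 2, 3], [0, 1, 4, 9], line_width=2)
        return p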
4 changes: 2 additions & 2 deletions notebooks/quickstart_classification_iris.py
@@ -152,10 +152,10 @@
# This is to indicate which quantity should be used for model selection.

# %%
from bokeh.plotting import figure, Figure
from bokeh.plotting import figure


def make_figure(title: str) -> Figure:
def make_figure(title: str) -> figure:
p = figure(title=title)
p.line([0, 1, 2, 3], np.random.random(size=4), line_width=2)
return p
4 changes: 2 additions & 2 deletions src/metriculous/comparison_test.py
@@ -4,12 +4,12 @@

import pytest
from bokeh import plotting
from bokeh.plotting import Figure
from bokeh.plotting import figure

from metriculous import Comparison, Evaluation, Quantity


def make_a_bokeh_figure() -> Figure:
def make_a_bokeh_figure() -> figure:
p = plotting.figure()
p.line(x=[0, 1, 5, 6], y=[40, 60, 30, 50])
return p
6 changes: 3 additions & 3 deletions src/metriculous/evaluation.py
@@ -6,7 +6,7 @@
from dataclasses import dataclass, field, replace
from typing import Callable, Generic, Optional, Sequence, TypeVar, Union

from bokeh.plotting import Figure
from bokeh.plotting import figure


@dataclass(frozen=True)
@@ -21,7 +21,7 @@ class Quantity:
class Evaluation:
model_name: str
quantities: Sequence[Quantity] = field(default_factory=list)
lazy_figures: Sequence[Callable[[], Figure]] = field(default_factory=list)
lazy_figures: Sequence[Callable[[], figure]] = field(default_factory=list)
primary_metric: Optional[str] = None

def get_by_name(self, quantity_name: str) -> Quantity:
@@ -37,7 +37,7 @@ def get_primary(self) -> Optional[Quantity]:
return None
return self.get_by_name(self.primary_metric)

def figures(self) -> Sequence[Figure]:
def figures(self) -> Sequence[figure]:
return [f() for f in self.lazy_figures]

def filtered(
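With figure as the annotation type, Evaluation is used exactly as before; a small usage sketch based on the fields shown in this diff (the metric name and values are made up):

    import numpy as np
    from bokeh.plotting import figure

    from metriculous import Evaluation, Quantity


    def lazy_scatter() -> figure:
        p = figure(title="Example", height=300, width=300)
        p.scatter(np.arange(5), np.random.random(size=5))
        return p


    evaluation = Evaluation(
        model_name="baseline",
        quantities=[Quantity(name="Accuracy", value=0.9, higher_is_better=True, description=None)],
        lazy_figures=[lazy_scatter],
        primary_metric="Accuracy",
    )

    figures = evaluation.figures()  # figures are only created when requested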
8 changes: 4 additions & 4 deletions src/metriculous/evaluators/bokeh_utils.py
@@ -3,7 +3,7 @@
import numpy as np
from bokeh.embed import file_html
from bokeh.models import Div, Title
from bokeh.plotting import Figure
from bokeh.plotting import figure
from bokeh.resources import CDN

TOOLS = "pan,box_zoom,wheel_zoom,reset"
@@ -37,15 +37,15 @@ def title_div(title_rows: Sequence[str]) -> Div:
)


def add_title_rows(p: Figure, title_rows: Sequence[str]) -> None:
def add_title_rows(p: figure, title_rows: Sequence[str]) -> None:
for title_row in reversed(title_rows):
p.add_layout(
Title(text=title_row, text_font_size=FONT_SIZE, align="center"),
place="above",
)


def apply_default_style(p: Figure) -> None:
def apply_default_style(p: figure) -> None:
p.background_fill_color = BACKGROUND_COLOR
p.grid.grid_line_color = "white"

@@ -80,7 +80,7 @@ def scatter_plot_circle_size(
return max(smallest, biggest - slope * num_points)


def check_that_all_figures_can_be_rendered(figures: Sequence[Figure]) -> None:
def check_that_all_figures_can_be_rendered(figures: Sequence[figure]) -> None:
"""Generates HTML for each figure.
In some cases this reveals issues that might not be noticed if we just instantiated the figures
12 changes: 6 additions & 6 deletions src/metriculous/evaluators/classification/calibration_plot.py
@@ -4,7 +4,7 @@
import numpy as np
from bokeh import plotting
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.plotting import Figure
from bokeh.plotting import figure
from sklearn.utils import check_consistent_length, column_or_1d

from metriculous.evaluators.bokeh_utils import (
@@ -23,7 +23,7 @@ def _bokeh_probability_calibration_plot(
y_pred_score: np.ndarray,
title_rows: Sequence[str],
# TODO sample_weights: Optional[np.ndarray],
) -> Callable[[], Figure]:
) -> Callable[[], figure]:
"""Probability calibration plot.
Args:
@@ -39,7 +39,7 @@ def _bokeh_probability_calibration_plot(
"""

def figure() -> Figure:
def create_figure() -> figure:
assert y_true_binary.shape == y_pred_score.shape
assert set(y_true_binary).issubset({0, 1}) or set(y_true_binary).issubset(
{False, True}
@@ -71,8 +71,8 @@ def figure() -> Figure:
)

p = plotting.figure(
# plot_height=370,
# plot_width=350,
# height=370,
# width=350,
x_range=(-0.05, 1.05),
y_range=(-0.05, 1.05),
tools=TOOLS,
@@ -121,7 +121,7 @@ def figure() -> Figure:

return p

return figure
return create_figure


@dataclass(frozen=True)
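The inner factory functions previously named figure are renamed to create_figure throughout, since the old name shadowed the figure class that is now imported for the annotations. A minimal sketch of the lazy-figure pattern these modules use (names and data are illustrative):

    from typing import Callable, Sequence

    from bokeh.plotting import figure


    def _lazy_line_plot(title_rows: Sequence[str]) -> Callable[[], figure]:
        def create_figure() -> figure:  # renamed from `figure` to avoid shadowing the import
            p = figure(title=" ".join(title_rows), height=350, width=350)
            p.line([0.0, 0.5, 1.0], [0.0, 0.25, 1.0], line_width=2)
            return p

        return create_figure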
(diff for the next changed file; filename not shown)
@@ -3,7 +3,7 @@
import numpy as np
import numpy.testing as npt
from assertpy import assert_that
from bokeh.plotting import Figure
from bokeh.plotting import figure
from scipy.stats import entropy
from sklearn import metrics as sklmetrics

@@ -203,9 +203,9 @@ def _lazy_figures(
data: ClassificationData,
maybe_sample_weights: Optional[np.ndarray],
class_names: Sequence[str],
) -> Sequence[Tuple[str, Callable[[], Figure]]]:
) -> Sequence[Tuple[str, Callable[[], figure]]]:

lazy_figures: List[Tuple[str, Callable[[], Figure]]] = []
lazy_figures: List[Tuple[str, Callable[[], figure]]] = []

y_true = data.target.argmaxes
y_true_one_hot = data.target.argmaxes_one_hot
(diff for the next changed file; filename not shown)
@@ -219,7 +219,7 @@ def test_ClassificationEvaluator_perfect_prediction(
),
Quantity(
name="Log Loss",
value=2.1094237467877998e-15,
value=2.220446e-16,
higher_is_better=False,
description=None,
),
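The expected Log Loss for the perfect-prediction case changes to 2.220446e-16, which is machine epsilon for float64; this matches recent scikit-learn versions clipping predicted probabilities at machine epsilon in log_loss rather than at 1e-15 (that cause is an assumption; the value itself is easy to check):

    import numpy as np

    # float64 machine epsilon, matching the new expected value above
    print(np.finfo(np.float64).eps)  # 2.220446049250313e-16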
(diff for the next changed file; filename not shown)
@@ -4,7 +4,7 @@
from assertpy import assert_that
from bokeh import plotting
from bokeh.models import ColumnDataSource, HoverTool, LinearColorMapper
from bokeh.plotting import Figure
from bokeh.plotting import figure
from sklearn import metrics as sklmetrics

from metriculous.evaluators.bokeh_utils import (
@@ -33,7 +33,7 @@ def _bokeh_output_histogram(
title_rows: Sequence[str],
sample_weights: Optional[np.ndarray] = None,
x_label_rotation: Union[str, float] = "horizontal",
) -> Callable[[], Figure]:
) -> Callable[[], figure]:
"""Histogram of ground truth and prediction.
Args:
@@ -60,15 +60,15 @@
assert_that(np.shape(y_true)).is_equal_to(np.shape(weights))
normalize = not np.allclose(weights, 1.0)

def figure() -> Figure:
def create_figure() -> figure:
n = len(class_names)

bins = np.arange(0, n + 1, 1)

p = plotting.figure(
x_range=class_names,
plot_height=350,
plot_width=350,
height=350,
width=350,
tools=TOOLS,
toolbar_location=TOOLBAR_LOCATION,
)
@@ -110,7 +110,7 @@ def figure() -> Figure:
p.x_range.bounds = (-0.5, 0.5 + len(class_names))
return p

return figure
return create_figure


def _bokeh_confusion_matrix(
@@ -120,7 +120,7 @@
title_rows: Sequence[str],
x_label_rotation: Union[str, float] = "horizontal",
y_label_rotation: Union[str, float] = "vertical",
) -> Callable[[], Figure]:
) -> Callable[[], figure]:
"""Confusion matrix heatmap.
Args:
@@ -142,7 +142,7 @@
"""

def figure() -> Figure:
def create_figure() -> figure:
cm = sklmetrics.confusion_matrix(
y_true, y_pred, labels=list(range(len(class_names)))
)
@@ -241,7 +241,7 @@ def figure() -> Figure:

return p

return figure
return create_figure


def _bokeh_confusion_scatter(
@@ -251,7 +251,7 @@
title_rows: Sequence[str],
x_label_rotation: Union[str, float] = "horizontal",
y_label_rotation: Union[str, float] = "vertical",
) -> Callable[[], Figure]:
) -> Callable[[], figure]:
"""Scatter plot that contains the same information as a confusion matrix.
Args:
@@ -273,15 +273,15 @@
"""

def figure() -> Figure:
def create_figure() -> figure:
if len(y_true) != len(y_pred):
raise ValueError("y_true and y_pred must have the same length!")

p = plotting.figure(
x_range=(-0.5, -0.5 + len(class_names)),
y_range=(-0.5, -0.5 + len(class_names)),
plot_height=350,
plot_width=350,
height=350,
width=350,
tools=TOOLS,
toolbar_location=TOOLBAR_LOCATION,
match_aspect=True,
@@ -334,15 +334,15 @@ def noise() -> np.ndarray:

return p

return figure
return create_figure


def _bokeh_roc_curve(
y_true_binary: np.ndarray,
y_pred_score: np.ndarray,
title_rows: Sequence[str],
sample_weights: Optional[np.ndarray],
) -> Callable[[], Figure]:
) -> Callable[[], figure]:
"""Interactive receiver operator characteristic (ROC) curve.
Args:
@@ -360,7 +360,7 @@
"""

def figure() -> Figure:
def create_figure() -> figure:
assert y_true_binary.shape == y_pred_score.shape
assert set(y_true_binary).issubset({0, 1}) or set(y_true_binary).issubset(
{False, True}
@@ -381,8 +381,8 @@ def figure() -> Figure:
)

p = plotting.figure(
plot_height=400,
plot_width=350,
height=400,
width=350,
tools=TOOLS,
toolbar_location=TOOLBAR_LOCATION,
# toolbar_location=None, # hides entire toolbar
@@ -422,15 +422,15 @@ def figure() -> Figure:

return p

return figure
return create_figure


def _bokeh_precision_recall_curve(
y_true_binary: np.ndarray,
y_pred_score: np.ndarray,
title_rows: Sequence[str],
sample_weights: Optional[np.ndarray],
) -> Callable[[], Figure]:
) -> Callable[[], figure]:
"""Interactive precision recall curve.
Args:
@@ -448,7 +448,7 @@
"""

def figure() -> Figure:
def create_figure() -> figure:
assert y_true_binary.shape == y_pred_score.shape
assert set(y_true_binary).issubset({0, 1}) or set(y_true_binary).issubset(
{False, True}
@@ -464,8 +464,8 @@ def figure() -> Figure:
recall = recall[:-1]

p = plotting.figure(
plot_height=400,
plot_width=350,
height=400,
width=350,
x_range=(-0.05, 1.05),
y_range=(-0.05, 1.05),
tools=TOOLS,
@@ -499,15 +499,15 @@ def figure() -> Figure:

return p

return figure
return create_figure


def _bokeh_automation_rate_analysis(
y_target_one_hot: np.ndarray,
y_pred_proba: np.ndarray,
title_rows: Sequence[str],
sample_weights: Optional[np.ndarray],
) -> Callable[[], Figure]:
) -> Callable[[], figure]:
"""
Plots various quantities over automation rate, where a single probability threshold
is used for all classes to decide if we are confident enough to automate the
@@ -528,7 +528,7 @@ def _bokeh_automation_rate_analysis(
"""

def figure() -> Figure:
def create_figure() -> figure:
# ----- Check input -----
assert y_target_one_hot.ndim == 2
assert y_pred_proba.ndim == 2
@@ -572,8 +572,8 @@ def figure() -> Figure:

# ----- Bokeh plot -----
p = plotting.figure(
plot_height=400,
plot_width=350,
height=400,
width=350,
x_range=(-0.05, 1.05),
y_range=(-0.05, 1.05),
tools=TOOLS,
@@ -634,7 +634,7 @@ def figure() -> Figure:

return p

return figure
return create_figure


def _faster_accuracy(
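The hunks above in this file repeat the same mechanical rename of the Bokeh sizing keywords; a call such as plotting.figure(plot_height=400, plot_width=350, ...) becomes, on Bokeh 3.x:

    from bokeh import plotting

    p = plotting.figure(
        height=400,  # was plot_height=400
        width=350,   # was plot_width=350
        tools="pan,box_zoom,wheel_zoom,reset",
    )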
(remaining changed files not shown)
