Skip to content

Commit

Permalink
Fix various type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege committed Apr 28, 2024
1 parent a35a2d0 commit f1274fb
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 24 deletions.
17 changes: 8 additions & 9 deletions calibr8/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,17 @@
import json
import logging
import os
import typing
import warnings
from pathlib import Path
from typing import Callable, Optional, Sequence, Tuple, Union
from typing import Callable, DefaultDict, List, Optional, Sequence, Tuple, Union

import numpy
import scipy

from . import utils
from .utils import DistributionType, pm

__version__ = "7.1.1"
__version__ = "7.1.2"
_log = logging.getLogger("calibr8")


Expand Down Expand Up @@ -170,7 +169,7 @@ def _interval_prob(x_cdf: numpy.ndarray, cdf: numpy.ndarray, a: float, b: float)
return cdf[ib] - cdf[ia]


def _get_eti(x_cdf: numpy.ndarray, cdf: numpy.ndarray, ci_prob: float) -> typing.Tuple[float, float]:
def _get_eti(x_cdf: numpy.ndarray, cdf: numpy.ndarray, ci_prob: float) -> Tuple[float, float]:
"""Find the equal tailed interval (ETI) corresponding to a certain credible interval probability level.
Parameters
Expand Down Expand Up @@ -203,8 +202,8 @@ def _get_hdi(
guess_lower: float,
guess_upper: float,
*,
history: typing.Optional[typing.DefaultDict[str, typing.List]] = None,
) -> typing.Tuple[float, float]:
history: Optional[DefaultDict[str, List]] = None,
) -> Tuple[float, float]:
"""Find the highest density interval (HDI) corresponding to a certain credible interval probability level.
Parameters
Expand Down Expand Up @@ -600,7 +599,7 @@ def likelihood(self, *, y, x, theta=None, scan_x: bool = False):
return numpy.exp([self.loglikelihood(y=y, x=xi, theta=theta) for xi in x])
return numpy.exp(self.loglikelihood(y=y, x=x, theta=theta))

def objective(self, independent, dependent, minimize=True) -> typing.Callable:
def objective(self, independent, dependent, minimize=True) -> Callable:
"""Creates an objective function for fitting to data.
Parameters
Expand Down Expand Up @@ -628,7 +627,7 @@ def objective(x):

return objective

def save(self, filepath: Union[Path, os.PathLike]):
def save(self, filepath: Union[str, Path, os.PathLike]):
"""Save key properties of the calibration model to a JSON file.
Parameters
Expand All @@ -654,7 +653,7 @@ def save(self, filepath: Union[Path, os.PathLike]):
return

@classmethod
def load(cls, filepath: Union[Path, os.PathLike]):
def load(cls, filepath: Union[str, Path, os.PathLike]):
"""Instantiates a model from a JSON file of key properties.
Parameters
Expand Down
19 changes: 12 additions & 7 deletions calibr8/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
likelihood estimation of calibration model parameters.
"""
import logging
import typing
from typing import Any, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Literal, Mapping, Optional, Sequence, Tuple, Union

import numpy
import scipy.optimize
Expand All @@ -14,7 +13,11 @@
_log = logging.getLogger("calibr8.optimization")


def _mask_and_warn_inf_or_nan(x: numpy.ndarray, y: numpy.ndarray, on: typing.Optional[str] = None):
def _mask_and_warn_inf_or_nan(
x: Union[Sequence[float], numpy.ndarray],
y: Union[Sequence[float], numpy.ndarray],
on: Optional[Literal["x", "y"]] = None,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
"""Filters `x` and `y` such that only finite elements remain.
Parameters
Expand All @@ -31,6 +34,8 @@ def _mask_and_warn_inf_or_nan(x: numpy.ndarray, y: numpy.ndarray, on: typing.Opt
x : array
y : array
"""
x = numpy.asarray(x)
y = numpy.asarray(y)
xdims = numpy.ndim(x)
if xdims == 1:
mask_x = numpy.isfinite(x)
Expand Down Expand Up @@ -82,8 +87,8 @@ def _warn_hit_bounds(theta, bounds, theta_names) -> bool:
def fit_scipy(
model: core.CalibrationModel,
*,
independent: numpy.ndarray,
dependent: numpy.ndarray,
independent: Union[Sequence[float], numpy.ndarray],
dependent: Union[Sequence[float], numpy.ndarray],
theta_guess: Union[Sequence[float], numpy.ndarray],
theta_bounds: Sequence[Tuple[float, float]],
minimize_kwargs: Optional[Mapping[str, Any]] = None,
Expand Down Expand Up @@ -154,8 +159,8 @@ def fit_scipy(
def fit_scipy_global(
model: core.CalibrationModel,
*,
independent: numpy.ndarray,
dependent: numpy.ndarray,
independent: Union[Sequence[float], numpy.ndarray],
dependent: Union[Sequence[float], numpy.ndarray],
theta_bounds: list,
method: Optional[str] = None,
maxiter: int = 5000,
Expand Down
21 changes: 13 additions & 8 deletions calibr8/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
imports, timestamp parsing and plotting.
"""
import datetime
import typing
import warnings
from collections.abc import Iterable
from typing import Literal, Optional, Sequence, Tuple
from typing import List, Literal, Optional, Sequence, Tuple

import matplotlib
import numpy
Expand Down Expand Up @@ -52,7 +51,7 @@ def __getattr__(self, attr):
pm = ImportWarner("pymc")


def parse_datetime(s: typing.Optional[str]) -> typing.Optional[datetime.datetime]:
def parse_datetime(s: Optional[str]) -> Optional[datetime.datetime]:
"""Parses a timezone-aware datetime formatted like 2020-08-05T13:37:00Z.
Returns
Expand All @@ -65,7 +64,7 @@ def parse_datetime(s: typing.Optional[str]) -> typing.Optional[datetime.datetime
return datetime.datetime.strptime(s.replace("Z", "+0000"), "%Y-%m-%dT%H:%M:%S%z")


def format_datetime(dt: typing.Optional[datetime.datetime]) -> typing.Optional[str]:
def format_datetime(dt: Optional[datetime.datetime]) -> Optional[str]:
"""Formats a datetime like 2020-08-05T13:37:00Z.
Returns
Expand Down Expand Up @@ -176,7 +175,9 @@ def plot_norm_band(ax, independent, mu, scale):
return artists


def plot_t_band(ax, independent, mu, scale, df, *, residual_type: typing.Optional[str] = None):
def plot_t_band(
ax, independent, mu, scale, df, *, residual_type: Optional[Literal["absolute", "relative"]] = None
):
"""Helper function for plotting the 68, 90 and 95 % likelihood-bands of a t-distribution.
Parameters
Expand Down Expand Up @@ -241,7 +242,9 @@ def plot_t_band(ax, independent, mu, scale, df, *, residual_type: typing.Optiona
return artists


def plot_continuous_band(ax, independent, model, residual_type: typing.Optional[str] = None):
def plot_continuous_band(
ax, independent, model, residual_type: Optional[Literal["absolute", "relative"]] = None
):
"""Helper function for plotting the 68, 90 and 95 % likelihood-bands of a univariate distribution.
Parameters
Expand Down Expand Up @@ -364,9 +367,9 @@ def plot_model(
*,
fig: Optional[matplotlib.figure.Figure] = None,
axs: Optional[Sequence[matplotlib.axes.Axes]] = None,
residual_type="absolute",
residual_type: Literal["absolute", "relative"] = "absolute",
band_xlim: Tuple[Optional[float], Optional[float]] = (None, None),
):
) -> Tuple[matplotlib.figure.Figure, List[matplotlib.axes.Axes]]:
"""Makes a plot of the model with its data.
Parameters
Expand Down Expand Up @@ -416,6 +419,8 @@ def plot_model(
axs.append(fig.add_subplot(gs1[0, 1], sharey=axs[0]))
pyplot.setp(axs[1].get_yticklabels(), visible=False)
axs.append(fig.add_subplot(gs2[0, 2]))
else:
axs = list(axs)

# ======= Left =======
# Untransformed, outer range
Expand Down

0 comments on commit f1274fb

Please sign in to comment.