import math
from collections.abc import Callable, Sequence

import keras

from bayesflow.links import OrderedQuantiles, PositiveSemiDefinite
from bayesflow.types import Shape, Tensor
from bayesflow.utils import logging

PI = keras.ops.convert_to_tensor(math.pi)


class ScoringRule:
    """Base class for scoring rules that compare a network output against a reference value."""

    def get_link(self):
        # By default, the network output is used as-is.
        return keras.layers.Activation("linear")

    def build(self, reference_shape: Shape):
        pass

    def score(self, reference: Tensor, target: Tensor) -> Tensor:
        raise NotImplementedError
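
# A scoring rule maps a batch of reference values and a batch of network outputs
# ("targets") to a scalar loss. A rough subclass sketch (the shapes in the comment
# are assumptions for illustration, not enforced by the base class):
#
#     class ToySquaredErrorScore(ScoringRule):
#         def score(self, reference, target):
#             # reference: (batch, D), target: (batch, num_points, D)
#             return keras.ops.mean((target - reference[:, None, :]) ** 2)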


class NormedDifferenceScore(ScoringRule):
    def __init__(
        self,
        k: int,  # exponent of the absolute difference: k=1 yields an estimator of the median, k=2 of the mean
    ):
        self.k = k
        self.target_shape = (1,)

    def score(self, reference: Tensor, target: Tensor) -> Tensor:
        pointwise_difference = target - reference[:, None, :]
        score = keras.ops.absolute(pointwise_difference) ** self.k
        score = keras.ops.mean(score)
        return score


class MedianScore(NormedDifferenceScore):
    def __init__(self):
        super().__init__(k=1)


class MeanScore(NormedDifferenceScore):
    def __init__(self):
        super().__init__(k=2)
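
# Illustrative usage sketch for the point scores above (tensor shapes are
# assumptions inferred from the broadcasting in `score`: `reference` is
# (batch, D) and `target` is (batch, 1, D)):
#
#     score_fn = MeanScore()
#     reference = keras.random.normal((4, 3))   # e.g. true parameter values
#     target = keras.random.normal((4, 1, 3))   # e.g. predicted point estimates
#     loss = score_fn.score(reference, target)  # scalar tensor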


class WeightedNormedDifferenceScore(ScoringRule):
    def __init__(
        self,
        weighting_function: Callable = None,
        k: int = 2,
    ):
        if weighting_function:
            self.weighting_function = weighting_function
        else:
            # Without a weighting function, fall back to the unweighted score.
            self.weighting_function = lambda reference: 1
        self.k = k
        self.target_shape = (1,)

    def score(self, reference: Tensor, target: Tensor) -> Tensor:
        pointwise_difference = target - reference[:, None, :]
        score = self.weighting_function(reference) * keras.ops.absolute(pointwise_difference) ** self.k
        score = keras.ops.mean(score)
        return score
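
# The weighting function can focus the score on parts of the reference space.
# A hypothetical example (name and shapes are illustrative only) that zeroes out
# the contribution of dimensions where the reference is non-positive:
#
#     def positive_region_weight(reference):
#         # reference: (batch, D) -> weights broadcastable against (batch, 1, D)
#         return keras.ops.cast(reference[:, None, :] > 0, "float32")
#
#     weighted_score = WeightedNormedDifferenceScore(weighting_function=positive_region_weight)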


class QuantileScore(ScoringRule):
    def __init__(
        self,
        q: Sequence[float] = None,
    ):
        if q is None:
            q = [0.1, 0.5, 0.9]
            logging.info(f"QuantileScore was not provided with argument `q`. Using the default quantile levels: {q}.")

        self.q = keras.ops.convert_to_tensor(q)
        self.target_shape = (len(self.q),)

    def get_link(self):
        if self.q is None:
            raise AssertionError("Quantile levels `q` must be set before constructing the link.")
        return OrderedQuantiles(self.q)

    def score(self, reference: Tensor, target: Tensor) -> Tensor:
        pointwise_difference = target - reference[:, None, :]

        # Pinball (quantile) loss, averaged over quantile levels and dimensions.
        score = pointwise_difference * (keras.ops.cast(pointwise_difference > 0, float) - self.q[None, :, None])
        score = keras.ops.mean(score)
        return score
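
# Illustrative sketch of the quantile (pinball) loss above with the default levels
# q = [0.1, 0.5, 0.9] (shapes are assumptions: `target` holds one predicted value
# per quantile level and dimension):
#
#     quantile_score = QuantileScore()
#     reference = keras.random.normal((4, 3))  # (batch, D)
#     target = keras.random.normal((4, 3, 3))  # (batch, len(q), D)
#     loss = quantile_score.score(reference, target)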


class ParametricDistributionRule(ScoringRule):
    """
    Scoring rule given by the negative log-likelihood of a parametric distribution
    whose parameters are predicted by the network.
    """

    def __init__(self, target_mappings: dict[str, str] = None):
        self.target_mappings = target_mappings
        # TODO: make less confusing: `build` currently needs to be called here initially AND again
        # whenever the scoring rule's attributes are updated, instead of just once at the end.
        # This is probably not the Keras way.
        self.build((None, 0))

    def build(self, reference_shape: Shape):
        if self.target_mappings is None:
            self.target_mappings = {key: key for key in self.compute_target_shape().keys()}

        self.target_shape = self.map_target_keys(self.compute_target_shape(), inverse=True)

    def map_target_keys(self, target_dict: dict, inverse: bool = False):
        if inverse:
            mapping = {v: k for k, v in self.target_mappings.items()}
        else:
            mapping = self.target_mappings
        return {mapping[key]: value for key, value in target_dict.items()}

    def compute_target_shape(self):
        raise NotImplementedError

    def log_prob(self, x, **kwargs):
        raise NotImplementedError

    def score(self, reference: Tensor, target: dict[str, Tensor]) -> Tensor:
        score = -self.log_prob(x=reference, **self.map_target_keys(target))
        score = keras.ops.mean(score)
        return score * 0.01
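
# `target_mappings` renames the keys of the predicted target dict to the parameter
# names expected by `log_prob`. A hypothetical sketch for the multivariate normal
# rule defined below (the head names are illustrative only):
#
#     rule = MultivariateNormalScore(
#         target_mappings={"loc_head": "mean", "cov_head": "covariance"}
#     )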


class MultivariateNormalScore(ParametricDistributionRule):
    def __init__(self, D: int = None, **kwargs):
        super().__init__(**kwargs)
        self.D = D

    def build(self, reference_shape: Shape):
        if reference_shape is None:
            raise AssertionError("Cannot build before setting D.")
        elif isinstance(reference_shape, tuple) and len(reference_shape) == 2:
            self.D = reference_shape[1]
        else:
            raise AssertionError(f"Cannot build from reference_shape {reference_shape}")
        super().build(reference_shape)

    def compute_target_shape(self) -> dict[str, Shape]:
        return dict(mean=(1,), covariance=(self.D,))

    def get_link(self):
        return self.map_target_keys(
            dict(mean=keras.layers.Activation("linear"), covariance=PositiveSemiDefinite()), inverse=True
        )

    def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
        diff = x[:, None, :] - mean
        inv_covariance = keras.ops.inv(covariance)
        log_det_covariance = keras.ops.slogdet(covariance)[1]  # only take the log of the determinant part

        # Compute the quadratic term in the exponential of the multivariate Gaussian
        quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, inv_covariance, diff)

        # Compute the log probability density
        log_prob = -0.5 * (self.D * keras.ops.log(2 * PI) + log_det_covariance + quadratic_term)

        return log_prob

    def sample(self, sample_size, mean, covariance):
        batch_size, D = mean.shape
        # Ensure covariance is (batch_size, D, D)
        assert covariance.shape == (batch_size, D, D)

        # Use the Cholesky decomposition to generate samples
        chol = keras.ops.cholesky(covariance)
        normal_samples = keras.random.normal((batch_size, D, sample_size))
        samples = mean[:, :, None] + keras.ops.einsum("ijk,ikl->ijl", chol, normal_samples)

        return samples
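
# Illustrative end-to-end sketch for sampling (shapes follow the assertion in
# `sample` above; everything else is an assumption for illustration):
#
#     mvn_score = MultivariateNormalScore()
#     mvn_score.build((None, 3))                         # infers D = 3
#     mean = keras.random.normal((4, 3))                 # (batch, D)
#     factor = keras.random.normal((4, 3, 3)) * 0.1
#     covariance = keras.ops.einsum("bij,bkj->bik", factor, factor) + 0.1 * keras.ops.eye(3)[None]
#     draws = mvn_score.sample(10, mean, covariance)     # (batch, D, 10)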