Skip to content

Commit ecef8ed

Browse files
committed
Refactor scores
* rename scoring_rules module to scores * mean and median wrappers for normed difference scores * argument handling for quantile score * scoring rule tests
1 parent 1a85fc8 commit ecef8ed

File tree

7 files changed

+298
-78
lines changed

7 files changed

+298
-78
lines changed

bayesflow/scores/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .scores import (
2+
ScoringRule,
3+
ParametricDistributionRule,
4+
NormedDifferenceScore,
5+
MedianScore,
6+
MeanScore,
7+
QuantileScore,
8+
MultivariateNormalScore,
9+
)

bayesflow/scores/scores.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from collections.abc import Callable, Sequence
2+
3+
from bayesflow.types import Shape, Tensor
4+
5+
from bayesflow.links import OrderedQuantiles, PositiveSemiDefinite
6+
7+
from bayesflow.utils import logging
8+
9+
import keras
10+
11+
import math
12+
13+
PI = keras.ops.convert_to_tensor(math.pi)
14+
15+
16+
class ScoringRule:
17+
def get_link(self):
18+
return keras.layers.Activation("linear")
19+
20+
def build(self, reference_shape: Shape):
21+
pass
22+
23+
def score(self, reference, target):
24+
raise NotImplementedError
25+
26+
27+
class NormedDifferenceScore(ScoringRule):
28+
def __init__(
29+
self,
30+
k: int, # results in an estimator for the mean
31+
):
32+
self.k = k
33+
self.target_shape = (1,)
34+
35+
def score(self, reference: Tensor, target: Tensor) -> Tensor:
36+
pointwise_differance = target - reference[:, None, :]
37+
score = keras.ops.absolute(pointwise_differance) ** self.k
38+
score = keras.ops.mean(score)
39+
return score
40+
41+
42+
class MedianScore(NormedDifferenceScore):
43+
def __init__(self):
44+
super().__init__(k=1)
45+
46+
47+
class MeanScore(NormedDifferenceScore):
48+
def __init__(self):
49+
super().__init__(k=2)
50+
51+
52+
class WeightedNormedDifferenceScore(ScoringRule):
53+
def __init__(
54+
self,
55+
weighting_function: Callable,
56+
k: int = 2,
57+
):
58+
if weighting_function:
59+
self.weighting_function = weighting_function
60+
else:
61+
self.weighting_function = lambda input: 1
62+
self.k = k
63+
self.target_shape = (1,)
64+
65+
def score(self, reference: Tensor, target: Tensor) -> Tensor:
66+
pointwise_differance = target - reference[:, None, :]
67+
score = self.weighting_function(reference) * keras.ops.absolute(pointwise_differance) ** self.k
68+
score = keras.ops.mean(score)
69+
return score
70+
71+
72+
class QuantileScore(ScoringRule):
73+
def __init__(
74+
self,
75+
q: Sequence[float] = None,
76+
):
77+
if q is None:
78+
q = [0.1, 0.5, 0.9]
79+
logging.info(f"QuantileScore was not provided with argument `q`. Using the default quantile levels: {q}.")
80+
81+
self.q = keras.ops.convert_to_tensor(q)
82+
self.target_shape = (len(self.q),)
83+
84+
def get_link(self):
85+
if self.q is None:
86+
raise AssertionError("Needs q to construct link")
87+
else:
88+
print(self.q)
89+
return OrderedQuantiles(self.q)
90+
91+
def score(self, reference: Tensor, target: Tensor) -> Tensor:
92+
pointwise_differance = target - reference[:, None, :]
93+
94+
score = pointwise_differance * (keras.ops.cast(pointwise_differance > 0, float) - self.q[None, :, None])
95+
score = keras.ops.mean(score)
96+
return score
97+
98+
99+
class ParametricDistributionRule(ScoringRule):
100+
"""
101+
TODO
102+
"""
103+
104+
def __init__(self, target_mappings: dict[str, str] = None):
105+
self.target_mappings = target_mappings
106+
self.build(
107+
(
108+
None,
109+
0,
110+
)
111+
) # TODO: make less confusing: this currently needs to be called initially AND another time when the
112+
# scoring rules attributes are updated instead of just once in the end. This is probably not the Keras way.
113+
114+
def build(self, reference_shape: Shape):
115+
if self.target_mappings is None:
116+
self.target_mappings = {key: key for key in self.compute_target_shape().keys()}
117+
118+
self.target_shape = self.map_target_keys(self.compute_target_shape(), inverse=True)
119+
120+
def map_target_keys(self, target_dict: dict[str, str], inverse=False):
121+
if inverse:
122+
map = {v: k for k, v in self.target_mappings.items()}
123+
else:
124+
map = self.target_mappings
125+
return {map[key]: value for key, value in target_dict.items()}
126+
127+
def compute_target_shape(self):
128+
raise NotImplementedError
129+
130+
def log_prob(self, x, **kwargs):
131+
raise NotImplementedError
132+
133+
def score(self, reference: Tensor, target: dict[str, Tensor]) -> Tensor:
134+
score = -self.log_prob(x=reference, **self.map_target_keys(target))
135+
score = keras.ops.mean(score)
136+
return score * 0.01
137+
138+
139+
class MultivariateNormalScore(ParametricDistributionRule):
140+
def __init__(self, D: int = None, **kwargs):
141+
super().__init__(**kwargs)
142+
self.D = D
143+
144+
def build(self, reference_shape: Shape):
145+
if reference_shape is None:
146+
raise AssertionError("Cannot build before setting D.")
147+
elif isinstance(reference_shape, tuple) and len(reference_shape) == 2:
148+
self.D = reference_shape[1]
149+
else:
150+
raise AssertionError(f"Cannot build from reference_shape {reference_shape}")
151+
super().build(reference_shape)
152+
153+
def compute_target_shape(self) -> dict[str, Shape]:
154+
return dict(mean=(1,), covariance=(self.D,))
155+
156+
def get_link(self):
157+
return self.map_target_keys(
158+
dict(mean=keras.layers.Activation("linear"), covariance=PositiveSemiDefinite()), inverse=True
159+
)
160+
161+
def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
162+
diff = x[:, None, :] - mean
163+
inv_covariance = keras.ops.inv(covariance)
164+
log_det_covariance = keras.ops.slogdet(covariance)[1] # Only take the log of the determinant part
165+
166+
# Compute the quadratic term in the exponential of the multivariate Gaussian
167+
quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, inv_covariance, diff)
168+
169+
# Compute the log probability density
170+
log_prob = -0.5 * (self.D * keras.ops.log(2 * PI) + log_det_covariance + quadratic_term)
171+
172+
return log_prob
173+
174+
def sample(self, sample_size, mean, covariance):
175+
batch_size, D = mean.shape
176+
# Ensure covariance is (batch_size, D, D)
177+
assert covariance.shape == (batch_size, D, D)
178+
179+
# Use Cholesky decomposition to generate samples
180+
chol = keras.ops.cholesky(covariance)
181+
normal_samples = keras.random.normal((batch_size, D, sample_size))
182+
samples = mean[:, :, None] + keras.ops.einsum("ijk,ikl->ijl", chol, normal_samples)
183+
184+
return samples

bayesflow/scoring_rules/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

bayesflow/scoring_rules/scoring_rules.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

tests/test_scores/__init__.py

Whitespace-only changes.

tests/test_scores/conftest.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import keras
2+
import pytest
3+
4+
5+
@pytest.fixture()
6+
def batch_size():
7+
return 16
8+
9+
10+
@pytest.fixture()
11+
def num_variables():
12+
return 4
13+
14+
15+
@pytest.fixture()
16+
def reference(batch_size, num_variables):
17+
return keras.random.uniform((batch_size, num_variables))
18+
19+
20+
@pytest.fixture()
21+
def median_score():
22+
from bayesflow.scores import MedianScore
23+
24+
return MedianScore()
25+
26+
27+
@pytest.fixture()
28+
def mean_score():
29+
from bayesflow.scores import MeanScore
30+
31+
return MeanScore()
32+
33+
34+
@pytest.fixture()
35+
def normed_diff_score():
36+
from bayesflow.scores import NormedDifferenceScore
37+
38+
return NormedDifferenceScore(k=3)
39+
40+
41+
@pytest.fixture()
42+
def quantile_score():
43+
from bayesflow.scores import QuantileScore
44+
45+
return QuantileScore()
46+
47+
48+
@pytest.fixture(params=["median_score", "mean_score", "normed_diff_score", "quantile_score"], scope="function")
49+
def basic_scoring_rule(request):
50+
return request.getfixturevalue(request.param)
51+
52+
53+
@pytest.fixture()
54+
def mvn_target(batch_size, num_variables):
55+
mean_target = keras.ops.zeros((batch_size, 1, num_variables))
56+
inputs = keras.random.normal((batch_size, num_variables, num_variables))
57+
print(inputs.shape)
58+
covariance_target = keras.ops.einsum("...ij,...kj->...ik", inputs, inputs)
59+
return dict(
60+
mean=mean_target,
61+
covariance=covariance_target,
62+
)
63+
64+
65+
@pytest.fixture()
66+
def multivariate_normal_score(num_variables):
67+
from bayesflow.scores import MultivariateNormalScore
68+
69+
return MultivariateNormalScore(D=num_variables)

tests/test_scores/test_scores.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import keras
2+
import pytest
3+
4+
5+
def test_require_argument_k():
6+
from bayesflow.scores import NormedDifferenceScore
7+
8+
with pytest.raises(TypeError) as excinfo:
9+
NormedDifferenceScore()
10+
11+
assert "missing 1 required positional argument: 'k'" in str(excinfo)
12+
13+
14+
def test_score_output(basic_scoring_rule, reference):
15+
target_shape = (reference.shape[0], *basic_scoring_rule.target_shape, reference.shape[-1])
16+
target = keras.ops.zeros(target_shape)
17+
score = basic_scoring_rule.score(reference, target)
18+
19+
assert score.ndim == 0
20+
21+
22+
def test_mean_score_optimality(mean_score, reference):
23+
suboptimal_target = keras.ops.expand_dims(keras.random.uniform(reference.shape), axis=1)
24+
optimal_target = keras.ops.expand_dims(reference, axis=1)
25+
26+
suboptimal_score = mean_score.score(reference, suboptimal_target)
27+
optimal_score = mean_score.score(reference, optimal_target)
28+
29+
assert suboptimal_score > optimal_score
30+
assert keras.ops.isclose(optimal_score, 0)
31+
32+
33+
def test_multivariate_normal_score_output(multivariate_normal_score, reference, mvn_target):
34+
score = multivariate_normal_score.score(reference, mvn_target)
35+
36+
assert score.ndim == 0

0 commit comments

Comments
 (0)