Skip to content

Commit

Permalink
Rename target and reference to estimates and targets; add optional we…
Browse files Browse the repository at this point in the history
…ights to all scores
  • Loading branch information
han-ol committed Feb 12, 2025
1 parent 63d448b commit b69df85
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 48 deletions.
6 changes: 3 additions & 3 deletions bayesflow/networks/point_inference_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
self.heads_flat = dict() # see comment regarding heads_flat below

for score_key, score in self.scores.items():
score.set_target_shapes(xz_shape)
score.set_head_shapes_from_target_shape(xz_shape)

self.heads[score_key] = {}

for head_key in score.target_shapes.keys():
for head_key in score.head_shapes.keys():
head = score.get_head(head_key)
head.build(body_output_shape)
# If head is not tracked explicitly, self.variables does not include them.
Expand Down Expand Up @@ -120,7 +120,7 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
# calculate negative score as mean over all scores
neg_score = 0
for score_key, score in self.scores.items():
score_value = score.score(x, output[score_key])
score_value = score.score(output[score_key], x)
neg_score += score_value
metrics |= {score_key: score_value}
neg_score /= len(self.scores)
Expand Down
71 changes: 39 additions & 32 deletions bayesflow/scores/scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def from_config(cls, config):

return cls(**config)

def get_target_shapes(self, reference_shape):
def get_head_shapes_from_target_shape(self, target_shape):
raise NotImplementedError

def set_target_shapes(self, reference_shape):
self.target_shapes = self.get_target_shapes(reference_shape)
def set_head_shapes_from_target_shape(self, target_shape):
self.head_shapes = self.get_head_shapes_from_target_shape(target_shape)

def get_subnet(self, key: str):
if key not in self.subnets.keys():
Expand All @@ -69,15 +69,22 @@ def get_link(self, key: str):

def get_head(self, key: str):
subnet = self.get_subnet(key)
target_shape = self.target_shapes[key]
dense = keras.layers.Dense(units=math.prod(target_shape))
reshape = keras.layers.Reshape(target_shape=target_shape)
head_shape = self.head_shapes[key]
dense = keras.layers.Dense(units=math.prod(head_shape))
reshape = keras.layers.Reshape(target_shape=head_shape)
link = self.get_link(key)
return keras.Sequential([subnet, dense, reshape, link])

def score(self, reference: Tensor, target: dict[str, Tensor]) -> Tensor:
def score(self, estimates: dict[str, Tensor], target: Tensor, weights: Tensor) -> Tensor:
raise NotImplementedError

def aggregate(self, scores: Tensor, weights: Tensor = None):
if weights is not None:
weighted = scores * weights
else:
weighted = scores
return keras.ops.mean(weighted)


class NormedDifferenceScore(ScoringRule):
def __init__(
Expand All @@ -92,17 +99,17 @@ def __init__(
"k": k,
}

def get_target_shapes(self, reference_shape):
# keras.saving.load_model sometimes passes reference_shape as a list.
def get_head_shapes_from_target_shape(self, target_shape):
# keras.saving.load_model sometimes passes target_shape as a list.
# This is why I force a conversion to tuple here.
reference_shape = tuple(reference_shape)
return dict(value=reference_shape[1:])

def score(self, reference: Tensor, target: dict[str, Tensor]) -> Tensor:
target = target["value"]
pointwise_differance = target - reference
score = keras.ops.absolute(pointwise_differance) ** self.k
score = keras.ops.mean(score)
target_shape = tuple(target_shape)
return dict(value=target_shape[1:])

def score(self, estimates: dict[str, Tensor], target: Tensor, weights: Tensor = None) -> Tensor:
estimates = estimates["value"]
pointwise_differance = estimates - target
scores = keras.ops.absolute(pointwise_differance) ** self.k
score = self.aggregate(scores, weights)
return score

def get_config(self):
Expand Down Expand Up @@ -141,18 +148,18 @@ def get_config(self):
base_config = super().get_config()
return base_config | self.config

def get_target_shapes(self, reference_shape):
# keras.saving.load_model sometimes passes reference_shape as a list.
def get_head_shapes_from_target_shape(self, target_shape):
# keras.saving.load_model sometimes passes target_shape as a list.
# This is why I force a conversion to tuple here.
reference_shape = tuple(reference_shape)
return dict(value=(len(self.q),) + reference_shape[1:])
target_shape = tuple(target_shape)
return dict(value=(len(self.q),) + target_shape[1:])

def score(self, reference: Tensor, target: dict[str, Tensor]) -> Tensor:
target = target["value"]
pointwise_differance = target - reference[:, None, :]
def score(self, estimates: dict[str, Tensor], target: Tensor, weights: Tensor = None) -> Tensor:
estimates = estimates["value"]
pointwise_differance = estimates - target[:, None, :]

score = pointwise_differance * (keras.ops.cast(pointwise_differance > 0, float) - self._q[None, :, None])
score = keras.ops.mean(score)
scores = pointwise_differance * (keras.ops.cast(pointwise_differance > 0, float) - self._q[None, :, None])
score = self.aggregate(scores, weights)
return score


Expand All @@ -161,7 +168,7 @@ class ParametricDistributionRule(ScoringRule):
TODO
"""

def __init__(self, **kwargs): # , target_mappings: dict[str, str] = None):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def log_prob(self, x, **kwargs):
Expand All @@ -170,9 +177,9 @@ def log_prob(self, x, **kwargs):
def sample(self, batch_shape, **kwargs):
raise NotImplementedError

def score(self, reference: Tensor, target: dict[str, Tensor]) -> Tensor:
score = -self.log_prob(x=reference, **target)
score = keras.ops.mean(score)
def score(self, estimates: dict[str, Tensor], target: Tensor, weights: Tensor = None) -> Tensor:
scores = -self.log_prob(x=target, **estimates)
score = self.aggregate(scores, weights)
# multipy to mitigate instability due to relatively high values of parametric score
return score * 0.01

Expand All @@ -191,8 +198,8 @@ def get_config(self):
base_config = super().get_config()
return base_config | self.config

def get_target_shapes(self, reference_shape) -> dict[str, Shape]:
self.D = reference_shape[-1]
def get_head_shapes_from_target_shape(self, target_shape) -> dict[str, Shape]:
self.D = target_shape[-1]
return dict(
mean=(self.D,),
covariance=(self.D, self.D),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ def test_output_structure(point_inference_network, random_samples, random_condit

assert isinstance(output, dict)
for score_key, score in point_inference_network.scores.items():
assert isinstance(score.target_shapes, dict)
assert isinstance(score.head_shapes, dict)

for head_key, target_shape in score.target_shapes.items():
for head_key, head_shape in score.head_shapes.items():
head_output = output[score_key][head_key]
assert keras.ops.is_tensor(head_output)
assert head_output.shape[1:] == target_shape
assert head_output.shape[1:] == head_shape


def test_serialize_deserialize(point_inference_network, random_samples, random_conditions):
Expand Down
20 changes: 10 additions & 10 deletions tests/test_scores/test_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ def test_score_output(scoring_rule, random_conditions):
if random_conditions is None:
random_conditions = keras.ops.convert_to_tensor([[1.0]])

scoring_rule.set_target_shapes(random_conditions.shape)
scoring_rule.set_head_shapes_from_target_shape(random_conditions.shape)
print(scoring_rule.get_config())
target = {
k: scoring_rule.get_link(k)(keras.random.normal((random_conditions.shape[0],) + target_shape))
for k, target_shape in scoring_rule.target_shapes.items()
estimates = {
k: scoring_rule.get_link(k)(keras.random.normal((random_conditions.shape[0],) + head_shape))
for k, head_shape in scoring_rule.head_shapes.items()
}
score = scoring_rule.score(random_conditions, target)
score = scoring_rule.score(estimates, random_conditions)

assert score.ndim == 0

Expand All @@ -30,13 +30,13 @@ def test_mean_score_optimality(mean_score, random_conditions):
if random_conditions is None:
random_conditions = keras.ops.convert_to_tensor([[1.0]])

mean_score.set_target_shapes(random_conditions.shape)
mean_score.set_head_shapes_from_target_shape(random_conditions.shape)
key = "value"
suboptimal_target = {key: keras.random.uniform(random_conditions.shape)}
optimal_target = {key: random_conditions}
suboptimal_estimates = {key: keras.random.uniform(random_conditions.shape)}
optimal_estimates = {key: random_conditions}

suboptimal_score = mean_score.score(random_conditions, suboptimal_target)
optimal_score = mean_score.score(random_conditions, optimal_target)
suboptimal_score = mean_score.score(suboptimal_estimates, random_conditions)
optimal_score = mean_score.score(optimal_estimates, random_conditions)

assert suboptimal_score > optimal_score
assert keras.ops.isclose(optimal_score, 0)

0 comments on commit b69df85

Please sign in to comment.