Skip to content

Commit

Permalink
Merge pull request #117 from elseml/Development
Browse files Browse the repository at this point in the history
Minor improvements
  • Loading branch information
stefanradev93 authored Dec 14, 2023
2 parents 28db526 + 219aeec commit 1db5d58
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
21 changes: 13 additions & 8 deletions bayesflow/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def plot_recovery(
n_row=None,
xlabel="Ground truth",
ylabel="Estimated",
**kwargs
**kwargs,
):
"""Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.
The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
Expand Down Expand Up @@ -110,7 +110,7 @@ def plot_recovery(
**kwargs : optional
Additional keyword arguments passed to ax.errorbar or ax.scatter.
Example: `rasterized=True` to reduce PDF file size with many dots
Returns
-------
f : plt.Figure - the figure instance for optional saving
Expand Down Expand Up @@ -240,7 +240,7 @@ def plot_z_score_contraction(
tick_fontsize=12,
color="#8f2727",
n_col=None,
n_row=None
n_row=None,
):
"""Implements a graphical check for global model sensitivity by plotting the posterior
z-score over the posterior contraction for each set of posterior samples in ``post_samples``
Expand Down Expand Up @@ -567,7 +567,7 @@ def plot_sbc_histograms(
tick_fontsize=12,
hist_color="#a34f4f",
n_row=None,
n_col=None
n_col=None,
):
"""Creates and plots publication-ready histograms of rank statistics for simulation-based calibration
(SBC) checks according to [1].
Expand Down Expand Up @@ -929,7 +929,7 @@ def plot_losses(
)
# Schmuck
ax.set_xlabel("Training step #", fontsize=label_fontsize)
ax.set_ylabel("Loss value", fontsize=label_fontsize)
ax.set_ylabel("Value", fontsize=label_fontsize)
sns.despine(ax=ax)
ax.grid(alpha=grid_alpha)
ax.set_title(train_losses.columns[i], fontsize=title_fontsize)
Expand Down Expand Up @@ -1061,7 +1061,7 @@ def plot_calibration_curves(
fig_size=None,
color="#8f2727",
n_row=None,
n_col=None
n_col=None,
):
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
Expand Down Expand Up @@ -1114,7 +1114,6 @@ def plot_calibration_curves(
elif n_row is not None and n_col is None:
n_col = int(np.ceil(num_models / n_row))


# Compute calibration
cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)

Expand Down Expand Up @@ -1273,7 +1272,13 @@ def plot_confusion_matrix(
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(
j, i, format(cm[i, j], fmt), fontsize=value_fontsize, ha="center", va="center", color="white" if cm[i, j] > thresh else "black"
j,
i,
format(cm[i, j], fmt),
fontsize=value_fontsize,
ha="center",
va="center",
color="white" if cm[i, j] > thresh else "black",
)
if title:
ax.set_title("Confusion Matrix", fontsize=title_fontsize)
Expand Down
5 changes: 3 additions & 2 deletions bayesflow/summary_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
# Construct final attention layer, which will perform cross-attention
# between the outputs of the self-attention layers and the dynamic template
if bidirectional:
final_input_dim = template_dim*2
final_input_dim = template_dim * 2
else:
final_input_dim = template_dim
self.output_attention = MultiHeadAttentionBlock(
Expand Down Expand Up @@ -184,7 +184,8 @@ def call(self, x, **kwargs):

class SetTransformer(tf.keras.Model):
"""Implements the set transformer architecture from [1] which ultimately represents
a learnable permutation-invariant function.
a learnable permutation-invariant function. Designed to naturally model interactions in
the input set, which may be hard to capture with the simpler ``DeepSet`` architecture.
[1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
Set transformer: A framework for attention-based permutation-invariant neural networks.
Expand Down

0 comments on commit 1db5d58

Please sign in to comment.