From 8210610de9e141a487a4226e6478cf15e9b9a5af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20Olischl=C3=A4ger?= <106988117+han-ol@users.noreply.github.com> Date: Tue, 11 Feb 2025 16:44:07 +0100 Subject: [PATCH] Fix single parameter plots (#303) * Fix single parameter plotting * Remove superfluous atleast_1d --- bayesflow/diagnostics/plots/loss.py | 5 ++--- bayesflow/diagnostics/plots/recovery.py | 2 +- bayesflow/utils/plot_utils.py | 1 + 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bayesflow/diagnostics/plots/loss.py b/bayesflow/diagnostics/plots/loss.py index 500d522ca..d30983c00 100644 --- a/bayesflow/diagnostics/plots/loss.py +++ b/bayesflow/diagnostics/plots/loss.py @@ -96,8 +96,7 @@ def loss( val_step_index = val_step_index[: val_losses.shape[0]] # Loop through loss entries and populate plot - looper = [axes] if num_row == 1 else axes.flat - for i, ax in enumerate(looper): + for i, ax in enumerate(axes.flat): # Plot train curve ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training") if moving_average and train_losses.columns[i] == "Loss": @@ -127,7 +126,7 @@ def loss( # Add labels, titles, and set font sizes add_titles_and_labels( - axes=np.atleast_1d(axes), + axes=axes, num_row=num_row, num_col=1, title=["Loss Trajectory"], diff --git a/bayesflow/diagnostics/plots/recovery.py b/bayesflow/diagnostics/plots/recovery.py index 403aefa03..c7d157a27 100644 --- a/bayesflow/diagnostics/plots/recovery.py +++ b/bayesflow/diagnostics/plots/recovery.py @@ -77,7 +77,7 @@ def recovery( if uncertainty_agg is not None: u = uncertainty_agg(targets, axis=1) - for i, ax in enumerate(np.atleast_1d(plot_data["axes"].flat)): + for i, ax in enumerate(plot_data["axes"].flat): if i >= plot_data["num_variables"]: break diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index 11c79995a..57764d20a 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -130,6 +130,7 @@ def make_figure(num_row: int = None, num_col: int = None, figsize: tuple = None) figsize = (int(5 * num_col), int(5 * num_row)) f, axes = plt.subplots(num_row, num_col, figsize=figsize) + axes = np.atleast_1d(axes) return f, axes