Skip to content

Commit

Permalink
Fix single parameter plots (#303)
Browse files Browse the repository at this point in the history
* Fix single parameter plotting

* Remove superfluous atleast_1d
  • Loading branch information
han-ol authored Feb 11, 2025
1 parent 5da3759 commit 8210610
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
5 changes: 2 additions & 3 deletions bayesflow/diagnostics/plots/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def loss(
val_step_index = val_step_index[: val_losses.shape[0]]

# Loop through loss entries and populate plot
looper = [axes] if num_row == 1 else axes.flat
for i, ax in enumerate(looper):
for i, ax in enumerate(axes.flat):
# Plot train curve
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
if moving_average and train_losses.columns[i] == "Loss":
Expand Down Expand Up @@ -127,7 +126,7 @@ def loss(

# Add labels, titles, and set font sizes
add_titles_and_labels(
axes=np.atleast_1d(axes),
axes=axes,
num_row=num_row,
num_col=1,
title=["Loss Trajectory"],
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/diagnostics/plots/recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def recovery(
if uncertainty_agg is not None:
u = uncertainty_agg(targets, axis=1)

for i, ax in enumerate(np.atleast_1d(plot_data["axes"].flat)):
for i, ax in enumerate(plot_data["axes"].flat):
if i >= plot_data["num_variables"]:
break

Expand Down
1 change: 1 addition & 0 deletions bayesflow/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def make_figure(num_row: int = None, num_col: int = None, figsize: tuple = None)
figsize = (int(5 * num_col), int(5 * num_row))

f, axes = plt.subplots(num_row, num_col, figsize=figsize)
axes = np.atleast_1d(axes)

return f, axes

Expand Down

0 comments on commit 8210610

Please sign in to comment.