Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions rocketpy/plots/monte_carlo_plots.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from ..tools import generate_monte_carlo_ellipses, import_optional_dependency
from .plot_helpers import show_or_save_plot


class _MonteCarloPlots:
Expand Down Expand Up @@ -147,7 +150,7 @@ def ellipses(
else:
plt.show()

def all(self, keys=None):
def all(self, keys=None, *, filename=None):
"""
Plot the histograms of the Monte Carlo simulation results.

Expand All @@ -156,6 +159,13 @@ def all(self, keys=None):
keys : str, list or tuple, optional
The keys of the results to be plotted. If None, all results will be
plotted. Default is None.
filename : str | None, optional
The path the plot should be saved to, by default None. If provided,
the plot will be saved instead of displayed. When multiple plots are
generated (one per key), the key name will be appended to the filename.
Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw,
rgba, svg, svgz, tif, tiff and webp (these are the formats supported
by matplotlib).

Returns
-------
Expand All @@ -173,6 +183,7 @@ def all(self, keys=None):
)
else:
raise ValueError("The 'keys' argument must be a string, list, or tuple.")

for key in keys:
# Create figure with GridSpec
fig = plt.figure(figsize=(8, 8))
Expand All @@ -194,7 +205,16 @@ def all(self, keys=None):
ax1.set_xticks([])

plt.tight_layout()
plt.show()

# Handle filename for multiple plots
if filename is not None:
# For multiple keys, append the key name to the filename
filepath = Path(filename)
# Use the full key name to avoid collisions between x_impact and y_impact
key_filename = filepath.parent / f"{filepath.stem}_{key}{filepath.suffix}"
show_or_save_plot(str(key_filename))
else:
show_or_save_plot(filename)

def plot_comparison(self, other_monte_carlo):
"""
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/test_monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,33 @@ def test_stochastic_environment_create_object_with_wind_x(stochastic_environment
# TODO: add a new test for the special case of ensemble member


def test_monte_carlo_plots_all_with_filename(monte_carlo_calisto_pre_loaded, tmp_path):
"""Tests the all method of the MonteCarlo plots with filename parameter.

Parameters
----------
monte_carlo_calisto_pre_loaded : MonteCarlo
A MonteCarlo object with pre-loaded results, this is a pytest fixture.
tmp_path : Path
Temporary directory path for saving test files.
"""
# Test without filename (should work as before)
result = monte_carlo_calisto_pre_loaded.plots.all()
assert result is None

# Test with filename - save to temporary directory
filename = tmp_path / "test_monte_carlo_plot.png"
result = monte_carlo_calisto_pre_loaded.plots.all(filename=str(filename))
assert result is None

# Test with specific keys and filename
filename_apogee = tmp_path / "test_apogee_plot.png"
result = monte_carlo_calisto_pre_loaded.plots.all(
keys="apogee", filename=str(filename_apogee)
)
assert result is None


def test_stochastic_solid_motor_create_object_with_impulse(stochastic_solid_motor):
"""Tests the stochastic solid motor object by checking if the total impulse
can be generated properly. The goal is to check if the create_object()
Expand Down
Loading