From 19229b2250e7895a313e09c0eaebe32b8127dfeb Mon Sep 17 00:00:00 2001 From: alexbrillant Date: Thu, 10 Sep 2020 21:57:51 -0400 Subject: [PATCH] Fix Trial Metrics Plotting Observer --- neuraxle/plotting.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/neuraxle/plotting.py b/neuraxle/plotting.py index ab0f46ac..8dd746e4 100644 --- a/neuraxle/plotting.py +++ b/neuraxle/plotting.py @@ -134,7 +134,7 @@ def on_next(self, value: Tuple[HyperparamsRepository, Trial]): def _plot_all_trial_main_and_validation_metric_results(self, repo, trial): trial_hash = repo._get_trial_hash(trial) - for split in trial.validation_splits: + for split_number, split in enumerate(trial.validation_splits): for metric_name in split.get_metric_names(): train_results = split.get_metric_train_results(metric_name=metric_name) validation_results = split.get_metric_validation_results(metric_name=metric_name) @@ -144,7 +144,12 @@ def _plot_all_trial_main_and_validation_metric_results(self, repo, trial): plt.xlabel('epoch') plt.legend(['train', 'validation'], loc='upper left') plt.title(metric_name) - plotting_file = os.path.join(repo.cache_folder, trial_hash, str(split), '{}.png'.format(metric_name)) + + plotting_folder = os.path.join(repo.cache_folder, trial_hash, str(split_number)) + if not os.path.exists(plotting_folder): + os.makedirs(plotting_folder) + plotting_file = os.path.join(plotting_folder, '{}.png'.format(metric_name)) + self._show_or_save_plot(plotting_file) def on_complete(self, value: Tuple[HyperparamsRepository, Trial]):