Skip to content

Commit

Permalink
Merge pull request #394 from alexbrillant/master
Browse files Browse the repository at this point in the history
Fix Trial Metrics Plotting Observer
  • Loading branch information
alexbrillant authored Sep 11, 2020
2 parents fac9bbc + 5514eef commit 9b81587
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions neuraxle/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def on_next(self, value: Tuple[HyperparamsRepository, Trial]):

def _plot_all_trial_main_and_validation_metric_results(self, repo, trial):
trial_hash = repo._get_trial_hash(trial)
for split in trial.validation_splits:
for split_number, split in enumerate(trial.validation_splits):
for metric_name in split.get_metric_names():
train_results = split.get_metric_train_results(metric_name=metric_name)
validation_results = split.get_metric_validation_results(metric_name=metric_name)
Expand All @@ -144,7 +144,12 @@ def _plot_all_trial_main_and_validation_metric_results(self, repo, trial):
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.title(metric_name)
plotting_file = os.path.join(repo.cache_folder, trial_hash, str(split), '{}.png'.format(metric_name))

plotting_folder = os.path.join(repo.cache_folder, trial_hash, str(split_number))
if not os.path.exists(plotting_folder):
os.makedirs(plotting_folder)
plotting_file = os.path.join(plotting_folder, '{}.png'.format(metric_name))

self._show_or_save_plot(plotting_file)

def on_complete(self, value: Tuple[HyperparamsRepository, Trial]):
Expand Down

0 comments on commit 9b81587

Please sign in to comment.