From f20edaa1263772a72c98ffe48838f0df49653885 Mon Sep 17 00:00:00 2001 From: dthuerck Date: Wed, 8 Feb 2017 17:32:16 +0100 Subject: [PATCH] Added callback for external logging. --- mapmap/header/mapmap.h | 27 ++++++--- mapmap/source/mapmap.impl.h | 114 ++++++++++++++++++++++-------------- 2 files changed, 90 insertions(+), 51 deletions(-) diff --git a/mapmap/header/mapmap.h b/mapmap/header/mapmap.h index b0124c1..60b75a8 100644 --- a/mapmap/header/mapmap.h +++ b/mapmap/header/mapmap.h @@ -58,15 +58,15 @@ class mapMAP mapMAP(const luint_t num_nodes, const luint_t num_labels); ~mapMAP(); - /* set graph and label set */ + /* set graph and label set */ void set_graph(const Graph * graph) throw(); void set_label_set(const LabelSet * label_set) throw(); /* alternatively - construct graph and label set */ - void add_edge(const luint_t node_a, const luint_t node_b, + void add_edge(const luint_t node_a, const luint_t node_b, const _s_t weight = 1.0) throw(); - void set_node_label_set(const luint_t node_id, const + void set_node_label_set(const luint_t node_id, const std::vector<_iv_st>& label_set) throw(); /* set MRF cost functions */ @@ -74,11 +74,17 @@ class mapMAP void set_pairwise(const PAIRWISE * pairwise); /* configuration */ - void set_multilevel_criterion(MultilevelCriterion * + void set_multilevel_criterion(MultilevelCriterion * criterion); void set_termination_criterion(TerminationCriterion * criterion); + /** + * callback for external logging - outputs time in ms and energy after + */ + void set_logging_callback(const std::function)>& callback); + /* start optimization */ _s_t optimize(std::vector<_iv_st>& solution) throw(); @@ -121,15 +127,15 @@ class mapMAP const LabelSet * m_label_set; /* configuration */ - MultilevelCriterion * + MultilevelCriterion * m_multilevel_criterion; - TerminationCriterion * + TerminationCriterion * m_termination_criterion; luint_t m_num_roots = 64u; /* storage for functional modules */ - std::unique_ptr> + std::unique_ptr> m_multilevel; std::unique_ptr> m_storage_multilevel_criterion; @@ -141,7 +147,7 @@ class mapMAP std::chrono::system_clock::time_point m_time_start; /* current solution */ - std::vector<_iv_st> m_solution; + std::vector<_iv_st> m_solution; _s_t m_objective; /* solver history data */ @@ -152,6 +158,11 @@ class mapMAP luint_t m_hist_acyclic_iterations; luint_t m_hist_spanningtree_iterations; luint_t m_hist_multilevel_iterations; + + /* callback data */ + bool m_use_callback; + std::function)> + m_callback; }; NS_MAPMAP_END diff --git a/mapmap/source/mapmap.impl.h b/mapmap/source/mapmap.impl.h index 0555018..f01b164 100644 --- a/mapmap/source/mapmap.impl.h +++ b/mapmap/source/mapmap.impl.h @@ -39,7 +39,7 @@ mapMAP() m_hist_spanningtree_iterations(0), m_hist_multilevel_iterations(0) { - + } /* ************************************************************************** */ @@ -66,7 +66,7 @@ mapMAP( m_hist_spanningtree_iterations(0), m_hist_multilevel_iterations(0) { - + } /* ************************************************************************** */ @@ -76,7 +76,7 @@ FORCEINLINE mapMAP:: ~mapMAP() { - + } /* ************************************************************************** */ @@ -127,7 +127,7 @@ add_edge( const luint_t node_b, const _s_t weight) throw() -{ +{ if(!m_construct_graph) throw std::runtime_error("Adding edges is only allowed " "in construction mode."); @@ -209,6 +209,20 @@ set_termination_criterion( /* ************************************************************************** */ +template +FORCEINLINE +void +mapMAP:: +set_logging_callback( + const std::function)>& + callback) +{ + m_use_callback = true; + m_callback = callback; +} + +/* ************************************************************************** */ + template FORCEINLINE _s_t @@ -228,11 +242,12 @@ throw() "incomplete or not sane."); /* report on starting the optimization process */ - std::cout << "[mapMAP] " - << UNIX_COLOR_GREEN - << "Starting optimization..." - << UNIX_COLOR_RESET - << std::endl; + if(!m_use_callback) + std::cout << "[mapMAP] " + << UNIX_COLOR_GREEN + << "Starting optimization..." + << UNIX_COLOR_RESET + << std::endl; /* start timer */ m_time_start = std::chrono::system_clock::now(); @@ -250,7 +265,6 @@ throw() /* check for termination */ if(check_termination()) { - solution.clear(); solution.assign(m_solution.begin(), m_solution.end()); return m_objective; @@ -272,7 +286,6 @@ throw() /* check if algorithms needs to terminate */ if(check_termination()) { - solution.clear(); solution.assign(m_solution.begin(), m_solution.end()); return m_objective; @@ -307,9 +320,16 @@ throw() } /* output solution */ - solution.clear(); solution.assign(m_solution.begin(), m_solution.end()); + /* report on starting the optimization process */ + if(!m_use_callback) + std::cout << "[mapMAP] " + << UNIX_COLOR_GREEN + << "Finished optimization." + << UNIX_COLOR_RESET + << std::endl; + return m_objective; } @@ -345,7 +365,7 @@ create_std_modules() } /* create a multilevel module for the current graph */ - m_multilevel = std::unique_ptr>(new Multilevel( m_graph, m_label_set, m_unaries, m_pairwise, m_multilevel_criterion)); } @@ -376,7 +396,7 @@ check_data_complete() if(m_construct_graph && m_label_set_check.size() != m_num_nodes) return false; - if(m_construct_graph && (m_label_set->max_label() > + if(m_construct_graph && (m_label_set->max_label() > (_iv_st) m_num_labels)) return false; @@ -404,31 +424,39 @@ void mapMAP:: print_status() { - std::cout << "[mapMAP] " - << UNIX_COLOR_RED - << m_hist_time.back() << " ms" - << UNIX_COLOR_RESET - << ", " - << UNIX_COLOR_GREEN - << "Objective " - << m_objective - << UNIX_COLOR_RESET - << " (after " - << m_hist_multilevel_iterations - << " multilevel, " - << m_hist_spanningtree_iterations - << " spanning tree, " - << m_hist_acyclic_iterations - << " acyclic iterations)" - << std::endl; + if(!m_use_callback) + { + std::cout << "[mapMAP] " + << UNIX_COLOR_RED + << m_hist_time.back() << " ms" + << UNIX_COLOR_RESET + << ", " + << UNIX_COLOR_GREEN + << "Objective " + << m_objective + << UNIX_COLOR_RESET + << " (after " + << m_hist_multilevel_iterations + << " multilevel, " + << m_hist_spanningtree_iterations + << " spanning tree, " + << m_hist_acyclic_iterations + << " acyclic iterations)" + << std::endl; + } + else + { + /* hand off current time and objective to logging callback */ + m_callback(m_hist_time.back(), m_objective); + } } - + /* ************************************************************************** */ template FORCEINLINE -_s_t +_s_t mapMAP:: initial_labelling() { @@ -462,7 +490,7 @@ initial_labelling() template FORCEINLINE _s_t -mapMAP:: +mapMAP:: opt_step_spanning_tree() { /* sample a tree (forest) without dependencies */ @@ -498,7 +526,7 @@ opt_step_spanning_tree() template FORCEINLINE _s_t -mapMAP:: +mapMAP:: opt_step_multilevel() { std::vector<_iv_st> lvl_solution; @@ -531,14 +559,14 @@ opt_step_multilevel() const Graph * lvl_graph = m_multilevel->get_level_graph(); const LabelSet * lvl_label_set = m_multilevel-> get_level_label_set(); - const UnaryTable * lvl_unaries = + const UnaryTable * lvl_unaries = m_multilevel->get_level_unaries(); - const PairwiseTable * lvl_pairwise = + const PairwiseTable * lvl_pairwise = m_multilevel->get_level_pairwise(); /* create new optimizer for level graph */ - CombinatorialDynamicProgramming, PairwiseTable> + CombinatorialDynamicProgramming, PairwiseTable> lvl_opt; lvl_opt.set_graph(lvl_graph); lvl_opt.set_label_set(lvl_label_set); @@ -550,10 +578,10 @@ opt_step_multilevel() roots.clear(); sampler.select_random_roots(m_num_roots, roots); - std::unique_ptr> lvl_tree = + std::unique_ptr> lvl_tree = sampler.sample(roots, true); lvl_opt.set_tree(lvl_tree.get()); - + /* optimize for level solution */ lvl_opt.optimize(upper_solution); @@ -588,7 +616,7 @@ opt_step_multilevel() template FORCEINLINE -_s_t +_s_t mapMAP:: opt_step_acyclic() { @@ -618,7 +646,7 @@ opt_step_acyclic() ++m_hist_acyclic_iterations; m_hist_mode.push_back(SolverMode::SOLVER_ACYCLIC); - const _s_t ac_opt = opt.objective(ac_solution); + const _s_t ac_opt = opt.objective(ac_solution); if(ac_opt < m_objective) { m_objective = ac_opt;