Skip to content

Commit

Permalink
Added callback for external logging.
Browse files Browse the repository at this point in the history
  • Loading branch information
dthuerck committed Feb 8, 2017
1 parent 79a4d87 commit f20edaa
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 51 deletions.
27 changes: 19 additions & 8 deletions mapmap/header/mapmap.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,27 +58,33 @@ class mapMAP
mapMAP(const luint_t num_nodes, const luint_t num_labels);
~mapMAP();

/* set graph and label set */
/* set graph and label set */
void set_graph(const Graph<COSTTYPE> * graph) throw();
void set_label_set(const LabelSet<COSTTYPE, SIMDWIDTH> * label_set)
throw();

/* alternatively - construct graph and label set */
void add_edge(const luint_t node_a, const luint_t node_b,
void add_edge(const luint_t node_a, const luint_t node_b,
const _s_t<COSTTYPE, SIMDWIDTH> weight = 1.0) throw();
void set_node_label_set(const luint_t node_id, const
void set_node_label_set(const luint_t node_id, const
std::vector<_iv_st<COSTTYPE, SIMDWIDTH>>& label_set) throw();

/* set MRF cost functions */
void set_unaries(const UNARY * unaries);
void set_pairwise(const PAIRWISE * pairwise);

/* configuration */
void set_multilevel_criterion(MultilevelCriterion<COSTTYPE, SIMDWIDTH> *
void set_multilevel_criterion(MultilevelCriterion<COSTTYPE, SIMDWIDTH> *
criterion);
void set_termination_criterion(TerminationCriterion<COSTTYPE,
SIMDWIDTH> * criterion);

/**
* callback for external logging - outputs time in ms and energy after
*/
void set_logging_callback(const std::function<void (const luint_t,
const _s_t<COSTTYPE, SIMDWIDTH>)>& callback);

/* start optimization */
_s_t<COSTTYPE, SIMDWIDTH> optimize(std::vector<_iv_st<COSTTYPE, SIMDWIDTH>>&
solution) throw();
Expand Down Expand Up @@ -121,15 +127,15 @@ class mapMAP
const LabelSet<COSTTYPE, SIMDWIDTH> * m_label_set;

/* configuration */
MultilevelCriterion<COSTTYPE, SIMDWIDTH> *
MultilevelCriterion<COSTTYPE, SIMDWIDTH> *
m_multilevel_criterion;
TerminationCriterion<COSTTYPE, SIMDWIDTH> *
TerminationCriterion<COSTTYPE, SIMDWIDTH> *
m_termination_criterion;

luint_t m_num_roots = 64u;

/* storage for functional modules */
std::unique_ptr<Multilevel<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>>
std::unique_ptr<Multilevel<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>>
m_multilevel;
std::unique_ptr<MultilevelCriterion<COSTTYPE, SIMDWIDTH>>
m_storage_multilevel_criterion;
Expand All @@ -141,7 +147,7 @@ class mapMAP
std::chrono::system_clock::time_point m_time_start;

/* current solution */
std::vector<_iv_st<COSTTYPE, SIMDWIDTH>> m_solution;
std::vector<_iv_st<COSTTYPE, SIMDWIDTH>> m_solution;
_s_t<COSTTYPE, SIMDWIDTH> m_objective;

/* solver history data */
Expand All @@ -152,6 +158,11 @@ class mapMAP
luint_t m_hist_acyclic_iterations;
luint_t m_hist_spanningtree_iterations;
luint_t m_hist_multilevel_iterations;

/* callback data */
bool m_use_callback;
std::function<void (const luint_t, const _s_t<COSTTYPE, SIMDWIDTH>)>
m_callback;
};

NS_MAPMAP_END
Expand Down
114 changes: 71 additions & 43 deletions mapmap/source/mapmap.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ mapMAP()
m_hist_spanningtree_iterations(0),
m_hist_multilevel_iterations(0)
{

}

/* ************************************************************************** */
Expand All @@ -66,7 +66,7 @@ mapMAP(
m_hist_spanningtree_iterations(0),
m_hist_multilevel_iterations(0)
{

}

/* ************************************************************************** */
Expand All @@ -76,7 +76,7 @@ FORCEINLINE
mapMAP<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>::
~mapMAP()
{

}

/* ************************************************************************** */
Expand Down Expand Up @@ -127,7 +127,7 @@ add_edge(
const luint_t node_b,
const _s_t<COSTTYPE, SIMDWIDTH> weight)
throw()
{
{
if(!m_construct_graph)
throw std::runtime_error("Adding edges is only allowed "
"in construction mode.");
Expand Down Expand Up @@ -209,6 +209,20 @@ set_termination_criterion(

/* ************************************************************************** */

template<typename COSTTYPE, uint_t SIMDWIDTH, typename UNARY, typename PAIRWISE>
FORCEINLINE
void
mapMAP<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>::
set_logging_callback(
const std::function<void (const luint_t, const _s_t<COSTTYPE, SIMDWIDTH>)>&
callback)
{
m_use_callback = true;
m_callback = callback;
}

/* ************************************************************************** */

template<typename COSTTYPE, uint_t SIMDWIDTH, typename UNARY, typename PAIRWISE>
FORCEINLINE
_s_t<COSTTYPE, SIMDWIDTH>
Expand All @@ -228,11 +242,12 @@ throw()
"incomplete or not sane.");

/* report on starting the optimization process */
std::cout << "[mapMAP] "
<< UNIX_COLOR_GREEN
<< "Starting optimization..."
<< UNIX_COLOR_RESET
<< std::endl;
if(!m_use_callback)
std::cout << "[mapMAP] "
<< UNIX_COLOR_GREEN
<< "Starting optimization..."
<< UNIX_COLOR_RESET
<< std::endl;

/* start timer */
m_time_start = std::chrono::system_clock::now();
Expand All @@ -250,7 +265,6 @@ throw()
/* check for termination */
if(check_termination())
{
solution.clear();
solution.assign(m_solution.begin(), m_solution.end());

return m_objective;
Expand All @@ -272,7 +286,6 @@ throw()
/* check if algorithms needs to terminate */
if(check_termination())
{
solution.clear();
solution.assign(m_solution.begin(), m_solution.end());

return m_objective;
Expand Down Expand Up @@ -307,9 +320,16 @@ throw()
}

/* output solution */
solution.clear();
solution.assign(m_solution.begin(), m_solution.end());

/* report on starting the optimization process */
if(!m_use_callback)
std::cout << "[mapMAP] "
<< UNIX_COLOR_GREEN
<< "Finished optimization."
<< UNIX_COLOR_RESET
<< std::endl;

return m_objective;
}

Expand Down Expand Up @@ -345,7 +365,7 @@ create_std_modules()
}

/* create a multilevel module for the current graph */
m_multilevel = std::unique_ptr<Multilevel<COSTTYPE, SIMDWIDTH,
m_multilevel = std::unique_ptr<Multilevel<COSTTYPE, SIMDWIDTH,
UNARY, PAIRWISE>>(new Multilevel<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>(
m_graph, m_label_set, m_unaries, m_pairwise, m_multilevel_criterion));
}
Expand Down Expand Up @@ -376,7 +396,7 @@ check_data_complete()
if(m_construct_graph && m_label_set_check.size() != m_num_nodes)
return false;

if(m_construct_graph && (m_label_set->max_label() >
if(m_construct_graph && (m_label_set->max_label() >
(_iv_st<COSTTYPE, SIMDWIDTH>) m_num_labels))
return false;

Expand Down Expand Up @@ -404,31 +424,39 @@ void
mapMAP<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>::
print_status()
{
std::cout << "[mapMAP] "
<< UNIX_COLOR_RED
<< m_hist_time.back() << " ms"
<< UNIX_COLOR_RESET
<< ", "
<< UNIX_COLOR_GREEN
<< "Objective "
<< m_objective
<< UNIX_COLOR_RESET
<< " (after "
<< m_hist_multilevel_iterations
<< " multilevel, "
<< m_hist_spanningtree_iterations
<< " spanning tree, "
<< m_hist_acyclic_iterations
<< " acyclic iterations)"
<< std::endl;
if(!m_use_callback)
{
std::cout << "[mapMAP] "
<< UNIX_COLOR_RED
<< m_hist_time.back() << " ms"
<< UNIX_COLOR_RESET
<< ", "
<< UNIX_COLOR_GREEN
<< "Objective "
<< m_objective
<< UNIX_COLOR_RESET
<< " (after "
<< m_hist_multilevel_iterations
<< " multilevel, "
<< m_hist_spanningtree_iterations
<< " spanning tree, "
<< m_hist_acyclic_iterations
<< " acyclic iterations)"
<< std::endl;
}
else
{
/* hand off current time and objective to logging callback */
m_callback(m_hist_time.back(), m_objective);
}
}


/* ************************************************************************** */

template<typename COSTTYPE, uint_t SIMDWIDTH, typename UNARY, typename PAIRWISE>
FORCEINLINE
_s_t<COSTTYPE, SIMDWIDTH>
_s_t<COSTTYPE, SIMDWIDTH>
mapMAP<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>::
initial_labelling()
{
Expand Down Expand Up @@ -462,7 +490,7 @@ initial_labelling()
template<typename COSTTYPE, uint_t SIMDWIDTH, typename UNARY, typename PAIRWISE>
FORCEINLINE
_s_t<COSTTYPE, SIMDWIDTH>
mapMAP<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>::
mapMAP<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>::
opt_step_spanning_tree()
{
/* sample a tree (forest) without dependencies */
Expand Down Expand Up @@ -498,7 +526,7 @@ opt_step_spanning_tree()
template<typename COSTTYPE, uint_t SIMDWIDTH, typename UNARY, typename PAIRWISE>
FORCEINLINE
_s_t<COSTTYPE, SIMDWIDTH>
mapMAP<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>::
mapMAP<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>::
opt_step_multilevel()
{
std::vector<_iv_st<COSTTYPE, SIMDWIDTH>> lvl_solution;
Expand Down Expand Up @@ -531,14 +559,14 @@ opt_step_multilevel()
const Graph<COSTTYPE> * lvl_graph = m_multilevel->get_level_graph();
const LabelSet<COSTTYPE, SIMDWIDTH> * lvl_label_set = m_multilevel->
get_level_label_set();
const UnaryTable<COSTTYPE, SIMDWIDTH> * lvl_unaries =
const UnaryTable<COSTTYPE, SIMDWIDTH> * lvl_unaries =
m_multilevel->get_level_unaries();
const PairwiseTable<COSTTYPE, SIMDWIDTH> * lvl_pairwise =
const PairwiseTable<COSTTYPE, SIMDWIDTH> * lvl_pairwise =
m_multilevel->get_level_pairwise();

/* create new optimizer for level graph */
CombinatorialDynamicProgramming<COSTTYPE, SIMDWIDTH,
UnaryTable<COSTTYPE, SIMDWIDTH>, PairwiseTable<COSTTYPE, SIMDWIDTH>>
CombinatorialDynamicProgramming<COSTTYPE, SIMDWIDTH,
UnaryTable<COSTTYPE, SIMDWIDTH>, PairwiseTable<COSTTYPE, SIMDWIDTH>>
lvl_opt;
lvl_opt.set_graph(lvl_graph);
lvl_opt.set_label_set(lvl_label_set);
Expand All @@ -550,10 +578,10 @@ opt_step_multilevel()

roots.clear();
sampler.select_random_roots(m_num_roots, roots);
std::unique_ptr<Tree<COSTTYPE>> lvl_tree =
std::unique_ptr<Tree<COSTTYPE>> lvl_tree =
sampler.sample(roots, true);
lvl_opt.set_tree(lvl_tree.get());

/* optimize for level solution */
lvl_opt.optimize(upper_solution);

Expand Down Expand Up @@ -588,7 +616,7 @@ opt_step_multilevel()

template<typename COSTTYPE, uint_t SIMDWIDTH, typename UNARY, typename PAIRWISE>
FORCEINLINE
_s_t<COSTTYPE, SIMDWIDTH>
_s_t<COSTTYPE, SIMDWIDTH>
mapMAP<COSTTYPE, SIMDWIDTH, UNARY, PAIRWISE>::
opt_step_acyclic()
{
Expand Down Expand Up @@ -618,7 +646,7 @@ opt_step_acyclic()
++m_hist_acyclic_iterations;
m_hist_mode.push_back(SolverMode::SOLVER_ACYCLIC);

const _s_t<COSTTYPE, SIMDWIDTH> ac_opt = opt.objective(ac_solution);
const _s_t<COSTTYPE, SIMDWIDTH> ac_opt = opt.objective(ac_solution);
if(ac_opt < m_objective)
{
m_objective = ac_opt;
Expand Down

0 comments on commit f20edaa

Please sign in to comment.