Skip to content

Commit

Permalink
Simplify argument handling
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed May 19, 2024
1 parent 8fadbb8 commit 337c255
Showing 1 changed file with 41 additions and 34 deletions.
75 changes: 41 additions & 34 deletions src/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,21 @@ static const char* DEBUG_KEYWORD = "debug";
static const char* DEBUG_DESCRIPTION = "Output debugging information such as intermediary steps, parameters, and other internals";

static const char* METHOD_KEYWORD = "method";
static const std::string METHOD_DESCRIPTION = "Dimension reduction method. One of the following: " +
comma_separated_keys(DIMENSION_REDUCTION_METHODS.begin(), DIMENSION_REDUCTION_METHODS.end());

static const char* NEIGHBORS_METHOD_KEYWORD = "neighbors-method";
static const std::string NEIGHBORS_METHOD_DESCRIPTION = "Neighbors search method. One of the following: " +
comma_separated_keys(NEIGHBORS_METHODS.begin(), NEIGHBORS_METHODS.end());

static const char* EIGEN_METHOD_KEYWORD = "eigen-method";
static const std::string EIGEN_METHOD_DESCRIPTION = "Eigendecomposition method. One of the following: " +
comma_separated_keys(EIGEN_METHODS.begin(), EIGEN_METHODS.end());

static const char* COMPUTATION_STRATEGY_KEYWORD = "computation-strategy";
static const std::string COMPUTATION_STRATEGY_DESCRIPTION = "Computation strategy. One of the following: " +
comma_separated_keys(COMPUTATION_STRATEGIES.begin(), COMPUTATION_STRATEGIES.end());

static const char* TARGET_DIMENSION_KEYWORD = "target-dimension";
static const char* NUM_NEIGHBORS_KEYWORD = "num-neighbors";
static const char* GAUSSIAN_WIDTH_KEYWORD = "gaussian-width";
Expand Down Expand Up @@ -161,14 +173,12 @@ int run(int argc, const char **argv)
)
(
either("m", METHOD_KEYWORD),
"Dimension reduction method. One of the following: " +
comma_separated_keys(DIMENSION_REDUCTION_METHODS.begin(), DIMENSION_REDUCTION_METHODS.end()),
METHOD_DESCRIPTION,
with_default("locally_linear_embedding"s)
)
(
either("nm", NEIGHBORS_METHOD_KEYWORD),
"Neighbors search method. One of the following: " +
comma_separated_keys(NEIGHBORS_METHODS.begin(), NEIGHBORS_METHODS.end()),
NEIGHBORS_METHOD_DESCRIPTION,
#ifdef TAPKEE_USE_LGPL_COVERTREE
with_default("covertree"s)
#else
Expand All @@ -177,8 +187,7 @@ int run(int argc, const char **argv)
)
(
either("em", EIGEN_METHOD_KEYWORD),
"Eigendecomposition method. One of the following: " +
comma_separated_keys(EIGEN_METHODS.begin(), EIGEN_METHODS.end()),
EIGEN_METHOD_DESCRIPTION,
#ifdef TAPKEE_WITH_ARPACK
with_default("arpack"s)
#else
Expand All @@ -187,8 +196,7 @@ int run(int argc, const char **argv)
)
(
either("cs", COMPUTATION_STRATEGY_KEYWORD),
"Computation strategy. One of the following: " +
comma_separated_keys(COMPUTATION_STRATEGIES.begin(), COMPUTATION_STRATEGIES.end()),
COMPUTATION_STRATEGY_DESCRIPTION,
with_default("cpu"s)
)
(
Expand Down Expand Up @@ -381,17 +389,6 @@ int run(int argc, const char **argv)
tapkee::Logging::instance().message_error("Number of timesteps is negative.");
return 1;
}
double eigenshift = opt[EIGENSHIFT_KEYWORD].as<double>();
double landmark_rt = opt[LANDMARK_RATIO_KEYWORD].as<double>();
bool spe_global = opt.count(SPE_LOCAL_KEYWORD);
double spe_tol = opt[SPE_TOLERANCE_KEYWORD].as<double>();
int spe_num_upd = opt[SPE_NUM_UPDATES_KEYWORD].as<int>();
int max_iters = opt[MAX_ITERS_KEYWORD].as<int>();
double fa_eps = opt[FA_EPSILON_KEYWORD].as<double>();
double perplexity = opt[SNE_PERPLEXITY_KEYWORD].as<double>();
double theta = opt[SNE_THETA_KEYWORD].as<double>();
double squishing = opt[MS_SQUISHING_RATE_KEYWORD].as<double>();

// Load data
string input_filename = opt[INPUT_FILE_KEYWORD].as<std::string>();
string output_filename = opt[OUTPUT_FILE_KEYWORD].as<std::string>();
Expand All @@ -418,23 +415,33 @@ int run(int argc, const char **argv)
input_data.transposeInPlace();
}

std::stringstream ss;
ss << "Data contains " << input_data.cols() << " feature vectors with dimension of " << input_data.rows();
tapkee::Logging::instance().message_info(ss.str());
tapkee::Logging::instance().message_info(fmt::format("Data contains {} feature vectors with dimension of {}", input_data.cols(), input_data.rows()));

tapkee::TapkeeOutput output;

tapkee::ParametersSet parameters =
tapkee::kwargs[(tapkee::method = tapkee_method, tapkee::computation_strategy = tapkee_computation_strategy,
tapkee::eigen_method = tapkee_eigen_method, tapkee::neighbors_method = tapkee_neighbors_method,
tapkee::num_neighbors = k, tapkee::target_dimension = target_dim,
tapkee::diffusion_map_timesteps = timesteps, tapkee::gaussian_kernel_width = width,
tapkee::max_iteration = max_iters, tapkee::spe_global_strategy = spe_global,
tapkee::spe_num_updates = spe_num_upd, tapkee::spe_tolerance = spe_tol,
tapkee::landmark_ratio = landmark_rt, tapkee::nullspace_shift = eigenshift,
tapkee::check_connectivity = true, tapkee::fa_epsilon = fa_eps,
tapkee::sne_perplexity = perplexity, tapkee::sne_theta = theta,
tapkee::squishing_rate = squishing)];
tapkee::kwargs[(
tapkee::method = tapkee_method,
tapkee::computation_strategy = tapkee_computation_strategy,
tapkee::eigen_method = tapkee_eigen_method,
tapkee::neighbors_method = tapkee_neighbors_method,
tapkee::num_neighbors = k,
tapkee::target_dimension = target_dim,
tapkee::diffusion_map_timesteps = timesteps,
tapkee::gaussian_kernel_width = width,
tapkee::max_iteration = opt[MAX_ITERS_KEYWORD].as<int>(),
tapkee::spe_global_strategy = opt.count(SPE_LOCAL_KEYWORD),
tapkee::spe_num_updates = opt[SPE_NUM_UPDATES_KEYWORD].as<int>(),
tapkee::spe_tolerance = opt[SPE_TOLERANCE_KEYWORD].as<double>(),
tapkee::landmark_ratio = opt[LANDMARK_RATIO_KEYWORD].as<double>(),
tapkee::nullspace_shift = opt[EIGENSHIFT_KEYWORD].as<double>(),
tapkee::check_connectivity = true,
tapkee::fa_epsilon = opt[FA_EPSILON_KEYWORD].as<double>(),
tapkee::sne_perplexity = opt[SNE_PERPLEXITY_KEYWORD].as<double>(),
tapkee::sne_theta = opt[SNE_THETA_KEYWORD].as<double>(),
tapkee::squishing_rate = opt[MS_SQUISHING_RATE_KEYWORD].as<double>()
)];


if (opt.count(PRECOMPUTE_KEYWORD))
{
Expand All @@ -449,13 +456,13 @@ int run(int argc, const char **argv)
{
tapkee::tapkee_internal::timed_context context("[+] Distance matrix computation");
distance_matrix = matrix_from_callback(static_cast<tapkee::IndexType>(input_data.cols()),
tapkee::eigen_distance_callback(input_data));
tapkee::eigen_distance_callback(input_data));
}
if (tapkee_method.needs_kernel)
{
tapkee::tapkee_internal::timed_context context("[+] Kernel matrix computation");
kernel_matrix = matrix_from_callback(static_cast<tapkee::IndexType>(input_data.cols()),
tapkee::eigen_kernel_callback(input_data));
tapkee::eigen_kernel_callback(input_data));
}
}
tapkee::precomputed_distance_callback dcb(distance_matrix);
Expand Down

0 comments on commit 337c255

Please sign in to comment.