diff --git a/src/cli/main.cpp b/src/cli/main.cpp index 78115ed..be07c26 100644 --- a/src/cli/main.cpp +++ b/src/cli/main.cpp @@ -282,9 +282,9 @@ int run(int argc, const char **argv) string method = opt[METHOD_KEYWORD].as(); try { - tapkee_method = parse_reduction_method(method.c_str()); + tapkee_method = parse_multiple(DIMENSION_REDUCTION_METHODS, method); } - catch (const std::exception &) + catch (const std::exception & ex) { tapkee::Logging::instance().message_error(string("Unknown method ") + method); return 1; @@ -296,7 +296,7 @@ int run(int argc, const char **argv) string method = opt[NEIGHBORS_METHOD_KEYWORD].as(); try { - tapkee_neighbors_method = parse_neighbors_method(method.c_str()); + tapkee_neighbors_method = parse_multiple(NEIGHBORS_METHODS, method); } catch (const std::exception &) { @@ -309,7 +309,7 @@ int run(int argc, const char **argv) string method = opt[EIGEN_METHOD_KEYWORD].as(); try { - tapkee_eigen_method = parse_eigen_method(method.c_str()); + tapkee_eigen_method = parse_multiple(EIGEN_METHODS, method); } catch (const std::exception &) { @@ -322,7 +322,7 @@ int run(int argc, const char **argv) string method = opt[COMPUTATION_STRATEGY_KEYWORD].as(); try { - tapkee_computation_strategy = parse_computation_strategy(method.c_str()); + tapkee_computation_strategy = parse_multiple(COMPUTATION_STRATEGIES, method); } catch (const std::exception &) { diff --git a/src/cli/util.hpp b/src/cli/util.hpp index 489d8ba..69b5c52 100644 --- a/src/cli/util.hpp +++ b/src/cli/util.hpp @@ -10,6 +10,7 @@ #include #include +#include using namespace std; @@ -22,6 +23,38 @@ inline bool is_wrong_char(char c) return false; } +int levenshtein_distance(const std::string& s1, const std::string& s2) +{ + const auto len1 = s1.size(); + const auto len2 = s2.size(); + + std::vector> d(len1 + 1, std::vector(len2 + 1)); + + d[0][0] = 0; + for (unsigned int i = 1; i <= len1; ++i) + { + d[i][0] = i; + } + for (unsigned int j = 1; j <= len2; ++j) + { + d[0][j] = j; + } + + for (unsigned int i = 1; i <= len1; ++i) + { + for (unsigned int j = 1; j <= len2; ++j) + { + d[i][j] = std::min({ + d[i - 1][j] + 1, + d[i][j - 1] + 1, + d[i - 1][j - 1] + (s1[i - 1] == s2[j - 1] ? 0 : 1) + }); + } + } + + return d[len1][len2]; +} + template std::string comma_separated_keys(Iterator begin, Iterator end) { std::ostringstream oss; @@ -109,7 +142,7 @@ void write_vector(tapkee::DenseVector* matrix, ofstream& of) } } -static const std::map DIMENSION_REDUCTION_METHODS = { +static const std::map DIMENSION_REDUCTION_METHODS = { {"local_tangent_space_alignment", tapkee::KernelLocalTangentSpaceAlignment}, {"ltsa", tapkee::KernelLocalTangentSpaceAlignment}, {"locally_linear_embedding", tapkee::KernelLocallyLinearEmbedding}, @@ -148,12 +181,7 @@ static const std::map DIMENSION_R {"manifold_sculpting", tapkee::ManifoldSculpting}, }; -tapkee::DimensionReductionMethod parse_reduction_method(const char* str) -{ - return DIMENSION_REDUCTION_METHODS.at(str); -} - -static const std::map NEIGHBORS_METHODS = { +static const std::map NEIGHBORS_METHODS = { {"brute", tapkee::Brute}, {"vptree", tapkee::VpTree}, #ifdef TAPKEE_USE_LGPL_COVERTREE @@ -161,12 +189,7 @@ static const std::map NEIGHBORS_METHODS = #endif }; -tapkee::NeighborsMethod parse_neighbors_method(const char* str) -{ - return NEIGHBORS_METHODS.at(str); -} - -static const std::map EIGEN_METHODS = { +static const std::map EIGEN_METHODS = { {"dense", tapkee::Dense}, {"randomized", tapkee::Randomized}, #ifdef TAPKEE_WITH_ARPACK @@ -174,21 +197,32 @@ static const std::map EIGEN_METHODS = { #endif }; -tapkee::EigenMethod parse_eigen_method(const char* str) -{ - return EIGEN_METHODS.at(str); -} - -static const std::map COMPUTATION_STRATEGIES = { +static const std::map COMPUTATION_STRATEGIES = { {"cpu", tapkee::HomogeneousCPUStrategy}, #ifdef TAPKEE_WITH_VIENNACL {"opencl", tapkee::HeterogeneousOpenCLStrategy}, #endif }; -tapkee::ComputationStrategy parse_computation_strategy(const char* str) +template +typename Mapping::mapped_type parse_multiple(Mapping mapping, const std::string& str) { - return COMPUTATION_STRATEGIES.at(str); + auto it = mapping.find(str); + if (it != mapping.end()) + { + return it->second; + } + + auto closest = std::min_element(mapping.begin(), mapping.end(), + [&str] (const auto &a, const auto &b) { + return levenshtein_distance(str, a.first) < levenshtein_distance(str, b.first); + }); + if (closest != mapping.end()) + { + tapkee::Logging::instance().message_info(fmt::format("Unknown parameter value `{}`. Did you mean `{}`?", str, closest->first)); + } + + throw std::logic_error(str); } template