Skip to content

Commit

Permalink
Code
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewB330 committed Jan 26, 2019
1 parent 1ebc33b commit d8f95a2
Show file tree
Hide file tree
Showing 9 changed files with 604 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,6 @@ ASALocalRun/

# MFractors (Xamarin productivity tool) working folder
.mfractor/

*.vcxproj
*.vcxproj.filters
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,12 @@
# EuclideanMST
Implementations of different algorithms for building Euclidean minimum spanning tree in k-dimensional space.
### Algorithms:
- EMST using Kd-tree O(NlogN)
Implementation of algorithm described in "Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis, and Applications. William B. March, Parikshit Ram, Alexander G. Gray"
- Prim's algorithm O(N^2)
Straightforward MST on the fully connected Euclidean graph

### TODO:
- Implement EMST using Cover-tree
- Benchmarks
- ...
62 changes: 62 additions & 0 deletions dsu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#pragma once
#include <vector>

/**
 * Disjoint-set-union (union-find) data structure over the elements 0 .. n-1.
 *
 * get_set() compresses paths and unite() links by rank.
 * No bounds checking is performed: every element passed to a member function
 * must be smaller than the size given to the constructor / reset().
 */
class DSU {

public:

    /** Creates n singleton sets {0}, {1}, ..., {n-1}. */
    DSU(size_t n = 0);

    /**
     * Returns the representative element of the set that contains x.
     * Declared const although it rewrites parent links (path compression);
     * the partition into sets is not changed by the call.
     */
    size_t get_set(size_t x) const;

    /** Returns true when x and y currently belong to the same set. */
    bool is_in_same_set(size_t x, size_t y) const;

    /** Forgets all previous unions and re-creates n singleton sets. */
    void reset(size_t n);

    /**
     * Merges the sets that contain x and y.
     * Returns false when both were already in the same set, true otherwise.
     */
    bool unite(size_t x, size_t y);

private:

    // Parent link of every element; a representative is its own parent.
    // Mutable so that the const lookups may compress paths.
    mutable std::vector<size_t> p;
    // Rank of every representative, used to decide which root is linked under which.
    std::vector<size_t> rank;
};

// Implementation
//
// The definitions are marked `inline` because they live in a header: without it,
// including this file from more than one translation unit produces
// multiple-definition errors at link time.

inline DSU::DSU(size_t n) {
    reset(n);
}


inline size_t DSU::get_set(size_t x) const {
    // First pass: walk up to the representative.
    size_t root = x;
    while (p[root] != root) {
        root = p[root];
    }
    // Second pass: point every element on the walked path directly at the representative.
    while (p[x] != root) {
        const size_t next = p[x];
        p[x] = root;
        x = next;
    }
    return root;
}

inline bool DSU::is_in_same_set(size_t x, size_t y) const {
    return get_set(x) == get_set(y);
}

inline void DSU::reset(size_t n) {
    p.resize(n);
    rank.assign(n, 0);
    for (size_t i = 0; i < n; i++) {
        p[i] = i;
    }
}

inline bool DSU::unite(size_t x, size_t y) {
    size_t set_a = get_set(x);
    size_t set_b = get_set(y);
    if (set_a == set_b) {
        return false;
    }
    // Make set_a the root with the smaller (or equal) rank; it becomes the child.
    if (rank[set_a] > rank[set_b]) {
        std::swap(set_a, set_b);
    }
    // Linking two roots of equal rank makes the surviving root one level taller.
    if (rank[set_a] == rank[set_b]) {
        rank[set_b]++;
    }
    p[set_a] = set_b;
    return true;
}
179 changes: 179 additions & 0 deletions emst.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
#pragma once
#include "model.h"
#include "tree.h"
#include "dsu.h"

/** An edge of the spanning tree, stored as the pair of its two end-point ids. */
using Edge = std::pair<size_t, size_t>;

/**
 * Common base of the solvers in this file.
 * Holds the list of chosen edges and their accumulated length; derived
 * solvers fill both in, callers read them through the two getters.
 */
template<size_t DIM>
class EmstSolver {
public:
    EmstSolver() = default;

    /** Edges collected by the solver. */
    const std::vector<Edge> & get_solution() const {
        return solution;
    }

    /** Sum of the lengths of the collected edges. */
    const double & get_total_length() const {
        return total_length;
    }

protected:
    std::vector<Edge> solution;
    double total_length = 0.0;
};

/**
 * Implementation of EMST algorithm using K-d tree
 * "Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis, and Applications. William B. March, Parikshit Ram, Alexander G. Gray"
 * Time complexity: O(cNlogN), extra constant c depends on the distribution of points
 *
 * All work is done in the constructor; the result is read through
 * get_solution() / get_total_length() of the base class.
 */
template<size_t DIM>
class KdTreeSolver : public EmstSolver<DIM> {
public:
    /**
     * Builds a KdTree over `points` and computes the spanning tree.
     * `points` is passed on to the KdTree constructor by non-const reference.
     * With fewer than two points the solution stays empty and the length 0.
     */
    KdTreeSolver(std::vector<Point<DIM>> & points) :num_points(points.size()) {
        if (num_points < 2) {
            // Nothing to connect, so no tree is needed. This also avoids
            // deriving a tree depth from log2(0) or log2(1) - 2.
            return;
        }
        dsu.reset(num_points);
        tree = KdTree<DIM>(points, initial_tree_depth(num_points));
        is_fully_connected.assign(tree.get_maximal_id() + 1, false);
        solve();
        // todo: clear containers
    }

private:
    /**
     * Returns floor(log2(count)) - 2, clamped at zero.
     * Computed with integers: the former floating point expression was
     * negative for fewer than four points before being converted to the
     * KdTree constructor argument (presumably the maximal depth - confirm
     * against tree.h).
     */
    static size_t initial_tree_depth(size_t count) {
        size_t log2_floor = 0;
        for (size_t rest = count; rest > 1; rest >>= 1) {
            ++log2_floor;
        }
        return log2_floor > 2 ? log2_floor - 2 : 0;
    }

    /**
     * Repeats merge rounds until the solution holds num_points - 1 edges.
     * Each round searches one candidate edge per component and then unites
     * along those candidates.
     */
    void solve() {
        auto & solution = EmstSolver<DIM>::solution;
        auto & total_length = EmstSolver<DIM>::total_length;

        while (solution.size() + 1 < num_points) {
            node_approximation.assign(tree.get_maximal_id() + 1, std::numeric_limits<double>::max());
            nearest_set.assign(num_points, { std::numeric_limits<double>::max(), Edge(0,0) });

            check_fully_connected(tree.get_root_id());

            find_component_neighbors(tree.get_root_id(), tree.get_root_id());

            bool merged_any = false;
            for (size_t i = 0; i < num_points; i++) {
                // Only the representative's slot of nearest_set carries the component's candidate.
                if (i == dsu.get_set(i)) {
                    const Edge e = nearest_set[i].second;
                    if (dsu.unite(e.first, e.second)) {
                        solution.push_back(e);
                        total_length += nearest_set[i].first;
                        merged_any = true;
                    }
                }
            }
            if (!merged_any) {
                // No component found a usable candidate (possible when distances
                // never compare smaller than the initial bound, e.g. NaN input).
                // Another round would find the same nothing, so stop instead of
                // looping forever.
                break;
            }
        }
    }

    /**
     * Dual-tree search over the node pair (q, r): for the points below q,
     * looks among the points below r for a closer point of another
     * component and records it in nearest_set. node_approximation[q] is
     * refreshed on the way back and used to prune later pairs.
     */
    void find_component_neighbors(size_t q, size_t r) {
        // Both subtrees lie completely inside the same component: no candidate here.
        if (is_fully_connected[q] && is_fully_connected[r] &&
            dsu.is_in_same_set(tree.points_begin(q)->get_id(), tree.points_begin(r)->get_id())) {
            return;
        }
        // The boxes are farther apart than the bound stored for q: skip the pair.
        if (distance(tree.get_bounding_box(q), tree.get_bounding_box(r)) > node_approximation[q]) {
            return;
        }
        if (tree.is_leaf(q) && tree.is_leaf(r)) {
            node_approximation[q] = 0.0;
            for (auto i = tree.points_begin(q); i != tree.points_end(q); i++) {
                // No unions happen during the search, so the representative is fixed for this point.
                const size_t component = dsu.get_set(i->get_id());
                for (auto j = tree.points_begin(r); j != tree.points_end(r); j++) {
                    if (component != dsu.get_set(j->get_id())) {
                        const double dist = distance(*i, *j);
                        if (dist < nearest_set[component].first) {
                            nearest_set[component] = { dist, { i->get_id(), j->get_id() } };
                        }
                    }
                }
                node_approximation[q] = std::max(node_approximation[q], nearest_set[component].first);
            }
        } else if (tree.is_leaf(q)) {
            // Only r can be split; the bound of q is updated by the leaf/leaf calls below.
            find_component_neighbors(q, tree.get_left_child_id(r));
            find_component_neighbors(q, tree.get_right_child_id(r));
        } else {
            const size_t qleft = tree.get_left_child_id(q);
            const size_t qright = tree.get_right_child_id(q);
            if (tree.is_leaf(r)) {
                find_component_neighbors(qleft, r);
                find_component_neighbors(qright, r);
            } else {
                const size_t rleft = tree.get_left_child_id(r);
                const size_t rright = tree.get_right_child_id(r);
                find_component_neighbors(qleft, rleft);
                find_component_neighbors(qleft, rright);
                find_component_neighbors(qright, rright);
                find_component_neighbors(qright, rleft);
            }
            node_approximation[q] = std::max(node_approximation[qleft], node_approximation[qright]);
        }
    }

    /**
     * Marks every node below node_id whose points all belong to one
     * component. Marked nodes are never re-examined: the flags are not
     * reset between rounds.
     */
    void check_fully_connected(size_t node_id) {
        if (is_fully_connected[node_id]) {
            return;
        }
        if (tree.is_leaf(node_id)) {
            bool fully_connected = true;
            const auto first = tree.points_begin(node_id);
            const auto last = tree.points_end(node_id);
            // Guard the empty range: `iter + 1 != last` would never become false for it.
            if (first != last) {
                for (auto iter = first; iter + 1 != last; ++iter) {
                    fully_connected &= dsu.is_in_same_set(iter->get_id(), (iter + 1)->get_id());
                }
            }
            is_fully_connected[node_id] = fully_connected;
            return;
        }
        const size_t left = tree.get_left_child_id(node_id);
        const size_t right = tree.get_right_child_id(node_id);
        check_fully_connected(left);
        check_fully_connected(right);
        // Two connected halves form a connected node when one point of each shares a component.
        if (is_fully_connected[left] && is_fully_connected[right] &&
            dsu.is_in_same_set(tree.points_begin(left)->get_id(), tree.points_begin(right)->get_id())) {
            is_fully_connected[node_id] = true;
        }
    }

    size_t num_points;
    DSU dsu;
    KdTree<DIM> tree;
    // Per tree node: true once all points below the node are in one component.
    std::vector<bool> is_fully_connected;
    // Per tree node: pruning bound, reset at the start of every round.
    std::vector<double> node_approximation;
    // Per component representative: shortest candidate edge found this round and its length.
    std::vector<std::pair<double, Edge>> nearest_set;
};

/**
 * Prim's algorithm
 * Time complexity: O(N^2)
 *
 * All work is done in the constructor; the result is read through
 * get_solution() / get_total_length() of the base class.
 */
template<size_t DIM>
class PrimSolver : public EmstSolver<DIM> {
public:
    /**
     * Computes the spanning tree of `points`; the vector is only read.
     * An empty or single-point input yields an empty solution of length 0.
     */
    PrimSolver(const std::vector<Point<DIM>> & points) {
        solve(points);
    }

private:
    /**
     * Grows the tree from point 0, adding in every iteration the unused
     * point that is closest to the tree. Edges are stored as
     * { new point index, index of the tree point it attaches to }.
     */
    void solve(const std::vector<Point<DIM>> & points) {
        auto & solution = EmstSolver<DIM>::solution;
        auto & total_length = EmstSolver<DIM>::total_length;
        const size_t num_points = points.size();
        if (num_points == 0) {
            // used[0] below would be an out-of-bounds write on an empty vector.
            return;
        }

        // For every point: { distance to the closest tree point, index of that tree point }.
        // Slot 0 keeps the maximal distance and serves as the sentinel of the search below.
        std::vector<std::pair<double, size_t>> distance_to_tree(num_points, { std::numeric_limits<double>::max(), 0 });
        std::vector<bool> used(num_points, false);
        used[0] = true;
        for (size_t i = 1; i < num_points; i++) {
            distance_to_tree[i] = { distance(points[0], points[i]), 0 };
        }

        for (size_t iteration = 1; iteration < num_points; iteration++) {
            // Pairs compare by distance first, then by the index of the attachment point.
            size_t nearest_id = 0;
            for (size_t i = 1; i < num_points; i++) {
                if (!used[i] && distance_to_tree[i] < distance_to_tree[nearest_id]) {
                    nearest_id = i;
                }
            }
            solution.push_back({ nearest_id, distance_to_tree[nearest_id].second });
            total_length += distance_to_tree[nearest_id].first;
            used[nearest_id] = true;
            // Relax the remaining points against the newly added one.
            for (size_t i = 1; i < num_points; i++) {
                if (used[i]) {
                    continue;
                }
                const double candidate = distance(points[i], points[nearest_id]);
                if (candidate < distance_to_tree[i].first) {
                    distance_to_tree[i] = { candidate, nearest_id };
                }
            }
        }
    }

};
45 changes: 45 additions & 0 deletions main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include <iomanip>
#include <iostream>
#include <fstream>
#include <algorithm>
#include <vector>
#include <limits>
#include <memory>
#include <string>
#include "emst.h"

using namespace std;

template<size_t DIM>
void example() {

fstream fin("test_data/dim" + to_string(DIM) + ".txt");
size_t n; fin >> n;
vector<Point<DIM>> points(n);
for (size_t i = 0; i < n; i++) {
for (size_t k = 0; k < DIM; k++) {
fin >> points[i][k];
}
}
string s; fin >> s;
double answer; fin >> answer;

KdTreeSolver<DIM> solver_fast(points);
PrimSolver<DIM> solver_slow(points);

cout << DIM << "-dimensional space:" << endl;
cout << "Answer: " << answer << endl;
cout << "Fast solver answer: " << solver_fast.get_total_length() << endl;
cout << "Slow solver answer: " << solver_slow.get_total_length() << endl;
cout << endl;
}

/** Prints the 2- and 10-dimensional examples using fixed notation with six decimals. */
int main() {
    std::cout << std::fixed << std::setprecision(6);

    example<2>();
    example<10>();

    return 0;
}
Loading

0 comments on commit d8f95a2

Please sign in to comment.