Skip to content

Commit 5c414a4

Browse files
committed
feat(NearestNeighbor): new class point
1 parent 516dae7 commit 5c414a4

File tree

5 files changed

+286
-324
lines changed

5 files changed

+286
-324
lines changed

examples/main.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,17 @@ int main(int argc, char const *argv[]) {
2121
{1.0, 2.0}, {2.0, 1.0}, {3.0, 2.0}, {7.0, 4.0}, {5.0, 9.0},
2222
{6.0, 1.0}, {0.0, 3.0}, {4.0, 7.0}, {8.0, 2.0}, {3.0, 5.0}};
2323

24-
auto tree = new KDTree();
24+
auto tree = new KDTree<double, 2>();
2525
tree->BuildTree(points);
2626
tree->PrintInorder();
2727

28-
std::vector<std::vector<int>> points_knn =
28+
std::vector<Point<double, 2>> points_knn =
2929
tree->KNearestNeighbor(Point<double, 2>({x, y}), k);
3030

3131
for (size_t i = 0; i < points_knn.size(); ++i) {
3232
std::cout << "Point " << i << ": ";
3333
for (size_t j = 0; j < points_knn[i].size(); ++j) {
34-
std::cout << points_knn[i][j] << " ";
34+
std::cout << points_knn[i].data().at(j) << " ";
3535
}
3636
std::cout << std::endl;
3737
}

include/mlcpppy/classifiers/neighbors/kdtree.h

Lines changed: 271 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,280 @@
1717
#ifndef KDTREE_H
1818
#define KDTREE_H
1919

20-
#include <queue>
2120
#include <vector>
22-
21+
#include <queue>
22+
#include <algorithm>
2323
#include "nearest_neighbor.h"
2424

25-
template <typename T = double, size_t N = 3>
25+
template <typename T, int N>
2626
class KDTree : public NearestNeighbor<T, N> {
27-
private:
28-
using PointType = typename NearestNeighbor<T, N>::PointType;
29-
class Node;
30-
Node *root_;
31-
int K_;
32-
std::priority_queue<std::pair<double, Node *>> bests_;
33-
34-
Node *Build(std::vector<PointType> points, int depth);
35-
Node *NearestNeighbor(Node *root, const PointType &target, int depth);
36-
void KNearestNeighbor(Node *root, const PointType &target, int depth);
37-
38-
Node *Closest(Node *n0, Node *n1, const PointType &target);
39-
double DistSquared(const PointType &p0, const PointType &p1);
40-
static void Inorder(Node *root);
41-
42-
public:
43-
KDTree();
44-
void Insert(const PointType &point) override;
45-
void BuildTree(const std::vector<PointType> &points) override;
46-
std::vector<PointType> KNearestNeighbor(const PointType &target_points,
47-
int k) override;
48-
void PrintInorder();
49-
50-
~KDTree() override;
27+
private:
28+
class Node {
29+
public:
30+
Node* left_;
31+
Node* right_;
32+
Point<T, N> point_;
33+
int depth_;
34+
35+
explicit Node() {}
36+
explicit Node(Point<T, N> point, Node* left, Node* right) : point_(point), left_(left), right_(right) {}
37+
explicit Node(Point<T, N> point, Node* left, Node* right, int depth) : point_(point), left_(left), right_(right), depth_(depth) {}
38+
explicit Node(Point<T, N> point) : point_(point) {
39+
this->left_ = nullptr;
40+
this->right_ = nullptr;
41+
}
42+
explicit Node(Point<T, N> point, int depth) : point_(point), depth_(depth) {
43+
this->left_ = nullptr;
44+
this->right_ = nullptr;
45+
}
46+
47+
~Node() {
48+
delete left_;
49+
delete right_;
50+
}
51+
};
52+
53+
Node* root_;
54+
int K_;
55+
std::priority_queue<std::pair<double, Node*>> bests_;
56+
57+
Node* Build(std::vector<Point<T, N>> points, int depth) {
58+
if (points.empty()) {
59+
return nullptr;
60+
}
61+
int k = points.at(0).data().size();
62+
int axis = depth % k;
63+
64+
std::sort(points.begin(), points.end(), [axis](const Point<T, N>& a, const Point<T, N>& b) {
65+
return a.data()[axis] < b.data()[axis];
66+
});
67+
68+
int median = points.size() / 2;
69+
std::vector<Point<T, N>> points_left(points.begin(), points.begin() + median);
70+
std::vector<Point<T, N>> points_right(points.begin() + median + 1, points.end());
71+
72+
73+
return new Node(
74+
points.at(median),
75+
Build(points_left, depth + 1),
76+
Build(points_right, depth + 1),
77+
depth
78+
);
79+
}
80+
81+
82+
Node* NearestNeighbor(Node* root, Point<T, N>& target, int depth) {
83+
if (root == nullptr) return nullptr;
84+
85+
Node* next_branch;
86+
Node* other_branch;
87+
88+
int axis = depth % root->point_.size();
89+
90+
if (target.data().at(axis) < root->point_.data().at(axis))
91+
{
92+
next_branch = root->left_;
93+
other_branch = root->right_;
94+
} else {
95+
next_branch = root->right_;
96+
other_branch = root->left_;
97+
}
98+
99+
Node* temp = NearestNeighbor(next_branch, target, depth + 1);
100+
Node* best = Closest(temp, root, target);
101+
double radius_squared = DistSquared(target, best->point_);
102+
103+
double dist = target.data().at(axis) - root->point_.data().at(axis);
104+
105+
if (radius_squared >= dist * dist)
106+
{
107+
temp = NearestNeighbor(other_branch, target, depth + 1);
108+
best = Closest(temp, best, target);
109+
}
110+
111+
return best;
112+
113+
}
114+
115+
void KNearestNeighbor(Node* root, Point<T, N>& target, int depth) {
116+
if (root == nullptr) return;
117+
118+
Node* next_branch;
119+
Node* other_branch;
120+
121+
int axis = depth % root->point_.size();
122+
123+
if (target.data().at(axis) < root->point_.data().at(axis))
124+
{
125+
next_branch = root->left_;
126+
other_branch = root->right_;
127+
} else {
128+
next_branch = root->right_;
129+
other_branch = root->left_;
130+
}
131+
132+
KNearestNeighbor(next_branch, target, depth + 1);
133+
double dist = DistSquared(target, root->point_);
134+
135+
if (this->bests_.size() < this->K_)
136+
{
137+
this->bests_.push({dist, root});
138+
}
139+
else if (dist < this->bests_.top().first)
140+
{
141+
bests_.pop();
142+
this->bests_.push({dist, root});
143+
}
144+
145+
double diff = target.data().at(axis) - root->point_.data().at(axis);
146+
147+
if (this->bests_.size() < this->K_ || diff * diff < this->bests_.top().first) {
148+
KNearestNeighbor(other_branch, target, depth + 1);
149+
}
150+
}
151+
152+
153+
Node* Closest(Node* n0, Node* n1, Point<T, N>& target) {
154+
if (n0 == nullptr) return n1;
155+
if (n1 == nullptr) return n0;
156+
157+
long d1 = DistSquared(n0->point_, target);
158+
long d2 = DistSquared(n1->point_, target);
159+
160+
if (d1 < d2)
161+
return n0;
162+
else
163+
return n1;
164+
}
165+
166+
double DistSquared(const Point<T, N>& p0, const Point<T, N>& p1) {
167+
long total = 0;
168+
size_t numDims = p0.size();
169+
170+
for (size_t i = 0; i < numDims; ++i) {
171+
int diff = std::abs(p0.data()[i] - p1.data()[i]);
172+
total += static_cast<double>(diff) * diff; // mais eficiente que pow para int
173+
}
174+
175+
return total;
176+
}
177+
178+
static void Inorder(Node* root) {
179+
if (!root) return;
180+
181+
Inorder(root->left_);
182+
183+
if (root->left_) {
184+
std::cout << " \"";
185+
for (size_t i = 0; i < root->point_.size(); i++)
186+
std::cout << root->point_.data()[i] << (i + 1 == root->point_.size() ? "" : ",");
187+
std::cout << "\" -> \"";
188+
for (size_t i = 0; i < root->left_->point_.size(); i++)
189+
std::cout << root->left_->point_.data()[i] << (i + 1 == root->left_->point_.size() ? "" : ",");
190+
std::cout << "\" [label=\"esq\"];\n";
191+
}
192+
193+
if (root->right_) {
194+
std::cout << " \"";
195+
for (size_t i = 0; i < root->point_.size(); i++)
196+
std::cout << root->point_.data()[i] << (i + 1 == root->point_.size() ? "" : ",");
197+
std::cout << "\" -> \"";
198+
for (size_t i = 0; i < root->right_->point_.size(); i++)
199+
std::cout << root->right_->point_.data()[i] << (i + 1 == root->right_->point_.size() ? "" : ",");
200+
std::cout << "\" [label=\"dir\"];\n";
201+
}
202+
203+
Inorder(root->right_);
204+
}
205+
206+
public:
207+
KDTree() : root_(nullptr) {}
208+
void Insert(Point<T, N> point) override {
209+
Node* p = this->root_;
210+
Node* prev = nullptr;
211+
212+
int depth = 0;
213+
int n_dims = point.size();
214+
215+
while (p != nullptr)
216+
{
217+
prev = p;
218+
if (point.data().at(depth) < p->point_.data().at(depth))
219+
p = p->left_;
220+
else
221+
p= p->right_;
222+
depth = (depth + 1) % n_dims;
223+
}
224+
225+
if (this->root_ == nullptr)
226+
this->root_ = new Node(point);
227+
else if((point.data().at((depth - 1) % n_dims)) < (prev->point_.data().at((depth - 1) % n_dims)))
228+
prev->left_ = new Node(point, depth);
229+
else
230+
prev->right_ = new Node(point, depth);
231+
}
232+
233+
void BuildTree(std::vector<Point<T, N>> points) override {
234+
if (points.empty())
235+
{
236+
return;
237+
}
238+
239+
int initial_size = points.at(0).size();
240+
241+
for (auto &point : points)
242+
{
243+
if (point.data().empty())
244+
{
245+
return;
246+
}
247+
248+
if (point.size() != initial_size)
249+
{
250+
return;
251+
}
252+
253+
254+
}
255+
256+
this->root_ = Build(points, 0);
257+
}
258+
259+
260+
std::vector<Point<T, N>> KNearestNeighbor(Point<T, N> target_points, int k) override {
261+
if (k)
262+
this->K_ = k;
263+
264+
if (k == 1){
265+
Node* result = NearestNeighbor(this->root_, target_points, 0);
266+
std::vector<Point<T, N>> points;
267+
if (result)
268+
{
269+
points.push_back(result->point_);
270+
}
271+
return points;
272+
}
273+
274+
KNearestNeighbor(this->root_, target_points, 0);
275+
std::vector<Point<T, N>> tmp;
276+
while (!this->bests_.empty())
277+
{
278+
tmp.push_back(this->bests_.top().second->point_);
279+
this->bests_.pop();
280+
}
281+
282+
return tmp;
283+
}
284+
285+
void PrintInorder() {
286+
// TODO: A ideia é isso gerar um arquivo para o graphiz gerar o grafo
287+
std::cout << "digraph G {\n";
288+
Inorder(this->root_);
289+
std::cout << "}\n";
290+
}
291+
292+
~KDTree() override {
293+
delete root_;
294+
}
51295
};
52296
#endif // KDTREE_H

include/mlcpppy/classifiers/neighbors/nearest_neighbor.h

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,15 @@
1717
#ifndef NEAREST_NEIGHBOR_H
1818
#define NEAREST_NEIGHBOR_H
1919
#include <vector>
20-
2120
#include "point.h"
2221

23-
template <typename T = double, size_t N = 3>
22+
template<typename T, int N>
2423
class NearestNeighbor {
25-
public:
26-
using PointType = Point<T, N>;
27-
virtual ~NearestNeighbor() = default;
28-
virtual std::vector<PointType> KNearestNeighbor(const PointType &,
29-
int) = 0;
30-
virtual void Insert(const PointType &) = 0;
31-
virtual void BuildTree(const std::vector<PointType> &) = 0;
32-
virtual void Delete(const PointType &) {};
24+
public:
25+
virtual ~NearestNeighbor() = default;
26+
virtual std::vector<Point<T, N>> KNearestNeighbor(Point<T, N>, int) = 0;
27+
virtual void Insert(Point<T, N>) = 0;
28+
virtual void BuildTree(std::vector<Point<T, N>>) = 0;
29+
virtual void Delete(Point<T, N>){};
3330
};
3431
#endif // NEAREST_NEIGHBOR_H

include/mlcpppy/classifiers/neighbors/point.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include <stdexcept>
2323
#include <type_traits>
2424

25-
template <typename T = double, size_t N = 3,
25+
template <typename T = double, int N = 3,
2626
typename =
2727
typename std::enable_if<std::is_floating_point<T>::value>::type>
2828
class Point {
@@ -40,6 +40,9 @@ class Point {
4040
std::copy(list.begin(), list.end(), data_.begin());
4141
}
4242
const std::array<T, N> &data() const { return data_; }
43+
const size_t size() const {
44+
return data_.size();
45+
}
4346
};
4447

45-
#endif
48+
#endif // POINT_H

0 commit comments

Comments
 (0)