1717#ifndef KDTREE_H
1818#define KDTREE_H
1919
20- #include < queue>
2120#include < vector>
22-
21+ #include < queue>
22+ #include < algorithm>
2323#include " nearest_neighbor.h"
2424
25- template <typename T = double , size_t N = 3 >
25+ template <typename T, int N >
2626class KDTree : public NearestNeighbor <T, N> {
27- private:
28- using PointType = typename NearestNeighbor<T, N>::PointType;
29- class Node ;
30- Node *root_;
31- int K_;
32- std::priority_queue<std::pair<double , Node *>> bests_;
33-
34- Node *Build (std::vector<PointType> points, int depth);
35- Node *NearestNeighbor (Node *root, const PointType &target, int depth);
36- void KNearestNeighbor (Node *root, const PointType &target, int depth);
37-
38- Node *Closest (Node *n0, Node *n1, const PointType &target);
39- double DistSquared (const PointType &p0, const PointType &p1);
40- static void Inorder (Node *root);
41-
42- public:
43- KDTree ();
44- void Insert (const PointType &point) override ;
45- void BuildTree (const std::vector<PointType> &points) override ;
46- std::vector<PointType> KNearestNeighbor (const PointType &target_points,
47- int k) override ;
48- void PrintInorder ();
49-
50- ~KDTree () override ;
27+ private:
28+ class Node {
29+ public:
30+ Node* left_;
31+ Node* right_;
32+ Point<T, N> point_;
33+ int depth_;
34+
35+ explicit Node () {}
36+ explicit Node (Point<T, N> point, Node* left, Node* right) : point_(point), left_(left), right_(right) {}
37+ explicit Node (Point<T, N> point, Node* left, Node* right, int depth) : point_(point), left_(left), right_(right), depth_(depth) {}
38+ explicit Node (Point<T, N> point) : point_(point) {
39+ this ->left_ = nullptr ;
40+ this ->right_ = nullptr ;
41+ }
42+ explicit Node (Point<T, N> point, int depth) : point_(point), depth_(depth) {
43+ this ->left_ = nullptr ;
44+ this ->right_ = nullptr ;
45+ }
46+
47+ ~Node () {
48+ delete left_;
49+ delete right_;
50+ }
51+ };
52+
53+ Node* root_;
54+ int K_;
55+ std::priority_queue<std::pair<double , Node*>> bests_;
56+
57+ Node* Build (std::vector<Point<T, N>> points, int depth) {
58+ if (points.empty ()) {
59+ return nullptr ;
60+ }
61+ int k = points.at (0 ).data ().size ();
62+ int axis = depth % k;
63+
64+ std::sort (points.begin (), points.end (), [axis](const Point<T, N>& a, const Point<T, N>& b) {
65+ return a.data ()[axis] < b.data ()[axis];
66+ });
67+
68+ int median = points.size () / 2 ;
69+ std::vector<Point<T, N>> points_left (points.begin (), points.begin () + median);
70+ std::vector<Point<T, N>> points_right (points.begin () + median + 1 , points.end ());
71+
72+
73+ return new Node (
74+ points.at (median),
75+ Build (points_left, depth + 1 ),
76+ Build (points_right, depth + 1 ),
77+ depth
78+ );
79+ }
80+
81+
82+ Node* NearestNeighbor (Node* root, Point<T, N>& target, int depth) {
83+ if (root == nullptr ) return nullptr ;
84+
85+ Node* next_branch;
86+ Node* other_branch;
87+
88+ int axis = depth % root->point_ .size ();
89+
90+ if (target.data ().at (axis) < root->point_ .data ().at (axis))
91+ {
92+ next_branch = root->left_ ;
93+ other_branch = root->right_ ;
94+ } else {
95+ next_branch = root->right_ ;
96+ other_branch = root->left_ ;
97+ }
98+
99+ Node* temp = NearestNeighbor (next_branch, target, depth + 1 );
100+ Node* best = Closest (temp, root, target);
101+ double radius_squared = DistSquared (target, best->point_ );
102+
103+ double dist = target.data ().at (axis) - root->point_ .data ().at (axis);
104+
105+ if (radius_squared >= dist * dist)
106+ {
107+ temp = NearestNeighbor (other_branch, target, depth + 1 );
108+ best = Closest (temp, best, target);
109+ }
110+
111+ return best;
112+
113+ }
114+
115+ void KNearestNeighbor (Node* root, Point<T, N>& target, int depth) {
116+ if (root == nullptr ) return ;
117+
118+ Node* next_branch;
119+ Node* other_branch;
120+
121+ int axis = depth % root->point_ .size ();
122+
123+ if (target.data ().at (axis) < root->point_ .data ().at (axis))
124+ {
125+ next_branch = root->left_ ;
126+ other_branch = root->right_ ;
127+ } else {
128+ next_branch = root->right_ ;
129+ other_branch = root->left_ ;
130+ }
131+
132+ KNearestNeighbor (next_branch, target, depth + 1 );
133+ double dist = DistSquared (target, root->point_ );
134+
135+ if (this ->bests_ .size () < this ->K_ )
136+ {
137+ this ->bests_ .push ({dist, root});
138+ }
139+ else if (dist < this ->bests_ .top ().first )
140+ {
141+ bests_.pop ();
142+ this ->bests_ .push ({dist, root});
143+ }
144+
145+ double diff = target.data ().at (axis) - root->point_ .data ().at (axis);
146+
147+ if (this ->bests_ .size () < this ->K_ || diff * diff < this ->bests_ .top ().first ) {
148+ KNearestNeighbor (other_branch, target, depth + 1 );
149+ }
150+ }
151+
152+
153+ Node* Closest (Node* n0, Node* n1, Point<T, N>& target) {
154+ if (n0 == nullptr ) return n1;
155+ if (n1 == nullptr ) return n0;
156+
157+ long d1 = DistSquared (n0->point_ , target);
158+ long d2 = DistSquared (n1->point_ , target);
159+
160+ if (d1 < d2)
161+ return n0;
162+ else
163+ return n1;
164+ }
165+
166+ double DistSquared (const Point<T, N>& p0, const Point<T, N>& p1) {
167+ long total = 0 ;
168+ size_t numDims = p0.size ();
169+
170+ for (size_t i = 0 ; i < numDims; ++i) {
171+ int diff = std::abs (p0.data ()[i] - p1.data ()[i]);
172+ total += static_cast <double >(diff) * diff; // mais eficiente que pow para int
173+ }
174+
175+ return total;
176+ }
177+
178+ static void Inorder (Node* root) {
179+ if (!root) return ;
180+
181+ Inorder (root->left_ );
182+
183+ if (root->left_ ) {
184+ std::cout << " \" " ;
185+ for (size_t i = 0 ; i < root->point_ .size (); i++)
186+ std::cout << root->point_ .data ()[i] << (i + 1 == root->point_ .size () ? " " : " ," );
187+ std::cout << " \" -> \" " ;
188+ for (size_t i = 0 ; i < root->left_ ->point_ .size (); i++)
189+ std::cout << root->left_ ->point_ .data ()[i] << (i + 1 == root->left_ ->point_ .size () ? " " : " ," );
190+ std::cout << " \" [label=\" esq\" ];\n " ;
191+ }
192+
193+ if (root->right_ ) {
194+ std::cout << " \" " ;
195+ for (size_t i = 0 ; i < root->point_ .size (); i++)
196+ std::cout << root->point_ .data ()[i] << (i + 1 == root->point_ .size () ? " " : " ," );
197+ std::cout << " \" -> \" " ;
198+ for (size_t i = 0 ; i < root->right_ ->point_ .size (); i++)
199+ std::cout << root->right_ ->point_ .data ()[i] << (i + 1 == root->right_ ->point_ .size () ? " " : " ," );
200+ std::cout << " \" [label=\" dir\" ];\n " ;
201+ }
202+
203+ Inorder (root->right_ );
204+ }
205+
206+ public:
207+ KDTree () : root_(nullptr ) {}
208+ void Insert (Point<T, N> point) override {
209+ Node* p = this ->root_ ;
210+ Node* prev = nullptr ;
211+
212+ int depth = 0 ;
213+ int n_dims = point.size ();
214+
215+ while (p != nullptr )
216+ {
217+ prev = p;
218+ if (point.data ().at (depth) < p->point_ .data ().at (depth))
219+ p = p->left_ ;
220+ else
221+ p= p->right_ ;
222+ depth = (depth + 1 ) % n_dims;
223+ }
224+
225+ if (this ->root_ == nullptr )
226+ this ->root_ = new Node (point);
227+ else if ((point.data ().at ((depth - 1 ) % n_dims)) < (prev->point_ .data ().at ((depth - 1 ) % n_dims)))
228+ prev->left_ = new Node (point, depth);
229+ else
230+ prev->right_ = new Node (point, depth);
231+ }
232+
233+ void BuildTree (std::vector<Point<T, N>> points) override {
234+ if (points.empty ())
235+ {
236+ return ;
237+ }
238+
239+ int initial_size = points.at (0 ).size ();
240+
241+ for (auto &point : points)
242+ {
243+ if (point.data ().empty ())
244+ {
245+ return ;
246+ }
247+
248+ if (point.size () != initial_size)
249+ {
250+ return ;
251+ }
252+
253+
254+ }
255+
256+ this ->root_ = Build (points, 0 );
257+ }
258+
259+
260+ std::vector<Point<T, N>> KNearestNeighbor (Point<T, N> target_points, int k) override {
261+ if (k)
262+ this ->K_ = k;
263+
264+ if (k == 1 ){
265+ Node* result = NearestNeighbor (this ->root_ , target_points, 0 );
266+ std::vector<Point<T, N>> points;
267+ if (result)
268+ {
269+ points.push_back (result->point_ );
270+ }
271+ return points;
272+ }
273+
274+ KNearestNeighbor (this ->root_ , target_points, 0 );
275+ std::vector<Point<T, N>> tmp;
276+ while (!this ->bests_ .empty ())
277+ {
278+ tmp.push_back (this ->bests_ .top ().second ->point_ );
279+ this ->bests_ .pop ();
280+ }
281+
282+ return tmp;
283+ }
284+
285+ void PrintInorder () {
286+ // TODO: A ideia é isso gerar um arquivo para o graphiz gerar o grafo
287+ std::cout << " digraph G {\n " ;
288+ Inorder (this ->root_ );
289+ std::cout << " }\n " ;
290+ }
291+
292+ ~KDTree () override {
293+ delete root_;
294+ }
51295};
52296#endif // KDTREE_H
0 commit comments