Skip to content

Commit a5b8512

Browse files
committed
Add a nearest neighbor model.
1 parent 5cf319f commit a5b8512

File tree

10 files changed

+262
-32
lines changed

10 files changed

+262
-32
lines changed

include/albatross/NearestNeighbor

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/*
2+
* Copyright (C) 2019 Swift Navigation Inc.
3+
* Contact: Swift Navigation <[email protected]>
4+
*
5+
* This source is subject to the license found in the file 'LICENSE' which must
6+
* be distributed together with this source. All other rights reserved.
7+
*
8+
* THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND,
9+
* EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED
10+
* WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE.
11+
*/
12+
13+
#ifndef ALBATROSS_NEAREST_NEIGHBOR_MODEL_H
14+
#define ALBATROSS_NEAREST_NEIGHBOR_MODEL_H
15+
16+
#include "Core"
17+
18+
#include <albatross/src/models/nearest_neighbor.hpp>
19+
20+
#endif
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/*
2+
* Copyright (C) 2019 Swift Navigation Inc.
3+
* Contact: Swift Navigation <[email protected]>
4+
*
5+
* This source is subject to the license found in the file 'LICENSE' which must
6+
* be distributed together with this source. All other rights reserved.
7+
*
8+
* THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND,
9+
* EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED
10+
* WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE.
11+
*/
12+
13+
#ifndef ALBATROSS_SERIALIZE_NEAREST_NEIGHBOR_H
14+
#define ALBATROSS_SERIALIZE_NEAREST_NEIGHBOR_H
15+
16+
#include "Core"
17+
18+
#include "../src/cereal/nearest_neighbor.hpp"
19+
20+
#endif

include/albatross/src/cereal/dataset.hpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,16 @@ serialize(Archive &archive, RegressionDataset<FeatureType> &dataset,
2929
archive(cereal::make_nvp("metadata", dataset.metadata));
3030
}
3131

32-
template <class Archive, class FeatureType>
33-
typename std::enable_if<!valid_in_out_serializer<FeatureType, Archive>::value,
34-
void>::type
35-
serialize(Archive &archive, RegressionDataset<FeatureType> &dataset,
36-
const std::uint32_t) {
37-
static_assert(delay_static_assert<Archive>::value,
38-
"In order to serialize a RegressionDataset the corresponding "
39-
"FeatureType must be serializable.");
40-
}
32+
// template <class Archive, class FeatureType>
33+
// typename std::enable_if<!valid_in_out_serializer<FeatureType,
34+
// Archive>::value,
35+
// void>::type
36+
// serialize(Archive &archive, RegressionDataset<FeatureType> &dataset,
37+
// const std::uint32_t) {
38+
// static_assert(delay_static_assert<Archive>::value,
39+
// "In order to serialize a RegressionDataset the corresponding "
40+
// "FeatureType must be serializable.");
41+
//}
4142

4243
} // namespace cereal
4344

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright (C) 2019 Swift Navigation Inc.
3+
* Contact: Swift Navigation <[email protected]>
4+
*
5+
* This source is subject to the license found in the file 'LICENSE' which must
6+
* be distributed together with this source. All other rights reserved.
7+
*
8+
* THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND,
9+
* EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED
10+
* WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE.
11+
*/
12+
13+
#ifndef ALBATROSS_CEREAL_NEAREST_NEIGHBOR_HPP_
14+
#define ALBATROSS_CEREAL_NEAREST_NEIGHBOR_HPP_
15+
16+
namespace albatross {

// Forward declarations of the nearest neighbor types referenced by the
// serialization routines below, so this header does not need the full
// model definitions.
template <typename FeatureType> struct NearestNeighborFit;

// Declared with `class` so the class-key agrees with the declaration and
// definition of NearestNeighborModel in the models header; mixing `struct`
// and `class` for the same type triggers mismatched-tag warnings and is
// unsafe on toolchains which encode the class-key in mangled names.
template <typename DistanceMetric> class NearestNeighborModel;

} // namespace albatross
23+
24+
namespace cereal {
25+
26+
template <typename Archive, typename FeatureType>
27+
inline void
28+
save(Archive &archive,
29+
const albatross::Fit<albatross::NearestNeighborFit<FeatureType>> &fit,
30+
const std::uint32_t) {
31+
archive(cereal::make_nvp("training_features", fit.training_data.features));
32+
archive(cereal::make_nvp("training_targets", fit.training_data.targets));
33+
}
34+
35+
template <typename Archive, typename FeatureType>
36+
inline void
37+
load(Archive &archive,
38+
albatross::Fit<albatross::NearestNeighborFit<FeatureType>> &fit,
39+
const std::uint32_t) {
40+
std::vector<FeatureType> features;
41+
archive(cereal::make_nvp("training_features", features));
42+
albatross::MarginalDistribution targets;
43+
archive(cereal::make_nvp("training_targets", targets));
44+
fit.training_data = RegressionDataset<FeatureType>(features, targets);
45+
}
46+
47+
} // namespace cereal
48+
49+
#endif /* ALBATROSS_CEREAL_NEAREST_NEIGHBOR_HPP_ */

include/albatross/src/evaluation/cross_validation_utils.hpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,25 @@ inline MarginalDistribution concatenate_marginal_predictions(
9393
Eigen::VectorXd variance(n);
9494
Eigen::Index number_filled = 0;
9595
// Put all the predicted means back in order.
96+
bool has_covariance = false;
9697
for (const auto &pair : indexer) {
9798
assert(preds.at(pair.first).size() == pair.second.size());
9899
set_subset(preds.at(pair.first).mean, pair.second, &mean);
99-
set_subset(preds.at(pair.first).covariance.diagonal(), pair.second,
100-
&variance);
100+
if (preds.at(pair.first).has_covariance()) {
101+
has_covariance = true;
102+
set_subset(preds.at(pair.first).covariance.diagonal(), pair.second,
103+
&variance);
104+
} else {
105+
assert(!has_covariance);
106+
}
101107
number_filled += static_cast<Eigen::Index>(pair.second.size());
102108
}
103109
assert(number_filled == n);
104-
return MarginalDistribution(mean, variance.asDiagonal());
110+
if (has_covariance) {
111+
return MarginalDistribution(mean, variance.asDiagonal());
112+
} else {
113+
return MarginalDistribution(mean);
114+
}
105115
}
106116

107117
template <typename PredictionMetricType, typename FeatureType,
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright (C) 2019 Swift Navigation Inc.
3+
* Contact: Swift Navigation <[email protected]>
4+
*
5+
* This source is subject to the license found in the file 'LICENSE' which must
6+
* be distributed together with this source. All other rights reserved.
7+
*
8+
* THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND,
9+
* EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED
10+
* WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE.
11+
*/
12+
13+
#ifndef ALBATROSS_SRC_MODELS_NEAREST_NEIGHBOR_MODEL_HPP_
14+
#define ALBATROSS_SRC_MODELS_NEAREST_NEIGHBOR_MODEL_HPP_
15+
16+
namespace albatross {
17+
18+
template <typename DistanceMetric> class NearestNeighborModel;
19+
20+
template <typename FeatureType> struct NearestNeighborFit;
21+
22+
/*
 * The fit of a nearest neighbor model is simply the dataset it was
 * trained on; two fits compare equal when their training data do.
 */
template <typename FeatureType> struct Fit<NearestNeighborFit<FeatureType>> {

  RegressionDataset<FeatureType> training_data;

  Fit() : training_data() {}

  // Intentionally not explicit: a fit may be copy-list-initialized directly
  // from a dataset.
  Fit(const RegressionDataset<FeatureType> &dataset)
      : training_data(dataset) {}

  bool operator==(const Fit &rhs) const {
    return training_data == rhs.training_data;
  }
};
34+
35+
template <typename DistanceMetric>
36+
class NearestNeighborModel
37+
: public ModelBase<NearestNeighborModel<DistanceMetric>> {
38+
39+
public:
40+
NearestNeighborModel() : distance_metric(){};
41+
42+
std::string get_name() const { return "nearest_neighbor_model"; };
43+
44+
template <typename FeatureType>
45+
Fit<NearestNeighborFit<FeatureType>>
46+
_fit_impl(const std::vector<FeatureType> &features,
47+
const MarginalDistribution &targets) const {
48+
return Fit<NearestNeighborFit<FeatureType>>(
49+
RegressionDataset<FeatureType>(features, targets));
50+
}
51+
52+
template <typename FeatureType>
53+
auto fit_from_prediction(const std::vector<FeatureType> &features,
54+
const JointDistribution &prediction) const {
55+
const NearestNeighborModel<DistanceMetric> m(*this);
56+
MarginalDistribution marginal_pred(
57+
prediction.mean, prediction.covariance.diagonal().asDiagonal());
58+
Fit<NearestNeighborFit<FeatureType>> fit = {
59+
RegressionDataset<FeatureType>(features, marginal_pred)};
60+
FitModel<NearestNeighborModel, Fit<NearestNeighborFit<FeatureType>>>
61+
fit_model(m, fit);
62+
return fit_model;
63+
}
64+
65+
template <typename FeatureType>
66+
MarginalDistribution
67+
_predict_impl(const std::vector<FeatureType> &features,
68+
const Fit<NearestNeighborFit<FeatureType>> &fit,
69+
PredictTypeIdentity<MarginalDistribution> &&) const {
70+
const Eigen::Index n = static_cast<Eigen::Index>(features.size());
71+
Eigen::VectorXd mean = Eigen::VectorXd::Zero(n);
72+
mean.fill(NAN);
73+
Eigen::VectorXd variance = Eigen::VectorXd::Zero(n);
74+
variance.fill(NAN);
75+
76+
for (std::size_t i = 0; i < features.size(); ++i) {
77+
const auto min_index =
78+
index_with_min_distance(features[i], fit.training_data.features);
79+
mean[i] = fit.training_data.targets.mean[min_index];
80+
variance[i] = fit.training_data.targets.get_diagonal(min_index);
81+
}
82+
83+
if (fit.training_data.targets.has_covariance()) {
84+
return MarginalDistribution(mean, variance.asDiagonal());
85+
} else {
86+
return MarginalDistribution(mean);
87+
}
88+
}
89+
90+
private:
91+
template <typename FeatureType>
92+
std::size_t
93+
index_with_min_distance(const FeatureType &ref,
94+
const std::vector<FeatureType> &features) const {
95+
assert(features.size() > 0);
96+
97+
std::size_t min_index = 0;
98+
double min_distance = distance_metric(ref, features[0]);
99+
100+
for (std::size_t i = 1; i < features.size(); ++i) {
101+
const double dist = distance_metric(ref, features[i]);
102+
if (dist < min_distance) {
103+
min_index = i;
104+
min_distance = dist;
105+
}
106+
}
107+
return min_index;
108+
}
109+
110+
DistanceMetric distance_metric;
111+
};
112+
113+
} // namespace albatross
114+
115+
#endif // ALBATROSS_SRC_MODELS_NEAREST_NEIGHBOR_MODEL_HPP_

include/albatross/src/models/null_model.hpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,13 @@ class NullModel : public ModelBase<NullModel> {
3232

3333
std::string get_name() const { return "null_model"; };
3434

35-
/*
36-
* The Gaussian Process Regression model derives its parameters from
37-
* the covariance functions.
38-
*/
3935
ParameterStore get_params() const override { return params_; }
4036

4137
void unchecked_set_param(const std::string &name,
4238
const Parameter &param) override {
4339
params_[name] = param;
4440
}
4541

46-
// If the implementing class doesn't have a fit method for this
47-
// FeatureType but the CovarianceFunction does.
4842
template <typename FeatureType>
4943
Fit<NullModel> _fit_impl(const std::vector<FeatureType> &features,
5044
const MarginalDistribution &targets) const {
@@ -87,5 +81,4 @@ class NullModel : public ModelBase<NullModel> {
8781

8882
} // namespace albatross
8983

90-
#endif /* THIRD_PARTY_ALBATROSS_INCLUDE_ALBATROSS_SRC_MODELS_NULL_MODEL_HPP_ \
91-
*/
84+
#endif // ALBATROSS_SRC_MODELS_NULL_MODEL_HPP_

tests/test_cross_validation.cc

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,22 @@ TYPED_TEST_P(RegressionModelTester, test_logo_predict_variants) {
5353
auto dataset = this->test_case.get_dataset();
5454
auto model = this->test_case.get_model();
5555

56-
// Here we assume that the test case is linear, then split
57-
// it using a group function which will not preserve order
58-
// and make sure that cross validation properly reassembles
59-
// the predictions
60-
LeaveOneGroupOut<typename decltype(dataset)::Feature> logo(group_by_interval);
61-
const auto prediction = model.cross_validate().predict(dataset, logo);
62-
63-
EXPECT_TRUE(is_monotonic_increasing(prediction.mean()));
64-
65-
expect_predict_variants_consistent(prediction);
56+
// The nearest neighbor approach is not capable of modelling linear
57+
// trends and in turn fails this test.
58+
if (!std::is_same<decltype(model),
59+
NearestNeighborModel<EuclideanDistance>>::value) {
60+
// Here we assume that the test case is linear, then split
61+
// it using a group function which will not preserve order
62+
// and make sure that cross validation properly reassembles
63+
// the predictions
64+
LeaveOneGroupOut<typename decltype(dataset)::Feature> logo(
65+
group_by_interval);
66+
const auto prediction = model.cross_validate().predict(dataset, logo);
67+
68+
EXPECT_TRUE(is_monotonic_increasing(prediction.mean()));
69+
70+
expect_predict_variants_consistent(prediction);
71+
}
6672
}
6773

6874
TYPED_TEST_P(RegressionModelTester, test_loo_predict_variants) {
@@ -110,7 +116,9 @@ TYPED_TEST_P(RegressionModelTester, test_score_variants) {
110116
// Here we make sure the cross validated mean absolute error is reasonable.
111117
// Note that because we are running leave one out cross validation, the
112118
// RMSE for each fold is just the absolute value of the error.
113-
if (!std::is_same<decltype(model), NullModel>::value) {
119+
if (!std::is_same<decltype(model), NullModel>::value &&
120+
!std::is_same<decltype(model),
121+
NearestNeighborModel<EuclideanDistance>>::value) {
114122
EXPECT_LE(cv_scores.mean(), 0.1);
115123
}
116124
}

tests/test_models.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <albatross/GP>
1414
#include <albatross/LeastSquares>
15+
#include <albatross/NearestNeighbor>
1516
#include <albatross/NullModel>
1617
#include <albatross/Ransac>
1718
#include <gtest/gtest.h>
@@ -181,6 +182,17 @@ class MakeNullModel {
181182
}
182183
};
183184

185+
class MakeNearestNeighborModel {
186+
public:
187+
NearestNeighborModel<EuclideanDistance> get_model() const {
188+
return NearestNeighborModel<EuclideanDistance>();
189+
}
190+
191+
RegressionDataset<double> get_dataset() const {
192+
return make_toy_linear_data();
193+
}
194+
};
195+
184196
template <typename ModelTestCase>
185197
class RegressionModelTester : public ::testing::Test {
186198
public:
@@ -189,7 +201,8 @@ class RegressionModelTester : public ::testing::Test {
189201

190202
typedef ::testing::Types<MakeLinearRegression, MakeGaussianProcess,
191203
MakeAdaptedGaussianProcess, MakeRansacGaussianProcess,
192-
MakeRansacAdaptedGaussianProcess, MakeNullModel>
204+
MakeRansacAdaptedGaussianProcess, MakeNullModel,
205+
MakeNearestNeighborModel>
193206
ExampleModels;
194207

195208
TYPED_TEST_CASE_P(RegressionModelTester);

tests/test_serialize.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <albatross/serialize/Common>
1717
#include <albatross/serialize/GP>
1818
#include <albatross/serialize/LeastSquares>
19+
#include <albatross/serialize/NearestNeighbor>
1920
#include <albatross/serialize/Ransac>
2021

2122
#include <gtest/gtest.h>

0 commit comments

Comments
 (0)