Skip to content

Commit 844c901

Browse files
committed
Added flatten and unflatten utils
1 parent 0cda5b2 commit 844c901

File tree

4 files changed

+42
-0
lines changed

4 files changed

+42
-0
lines changed

src/finitediff.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,4 +224,27 @@ bool compare_hessian(
224224
return compare_jacobian(x, y, test_eps, msg);
225225
}
226226

227+
// Flatten the matrix rowwise
228+
Eigen::VectorXd flatten(const Eigen::MatrixXd& X)
229+
{
230+
Eigen::VectorXd x(X.size());
231+
for (int i = 0; i < X.rows(); i++) {
232+
for (int j = 0; j < X.cols(); j++) {
233+
x(i * X.cols() + j) = X(i, j);
234+
}
235+
}
236+
return x;
237+
}
238+
239+
// Unflatten rowwise
240+
Eigen::MatrixXd unflatten(const Eigen::VectorXd& x, int dim)
241+
{
242+
assert(x.size() % dim == 0);
243+
Eigen::MatrixXd X(x.size() / dim, dim);
244+
for (int i = 0; i < x.size(); i++) {
245+
X(i / dim, i % dim) = x(i);
246+
}
247+
return X;
248+
}
249+
227250
} // namespace fd

src/finitediff.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,10 @@ bool compare_hessian(
118118
const double test_eps = 1e-4,
119119
const std::string& msg = "compare_hessian ");
120120

121+
/// @brief Flatten the matrix rowwise
122+
Eigen::VectorXd flatten(const Eigen::MatrixXd& X);
123+
124+
/// @brief Unflatten rowwise
125+
Eigen::MatrixXd unflatten(const Eigen::VectorXd& x, int dim);
126+
121127
} // namespace fd

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_executable(finitediff_tests
88
test_gradient.cpp
99
test_jacobian.cpp
1010
test_hessian.cpp
11+
test_flatten.cpp
1112
)
1213

1314
################################################################################

tests/test_flatten.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#include <catch2/catch.hpp>
2+
3+
#include <finitediff.hpp>
4+
5+
using namespace fd;
6+
7+
TEST_CASE("Flatten and unflatten", "[utils]")
8+
{
9+
Eigen::MatrixXd X = Eigen::MatrixXd::Random(1000, 3);
10+
Eigen::MatrixXd R = unflatten(flatten(X), X.cols());
11+
CHECK(X == R);
12+
}

0 commit comments

Comments
 (0)