Skip to content

Commit a77c5bf

Browse files
nhuurreyashiknostan-buildbotrok-cesnovar
authored
more general check_matching_dims (#1936)
* more general check_matching_dims * use dims() * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) * fix overload issue in opencl/dims * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2 (tags/RELEASE_600/final) Co-authored-by: Jenkins <nobody@nowhere> Co-authored-by: Stan Jenkins <[email protected]> Co-authored-by: rok-cesnovar <[email protected]>
1 parent 1b300c5 commit a77c5bf

File tree

3 files changed

+104
-8
lines changed

3 files changed

+104
-8
lines changed

stan/math/opencl/prim/dims.hpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,26 @@
22
#define STAN_MATH_OPENCL_PRIM_DIMS_HPP
33
#ifdef STAN_OPENCL
44

5+
#include <stan/math/prim/fun/dims.hpp>
56
#include <stan/math/opencl/matrix_cl.hpp>
67
#include <stan/math/opencl/rev/matrix_cl.hpp>
78
#include <vector>
89

910
namespace stan {
1011
namespace math {
1112
/** \ingroup opencl
12-
* Returns a vector of matrix_cl dimensions.
13+
* matrix_cl overload of the dims helper function in prim/fun/dims.hpp.
14+
* Pushes the rows and columns to the result vector argument.
1315
*
1416
* @tparam T_x type of input kernel generator expression a
15-
* @param x the input matrix_cl
16-
*
17-
* @return std::vector of the dimensions of the input kernel generato expression
17+
* @param[in] x the input matrix_cl
18+
* @param[out] result the output vector of dimensions
1819
*/
1920
template <typename T_x,
20-
typename = require_all_kernel_expressions_and_none_scalar_t<T_x>>
21-
inline std::vector<int> dims(const T_x& x) {
22-
return {x.rows(), x.cols()};
21+
require_all_kernel_expressions_and_none_scalar_t<T_x>* = nullptr>
22+
inline void dims(const T_x& x, std::vector<int>& result) {
23+
result.push_back(x.rows());
24+
result.push_back(x.cols());
2325
}
2426
} // namespace math
2527
} // namespace stan

stan/math/prim/err/check_matching_dims.hpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/Eigen.hpp>
6-
#include <stan/math/prim/err/check_size_match.hpp>
6+
#include <stan/math/prim/fun/dims.hpp>
77
#include <stan/math/prim/err/invalid_argument.hpp>
88
#include <sstream>
99
#include <string>
@@ -70,6 +70,56 @@ inline void check_matching_dims(const char* function, const char* name1,
7070
check_matching_dims(function, name1, y1, name2, y2);
7171
}
7272

73+
/**
74+
* Check if the two containers have the same dimensions.
75+
* @tparam T1 type of the first container
76+
* @tparam T2 type of the second container
77+
* @param function name of function (for error messages)
78+
* @param name1 variable name for the first container (for error messages)
79+
* @param y1 first container to test
80+
* @param name2 variable name for the second container (for error messages)
81+
* @param y2 second container to test
82+
* @throw <code>std::invalid_argument</code> if the dimensions of the
83+
* containers do not match
84+
*/
85+
template <typename T1, typename T2, require_all_std_vector_t<T1, T2>* = nullptr>
86+
inline void check_matching_dims(const char* function, const char* name1,
87+
const T1& y1, const char* name2, const T2& y2) {
88+
std::vector<int> y1_d = dims(y1);
89+
std::vector<int> y2_d = dims(y2);
90+
bool error = false;
91+
if (y1_d.size() != y2_d.size()) {
92+
error = true;
93+
} else {
94+
for (int i = 0; i < y1_d.size(); i++) {
95+
if (y1_d[i] != y2_d[i]) {
96+
error = true;
97+
break;
98+
}
99+
}
100+
}
101+
if (error) {
102+
std::ostringstream y1s;
103+
if (y1_d.size() > 0) {
104+
y1s << y1_d[0];
105+
for (int i = 1; i < y1_d.size(); i++) {
106+
y1s << ", " << y1_d[i];
107+
}
108+
}
109+
std::ostringstream msg;
110+
msg << ") and " << name2 << " (";
111+
if (y2_d.size() > 0) {
112+
msg << y2_d[0];
113+
for (int i = 1; i < y2_d.size(); i++) {
114+
msg << ", " << y2_d[i];
115+
}
116+
}
117+
msg << ") must match in size";
118+
std::string msg_str(msg.str());
119+
invalid_argument(function, name1, y1s.str(), "(", msg_str.c_str());
120+
}
121+
}
122+
73123
} // namespace math
74124
} // namespace stan
75125
#endif

test/unit/math/prim/err/check_matching_dims_test.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,50 @@ TEST(ErrorHandlingMatrix, checkMatchingDimsMatrix) {
2727
std::invalid_argument);
2828
}
2929

30+
TEST(ErrorHandlingMatrix, checkMatchingDimsArray) {
31+
std::vector<double> y(3);
32+
std::vector<double> x(3);
33+
34+
EXPECT_NO_THROW(
35+
stan::math::check_matching_dims("checkMatchingDims", "x", x, "y", y));
36+
x.resize(0);
37+
y.resize(0);
38+
EXPECT_NO_THROW(
39+
stan::math::check_matching_dims("checkMatchingDims", "x", x, "y", y));
40+
41+
y.resize(1);
42+
EXPECT_THROW(
43+
stan::math::check_matching_dims("checkMatchingDims", "x", x, "y", y),
44+
std::invalid_argument);
45+
46+
x.resize(2);
47+
EXPECT_THROW(
48+
stan::math::check_matching_dims("checkMatchingDims", "x", x, "y", y),
49+
std::invalid_argument);
50+
}
51+
52+
TEST(ErrorHandlingMatrix, checkMatchingDimsVectorArray) {
53+
std::vector<Eigen::VectorXd> y(3, Eigen::VectorXd(3));
54+
std::vector<Eigen::VectorXd> x(3, Eigen::VectorXd(3));
55+
56+
EXPECT_NO_THROW(
57+
stan::math::check_matching_dims("checkMatchingDims", "x", x, "y", y));
58+
x.resize(0);
59+
y.resize(0);
60+
EXPECT_NO_THROW(
61+
stan::math::check_matching_dims("checkMatchingDims", "x", x, "y", y));
62+
63+
y.resize(1, Eigen::VectorXd(3));
64+
EXPECT_THROW(
65+
stan::math::check_matching_dims("checkMatchingDims", "x", x, "y", y),
66+
std::invalid_argument);
67+
68+
x.resize(1, Eigen::VectorXd(2));
69+
EXPECT_THROW(
70+
stan::math::check_matching_dims("checkMatchingDims", "x", x, "y", y),
71+
std::invalid_argument);
72+
}
73+
3074
TEST(ErrorHandlingMatrix, checkMatchingDimsMatrix_nan) {
3175
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> y;
3276
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> x;

0 commit comments

Comments
 (0)