Skip to content

Commit 9fa4e13

Browse files
author
Bob Carpenter
authored
Merge pull request #1718 from stan-dev/cleanup/1689-check-multiplicable-size-zero
Remove size zero checks from check_multiplicable
2 parents 6bd1ced + 84bbd9a commit 9fa4e13

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+383
-128
lines changed

stan/math/fwd/fun/mdivide_left.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_left(
1919
const Eigen::Matrix<fvar<T>, R2, C2> &b) {
2020
check_square("mdivide_left", "A", A);
2121
check_multiplicable("mdivide_left", "A", A, "b", b);
22+
if (A.size() == 0) {
23+
return {0, b.cols()};
24+
}
2225

2326
Eigen::Matrix<T, R1, C2> inv_A_mult_b(A.rows(), b.cols());
2427
Eigen::Matrix<T, R1, C2> inv_A_mult_deriv_b(A.rows(), b.cols());
@@ -58,6 +61,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_left(
5861
const Eigen::Matrix<fvar<T>, R2, C2> &b) {
5962
check_square("mdivide_left", "A", A);
6063
check_multiplicable("mdivide_left", "A", A, "b", b);
64+
if (A.size() == 0) {
65+
return {0, b.cols()};
66+
}
6167

6268
Eigen::Matrix<T, R2, C2> val_b(b.rows(), b.cols());
6369
Eigen::Matrix<T, R2, C2> deriv_b(b.rows(), b.cols());
@@ -78,6 +84,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_left(
7884
const Eigen::Matrix<double, R2, C2> &b) {
7985
check_square("mdivide_left", "A", A);
8086
check_multiplicable("mdivide_left", "A", A, "b", b);
87+
if (A.size() == 0) {
88+
return {0, b.cols()};
89+
}
8190

8291
Eigen::Matrix<T, R1, C2> inv_A_mult_b(A.rows(), b.cols());
8392
Eigen::Matrix<T, R1, C1> inv_A_mult_deriv_A(A.rows(), A.cols());

stan/math/fwd/fun/mdivide_left_tri_low.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_left_tri_low(
1818
const Eigen::Matrix<fvar<T>, R2, C2>& b) {
1919
check_square("mdivide_left_tri_low", "A", A);
2020
check_multiplicable("mdivide_left_tri_low", "A", A, "b", b);
21+
if (A.size() == 0) {
22+
return {0, b.cols()};
23+
}
2124

2225
Eigen::Matrix<T, R1, C2> inv_A_mult_b(A.rows(), b.cols());
2326
Eigen::Matrix<T, R1, C2> inv_A_mult_deriv_b(A.rows(), b.cols());
@@ -59,6 +62,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_left_tri_low(
5962
const Eigen::Matrix<fvar<T>, R2, C2>& b) {
6063
check_square("mdivide_left_tri_low", "A", A);
6164
check_multiplicable("mdivide_left_tri_low", "A", A, "b", b);
65+
if (A.size() == 0) {
66+
return {0, b.cols()};
67+
}
6268

6369
Eigen::Matrix<T, R1, C2> inv_A_mult_b(A.rows(), b.cols());
6470
Eigen::Matrix<T, R1, C2> inv_A_mult_deriv_b(A.rows(), b.cols());
@@ -95,6 +101,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_left_tri_low(
95101
const Eigen::Matrix<double, R2, C2>& b) {
96102
check_square("mdivide_left_tri_low", "A", A);
97103
check_multiplicable("mdivide_left_tri_low", "A", A, "b", b);
104+
if (A.size() == 0) {
105+
return {0, b.cols()};
106+
}
98107

99108
Eigen::Matrix<T, R1, C2> inv_A_mult_b(A.rows(), b.cols());
100109
Eigen::Matrix<T, R1, C1> inv_A_mult_deriv_A(A.rows(), A.cols());

stan/math/fwd/fun/mdivide_right.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right(
1919
const Eigen::Matrix<fvar<T>, R2, C2> &b) {
2020
check_square("mdivide_right", "b", b);
2121
check_multiplicable("mdivide_right", "A", A, "b", b);
22+
if (b.size() == 0) {
23+
return {A.rows(), 0};
24+
}
2225

2326
Eigen::Matrix<T, R1, C2> A_mult_inv_b(A.rows(), b.cols());
2427
Eigen::Matrix<T, R1, C2> deriv_A_mult_inv_b(A.rows(), b.cols());
@@ -58,6 +61,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right(
5861
const Eigen::Matrix<double, R2, C2> &b) {
5962
check_square("mdivide_right", "b", b);
6063
check_multiplicable("mdivide_right", "A", A, "b", b);
64+
if (b.size() == 0) {
65+
return {A.rows(), 0};
66+
}
6167

6268
Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b(b.rows(), b.cols());
6369
Eigen::Matrix<T, R1, C1> val_A(A.rows(), A.cols());
@@ -79,6 +85,10 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right(
7985
const Eigen::Matrix<fvar<T>, R2, C2> &b) {
8086
check_square("mdivide_right", "b", b);
8187
check_multiplicable("mdivide_right", "A", A, "b", b);
88+
if (b.size() == 0) {
89+
return {A.rows(), 0};
90+
}
91+
8292
Eigen::Matrix<T, R1, C2> A_mult_inv_b(A.rows(), b.cols());
8393
Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b(b.rows(), b.cols());
8494
Eigen::Matrix<T, R2, C2> val_b(b.rows(), b.cols());

stan/math/fwd/fun/mdivide_right_tri_low.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ inline Eigen::Matrix<fvar<T>, R1, C1> mdivide_right_tri_low(
1717
const Eigen::Matrix<fvar<T>, R2, C2> &b) {
1818
check_square("mdivide_right_tri_low", "b", b);
1919
check_multiplicable("mdivide_right_tri_low", "A", A, "b", b);
20+
if (b.size() == 0) {
21+
return {A.rows(), 0};
22+
}
2023

2124
Eigen::Matrix<T, R1, C2> A_mult_inv_b(A.rows(), b.cols());
2225
Eigen::Matrix<T, R1, C2> deriv_A_mult_inv_b(A.rows(), b.cols());
@@ -58,6 +61,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right_tri_low(
5861
const Eigen::Matrix<double, R2, C2> &b) {
5962
check_square("mdivide_right_tri_low", "b", b);
6063
check_multiplicable("mdivide_right_tri_low", "A", A, "b", b);
64+
if (b.size() == 0) {
65+
return {A.rows(), 0};
66+
}
6167

6268
Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b(b.rows(), b.cols());
6369
Eigen::Matrix<T, R1, C1> val_A(A.rows(), A.cols());
@@ -87,6 +93,9 @@ inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_right_tri_low(
8793
const Eigen::Matrix<fvar<T>, R2, C2> &b) {
8894
check_square("mdivide_right_tri_low", "b", b);
8995
check_multiplicable("mdivide_right_tri_low", "A", A, "b", b);
96+
if (b.size() == 0) {
97+
return {A.rows(), 0};
98+
}
9099

91100
Eigen::Matrix<T, R1, C2> A_mult_inv_b(A.rows(), b.cols());
92101
Eigen::Matrix<T, R2, C2> deriv_b_mult_inv_b(b.rows(), b.cols());

stan/math/fwd/fun/multiply.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ template <typename Mat1, typename Mat2,
1616
require_same_vt<Mat1, Mat2>* = nullptr,
1717
require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
1818
inline auto multiply(const Mat1& m1, const Mat2& m2) {
19-
check_size_match("multiply", "Columns of m1", m1.cols(), "Rows of m2",
20-
m2.rows());
19+
check_multiplicable("multiply", "m1", m1, "m2", m2);
2120
return m1 * m2;
2221
}
2322

@@ -26,8 +25,7 @@ template <typename Mat1, typename Mat2,
2625
require_eigen_vt<std::is_floating_point, Mat2>* = nullptr,
2726
require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
2827
inline auto multiply(const Mat1& m1, const Mat2& m2) {
29-
check_size_match("multiply", "Columns of m1", m1.cols(), "Rows of m2",
30-
m2.rows());
28+
check_multiplicable("multiply", "m1", m1, "m2", m2);
3129
Eigen::Matrix<value_type_t<Mat1>, Mat1::RowsAtCompileTime,
3230
Mat2::ColsAtCompileTime>
3331
result(m1.rows(), m2.cols());
@@ -46,8 +44,7 @@ template <typename Mat1, typename Mat2,
4644
require_eigen_vt<is_fvar, Mat2>* = nullptr,
4745
require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
4846
inline auto multiply(const Mat1& m1, const Mat2& m2) {
49-
check_size_match("multiply", "Columns of m1", m1.cols(), "Rows of m2",
50-
m2.rows());
47+
check_multiplicable("multiply", "m1", m1, "m2", m2);
5148
Eigen::Matrix<value_type_t<Mat2>, Mat1::RowsAtCompileTime,
5249
Mat2::ColsAtCompileTime>
5350
result(m1.rows(), m2.cols());

stan/math/prim/err/check_multiplicable.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,8 @@ namespace math {
2929
template <typename T1, typename T2>
3030
inline void check_multiplicable(const char* function, const char* name1,
3131
const T1& y1, const char* name2, const T2& y2) {
32-
check_positive(function, name1, "rows()", y1.rows());
33-
check_positive(function, name2, "cols()", y2.cols());
3432
check_size_match(function, "Columns of ", name1, y1.cols(), "Rows of ", name2,
3533
y2.rows());
36-
check_positive(function, name1, "cols()", y1.cols());
3734
}
3835
} // namespace math
3936
} // namespace stan

stan/math/prim/fun/matrix_exp_multiply.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@ template <int Cb>
2020
inline Eigen::Matrix<double, -1, Cb> matrix_exp_multiply(
2121
const Eigen::MatrixXd& A, const Eigen::Matrix<double, -1, Cb>& B) {
2222
check_square("matrix_exp_multiply", "input matrix", A);
23-
if (A.size() == 0 && B.rows() == 0) {
24-
return Eigen::Matrix<double, -1, Cb>(0, B.cols());
25-
}
26-
2723
check_multiplicable("matrix_exp_multiply", "A", A, "B", B);
24+
if (A.size() == 0) {
25+
return {0, B.cols()};
26+
}
2827

2928
return matrix_exp_action_handler().action(A, B);
3029
}

stan/math/prim/fun/mdivide_left.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_left(
3131
const Eigen::Matrix<T1, R1, C1> &A, const Eigen::Matrix<T2, R2, C2> &b) {
3232
check_square("mdivide_left", "A", A);
3333
check_multiplicable("mdivide_left", "A", A, "b", b);
34+
if (A.size() == 0) {
35+
return {0, b.cols()};
36+
}
3437

3538
return Eigen::Matrix<return_type_t<T1, T2>, R1, C1>(A).lu().solve(
3639
Eigen::Matrix<return_type_t<T1, T2>, R2, C2>(b));

stan/math/prim/fun/mdivide_left_ldlt.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@ namespace math {
3030
template <int R1, int C1, int R2, int C2, typename T1, typename T2>
3131
inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_left_ldlt(
3232
const LDLT_factor<T1, R1, C1> &A, const Eigen::Matrix<T2, R2, C2> &b) {
33-
if (A.cols() == 0 && b.rows() == 0) {
34-
return Eigen::Matrix<return_type_t<T1, T2>, R1, C2>(0, b.cols());
35-
}
36-
3733
check_multiplicable("mdivide_left_ldlt", "A", A, "b", b);
34+
if (A.cols() == 0) {
35+
return {0, b.cols()};
36+
}
3837

3938
return A.solve(Eigen::Matrix<return_type_t<T1, T2>, R2, C2>(b));
4039
}

stan/math/prim/fun/mdivide_left_spd.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_left_spd(
3232
const Eigen::Matrix<T1, R1, C1> &A, const Eigen::Matrix<T2, R2, C2> &b) {
3333
static const char *function = "mdivide_left_spd";
3434
check_multiplicable(function, "A", A, "b", b);
35-
check_positive(function, "rows", A.rows());
3635
check_symmetric(function, "A", A);
3736
check_not_nan(function, "A", A);
37+
if (A.size() == 0) {
38+
return {0, b.cols()};
39+
}
3840

3941
auto llt = Eigen::Matrix<return_type_t<T1, T2>, R1, C1>(A).llt();
4042
check_pos_definite(function, "A", llt);

stan/math/prim/fun/mdivide_left_tri.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_left_tri(
3636
const Eigen::Matrix<T1, R1, C1> &A, const Eigen::Matrix<T2, R2, C2> &b) {
3737
check_square("mdivide_left_tri", "A", A);
3838
check_multiplicable("mdivide_left_tri", "A", A, "b", b);
39+
if (A.rows() == 0) {
40+
return {0, b.cols()};
41+
}
3942

4043
return Eigen::Matrix<return_type_t<T1, T2>, R1, C1>(A)
4144
.template triangularView<TriView>()
@@ -57,6 +60,10 @@ template <int TriView, typename T, int R1, int C1>
5760
inline Eigen::Matrix<T, R1, C1> mdivide_left_tri(
5861
const Eigen::Matrix<T, R1, C1> &A) {
5962
check_square("mdivide_left_tri", "A", A);
63+
if (A.rows() == 0) {
64+
return {};
65+
}
66+
6067
int n = A.rows();
6168
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> b;
6269
b.setIdentity(n, n);
@@ -89,6 +96,10 @@ inline Eigen::Matrix<double, R1, C2> mdivide_left_tri(
8996
const Eigen::Matrix<double, R2, C2> &b) {
9097
check_square("mdivide_left_tri", "A", A);
9198
check_multiplicable("mdivide_left_tri", "A", A, "b", b);
99+
if (A.rows() == 0) {
100+
return {0, b.cols()};
101+
}
102+
92103
#ifdef STAN_OPENCL
93104
if (A.rows()
94105
>= opencl_context.tuning_opts().tri_inverse_size_worth_transfer) {
@@ -122,6 +133,10 @@ template <Eigen::UpLoType TriView, int R1, int C1>
122133
inline Eigen::Matrix<double, R1, C1> mdivide_left_tri(
123134
const Eigen::Matrix<double, R1, C1> &A) {
124135
check_square("mdivide_left_tri", "A", A);
136+
if (A.rows() == 0) {
137+
return {};
138+
}
139+
125140
const int n = A.rows();
126141
#ifdef STAN_OPENCL
127142
if (A.rows()

stan/math/prim/fun/mdivide_left_tri_low.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,21 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_left_tri_low(
3535
const Eigen::Matrix<T1, R1, C1> &A, const Eigen::Matrix<T2, R2, C2> &b) {
3636
check_square("mdivide_left_tri_low", "A", A);
3737
check_multiplicable("mdivide_left_tri_low", "A", A, "b", b);
38+
if (A.rows() == 0) {
39+
return {0, b.cols()};
40+
}
41+
3842
return mdivide_left_tri<Eigen::Lower>(A, b);
3943
}
4044

4145
template <typename T, int R1, int C1>
4246
inline Eigen::Matrix<T, R1, C1> mdivide_left_tri_low(
4347
const Eigen::Matrix<T, R1, C1> &A) {
4448
check_square("mdivide_left_tri_low", "A", A);
49+
if (A.rows() == 0) {
50+
return {};
51+
}
52+
4553
return mdivide_left_tri<Eigen::Lower>(A);
4654
}
4755

stan/math/prim/fun/mdivide_right.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_right(
3131
const Eigen::Matrix<T1, R1, C1> &b, const Eigen::Matrix<T2, R2, C2> &A) {
3232
check_square("mdivide_right", "A", A);
3333
check_multiplicable("mdivide_right", "b", b, "A", A);
34+
if (A.size() == 0) {
35+
return {b.rows(), 0};
36+
}
3437

3538
return Eigen::Matrix<return_type_t<T1, T2>, R2, C2>(A)
3639
.transpose()

stan/math/prim/fun/mdivide_right_ldlt.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@ namespace math {
3030
template <typename T1, typename T2, int R1, int C1, int R2, int C2>
3131
inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_right_ldlt(
3232
const Eigen::Matrix<T1, R1, C1> &b, const LDLT_factor<T2, R2, C2> &A) {
33-
if (b.cols() == 0 && A.rows() == 0) {
34-
return Eigen::Matrix<return_type_t<T1, T2>, R1, C2>(b.rows(), 0);
35-
}
36-
3733
check_multiplicable("mdivide_right_ldlt", "b", b, "A", A);
34+
if (A.rows() == 0) {
35+
return {b.rows(), 0};
36+
}
3837

3938
return transpose(mdivide_left_ldlt(A, transpose(b)));
4039
}
@@ -43,11 +42,10 @@ template <int R1, int C1, int R2, int C2>
4342
inline Eigen::Matrix<double, R1, C2> mdivide_right_ldlt(
4443
const Eigen::Matrix<double, R1, C1> &b,
4544
const LDLT_factor<double, R2, C2> &A) {
46-
if (b.cols() == 0 && A.rows() == 0) {
47-
return Eigen::Matrix<double, R1, C2>(b.rows(), 0);
48-
}
49-
5045
check_multiplicable("mdivide_right_ldlt", "b", b, "A", A);
46+
if (A.rows() == 0) {
47+
return {b.rows(), 0};
48+
}
5149

5250
return A.solveRight(b);
5351
}

stan/math/prim/fun/mdivide_right_spd.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,12 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_right_spd(
3333
const Eigen::Matrix<T1, R1, C1> &b, const Eigen::Matrix<T2, R2, C2> &A) {
3434
static const char *function = "mdivide_right_spd";
3535
check_multiplicable(function, "b", b, "A", A);
36-
check_positive(function, "rows", A.rows());
3736
check_symmetric(function, "A", A);
3837
check_not_nan(function, "A", A);
38+
if (A.size() == 0) {
39+
return {b.rows(), 0};
40+
}
41+
3942
// FIXME: After allowing for general MatrixBase in mdivide_left_spd,
4043
// change to b.transpose()
4144
return mdivide_left_spd(A, transpose(b)).transpose();

stan/math/prim/fun/mdivide_right_tri.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_right_tri(
4343
"triangular view must be Eigen::Lower or Eigen::Upper",
4444
"", "");
4545
}
46+
if (A.rows() == 0) {
47+
return {b.rows(), 0};
48+
}
4649

4750
return Eigen::Matrix<return_type_t<T1, T2>, R2, C2>(A)
4851
.template triangularView<TriView>()
@@ -77,6 +80,10 @@ inline Eigen::Matrix<double, R1, C2> mdivide_right_tri(
7780
const Eigen::Matrix<double, R2, C2> &A) {
7881
check_square("mdivide_right_tri", "A", A);
7982
check_multiplicable("mdivide_right_tri", "b", b, "A", A);
83+
if (A.rows() == 0) {
84+
return {b.rows(), 0};
85+
}
86+
8087
#ifdef STAN_OPENCL
8188
if (A.rows()
8289
>= opencl_context.tuning_opts().tri_inverse_size_worth_transfer) {

stan/math/prim/fun/multiply.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ template <typename Mat1, typename Mat2,
9292
require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
9393
inline auto multiply(const Mat1& m1, const Mat2& m2)
9494
-> decltype((m1 * m2).eval()) {
95-
check_size_match("multiply", "Columns of m1", m1.cols(), "Rows of m2",
96-
m2.rows());
95+
check_multiplicable("multiply", "m1", m1, "m2", m2);
96+
9797
#ifdef STAN_OPENCL
9898
if (m1.rows() * m1.cols() * m2.cols()
9999
> opencl_context.tuning_opts().multiply_dim_prod_worth_transfer) {
@@ -128,7 +128,7 @@ template <typename RowVec, typename ColVec,
128128
scalar_type_t<ColVec>>* = nullptr,
129129
require_eigen_row_and_col_t<RowVec, ColVec>* = nullptr>
130130
inline auto multiply(const RowVec& rv, const ColVec& v) {
131-
check_matching_sizes("multiply", "rv", rv, "v", v);
131+
check_multiplicable("multiply", "rv", rv, "v", v);
132132
return dot_product(rv, v);
133133
}
134134

0 commit comments

Comments
 (0)