Skip to content

Commit 3d03131

Browse files
authored
Merge pull request #2230 from stan-dev/feature/varmat-mdivide-left
Added `var<mat>` implementation of `mdivide_left` (Issue #2101)
2 parents 72552dc + ac60eb1 commit 3d03131

File tree

6 files changed

+199
-18
lines changed

6 files changed

+199
-18
lines changed

stan/math/rev/core.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stan/math/rev/core/autodiffstackstorage.hpp>
88
#include <stan/math/rev/core/build_vari_array.hpp>
99
#include <stan/math/rev/core/chainable_alloc.hpp>
10+
#include <stan/math/rev/core/chainable_object.hpp>
1011
#include <stan/math/rev/core/chainablestack.hpp>
1112
#include <stan/math/rev/core/count_vars.hpp>
1213
#include <stan/math/rev/core/callback_vari.hpp>
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#ifndef STAN_MATH_REV_CORE_CHAINABLE_OBJECT_HPP
2+
#define STAN_MATH_REV_CORE_CHAINABLE_OBJECT_HPP
3+
4+
#include <stan/math/rev/meta.hpp>
5+
#include <stan/math/rev/core/chainable_alloc.hpp>
6+
#include <stan/math/rev/core/typedefs.hpp>
7+
#include <stan/math/prim/fun/Eigen.hpp>
8+
#include <stan/math/prim/fun/typedefs.hpp>
9+
#include <vector>
10+
11+
namespace stan {
12+
namespace math {
13+
14+
/**
15+
* `chainable_object` holds another object and is useful for connecting
16+
* the lifetime of a specific object to the chainable stack
17+
*
18+
* `chainable_object` objects should only be allocated with `new`.
19+
* `chainable_object` objects allocated on the stack will result
20+
* in a double free (`obj_` will get destructed once when the
21+
* chainable_object leaves scope and once when the chainable
22+
* stack memory is recovered).
23+
*
24+
* @tparam T type of object to hold
25+
*/
26+
template <typename T>
27+
class chainable_object : public chainable_alloc {
28+
private:
29+
plain_type_t<T> obj_;
30+
31+
public:
32+
/**
33+
* Construct chainable object from another object
34+
*
35+
* @tparam S type of object to hold (must have the same plain type as `T`)
36+
*/
37+
template <typename S,
38+
require_same_t<plain_type_t<T>, plain_type_t<S>>* = nullptr>
39+
explicit chainable_object(S&& obj) : obj_(std::forward<S>(obj)) {}
40+
41+
/**
42+
* Return a reference to the underlying object
43+
*
44+
* @return reference to underlying object
45+
*/
46+
inline auto& get() noexcept { return obj_; }
47+
inline const auto& get() const noexcept { return obj_; }
48+
};
49+
50+
/**
51+
* Store the given object in a `chainable_object` so it is destructed
52+
* only when the chainable stack memory is recovered and return
53+
* a pointer to the underlying object
54+
*
55+
* @tparam T type of object to hold
56+
* @param obj object to hold
57+
* @return pointer to object held in `chainable_object`
58+
*/
59+
template <typename T>
60+
auto make_chainable_ptr(T&& obj) {
61+
auto ptr = new chainable_object<T>(std::forward<T>(obj));
62+
return &ptr->get();
63+
}
64+
65+
} // namespace math
66+
} // namespace stan
67+
#endif

stan/math/rev/fun/mdivide_left.hpp

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
21
#ifndef STAN_MATH_REV_FUN_MDIVIDE_LEFT_HPP
32
#define STAN_MATH_REV_FUN_MDIVIDE_LEFT_HPP
43

54
#include <stan/math/rev/meta.hpp>
65
#include <stan/math/rev/core.hpp>
76
#include <stan/math/rev/core/typedefs.hpp>
7+
#include <stan/math/rev/core/chainable_object.hpp>
88
#include <stan/math/prim/err.hpp>
99
#include <stan/math/prim/fun/Eigen.hpp>
1010
#include <stan/math/prim/fun/typedefs.hpp>
@@ -13,45 +13,74 @@
1313
namespace stan {
1414
namespace math {
1515

16-
template <typename T1, typename T2, require_all_eigen_t<T1, T2>* = nullptr,
17-
require_any_vt_var<T1, T2>* = nullptr>
16+
/**
17+
* Return the solution `X` of `AX = B`.
18+
*
19+
* A must be a square matrix, but B can be a matrix or a vector
20+
*
21+
* @tparam T1 type of first matrix
22+
* @tparam T2 type of second matrix
23+
*
24+
* @param[in] A square matrix
25+
* @param[in] B right hand side
26+
* @return solution of AX = B
27+
*/
28+
template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
29+
require_any_st_var<T1, T2>* = nullptr>
1830
inline auto mdivide_left(const T1& A, const T2& B) {
19-
using ret_type = plain_type_t<decltype(A * B)>;
31+
using ret_val_type = plain_type_t<decltype(value_of(A) * value_of(B))>;
32+
using ret_type = promote_var_matrix_t<ret_val_type, T1, T2>;
2033

2134
check_square("mdivide_left", "A", A);
2235
check_multiplicable("mdivide_left", "A", A, "B", B);
2336

2437
if (A.size() == 0) {
25-
return ret_type(0, B.cols());
38+
return ret_type(ret_val_type(0, B.cols()));
2639
}
40+
2741
if (!is_constant<T1>::value && !is_constant<T2>::value) {
2842
arena_t<promote_scalar_t<var, T1>> arena_A = A;
2943
arena_t<promote_scalar_t<var, T2>> arena_B = B;
30-
arena_t<promote_scalar_t<double, T1>> arena_A_val = arena_A.val();
31-
arena_t<ret_type> res = arena_A_val.householderQr().solve(arena_B.val());
32-
reverse_pass_callback([arena_A, arena_B, arena_A_val, res]() mutable {
44+
45+
auto hqr_A_ptr = make_chainable_ptr(arena_A.val().householderQr());
46+
arena_t<ret_type> res = hqr_A_ptr->solve(arena_B.val());
47+
reverse_pass_callback([arena_A, arena_B, hqr_A_ptr, res]() mutable {
3348
promote_scalar_t<double, T2> adjB
34-
= arena_A_val.transpose().householderQr().solve(res.adj());
49+
= hqr_A_ptr->householderQ()
50+
* hqr_A_ptr->matrixQR()
51+
.template triangularView<Eigen::Upper>()
52+
.transpose()
53+
.solve(res.adj());
3554
arena_A.adj() -= adjB * res.val_op().transpose();
3655
arena_B.adj() += adjB;
3756
});
3857

3958
return ret_type(res);
4059
} else if (!is_constant<T2>::value) {
4160
arena_t<promote_scalar_t<var, T2>> arena_B = B;
42-
arena_t<promote_scalar_t<double, T1>> arena_A_val = value_of(A);
43-
arena_t<ret_type> res = arena_A_val.householderQr().solve(arena_B.val());
44-
reverse_pass_callback([arena_B, arena_A_val, res]() mutable {
45-
arena_B.adj() += arena_A_val.transpose().householderQr().solve(res.adj());
61+
62+
auto hqr_A_ptr = make_chainable_ptr(value_of(A).householderQr());
63+
arena_t<ret_type> res = hqr_A_ptr->solve(arena_B.val());
64+
reverse_pass_callback([arena_B, hqr_A_ptr, res]() mutable {
65+
arena_B.adj() += hqr_A_ptr->householderQ()
66+
* hqr_A_ptr->matrixQR()
67+
.template triangularView<Eigen::Upper>()
68+
.transpose()
69+
.solve(res.adj());
4670
});
4771
return ret_type(res);
4872
} else {
4973
arena_t<promote_scalar_t<var, T1>> arena_A = A;
50-
arena_t<ret_type> res = arena_A.val().householderQr().solve(value_of(B));
51-
reverse_pass_callback([arena_A, res]() mutable {
52-
arena_A.adj()
53-
-= arena_A.val().transpose().householderQr().solve(res.adj())
54-
* res.val_op().transpose();
74+
75+
auto hqr_A_ptr = make_chainable_ptr(arena_A.val().householderQr());
76+
arena_t<ret_type> res = hqr_A_ptr->solve(value_of(B));
77+
reverse_pass_callback([arena_A, hqr_A_ptr, res]() mutable {
78+
arena_A.adj() -= hqr_A_ptr->householderQ()
79+
* hqr_A_ptr->matrixQR()
80+
.template triangularView<Eigen::Upper>()
81+
.transpose()
82+
.solve(res.adj())
83+
* res.val_op().transpose();
5584
});
5685
return ret_type(res);
5786
}

test/unit/math/mix/fun/mdivide_left_test.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,24 @@ TEST(MathMixMatFun, mdivideLeft) {
1313
stan::test::expect_ad(f, m00, m00);
1414
stan::test::expect_ad(f, m00, m02);
1515
stan::test::expect_ad(f, m00, v0);
16+
stan::test::expect_ad_matvar(f, m00, m00);
17+
stan::test::expect_ad_matvar(f, m00, m02);
18+
stan::test::expect_ad_matvar(f, m00, v0);
1619

1720
Eigen::MatrixXd aa(1, 1);
1821
aa << 1;
1922
Eigen::MatrixXd bb(1, 1);
2023
bb << 2;
2124
stan::test::expect_ad(f, aa, bb);
25+
stan::test::expect_ad_matvar(f, aa, bb);
2226
Eigen::MatrixXd b0(1, 0);
2327
stan::test::expect_ad(f, aa, b0);
28+
stan::test::expect_ad_matvar(f, aa, b0);
2429

2530
Eigen::VectorXd cc(1);
2631
cc << 3;
2732
stan::test::expect_ad(f, aa, cc);
33+
stan::test::expect_ad_matvar(f, aa, cc);
2834

2935
Eigen::MatrixXd a(2, 2);
3036
a << 2, 3, 3, 7;
@@ -47,12 +53,14 @@ TEST(MathMixMatFun, mdivideLeft) {
4753
for (const auto& m1 : std::vector<Eigen::MatrixXd>{a, b, c, d}) {
4854
for (const auto& m2 : std::vector<Eigen::MatrixXd>{a, b, c, d, e}) {
4955
stan::test::expect_ad(f, m1, m2);
56+
stan::test::expect_ad_matvar(f, m1, m2);
5057
}
5158
}
5259

5360
// matrix, vector
5461
for (const auto& m : std::vector<Eigen::MatrixXd>{a, b, c, d}) {
5562
stan::test::expect_ad(f, m, g);
63+
stan::test::expect_ad_matvar(f, m, g);
5664
}
5765

5866
Eigen::MatrixXd v(5, 5);
@@ -61,6 +69,7 @@ TEST(MathMixMatFun, mdivideLeft) {
6169
Eigen::VectorXd u(5);
6270
u << 62, 84, 84, 76, 108;
6371
stan::test::expect_ad(f, v, u);
72+
stan::test::expect_ad_matvar(f, v, u);
6473

6574
Eigen::MatrixXd m33 = Eigen::MatrixXd::Zero(3, 3);
6675
Eigen::MatrixXd m44 = Eigen::MatrixXd::Zero(4, 4);
@@ -72,7 +81,10 @@ TEST(MathMixMatFun, mdivideLeft) {
7281
// exceptions: wrong sizes
7382
stan::test::expect_ad(f, m33, m44);
7483
stan::test::expect_ad(f, m33, v4);
84+
stan::test::expect_ad_matvar(f, m33, m44);
85+
stan::test::expect_ad_matvar(f, m33, v4);
7586

7687
// exceptions: wrong types
7788
stan::test::expect_ad(f, m33, rv3);
89+
stan::test::expect_ad_matvar(f, m33, rv3);
7890
}

test/unit/math/mix/prob/multi_gp_cholesky_test.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ TEST(ProbDistributionsMultiGPCholesky, fvar_var) {
4444
stan::math::multi_gp_cholesky_log(y, L, w).val_.val());
4545
EXPECT_FLOAT_EQ(-74.572952,
4646
stan::math::multi_gp_cholesky_log(y, L, w).d_.val());
47+
48+
stan::math::recover_memory();
4749
}
4850

4951
TEST(ProbDistributionsMultiGPCholesky, fvar_fvar_var) {
@@ -89,4 +91,6 @@ TEST(ProbDistributionsMultiGPCholesky, fvar_fvar_var) {
8991
stan::math::multi_gp_cholesky_log(y, L, w).val_.val_.val());
9092
EXPECT_FLOAT_EQ(-74.572952,
9193
stan::math::multi_gp_cholesky_log(y, L, w).d_.val_.val());
94+
95+
stan::math::recover_memory();
9296
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#include <stan/math.hpp>
2+
#include <gtest/gtest.h>
3+
#include <test/unit/util.hpp>
4+
5+
/**
 * Test helper whose destructor bumps a shared counter, so tests can
 * observe exactly when instances are destroyed.
 */
struct ChainableObjectTest {
  // Number of destructor calls since the counter was last reset.
  static int counter;

  ~ChainableObjectTest() { ++counter; }
};

int ChainableObjectTest::counter = 0;
13+
14+
// A directly allocated `chainable_object` must keep its payload alive after
// the owning pointer goes out of scope and destroy it on `recover_memory()`.
TEST(AgradRev, chainable_object_test) {
  using stan::math::chainable_object;

  {
    auto* holder
        = new chainable_object<ChainableObjectTest>(ChainableObjectTest());
    // Reset only after construction: the temporary argument above has
    // already run its destructor and incremented the counter.
    ChainableObjectTest::counter = 0;
  }

  EXPECT_EQ((ChainableObjectTest::counter), 0);
  stan::math::recover_memory();
  EXPECT_EQ((ChainableObjectTest::counter), 1);
}
25+
26+
// Same check as the non-nested case, but the holder is created after
// `start_nested()` and released by `recover_memory_nested()`.
TEST(AgradRev, chainable_object_nested_test) {
  using stan::math::chainable_object;

  stan::math::start_nested();

  {
    auto* holder
        = new chainable_object<ChainableObjectTest>(ChainableObjectTest());
    // Reset only after construction: the temporary argument above has
    // already run its destructor and incremented the counter.
    ChainableObjectTest::counter = 0;
  }

  EXPECT_EQ((ChainableObjectTest::counter), 0);

  stan::math::recover_memory_nested();

  EXPECT_EQ((ChainableObjectTest::counter), 1);
}
41+
42+
// `make_chainable_ptr` must hand back a pointer to a payload that is not
// destroyed when the pointer leaves scope, only on `recover_memory()`.
TEST(AgradRev, make_chainable_ptr_test) {
  using stan::math::make_chainable_ptr;

  {
    ChainableObjectTest* held = make_chainable_ptr(ChainableObjectTest());
    // Reset only after the call: the temporary argument above has already
    // run its destructor and incremented the counter.
    ChainableObjectTest::counter = 0;
  }

  EXPECT_EQ((ChainableObjectTest::counter), 0);
  stan::math::recover_memory();
  EXPECT_EQ((ChainableObjectTest::counter), 1);
}
53+
54+
// Same check as the non-nested case, but the payload is created after
// `start_nested()` and released by `recover_memory_nested()`.
TEST(AgradRev, make_chainable_ptr_nested_test) {
  using stan::math::make_chainable_ptr;

  stan::math::start_nested();

  {
    ChainableObjectTest* held = make_chainable_ptr(ChainableObjectTest());
    // Reset only after the call: the temporary argument above has already
    // run its destructor and incremented the counter.
    ChainableObjectTest::counter = 0;
  }

  EXPECT_EQ((ChainableObjectTest::counter), 0);

  stan::math::recover_memory_nested();

  EXPECT_EQ((ChainableObjectTest::counter), 1);
}

0 commit comments

Comments
 (0)