Skip to content

Commit 2bce5b6

Browse files
committed
update multiply for single to double loop
1 parent 6d3c6d1 commit 2bce5b6

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

stan/math/rev/fun/multiply.hpp

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -135,31 +135,37 @@ template <typename T1, typename T2, require_not_matrix_t<T1>* = nullptr,
135135
require_not_row_and_col_vector_t<T1, T2>* = nullptr>
136136
inline auto multiply(const T1& A, const T2& B) {
137137
if (!is_constant<T2>::value && !is_constant<T1>::value) {
138-
arena_t<promote_scalar_t<var, T2>> arena_B = to_ref(B);
139-
arena_t<promote_scalar_t<double, T2>> arena_B_val = value_of(arena_B);
138+
arena_t<promote_scalar_t<var, T1>> arena_A = A;
139+
arena_t<promote_scalar_t<var, T2>> arena_B = B;
140140
using return_t = return_var_matrix_t<T2, T1, T2>;
141-
arena_t<return_t> res = value_of(A) * arena_B_val.array();
142-
reverse_pass_callback([A, arena_B, arena_B_val, res]() mutable {
143-
auto res_adj = res.adj().eval();
144-
forward_as<var>(A).adj() += (res_adj.array() * arena_B_val.array()).sum();
145-
arena_B.adj().array() += value_of(A) * res_adj.array();
141+
arena_t<return_t> res = arena_A.val() * arena_B.val().array();
142+
reverse_pass_callback([arena_A, arena_B, res]() mutable {
143+
const auto a_val = arena_A.val();
144+
for (Eigen::Index j = 0; j < res.cols(); ++j) {
145+
for (Eigen::Index i = 0; i < res.rows(); ++i) {
146+
const auto res_adj = res.adj().coeffRef(i, j);
147+
arena_A.adj() += res_adj * arena_B.val().coeff(i, j);
148+
arena_B.adj().coeffRef(i, j) += a_val * res_adj;
149+
}
150+
}
146151
});
147152
return return_t(res);
148153
} else if (!is_constant<T2>::value) {
149-
arena_t<promote_scalar_t<var, T2>> arena_B = to_ref(B);
154+
arena_t<promote_scalar_t<double, T1>> arena_A = value_of(A);
155+
arena_t<promote_scalar_t<var, T2>> arena_B = B;
150156
using return_t = return_var_matrix_t<T2, T1, T2>;
151-
arena_t<return_t> res = value_of(A) * value_of(arena_B).array();
152-
reverse_pass_callback([A, arena_B, res]() mutable {
153-
arena_B.adj().array() += value_of(A) * res.adj().array();
157+
arena_t<return_t> res = arena_A * arena_B.val().array();
158+
reverse_pass_callback([arena_A, arena_B, res]() mutable {
159+
arena_B.adj().array() += arena_A * res.adj().array();
154160
});
155161
return return_t(res);
156162
} else {
157-
arena_t<promote_scalar_t<double, T2>> arena_B_val = value_of(B);
163+
arena_t<promote_scalar_t<var, T1>> arena_A = A;
164+
arena_t<promote_scalar_t<double, T2>> arena_B = value_of(B);
158165
using return_t = return_var_matrix_t<T2, T1, T2>;
159-
arena_t<return_t> res = value_of(A) * arena_B_val.array();
160-
reverse_pass_callback([A, arena_B_val, res]() mutable {
161-
forward_as<var>(A).adj()
162-
+= (res.adj().array() * arena_B_val.array()).sum();
166+
arena_t<return_t> res = arena_A.val() * arena_B.array();
167+
reverse_pass_callback([arena_A, arena_B, res]() mutable {
168+
arena_A.adj() += (res.adj().array() * arena_B.array()).sum();
163169
});
164170
return return_t(res);
165171
}

0 commit comments

Comments
 (0)