@@ -135,31 +135,37 @@ template <typename T1, typename T2, require_not_matrix_t<T1>* = nullptr,
135
135
require_not_row_and_col_vector_t <T1, T2>* = nullptr >
136
136
inline auto multiply (const T1& A, const T2& B) {
137
137
if (!is_constant<T2>::value && !is_constant<T1>::value) {
138
- arena_t <promote_scalar_t <var, T2 >> arena_B = to_ref (B) ;
139
- arena_t <promote_scalar_t <double , T2>> arena_B_val = value_of (arena_B) ;
138
+ arena_t <promote_scalar_t <var, T1 >> arena_A = A ;
139
+ arena_t <promote_scalar_t <var , T2>> arena_B = B ;
140
140
using return_t = return_var_matrix_t <T2, T1, T2>;
141
- arena_t <return_t > res = value_of (A) * arena_B_val.array ();
142
- reverse_pass_callback ([A, arena_B, arena_B_val, res]() mutable {
143
- auto res_adj = res.adj ().eval ();
144
- forward_as<var>(A).adj () += (res_adj.array () * arena_B_val.array ()).sum ();
145
- arena_B.adj ().array () += value_of (A) * res_adj.array ();
141
+ arena_t <return_t > res = arena_A.val () * arena_B.val ().array ();
142
+ reverse_pass_callback ([arena_A, arena_B, res]() mutable {
143
+ const auto a_val = arena_A.val ();
144
+ for (Eigen::Index j = 0 ; j < res.cols (); ++j) {
145
+ for (Eigen::Index i = 0 ; i < res.rows (); ++i) {
146
+ const auto res_adj = res.adj ().coeffRef (i, j);
147
+ arena_A.adj () += res_adj * arena_B.val ().coeff (i, j);
148
+ arena_B.adj ().coeffRef (i, j) += a_val * res_adj;
149
+ }
150
+ }
146
151
});
147
152
return return_t (res);
148
153
} else if (!is_constant<T2>::value) {
149
- arena_t <promote_scalar_t <var, T2>> arena_B = to_ref (B);
154
+ arena_t <promote_scalar_t <double , T1>> arena_A = value_of (A);
155
+ arena_t <promote_scalar_t <var, T2>> arena_B = B;
150
156
using return_t = return_var_matrix_t <T2, T1, T2>;
151
- arena_t <return_t > res = value_of (A) * value_of ( arena_B).array ();
152
- reverse_pass_callback ([A , arena_B, res]() mutable {
153
- arena_B.adj ().array () += value_of (A) * res.adj ().array ();
157
+ arena_t <return_t > res = arena_A * arena_B. val ( ).array ();
158
+ reverse_pass_callback ([arena_A , arena_B, res]() mutable {
159
+ arena_B.adj ().array () += arena_A * res.adj ().array ();
154
160
});
155
161
return return_t (res);
156
162
} else {
157
- arena_t <promote_scalar_t <double , T2>> arena_B_val = value_of (B);
163
+ arena_t <promote_scalar_t <var, T1>> arena_A = A;
164
+ arena_t <promote_scalar_t <double , T2>> arena_B = value_of (B);
158
165
using return_t = return_var_matrix_t <T2, T1, T2>;
159
- arena_t <return_t > res = value_of (A) * arena_B_val.array ();
160
- reverse_pass_callback ([A, arena_B_val, res]() mutable {
161
- forward_as<var>(A).adj ()
162
- += (res.adj ().array () * arena_B_val.array ()).sum ();
166
+ arena_t <return_t > res = arena_A.val () * arena_B.array ();
167
+ reverse_pass_callback ([arena_A, arena_B, res]() mutable {
168
+ arena_A.adj () += (res.adj ().array () * arena_B.array ()).sum ();
163
169
});
164
170
return return_t (res);
165
171
}
0 commit comments