Skip to content

Commit

Permalink
check sympd instead of catch #43
Browse files Browse the repository at this point in the history
  • Loading branch information
adamwang15 committed Aug 10, 2024
1 parent cc6db2d commit 55a62a5
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions src/sample_hyper.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

#define ARMA_WARN_LEVEL 1
#include <RcppArmadillo.h>

#include "utils.h"
Expand Down Expand Up @@ -108,25 +109,26 @@ double log_ml(
int N = Y.n_cols;

double log_ml = 0;

mat inv_Omega = diagmat(1 / Omega.diag());
mat XX = X.t() * X + inv_Omega;

try {
mat Bhat = inv_sympd(X.t() * X + inv_Omega) * (X.t() * Y + inv_Omega * b);
mat ehat = Y - X * Bhat;

log_ml += - N * T / 2.0 * log(M_PI);
log_ml += log_mvgamma(N, (T + d) / 2.0);
log_ml += -log_mvgamma(N, d / 2.0);
log_ml += - N / 2.0 * log_det_sympd(Omega);
log_ml += d / 2.0 * log_det_sympd(Psi);
log_ml += - N / 2.0 * log_det_sympd(X.t() * X + inv_Omega);
mat A = Psi + ehat.t() * ehat + (Bhat - b).t() * inv_Omega * (Bhat - b);
log_ml += - (T + d) / 2.0 * log_det_sympd(A);

} catch(...) {
log_ml = -1e+10;
if (!Omega.is_sympd() or !Psi.is_sympd() or !XX.is_sympd()) {
return -1e10;
}

mat Bhat = solve(XX, X.t() * Y + inv_Omega * b, solve_opts::likely_sympd);
mat ehat = Y - X * Bhat;

log_ml += - N * T / 2.0 * log(M_PI);
log_ml += log_mvgamma(N, (T + d) / 2.0);
log_ml += -log_mvgamma(N, d / 2.0);
log_ml += - N / 2.0 * log_det_sympd(Omega);
log_ml += d / 2.0 * log_det_sympd(Psi);
log_ml += - N / 2.0 * log_det_sympd(XX);
mat A = Psi + ehat.t() * ehat + (Bhat - b).t() * inv_Omega * (Bhat - b);
log_ml += - (T + d) / 2.0 * log_det_sympd(A);

return log_ml;
}

Expand Down

0 comments on commit 55a62a5

Please sign in to comment.