diff --git a/src/restrictions_zero.cpp b/src/restrictions_zero.cpp index 527a69d..6090c6c 100644 --- a/src/restrictions_zero.cpp +++ b/src/restrictions_zero.cpp @@ -34,13 +34,15 @@ arma::colvec zero_restrictions( const arma::colvec vec_structural ) { int N = Z(0).n_cols; - mat A0 = reshape(vec_structural.rows(0, N*N-1), N, N); + + mat A0 = reshape(vec_structural.rows(0, N * N - 1), N, N); arma::field ZF = ZIRF(Z, inv(A0.t())); - colvec z; + vec z; for (int j=0; j 0) z = join_vert(z, ZF_j.col(j)); } return z; diff --git a/src/sample_hyper.cpp b/src/sample_hyper.cpp index 4e5ec8a..28fb2c5 100644 --- a/src/sample_hyper.cpp +++ b/src/sample_hyper.cpp @@ -1,4 +1,5 @@ +#define ARMA_WARN_LEVEL 1 #include #include "utils.h" @@ -108,25 +109,26 @@ double log_ml( int N = Y.n_cols; double log_ml = 0; + mat inv_Omega = diagmat(1 / Omega.diag()); + mat XX = X.t() * X + inv_Omega; - try { - mat Bhat = inv_sympd(X.t() * X + inv_Omega) * (X.t() * Y + inv_Omega * b); - mat ehat = Y - X * Bhat; - - log_ml += - N * T / 2.0 * log(M_PI); - log_ml += log_mvgamma(N, (T + d) / 2.0); - log_ml += -log_mvgamma(N, d / 2.0); - log_ml += - N / 2.0 * log_det_sympd(Omega); - log_ml += d / 2.0 * log_det_sympd(Psi); - log_ml += - N / 2.0 * log_det_sympd(X.t() * X + inv_Omega); - mat A = Psi + ehat.t() * ehat + (Bhat - b).t() * inv_Omega * (Bhat - b); - log_ml += - (T + d) / 2.0 * log_det_sympd(A); - - } catch(...) { - log_ml = -1e+10; + if (!Omega.is_sympd() or !Psi.is_sympd() or !XX.is_sympd()) { + return -1e10; } + mat Bhat = solve(XX, X.t() * Y + inv_Omega * b, solve_opts::likely_sympd); + mat ehat = Y - X * Bhat; + + log_ml += - N * T / 2.0 * log(M_PI); + log_ml += log_mvgamma(N, (T + d) / 2.0); + log_ml += -log_mvgamma(N, d / 2.0); + log_ml += - N / 2.0 * log_det_sympd(Omega); + log_ml += d / 2.0 * log_det_sympd(Psi); + log_ml += - N / 2.0 * log_det_sympd(XX); + mat A = Psi + ehat.t() * ehat + (Bhat - b).t() * inv_Omega * (Bhat - b); + log_ml += - (T + d) / 2.0 * log_det_sympd(A); + return log_ml; }