diff --git a/src/sample_hyper.cpp b/src/sample_hyper.cpp index 3464a2c..b349147 100644 --- a/src/sample_hyper.cpp +++ b/src/sample_hyper.cpp @@ -110,25 +110,25 @@ double log_ml( double log_ml = 0; mat inv_Omega = diagmat(1 / Omega.diag()); - + try { mat XX = X.t() * X + inv_Omega; mat Bhat = solve(XX, X.t() * Y + inv_Omega * b, solve_opts::likely_sympd); mat ehat = Y - X * Bhat; - - log_ml += - N * T / 2.0 * log(M_PI); - log_ml += log_mvgamma(N, (T + d) / 2.0); - log_ml += -log_mvgamma(N, d / 2.0); - log_ml += - N / 2.0 * log_det_sympd(Omega); - log_ml += d / 2.0 * log_det_sympd(Psi); - log_ml += - N / 2.0 * log_det_sympd(XX); - mat A = Psi + ehat.t() * ehat + (Bhat - b).t() * inv_Omega * (Bhat - b); - log_ml += - (T + d) / 2.0 * log_det_sympd(A); - - } catch(std::runtime_error) { + + log_ml += - N * T / 2.0 * log(M_PI); + log_ml += log_mvgamma(N, (T + d) / 2.0); + log_ml += -log_mvgamma(N, d / 2.0); + log_ml += - N / 2.0 * log_det_sympd(Omega); + log_ml += d / 2.0 * log_det_sympd(Psi); + log_ml += - N / 2.0 * log_det_sympd(XX); + mat A = Psi + ehat.t() * ehat + (Bhat - b).t() * inv_Omega * (Bhat - b); + log_ml += - (T + d) / 2.0 * log_det_sympd(A); + + } catch(const std::runtime_error& e) { log_ml = -1e+10; } - + return log_ml; }