Skip to content

Commit

Permalink
feat: add 'plot_rhats' function
Browse files Browse the repository at this point in the history
  • Loading branch information
ntorresd committed Aug 1, 2024
1 parent 962df9e commit aedfc4f
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions R/plot_seromodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,64 @@ plot_foi_estimates <- function(

return(foi_plot)
}

#' Plot r-hats convergence criteria for the specified model
#'
#' @inheritParams extract_central_estimates
#' @inheritParams plot_serosurvey
#' @return ggplot object showing the r-hats of the model to be compared with the
#' convergence criteria (horizontal dashed line)
plot_rhats <- function(
seromodel,
serosurvey,
par_name = "foi_expanded",
size_text = 11
) {
checkmate::assert_class(seromodel, "stanfit", null.ok = TRUE)

rhats <- bayesplot::rhat(seromodel, par_name)

if (startsWith(seromodel@model_name, "age")) {
xlab <- "Age"
ages <- 1:max(serosurvey$age_max)
rhats_df <- data.frame(
age = ages,
rhat = rhats
)

rhats_plot <- ggplot2::ggplot(
data = rhats_df, ggplot2::aes(x = age)
)
} else if (startsWith(seromodel@model_name, "time")) {
xlab <- "Year"
ages <- rev(1:max(serosurvey$age_max))
years <- unique(serosurvey$tsur) - ages
rhats_df <- data.frame(
year = years,
rhat = rhats
)

rhats_plot <- ggplot2::ggplot(
data = rhats_df, ggplot2::aes(x = year)
)
}

rhats_plot <- rhats_plot +
ggplot2::geom_hline(
yintercept = 1.01,
linetype = 'dashed'
) +
ggplot2::geom_line(ggplot2::aes(y = rhat)) +
ggplot2::geom_point(ggplot2::aes(y = rhat)) +
ggplot2::coord_cartesian(
ylim = c(
min(1.0, min(rhats_df$rhat)),
max(1.02, max(rhats_df$rhat))
)
) +
ggplot2::theme_bw(size_text) +
ggplot2::xlab(xlab) +
ggplot2::ylab("Convergence (r-hats)")

return(rhats_plot)
}

0 comments on commit aedfc4f

Please sign in to comment.