From aedfc4fdca0f1f23ca8e99e8321b92c4e8993cf4 Mon Sep 17 00:00:00 2001 From: ntorresd Date: Wed, 31 Jul 2024 18:47:19 -0500 Subject: [PATCH] feat: add 'plot_rhats' function --- R/plot_seromodel.R | 61 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/R/plot_seromodel.R b/R/plot_seromodel.R index 7834cc36..469ca8db 100644 --- a/R/plot_seromodel.R +++ b/R/plot_seromodel.R @@ -254,3 +254,64 @@ plot_foi_estimates <- function( return(foi_plot) } + +#' Plot r-hats convergence criteria for the specified model +#' +#' @inheritParams extract_central_estimates +#' @inheritParams plot_serosurvey +#' @return ggplot object showing the r-hats of the model to be compared with the +#' convergence criteria (horizontal dashed line) +plot_rhats <- function( + seromodel, + serosurvey, + par_name = "foi_expanded", + size_text = 11 +) { + checkmate::assert_class(seromodel, "stanfit", null.ok = TRUE) + + rhats <- bayesplot::rhat(seromodel, par_name) + + if (startsWith(seromodel@model_name, "age")) { + xlab <- "Age" + ages <- 1:max(serosurvey$age_max) + rhats_df <- data.frame( + age = ages, + rhat = rhats + ) + + rhats_plot <- ggplot2::ggplot( + data = rhats_df, ggplot2::aes(x = age) + ) + } else if (startsWith(seromodel@model_name, "time")) { + xlab <- "Year" + ages <- rev(1:max(serosurvey$age_max)) + years <- unique(serosurvey$tsur) - ages + rhats_df <- data.frame( + year = years, + rhat = rhats + ) + + rhats_plot <- ggplot2::ggplot( + data = rhats_df, ggplot2::aes(x = year) + ) + } + + rhats_plot <- rhats_plot + + ggplot2::geom_hline( + yintercept = 1.01, + linetype = 'dashed' + ) + + ggplot2::geom_line(ggplot2::aes(y = rhat)) + + ggplot2::geom_point(ggplot2::aes(y = rhat)) + + ggplot2::coord_cartesian( + ylim = c( + min(1.0, min(rhats_df$rhat)), + max(1.02, max(rhats_df$rhat)) + ) + ) + + ggplot2::theme_bw(size_text) + + ggplot2::xlab(xlab) + + ggplot2::ylab("Convergence (r-hats)") + + return(rhats_plot) +}