-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d0f39c0
commit 7eb02a7
Showing
15 changed files
with
1,131 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
#' @title Meta module for the BRM class of models | ||
#' | ||
#' @description TODO | ||
DiseasyModelBRM <- R6::R6Class( # nolint: object_name_linter | ||
classname = "DiseasyModelBRM", | ||
inherit = DiseasyModelRegression, | ||
|
||
private = list( | ||
|
||
.parameters = list(seed = 0, | ||
n_warmup = 1000, | ||
n_iter = 500, | ||
n_chains = 4), | ||
|
||
fit_regression = function(data, formula) { | ||
coll <- checkmate::makeAssertCollection() | ||
checkmate::assert_tibble(data, add = coll) | ||
checkmate::assert_formula(formula, add = coll) | ||
checkmate::assert_class(self$family, "family", add = coll) | ||
checkmate::assert_number(self$parameters$n_warmup, add = coll) | ||
checkmate::assert_number(self$parameters$n_iter, add = coll) | ||
checkmate::assert_number(self$parameters$n_chains, add = coll) | ||
checkmate::assert_number(self$parameters$seed, add = coll) | ||
checkmate::reportAssertions(coll) | ||
rm(coll) | ||
|
||
# Look for hashed results | ||
hash <- private$get_hash() | ||
if (!private$is_cached(hash)) { | ||
|
||
# Run the model with the provided params | ||
private$report_regression_fit("brm", formula, self$family, hash) | ||
brms_fit <- brms::brm(formula = formula, | ||
data = data, | ||
family = self$family, | ||
warmup = self$parameters$n_warmup, | ||
iter = self$parameters$n_warmup + self$parameters$n_iter, | ||
chains = self$parameters$n_chains, | ||
seed = self$parameters$seed) | ||
|
||
# Store in cache | ||
private$cache(hash, brms_fit) | ||
} | ||
|
||
# Write to the log | ||
private$lg$info("Using fitted brm (hash: {hash})") | ||
|
||
return(private$cache(hash)) | ||
|
||
}, | ||
|
||
get_prediction = function(regression_fit, new_data, quantiles) { | ||
coll <- checkmate::makeAssertCollection() | ||
checkmate::assert_class(regression_fit, "brmsfit", add = coll) | ||
checkmate::assert_tibble(new_data, add = coll) | ||
checkmate::assert_true(nrow(new_data) > 0, add = coll) | ||
checkmate::assert_number(quantiles, null.ok = TRUE, add = coll) | ||
checkmate::assert_number(self$parameters$n_iter, add = coll) | ||
checkmate::assert_number(self$parameters$n_chains, add = coll) | ||
checkmate::reportAssertions(coll) | ||
|
||
brms_predict <- stats::predict(regression_fit, newdata = new_data, summary = FALSE) | ||
|
||
# Draw samples if quantiles is not given | ||
if (is.null(quantiles)) { | ||
|
||
# First we expand the combinations of draws | ||
combinations <- tidyr::expand_grid(chain = seq(self$parameters$n_chains), iter = seq(self$parameters$n_iter)) |> | ||
dplyr::mutate(index = dplyr::row_number()) | ||
|
||
# Construct output tibble | ||
brm_samples <- purrr::pmap(combinations, | ||
~ dplyr::mutate(new_data, | ||
observable = brms_predict[..3, ], | ||
realization_id = paste(..1, ..2, sep = "_"))) |> | ||
purrr::reduce(dplyr::union_all) | ||
|
||
return(brm_samples) | ||
} else { | ||
private$not_implemented_error("quantile support not yet implemented") | ||
} | ||
} | ||
), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
#' @title Meta module for the simple, b* reference models models | ||
#' | ||
#' @description TODO | ||
DiseasyModelB_ <- R6::R6Class( # nolint: object_name_linter | ||
inherit = DiseasyModelBRM, | ||
public = list( | ||
|
||
#' @description | ||
#' Creates a new instance of the `DiseasyModelB_` [R6][R6::R6Class] class. | ||
#' This module is typically not constructed directly but rather through `DiseasyModelB*` classes, | ||
#' such as [DiseasyModelB0] and [DiseasyModelB1]. | ||
#' @param ... | ||
#' parameters sent to `DiseasyModelBRM` [R6][R6::R6Class] constructor. | ||
#' @details | ||
#' Helper class for the the `DiseasyModelB*` [R6][R6::R6Class] classes. | ||
#' @seealso [stats::family], [stats::as.formula] | ||
#' @return | ||
#' A new instance of the `DiseasyModelB_` [R6][R6::R6Class] class. | ||
initialize = function(...) { | ||
super$initialize(formula = self$formula, | ||
family = stats::poisson(), | ||
...) | ||
} | ||
), | ||
|
||
private = list( | ||
update_formula = function(formula, aggregation) { | ||
|
||
# When aggregation is given, we treat each group as having their own rates and intercepts | ||
if (!is.null(aggregation)) { | ||
|
||
# stats::update.formula does not update formulas with only intercept term as expected | ||
# with the `*` operator so we need to manually detect if the formula initially is only | ||
# intercept and use the `+` operator for the first reduction. | ||
initial_operator <- ifelse(rlang::is_empty(labels(terms(formula))), "+", "*") | ||
|
||
# Now we can reduce with the operators set | ||
purrr::pmap(tibble::lst(label = names(aggregation), | ||
aggregation = aggregation, | ||
operator = c(initial_operator, rep("*", length(label) - 1))), | ||
\(label, aggregation, operator) { | ||
glue::glue("~ . {operator} {ifelse(label != '', label, dplyr::as_label(aggregation))}") |> | ||
stats::as.formula() | ||
}) |> | ||
purrr::reduce(stats::update.formula, .init = formula) | ||
|
||
} else { # Do nothing | ||
return(formula) | ||
} | ||
} | ||
) | ||
) | ||
|
||
|
||
#' @title Model module for the b0 reference model | ||
#' | ||
#' @description TODO | ||
#' @export | ||
DiseasyModelB0 <- R6::R6Class( # nolint: object_name_linter | ||
classname = "DiseasyModelB0", | ||
inherit = DiseasyModelB_, | ||
public = list( | ||
#' @description | ||
#' Creates a new instance of the `DiseasyModelB0` [R6][R6::R6Class] class. | ||
#' @param training_length `r rd_training_length()` | ||
#' @param ... | ||
#' parameters sent to `DiseasyModelB_` [R6][R6::R6Class] constructor | ||
#' @return | ||
#' A new instance of the `DiseasyModelB0` [R6][R6::R6Class] class. | ||
initialize = function(...) { | ||
super$initialize(label = "b0", training_length = 7, ...) | ||
} | ||
), | ||
|
||
private = list( | ||
.formula = "{observable} ~ 1" # "{observable}" will be replaced by the observable at runtime | ||
) | ||
) | ||
|
||
|
||
#' @title Model module for the b1 reference model | ||
#' | ||
#' @description TODO | ||
#' @export | ||
DiseasyModelB1 <- R6::R6Class( # nolint: object_name_linter | ||
classname = "DiseasyModelB1", | ||
inherit = DiseasyModelB_, | ||
public = list( | ||
#' @description | ||
#' Creates a new instance of the `DiseasyModelB1` [R6][R6::R6Class] class. | ||
#' @param training_length `r rd_training_length()` | ||
#' @param ... | ||
#' parameters sent to `DiseasyModelB_` [R6][R6::R6Class] constructor | ||
#' @return | ||
#' A new instance of the `DiseasyModelB1` [R6][R6::R6Class] class. | ||
initialize = function(...) { | ||
super$initialize(label = "b1", training_length = 21, ...) | ||
} | ||
), | ||
|
||
private = list( | ||
.formula = "{observable} ~ t" # "{observable}" will be replaced by the observable at runtime | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#' @title Meta module for the GLM class of models | ||
#' | ||
#' @description TODO | ||
DiseasyModelGLM <- R6::R6Class( # nolint: object_name_linter | ||
classname = "DiseasyModelGLM", | ||
inherit = DiseasyModelRegression, | ||
|
||
private = list( | ||
|
||
.parameters = list(seed = 0, | ||
n_realizations = 100), | ||
|
||
fit_regression = function(data, formula) { | ||
coll <- checkmate::makeAssertCollection() | ||
checkmate::assert_tibble(data, add = coll) | ||
checkmate::assert_formula(formula, add = coll) | ||
checkmate::assert_class(self$family, "family", add = coll) | ||
checkmate::reportAssertions(coll) | ||
rm(coll) | ||
|
||
# Look for hashed results | ||
hash <- private$get_hash() | ||
if (!private$is_cached(hash)) { | ||
|
||
# Run the model with the provided params | ||
private$report_regression_fit("glm", formula, self$family, hash) | ||
glm_fit <- stats::glm(formula, data = data, family = self$family) | ||
|
||
# Store in cache | ||
private$cache(hash, glm_fit) | ||
} | ||
|
||
# Write to the log | ||
private$lg$info("Using fitted glm (hash: {hash})") | ||
|
||
return(private$cache(hash)) | ||
|
||
}, | ||
|
||
get_prediction = function(regression_fit, new_data, quantiles) { | ||
coll <- checkmate::makeAssertCollection() | ||
checkmate::assert_class(regression_fit, "glm", add = coll) | ||
checkmate::assert_tibble(new_data, add = coll) | ||
checkmate::assert_true(nrow(new_data) > 0, add = coll) | ||
checkmate::assert_number(quantiles, null.ok = TRUE, add = coll) | ||
checkmate::assert_number(self$parameters$n_realizations, add = coll) | ||
checkmate::reportAssertions(coll) | ||
|
||
glm_predict <- stats::predict(regression_fit, newdata = new_data, type = "link", se.fit = TRUE) | ||
|
||
# Draw samples if quantiles is not given | ||
if (is.null(quantiles)) { | ||
set.seed(seed = self$parameters$seed) | ||
glm_samples <- seq(self$parameters$n_realizations) |> | ||
purrr::map(~ dplyr::mutate(new_data, | ||
observable = self$family$linkinv( | ||
stats::rnorm(nrow(new_data), | ||
glm_predict$fit, | ||
glm_predict$se.fit) | ||
), | ||
realization_id = as.character(.x))) |> | ||
purrr::reduce(dplyr::union_all) | ||
|
||
return(glm_samples) | ||
} else { | ||
private$not_implemented_error("quantile support not yet implemented") | ||
} | ||
} | ||
), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
#' @title Meta module for the simple, g* reference models models | ||
#' | ||
#' @description TODO | ||
DiseasyModelG_ <- R6::R6Class( # nolint: object_name_linter | ||
inherit = DiseasyModelGLM, | ||
public = list( | ||
|
||
#' @description | ||
#' Creates a new instance of the `DiseasyModelG_` [R6][R6::R6Class] class. | ||
#' This module is typically not constructed directly but rather through `DiseasyModelG*` classes, | ||
#' such as [DiseasyModelG0] and [DiseasyModelG1]. | ||
#' @param ... | ||
#' parameters sent to `DiseasyModelGLM` [R6][R6::R6Class] constructor. | ||
#' @details | ||
#' Helper class for the the `DiseasyModelG*` [R6][R6::R6Class] classes. | ||
#' @seealso [stats::family], [stats::as.formula] | ||
#' @return | ||
#' A new instance of the `DiseasyModelG_` [R6][R6::R6Class] class. | ||
initialize = function(...) { | ||
super$initialize(formula = self$formula, | ||
family = stats::quasipoisson(), | ||
...) | ||
} | ||
), | ||
|
||
private = list( | ||
update_formula = function(formula, aggregation) { | ||
|
||
# When aggregation is given, we treat each group as having their own rates and intercepts | ||
if (!is.null(aggregation)) { | ||
|
||
# stats::update.formula does not update formulas with only intercept term as expected | ||
# when using the `*` operator so we need to manually detect if the formula initially is only | ||
# intercept and use the `+` operator for the first reduction. | ||
initial_operator <- ifelse(rlang::is_empty(labels(terms(formula))), "+", "*") | ||
|
||
# Now we can reduce with the operators set | ||
purrr::pmap(tibble::lst(label = names(aggregation), | ||
aggregation = aggregation, | ||
operator = c(initial_operator, rep("*", length(label) - 1))), | ||
\(label, aggregation, operator) { | ||
glue::glue("~ . {operator} {ifelse(label != '', label, dplyr::as_label(aggregation))}") |> | ||
stats::as.formula() | ||
}) |> | ||
purrr::reduce(stats::update.formula, .init = formula) | ||
|
||
} else { # Do nothing | ||
return(formula) | ||
} | ||
} | ||
) | ||
) | ||
|
||
|
||
#' @title Model module for the g0 reference model | ||
#' | ||
#' @description TODO | ||
#' @export | ||
DiseasyModelG0 <- R6::R6Class( # nolint: object_name_linter | ||
classname = "DiseasyModelG0", | ||
inherit = DiseasyModelG_, | ||
public = list( | ||
#' @description | ||
#' Creates a new instance of the `DiseasyModelG0` [R6][R6::R6Class] class. | ||
#' @param training_length `r rd_training_length()` | ||
#' @param ... | ||
#' parameters sent to `DiseasyModelG_` [R6][R6::R6Class] constructor | ||
#' @return | ||
#' A new instance of the `DiseasyModelG1` [R6][R6::R6Class] class. | ||
initialize = function(...) { | ||
super$initialize(label = "g0", training_length = 7, ...) | ||
} | ||
), | ||
|
||
private = list( | ||
.formula = "{observable} ~ 1" # "{observable}" will be replaced by the observable at runtime | ||
) | ||
) | ||
|
||
|
||
#' @title Model module for the g1 reference model | ||
#' | ||
#' @description TODO | ||
#' @export | ||
DiseasyModelG1 <- R6::R6Class( # nolint: object_name_linter | ||
classname = "DiseasyModelG1", | ||
inherit = DiseasyModelG_, | ||
public = list( | ||
#' @description | ||
#' Creates a new instance of the `DiseasyModelG1` [R6][R6::R6Class] class. | ||
#' @param training_length `r rd_training_length()` | ||
#' @param ... | ||
#' parameters sent to `DiseasyModelG_` [R6][R6::R6Class] constructor | ||
#' @return | ||
#' A new instance of the `DiseasyModelG1` [R6][R6::R6Class] class. | ||
initialize = function(...) { | ||
super$initialize(label = "g1", training_length = 21, ...) | ||
} | ||
), | ||
|
||
private = list( | ||
.formula = "{observable} ~ t" # "{observable}" will be replaced by the observable at runtime | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.