Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusSkytte committed Oct 11, 2023
1 parent d0f39c0 commit 7eb02a7
Show file tree
Hide file tree
Showing 15 changed files with 1,131 additions and 1 deletion.
84 changes: 84 additions & 0 deletions R/DiseasyModelBRM.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#' @title Meta module for the BRM class of models
#'
#' @description TODO
DiseasyModelBRM <- R6::R6Class( # nolint: object_name_linter
classname = "DiseasyModelBRM",
inherit = DiseasyModelRegression,

private = list(

.parameters = list(seed = 0,
n_warmup = 1000,
n_iter = 500,
n_chains = 4),

fit_regression = function(data, formula) {
coll <- checkmate::makeAssertCollection()
checkmate::assert_tibble(data, add = coll)
checkmate::assert_formula(formula, add = coll)
checkmate::assert_class(self$family, "family", add = coll)
checkmate::assert_number(self$parameters$n_warmup, add = coll)
checkmate::assert_number(self$parameters$n_iter, add = coll)
checkmate::assert_number(self$parameters$n_chains, add = coll)
checkmate::assert_number(self$parameters$seed, add = coll)
checkmate::reportAssertions(coll)
rm(coll)

# Look for hashed results
hash <- private$get_hash()
if (!private$is_cached(hash)) {

# Run the model with the provided params
private$report_regression_fit("brm", formula, self$family, hash)
brms_fit <- brms::brm(formula = formula,
data = data,
family = self$family,
warmup = self$parameters$n_warmup,
iter = self$parameters$n_warmup + self$parameters$n_iter,
chains = self$parameters$n_chains,
seed = self$parameters$seed)

# Store in cache
private$cache(hash, brms_fit)
}

# Write to the log
private$lg$info("Using fitted brm (hash: {hash})")

return(private$cache(hash))

},

get_prediction = function(regression_fit, new_data, quantiles) {
coll <- checkmate::makeAssertCollection()
checkmate::assert_class(regression_fit, "brmsfit", add = coll)
checkmate::assert_tibble(new_data, add = coll)
checkmate::assert_true(nrow(new_data) > 0, add = coll)
checkmate::assert_number(quantiles, null.ok = TRUE, add = coll)
checkmate::assert_number(self$parameters$n_iter, add = coll)
checkmate::assert_number(self$parameters$n_chains, add = coll)
checkmate::reportAssertions(coll)

brms_predict <- stats::predict(regression_fit, newdata = new_data, summary = FALSE)

# Draw samples if quantiles is not given
if (is.null(quantiles)) {

# First we expand the combinations of draws
combinations <- tidyr::expand_grid(chain = seq(self$parameters$n_chains), iter = seq(self$parameters$n_iter)) |>
dplyr::mutate(index = dplyr::row_number())

# Construct output tibble
brm_samples <- purrr::pmap(combinations,
~ dplyr::mutate(new_data,
observable = brms_predict[..3, ],
realization_id = paste(..1, ..2, sep = "_"))) |>
purrr::reduce(dplyr::union_all)

return(brm_samples)
} else {
private$not_implemented_error("quantile support not yet implemented")
}
}
),
)
104 changes: 104 additions & 0 deletions R/DiseasyModelB_.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#' @title Meta module for the simple, b* reference models models
#'
#' @description TODO
DiseasyModelB_ <- R6::R6Class( # nolint: object_name_linter
inherit = DiseasyModelBRM,
public = list(

#' @description
#' Creates a new instance of the `DiseasyModelB_` [R6][R6::R6Class] class.
#' This module is typically not constructed directly but rather through `DiseasyModelB*` classes,
#' such as [DiseasyModelB0] and [DiseasyModelB1].
#' @param ...
#' parameters sent to `DiseasyModelBRM` [R6][R6::R6Class] constructor.
#' @details
#' Helper class for the the `DiseasyModelB*` [R6][R6::R6Class] classes.
#' @seealso [stats::family], [stats::as.formula]
#' @return
#' A new instance of the `DiseasyModelB_` [R6][R6::R6Class] class.
initialize = function(...) {
super$initialize(formula = self$formula,
family = stats::poisson(),
...)
}
),

private = list(
update_formula = function(formula, aggregation) {

# When aggregation is given, we treat each group as having their own rates and intercepts
if (!is.null(aggregation)) {

# stats::update.formula does not update formulas with only intercept term as expected
# with the `*` operator so we need to manually detect if the formula initially is only
# intercept and use the `+` operator for the first reduction.
initial_operator <- ifelse(rlang::is_empty(labels(terms(formula))), "+", "*")

# Now we can reduce with the operators set
purrr::pmap(tibble::lst(label = names(aggregation),
aggregation = aggregation,
operator = c(initial_operator, rep("*", length(label) - 1))),
\(label, aggregation, operator) {
glue::glue("~ . {operator} {ifelse(label != '', label, dplyr::as_label(aggregation))}") |>
stats::as.formula()
}) |>
purrr::reduce(stats::update.formula, .init = formula)

} else { # Do nothing
return(formula)
}
}
)
)


#' @title Model module for the b0 reference model
#'
#' @description TODO
#' @export
DiseasyModelB0 <- R6::R6Class( # nolint: object_name_linter
classname = "DiseasyModelB0",
inherit = DiseasyModelB_,
public = list(
#' @description
#' Creates a new instance of the `DiseasyModelB0` [R6][R6::R6Class] class.
#' @param training_length `r rd_training_length()`
#' @param ...
#' parameters sent to `DiseasyModelB_` [R6][R6::R6Class] constructor
#' @return
#' A new instance of the `DiseasyModelB0` [R6][R6::R6Class] class.
initialize = function(...) {
super$initialize(label = "b0", training_length = 7, ...)
}
),

private = list(
.formula = "{observable} ~ 1" # "{observable}" will be replaced by the observable at runtime
)
)


#' @title Model module for the b1 reference model
#'
#' @description TODO
#' @export
DiseasyModelB1 <- R6::R6Class( # nolint: object_name_linter
classname = "DiseasyModelB1",
inherit = DiseasyModelB_,
public = list(
#' @description
#' Creates a new instance of the `DiseasyModelB1` [R6][R6::R6Class] class.
#' @param training_length `r rd_training_length()`
#' @param ...
#' parameters sent to `DiseasyModelB_` [R6][R6::R6Class] constructor
#' @return
#' A new instance of the `DiseasyModelB1` [R6][R6::R6Class] class.
initialize = function(...) {
super$initialize(label = "b1", training_length = 21, ...)
}
),

private = list(
.formula = "{observable} ~ t" # "{observable}" will be replaced by the observable at runtime
)
)
70 changes: 70 additions & 0 deletions R/DiseasyModelGLM.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#' @title Meta module for the GLM class of models
#'
#' @description TODO
DiseasyModelGLM <- R6::R6Class( # nolint: object_name_linter
classname = "DiseasyModelGLM",
inherit = DiseasyModelRegression,

private = list(

.parameters = list(seed = 0,
n_realizations = 100),

fit_regression = function(data, formula) {
coll <- checkmate::makeAssertCollection()
checkmate::assert_tibble(data, add = coll)
checkmate::assert_formula(formula, add = coll)
checkmate::assert_class(self$family, "family", add = coll)
checkmate::reportAssertions(coll)
rm(coll)

# Look for hashed results
hash <- private$get_hash()
if (!private$is_cached(hash)) {

# Run the model with the provided params
private$report_regression_fit("glm", formula, self$family, hash)
glm_fit <- stats::glm(formula, data = data, family = self$family)

# Store in cache
private$cache(hash, glm_fit)
}

# Write to the log
private$lg$info("Using fitted glm (hash: {hash})")

return(private$cache(hash))

},

get_prediction = function(regression_fit, new_data, quantiles) {
coll <- checkmate::makeAssertCollection()
checkmate::assert_class(regression_fit, "glm", add = coll)
checkmate::assert_tibble(new_data, add = coll)
checkmate::assert_true(nrow(new_data) > 0, add = coll)
checkmate::assert_number(quantiles, null.ok = TRUE, add = coll)
checkmate::assert_number(self$parameters$n_realizations, add = coll)
checkmate::reportAssertions(coll)

glm_predict <- stats::predict(regression_fit, newdata = new_data, type = "link", se.fit = TRUE)

# Draw samples if quantiles is not given
if (is.null(quantiles)) {
set.seed(seed = self$parameters$seed)
glm_samples <- seq(self$parameters$n_realizations) |>
purrr::map(~ dplyr::mutate(new_data,
observable = self$family$linkinv(
stats::rnorm(nrow(new_data),
glm_predict$fit,
glm_predict$se.fit)
),
realization_id = as.character(.x))) |>
purrr::reduce(dplyr::union_all)

return(glm_samples)
} else {
private$not_implemented_error("quantile support not yet implemented")
}
}
),
)
104 changes: 104 additions & 0 deletions R/DiseasyModelG_.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#' @title Meta module for the simple, g* reference models models
#'
#' @description TODO
DiseasyModelG_ <- R6::R6Class( # nolint: object_name_linter
inherit = DiseasyModelGLM,
public = list(

#' @description
#' Creates a new instance of the `DiseasyModelG_` [R6][R6::R6Class] class.
#' This module is typically not constructed directly but rather through `DiseasyModelG*` classes,
#' such as [DiseasyModelG0] and [DiseasyModelG1].
#' @param ...
#' parameters sent to `DiseasyModelGLM` [R6][R6::R6Class] constructor.
#' @details
#' Helper class for the the `DiseasyModelG*` [R6][R6::R6Class] classes.
#' @seealso [stats::family], [stats::as.formula]
#' @return
#' A new instance of the `DiseasyModelG_` [R6][R6::R6Class] class.
initialize = function(...) {
super$initialize(formula = self$formula,
family = stats::quasipoisson(),
...)
}
),

private = list(
update_formula = function(formula, aggregation) {

# When aggregation is given, we treat each group as having their own rates and intercepts
if (!is.null(aggregation)) {

# stats::update.formula does not update formulas with only intercept term as expected
# when using the `*` operator so we need to manually detect if the formula initially is only
# intercept and use the `+` operator for the first reduction.
initial_operator <- ifelse(rlang::is_empty(labels(terms(formula))), "+", "*")

# Now we can reduce with the operators set
purrr::pmap(tibble::lst(label = names(aggregation),
aggregation = aggregation,
operator = c(initial_operator, rep("*", length(label) - 1))),
\(label, aggregation, operator) {
glue::glue("~ . {operator} {ifelse(label != '', label, dplyr::as_label(aggregation))}") |>
stats::as.formula()
}) |>
purrr::reduce(stats::update.formula, .init = formula)

} else { # Do nothing
return(formula)
}
}
)
)


#' @title Model module for the g0 reference model
#'
#' @description TODO
#' @export
DiseasyModelG0 <- R6::R6Class( # nolint: object_name_linter
classname = "DiseasyModelG0",
inherit = DiseasyModelG_,
public = list(
#' @description
#' Creates a new instance of the `DiseasyModelG0` [R6][R6::R6Class] class.
#' @param training_length `r rd_training_length()`
#' @param ...
#' parameters sent to `DiseasyModelG_` [R6][R6::R6Class] constructor
#' @return
#' A new instance of the `DiseasyModelG1` [R6][R6::R6Class] class.
initialize = function(...) {
super$initialize(label = "g0", training_length = 7, ...)
}
),

private = list(
.formula = "{observable} ~ 1" # "{observable}" will be replaced by the observable at runtime
)
)


#' @title Model module for the g1 reference model
#'
#' @description TODO
#' @export
DiseasyModelG1 <- R6::R6Class( # nolint: object_name_linter
classname = "DiseasyModelG1",
inherit = DiseasyModelG_,
public = list(
#' @description
#' Creates a new instance of the `DiseasyModelG1` [R6][R6::R6Class] class.
#' @param training_length `r rd_training_length()`
#' @param ...
#' parameters sent to `DiseasyModelG_` [R6][R6::R6Class] constructor
#' @return
#' A new instance of the `DiseasyModelG1` [R6][R6::R6Class] class.
initialize = function(...) {
super$initialize(label = "g1", training_length = 21, ...)
}
),

private = list(
.formula = "{observable} ~ t" # "{observable}" will be replaced by the observable at runtime
)
)
2 changes: 1 addition & 1 deletion R/DiseasyObservables.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ DiseasyObservables <- R6::R6Class(
ds_case_definition <- diseasystore:::diseasystore_case_definition(case_definition)
private$.ds <- get(ds_case_definition)$new(slice_ts = self %.% slice_ts,
verbose = !testthat::is_testing(),
target_conn = parse_conn(self %.% conn))
target_conn = self %.% conn)

private$.case_definition <- private$.ds %.% case_definition # Use the human readable from the diseasystore

Expand Down
Loading

0 comments on commit 7eb02a7

Please sign in to comment.