Skip to content

Commit 573fc9b

Browse files
committed
WIP epi_slide refactor
1 parent d962101 commit 573fc9b

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

R/slide-refactor.R

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
2+
#' Convert a time-slide computation into a simple "hopper" function
#'
#' Wraps `.slide_comp` in a closure that evaluates it on a window of rows
#' around each requested reference row of a single group's data, validates the
#' outputs, and combines them.
#'
#' @param .slide_comp Function called as
#'   `.slide_comp(.x, .group_key, .ref_time_value, ...)`, where `.x` is a
#'   window of rows of `grp_data`. Presumably the output of
#'   `as_time_slide_computation()`, as in the example below (TODO confirm).
#' @param ... Additional arguments forwarded to `.slide_comp` on every call.
#' @param .before_n_steps Number of rows before each reference row to include
#'   in its window. `Inf` makes every window start at the first row.
#' @param .after_n_steps Number of rows after each reference row to include in
#'   its window.
#'
#' @return A function `function(grp_data, grp_key, ref_inds)`. `grp_data` must
#'   have a `time_value` column; `ref_inds` are row indices into `grp_data`.
#'   The function returns the slide computation outputs for those rows,
#'   combined with `vctrs::list_unchop()`. It aborts with class
#'   `"epiprocess__invalid_slide_comp_value"` if the outputs are not all data
#'   frames or all unnamed vectors, or if any output does not have size 1.
#'
#' @examples
#' time_slide_to_simple_hopper(as_time_slide_computation(~ .x[1L, ]),
#'   .before_n_steps = 2L, .after_n_steps = 0L
#' )(
#'   tibble(time_value = 1:5, value = 1:5),
#'   tibble(geo_value = 1),
#'   3:4
#' )
#'
#' @keywords internal
time_slide_to_simple_hopper <- function(.slide_comp, ..., .before_n_steps, .after_n_steps) {
  function(grp_data, grp_key, ref_inds) {
    available_ref_time_values <- vctrs::vec_slice(grp_data$time_value, ref_inds)
    # Counter of slide computation calls made so far. It must be a local
    # binding (`<-`, not `<<-`): otherwise the counter would be created in /
    # overwrite a variable in an outer (typically global) environment. The
    # `<<-` inside `wrapped_slide_comp` then updates this local binding.
    i <- 0L
    wrapped_slide_comp <- function(.x, .group_key, ...) {
      # Relies on the windows being processed in order, so that the i-th call
      # corresponds to the i-th reference time value.
      i <<- i + 1L
      # XXX could also use .after_n_steps to figure out...

      # FIXME wrong dots here?
      .slide_comp(.x, .group_key, available_ref_time_values[[i]], ...)
    }
    # Window boundaries are row indices into `grp_data`. Out-of-bounds
    # boundaries (e.g., `starts < 1`) are left for `slider::hop()` to handle.
    if (.before_n_steps == Inf) {
      starts <- 1L
    } else {
      starts <- ref_inds - .before_n_steps
    }
    stops <- ref_inds + .after_n_steps
    # Compute the slide values. slider::hop will return a list of f outputs
    # e.g. list(f(.slide_group_1, .group_key, .ref_time_value_1),
    # f(.slide_group_1, .group_key, .ref_time_value_2), ...)
    #
    # `slider::hop()` (unlike `slider::hop_index()`) has no `.i` argument;
    # anything besides `.x`, `.starts`, `.stops`, and `.f` is forwarded to
    # `.f`, so only the group key and the user's dots are passed along here.
    slide_values_list <- slider::hop(
      .x = grp_data,
      .starts = starts,
      .stops = stops,
      .f = wrapped_slide_comp,
      grp_key, ...
    )

    # Validate returned values. This used to only happen when
    # .used_data_masking=FALSE, so if it seems too slow, consider bringing that
    # back.
    return_types <- purrr::map_chr(slide_values_list, function(x) {
      if (is.data.frame(x)) {
        "data.frame"
      } else if (vctrs::obj_is_vector(x) && is.null(vctrs::vec_names(x))) {
        "vector"
      } else {
        "other"
      }
    }) %>% unique()
    # Returned values must be data.frame or vector.
    if ("other" %in% return_types) {
      cli_abort(
        "epi_slide: slide computations must always return either data frames
        or unnamed vectors (as determined by the vctrs package).",
        class = "epiprocess__invalid_slide_comp_value"
      )
    }
    # Returned values must all be the same type.
    if (length(return_types) != 1L) {
      cli_abort(
        "epi_slide: slide computations must always return either a data.frame or a vector (as determined by the
        vctrs package), but not a mix of the two.",
        class = "epiprocess__invalid_slide_comp_value"
      )
    }
    # Returned values must always be a scalar vector or a data frame with one row.
    if (any(vctrs::list_sizes(slide_values_list) != 1L)) {
      cli_abort(
        "epi_slide: slide computations must return a single element (e.g. a scalar value, a single data.frame row,
        or a list).",
        class = "epiprocess__invalid_slide_comp_value"
      )
    }
    # Flatten the output list. This will also error if the user's slide function
    # returned inconsistent types.
    slide_values <- slide_values_list %>% vctrs::list_unchop()

    slide_values
  }
}
84+
85+
# TODO hopper -> skipper?
86+
87+
# TODO simplify to just trailing and put shift elsewhere?
88+
#' Convert an upstream (data.table / slider) slide function into a simple
#' "hopper" function
#'
#' Builds a closure that applies the rolling function `.f` to the input
#' columns of a single group's data and stores the results in the output
#' columns. The calling convention used for `.f` is chosen based on
#' `upstream_slide_f_info(.f, ...)$from_package`.
#'
#' @param .f Rolling function. When it is from data.table, it is called as
#'   `.f(<input columns>, <window size(s)>, ...)`; when it is from slider, it
#'   is called once per input column as
#'   `.f(<column>, before = .before_n_steps, after = .after_n_steps, ...)`.
#' @param ... Additional arguments forwarded to `.f` on every call.
#' @param .in_colnames Names of the columns of `grp_data` to feed to `.f`.
#' @param .out_colnames Names of the columns in which to store the results;
#'   paired by position with `.in_colnames`.
#' @param .before_n_steps Number of rows before each row to include in its
#'   window. For data.table functions, `Inf` is supported only together with
#'   `.after_n_steps = 0`.
#' @param .after_n_steps Number of rows after each row to include in its
#'   window.
#'
#' @return A function `function(grp_data, grp_key, ref_inds)` that returns
#'   `grp_data` with the `.out_colnames` columns added or overwritten.
#'   `grp_key` and `ref_inds` are accepted but not used by these hoppers:
#'   results are computed for every row of `grp_data`.
#'
#' @examples
#' upstream_slide_to_simple_hopper(
#'   frollmean,
#'   .in_colnames = "value", .out_colnames = "slide_value",
#'   .before_n_steps = 1L, .after_n_steps = 0L
#' )(
#'   tibble(time_value = 1:5, value = 1:5),
#'   tibble(geo_value = 1),
#'   3:4
#' )
#'
#' @keywords internal
upstream_slide_to_simple_hopper <- function(.f, ..., .in_colnames, .out_colnames, .before_n_steps, .after_n_steps) {
  f_info <- upstream_slide_f_info(.f, ...)
  in_colnames <- .in_colnames
  out_colnames <- .out_colnames
  f_from_package <- f_info$from_package
  # NOTE: the hoppers below declare no `...` of their own, so `...` inside them
  # refers to this function's dots (the extra arguments for `.f`).
  # TODO move .before_n_steps, .after_n_steps to args of this function?
  switch(
    f_from_package,
    data.table = if (.before_n_steps == Inf) {
      if (.after_n_steps != 0L) {
        stop(".before_n_steps = Inf only supported with .after_n_steps = 0")
      }
      function(grp_data, grp_key, ref_inds) {
        # Adaptive mode with window sizes 1, 2, ..., nrow: the window for row
        # k has size k, i.e., it covers rows 1 through k.
        out_cols <- .f(grp_data[, in_colnames], seq_len(nrow(grp_data)), adaptive = TRUE, ...)
        grp_data[, out_colnames] <- out_cols
        grp_data
      }
    } else {
      function(grp_data, grp_key, ref_inds) {
        # Compute with a window of the full width, treated as trailing
        # (right-aligned) here; it is re-aligned below when
        # `.after_n_steps != 0`.
        out_cols <- .f(grp_data[, in_colnames], .before_n_steps + .after_n_steps + 1L, ...)
        if (.after_n_steps != 0L) {
          # Shift an appropriate amount of NA padding from the start to the end.
          # (This padding will later be cut off when we filter down to the
          # original time_values.)
          out_cols <- lapply(out_cols, function(out_col) {
            n_out <- length(out_col)
            # Cap the shift at the column length so that short columns come
            # back all-NA with their length unchanged (`a:b` would count
            # backwards when `a > b`).
            n_shift <- min(.after_n_steps, n_out)
            c(out_col[seq_len(n_out - n_shift) + n_shift], rep(NA, n_shift))
          })
        }
        grp_data[, out_colnames] <- out_cols
        grp_data
      }
    },
    slider = function(grp_data, grp_key, ref_inds) {
      # Called one column at a time, with the window given directly via
      # `before` and `after`.
      for (col_i in seq_along(in_colnames)) {
        grp_data[[out_colnames[[col_i]]]] <- .f(
          grp_data[[in_colnames[[col_i]]]],
          before = .before_n_steps, after = .after_n_steps,
          ...
        )
      }
      grp_data
    },
    # TODO Inf checks?
    stop("unsupported package")
  )
}
137+
138+
# TODO maybe make ref_inds optional or have special handling if it's the whole sequence?
139+
#
140+
# TODO decide whether/where to put time range stuff
141+
142+
# TODO grp_ -> ek_ ?

0 commit comments

Comments
 (0)