Skip to content

Commit fdee9de

Browse files
committed
Test unifying epi_slide_opt inner comps between edf and archive
1 parent 7a5708f commit fdee9de

File tree

2 files changed

+80
-153
lines changed

2 files changed

+80
-153
lines changed

R/epi_slide_opt_archive.R

Lines changed: 2 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -52,81 +52,14 @@ epi_slide_opt_archive_one_epikey <- function(
5252
grp_updates_by_version <- grp_updates %>%
5353
nest(.by = version, .key = "subtbl") %>%
5454
arrange(version)
55-
unit_step <- unit_time_delta(time_type)
55+
unit_step <- unit_time_delta(time_type, format = "fast")
5656
prev_inp_snapshot <- NULL
5757
prev_out_snapshot <- NULL
5858
result <- map(seq_len(nrow(grp_updates_by_version)), function(version_i) {
5959
version <- grp_updates_by_version$version[[version_i]]
6060
inp_update <- grp_updates_by_version$subtbl[[version_i]]
6161
inp_snapshot <- tbl_patch(prev_inp_snapshot, inp_update, "time_value")
62-
if (before == Inf) {
63-
if (after != 0) {
64-
cli_abort('.window_size = Inf is only supported with .align = "right"',
65-
class = "epiprocess__epi_slide_opt_archive__inf_window_invalid_align"
66-
)
67-
}
68-
# We need to use the entire input snapshot range, filling in time gaps. We
69-
# shouldn't pad the ends.
70-
slide_t_min <- min(inp_snapshot$time_value)
71-
slide_t_max <- max(inp_snapshot$time_value)
72-
} else {
73-
# If the input had updates in the range t1..t2, this could produce changes
74-
# in slide outputs in the range t1-after..t2+before, and to compute those
75-
# slide values, we need to look at the input snapshot from
76-
# t1-after-before..t2+before+after. nolint: commented_code_linter
77-
inp_update_t_min <- min(inp_update$time_value)
78-
inp_update_t_max <- max(inp_update$time_value)
79-
slide_t_min <- inp_update_t_min - (before + after) * unit_step
80-
slide_t_max <- inp_update_t_max + (before + after) * unit_step
81-
}
82-
slide_nrow <- time_delta_to_n_steps(slide_t_max - slide_t_min, time_type) + 1L
83-
slide_time_values <- slide_t_min + 0L:(slide_nrow - 1L) * unit_step
84-
slide_inp_backrefs <- vec_match(slide_time_values, inp_snapshot$time_value)
85-
# Get additional values needed from inp_snapshot + perform any NA
86-
# tail-padding needed to make slider results a fixed window size rather than
87-
# adaptive at tails + perform any NA gap-filling needed:
88-
slide <- vec_slice(inp_snapshot, slide_inp_backrefs)
89-
slide$time_value <- slide_time_values
90-
if (f_from_package == "data.table") {
91-
if (before == Inf) {
92-
slide[, out_colnames] <-
93-
f_dots_baked(slide[, in_colnames], seq_len(slide_nrow), adaptive = TRUE)
94-
} else {
95-
out_cols <- f_dots_baked(slide[, in_colnames], before + after + 1L)
96-
if (after != 0L) {
97-
# Shift an appropriate amount of NA padding from the start to the end.
98-
# (This padding will later be cut off when we filter down to the
99-
# original time_values.)
100-
out_cols <- lapply(out_cols, function(out_col) {
101-
c(out_col[(after + 1L):length(out_col)], rep(NA, after))
102-
})
103-
}
104-
slide[, out_colnames] <- out_cols
105-
}
106-
} else if (f_from_package == "slider") {
107-
for (col_i in seq_along(in_colnames)) {
108-
slide[[out_colnames[[col_i]]]] <- f_dots_baked(slide[[in_colnames[[col_i]]]], before = before, after = after)
109-
}
110-
} else {
111-
cli_abort(
112-
"epiprocess internal error: `f_from_package` was {format_chr_deparse(f_from_package)}, which is unsupported",
113-
class = "epiprocess__epi_slide_opt_archive__f_from_package_invalid"
114-
)
115-
}
116-
rows_should_keep <-
117-
if (before == Inf) {
118-
# Re-introduce time gaps:
119-
!is.na(slide_inp_backrefs)
120-
} else {
121-
# Get back to t1-after..t2+before; times outside this range were included
122-
# only so those inside would have enough context for their slide
123-
# computations, but these "context" rows may contain invalid slide
124-
# computation outputs:
125-
vec_rep_each(c(FALSE, TRUE, FALSE), c(before, slide_nrow - before - after, after)) &
126-
# Only include time_values that appeared in the input snapshot:
127-
!is.na(slide_inp_backrefs)
128-
}
129-
out_update <- vec_slice(slide, rows_should_keep)
62+
out_update <- epi_slide_opt_one_epikey(inp_snapshot, f_dots_baked, f_from_package, before, after, unit_step, time_type, inp_update$time_value, in_colnames, out_colnames)
13063
out_diff <- tbl_diff2(prev_out_snapshot, out_update, "time_value", "update")
13164
prev_inp_snapshot <<- inp_snapshot
13265
prev_out_snapshot <<- tbl_patch(prev_out_snapshot, out_diff, "time_value")

R/epi_slide_opt_edf.R

Lines changed: 78 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,67 @@ across_ish_names_info <- function(.x, time_type, col_names_quo, .f_namer,
167167
)
168168
}
169169

170+
#' Core of `epi_slide_opt` for the rows of a single epikey
#'
#' Builds a gap-free, padded sequence of time values around the requested
#' reference times, looks up the matching input rows (rows for time values
#' absent from the input are all-`NA`), applies the rolling function, and
#' returns only the rows for `ref_time_values`.
#'
#' @param inp_snapshot Table with a `time_value` column plus the columns named
#'   in `in_colnames`.
#' @param f_dots_baked Rolling function. For `"data.table"` it is called as
#'   `f(<input columns>, <window size>)`, or with `adaptive = TRUE` and one
#'   window size per row when `before` is `Inf`. For `"slider"` it is called
#'   once per input column as `f(<column>, before = before, after = after)`.
#' @param f_from_package `"data.table"` or `"slider"`; any other value aborts.
#' @param before,after Window extents, as a number of `unit_step`s. `before`
#'   may be `Inf`, in which case `after` must be 0.
#' @param unit_step Time delta between consecutive time values.
#' @param time_type Passed to `time_delta_to_n_steps()` to count the steps
#'   spanned by the slide range.
#' @param ref_time_values Time values to return output rows for. May be empty,
#'   in which case a zero-row result (with output columns) is returned.
#' @param in_colnames,out_colnames Names of the input columns and of the
#'   corresponding output columns, in matching order.
#' @return The rows of the padded slide table matching `ref_time_values`,
#'   including the `out_colnames` columns.
#' @noRd
epi_slide_opt_one_epikey <- function(inp_snapshot,
                                     f_dots_baked,
                                     f_from_package,
                                     before,
                                     after,
                                     unit_step,
                                     time_type,
                                     ref_time_values,
                                     in_colnames,
                                     out_colnames) {
  # TODO try converting time values to reals, do work on reals, convert back at very end?
  if (before == Inf) {
    if (after != 0L) {
      cli_abort('.window_size = Inf is only supported with .align = "right"',
        class = "epiprocess__epi_slide_opt_archive__inf_window_invalid_align"
      )
    }
    # We need to use the entire input snapshot range, filling in time gaps. We
    # shouldn't pad the ends.
    slide_t_min <- min(inp_snapshot$time_value)
    slide_t_max <- max(inp_snapshot$time_value)
  } else {
    # `min()`/`max()` of an empty vector warn and return non-finite values,
    # which would break the range computation below. With no reference times
    # requested, anchor the range on the first snapshot time value instead: the
    # slide still runs (so output columns get their proper types) and zero rows
    # are selected at the end.
    range_time_values <-
      if (length(ref_time_values) == 0L) {
        vec_slice(inp_snapshot$time_value, 1L)
      } else {
        ref_time_values
      }
    # If the input had updates in the range t1..t2, this could produce changes
    # in slide outputs in the range t1-after..t2+before, and to compute those
    # slide values, we need to look at the input snapshot from
    # t1-after-before..t2+before+after. nolint: commented_code_linter
    inp_update_t_min <- min(range_time_values)
    inp_update_t_max <- max(range_time_values)
    slide_t_min <- inp_update_t_min - (before + after) * unit_step
    slide_t_max <- inp_update_t_max + (before + after) * unit_step
  }
  slide_nrow <- time_delta_to_n_steps(slide_t_max - slide_t_min, time_type) + 1L
  slide_time_values <- slide_t_min + 0L:(slide_nrow - 1L) * unit_step
  slide_inp_backrefs <- vec_match(slide_time_values, inp_snapshot$time_value)
  # Get additional values needed from inp_snapshot + perform any NA
  # tail-padding needed to make slider results a fixed window size rather than
  # adaptive at tails + perform any NA gap-filling needed:
  slide <- vec_slice(inp_snapshot, slide_inp_backrefs)
  slide$time_value <- slide_time_values
  if (f_from_package == "data.table") {
    if (before == Inf) {
      # One window size per row (1, 2, ..., nrow): cumulative from the start.
      slide[, out_colnames] <-
        f_dots_baked(slide[, in_colnames], seq_len(slide_nrow), adaptive = TRUE)
    } else {
      out_cols <- f_dots_baked(slide[, in_colnames], before + after + 1L)
      if (after != 0L) {
        # Shift an appropriate amount of NA padding from the start to the end.
        # (This padding will later be cut off when we filter down to the
        # original time_values.)
        out_cols <- lapply(out_cols, function(out_col) {
          c(out_col[(after + 1L):length(out_col)], rep(NA, after))
        })
      }
      slide[, out_colnames] <- out_cols
    }
  } else if (f_from_package == "slider") {
    for (col_i in seq_along(in_colnames)) {
      slide[[out_colnames[[col_i]]]] <- f_dots_baked(
        slide[[in_colnames[[col_i]]]],
        before = before, after = after
      )
    }
  } else {
    cli_abort(
      "epiprocess internal error: `f_from_package` was {format_chr_deparse(f_from_package)}, which is unsupported",
      class = "epiprocess__epi_slide_opt_archive__f_from_package_invalid"
    )
  }
  # Keep only the requested reference times; the extra "context" rows were
  # included only so the kept rows would have full windows.
  # NOTE(review): the archive code path previously kept every snapshot row in
  # t1-after..t2+before rather than only the updated time values -- confirm
  # that the archive caller passes a suitably widened `ref_time_values`.
  rows_should_keep <- vec_match(ref_time_values, slide_time_values)
  vec_slice(slide, rows_should_keep)
}
230+
170231
#' Optimized slide functions for common cases
171232
#'
172233
#' @description
@@ -425,6 +486,9 @@ epi_slide_opt.epi_df <- function(.x, .col_names, .f, ...,
425486
}
426487
validate_slide_window_arg(.window_size, time_type)
427488
window_args <- get_before_after_from_window(.window_size, .align, time_type)
489+
before <- time_delta_to_n_steps(window_args$before, time_type)
490+
after <- time_delta_to_n_steps(window_args$after, time_type)
491+
unit_step <- unit_time_delta(time_type, format = "fast")
428492

429493
# Handle output naming:
430494
names_info <- across_ish_names_info(
@@ -434,97 +498,27 @@ epi_slide_opt.epi_df <- function(.x, .col_names, .f, ...,
434498
input_col_names <- names_info$input_col_names
435499
output_col_names <- names_info$output_col_names
436500

437-
# Make a complete date sequence between min(.x$time_value) and max(.x$time_value).
438-
date_seq_list <- full_date_seq(.x, window_args$before, window_args$after, time_type)
439-
all_dates <- date_seq_list$all_dates
440-
pad_early_dates <- date_seq_list$pad_early_dates
441-
pad_late_dates <- date_seq_list$pad_late_dates
442-
443-
slide_one_grp <- function(.data_group, .group_key, ...) {
444-
missing_times <- all_dates[!vec_in(all_dates, .data_group$time_value)]
445-
# `frollmean` requires a full window to compute a result. Add NA values
446-
# to beginning and end of the group so that we get results for the
447-
# first `before` and last `after` elements.
448-
.data_group <- vec_rbind(
449-
.data_group, # (tibble; epi_slide_opt uses .keep = FALSE)
450-
new_tibble(vec_recycle_common(
451-
time_value = c(missing_times, pad_early_dates, pad_late_dates),
452-
.real = FALSE
453-
))
454-
) %>%
455-
`[`(vec_order(.$time_value), )
456-
457-
if (f_from_package == "data.table") {
458-
# Grouping should ensure that we don't have duplicate time values.
459-
# Completion above should ensure we have at least .window_size rows. Check
460-
# that we don't have more than .window_size rows (or fewer somehow):
461-
if (nrow(.data_group) != length(c(all_dates, pad_early_dates, pad_late_dates))) {
462-
cli_abort(
463-
c(
464-
"group contains an unexpected number of rows",
465-
"i" = c("Input data may contain `time_values` closer together than the
466-
expected `time_step` size")
467-
),
468-
class = "epiprocess__epi_slide_opt__unexpected_row_number",
469-
epiprocess__data_group = .data_group,
470-
epiprocess__group_key = .group_key
471-
)
472-
}
473-
474-
# `frollmean` is 1-indexed, so create a new window width based on our
475-
# `before` and `after` params. Right-aligned `frollmean` results'
476-
# `ref_time_value`s will be `after` timesteps ahead of where they should
477-
# be; shift results to the left by `after` timesteps.
478-
if (window_args$before != Inf) {
479-
window_size <- window_args$before + window_args$after + 1L
480-
roll_output <- .f(x = .data_group[, input_col_names], n = window_size, ...)
481-
} else {
482-
window_size <- list(seq_along(.data_group$time_value))
483-
roll_output <- .f(x = .data_group[, input_col_names], n = window_size, adaptive = TRUE, ...)
484-
}
485-
if (window_args$after >= 1) {
486-
.data_group[, output_col_names] <- lapply(roll_output, function(out_col) {
487-
# Shift an appropriate amount of NA padding from the start to the end.
488-
# (This padding will later be cut off when we filter down to the
489-
# original time_values.)
490-
c(out_col[(window_args$after + 1L):length(out_col)], rep(NA, window_args$after))
491-
})
492-
} else {
493-
.data_group[, output_col_names] <- roll_output
494-
}
495-
}
496-
if (f_from_package == "slider") {
497-
for (i in seq_along(input_col_names)) {
498-
.data_group[, output_col_names[i]] <- .f(
499-
x = .data_group[[input_col_names[i]]],
500-
before = as.numeric(window_args$before),
501-
after = as.numeric(window_args$after),
502-
...
503-
)
504-
}
501+
f_dots_baked <-
502+
if (rlang::dots_n(...) == 0L) {
503+
# Leaving `.f` unchanged slightly improves computation speed and trims
504+
# debug stack traces:
505+
.f
506+
} else {
507+
purrr::partial(.f, ...)
505508
}
506509

507-
.data_group
508-
}
509-
510510
result <- .x %>%
511-
`[[<-`(".real", value = TRUE) %>%
512-
group_modify(slide_one_grp, ..., .keep = FALSE) %>%
513-
`[`(.$.real, names(.) != ".real") %>%
514-
arrange_col_canonical() %>%
515-
group_by(!!!.x_orig_groups)
511+
group_modify(function(grp_data, grp_key) {
512+
epi_slide_opt_one_epikey(grp_data, f_dots_baked, f_from_package, before, after, unit_step, time_type, vctrs::vec_set_intersect(ref_time_values, grp_data$time_value), names_info$input_col_names, names_info$output_col_names)
513+
}) %>%
514+
arrange_col_canonical()
516515

517516
if (.all_rows) {
518-
result[!vec_in(result$time_value, ref_time_values), output_col_names] <- NA
519-
} else if (user_provided_rtvs) {
520-
result <- result[vec_in(result$time_value, ref_time_values), ]
517+
ekt_names <- key_colnames(.x)
518+
result <- left_join(ungroup(.x), result[c(ekt_names, output_col_names)], by = ekt_names)
521519
}
522520

523-
if (!is_epi_df(result)) {
524-
# `.all_rows` handling strips epi_df format and metadata.
525-
# Restore them.
526-
result <- reclass(result, attributes(.x)$metadata)
527-
}
521+
result <- group_by(result, !!!.x_orig_groups)
528522

529523
return(result)
530524
}

0 commit comments

Comments
 (0)