@@ -167,6 +167,67 @@ across_ish_names_info <- function(.x, time_type, col_names_quo, .f_namer,
167167 )
168168}
169169
170+ epi_slide_opt_one_epikey <- function (inp_snapshot , f_dots_baked , f_from_package , before , after , unit_step , time_type , ref_time_values , in_colnames , out_colnames ) {
171+ # TODO try converting time values to reals, do work on reals, convert back at very end?
172+ if (before == Inf ) {
173+ if (after != 0L ) {
174+ cli_abort(' .window_size = Inf is only supported with .align = "right"' ,
175+ class = " epiprocess__epi_slide_opt_archive__inf_window_invalid_align"
176+ )
177+ }
178+ # We need to use the entire input snapshot range, filling in time gaps. We
179+ # shouldn't pad the ends.
180+ slide_t_min <- min(inp_snapshot $ time_value )
181+ slide_t_max <- max(inp_snapshot $ time_value )
182+ } else {
183+ # If the input had updates in the range t1..t2, this could produce changes
184+ # in slide outputs in the range t1-after..t2+before, and to compute those
185+ # slide values, we need to look at the input snapshot from
186+ # t1-after-before..t2+before+after. nolint: commented_code_linter
187+ inp_update_t_min <- min(ref_time_values )
188+ inp_update_t_max <- max(ref_time_values )
189+ slide_t_min <- inp_update_t_min - (before + after ) * unit_step
190+ slide_t_max <- inp_update_t_max + (before + after ) * unit_step
191+ }
192+ slide_nrow <- time_delta_to_n_steps(slide_t_max - slide_t_min , time_type ) + 1L
193+ slide_time_values <- slide_t_min + 0L : (slide_nrow - 1L ) * unit_step
194+ slide_inp_backrefs <- vec_match(slide_time_values , inp_snapshot $ time_value )
195+ # Get additional values needed from inp_snapshot + perform any NA
196+ # tail-padding needed to make slider results a fixed window size rather than
197+ # adaptive at tails + perform any NA gap-filling needed:
198+ slide <- vec_slice(inp_snapshot , slide_inp_backrefs )
199+ slide $ time_value <- slide_time_values
200+ if (f_from_package == " data.table" ) {
201+ if (before == Inf ) {
202+ slide [, out_colnames ] <-
203+ f_dots_baked(slide [, in_colnames ], seq_len(slide_nrow ), adaptive = TRUE )
204+ } else {
205+ out_cols <- f_dots_baked(slide [, in_colnames ], before + after + 1L )
206+ if (after != 0L ) {
207+ # Shift an appropriate amount of NA padding from the start to the end.
208+ # (This padding will later be cut off when we filter down to the
209+ # original time_values.)
210+ out_cols <- lapply(out_cols , function (out_col ) {
211+ c(out_col [(after + 1L ): length(out_col )], rep(NA , after ))
212+ })
213+ }
214+ slide [, out_colnames ] <- out_cols
215+ }
216+ } else if (f_from_package == " slider" ) {
217+ for (col_i in seq_along(in_colnames )) {
218+ slide [[out_colnames [[col_i ]]]] <- f_dots_baked(slide [[in_colnames [[col_i ]]]], before = before , after = after )
219+ }
220+ } else {
221+ cli_abort(
222+ " epiprocess internal error: `f_from_package` was {format_chr_deparse(f_from_package)}, which is unsupported" ,
223+ class = " epiprocess__epi_slide_opt_archive__f_from_package_invalid"
224+ )
225+ }
226+ rows_should_keep <- vec_match(ref_time_values , slide_time_values )
227+ out_update <- vec_slice(slide , rows_should_keep )
228+ out_update
229+ }
230+
170231# ' Optimized slide functions for common cases
171232# '
172233# ' @description
@@ -425,6 +486,9 @@ epi_slide_opt.epi_df <- function(.x, .col_names, .f, ...,
425486 }
426487 validate_slide_window_arg(.window_size , time_type )
427488 window_args <- get_before_after_from_window(.window_size , .align , time_type )
489+ before <- time_delta_to_n_steps(window_args $ before , time_type )
490+ after <- time_delta_to_n_steps(window_args $ after , time_type )
491+ unit_step <- unit_time_delta(time_type , format = " fast" )
428492
429493 # Handle output naming:
430494 names_info <- across_ish_names_info(
@@ -434,97 +498,27 @@ epi_slide_opt.epi_df <- function(.x, .col_names, .f, ...,
434498 input_col_names <- names_info $ input_col_names
435499 output_col_names <- names_info $ output_col_names
436500
437- # Make a complete date sequence between min(.x$time_value) and max(.x$time_value).
438- date_seq_list <- full_date_seq(.x , window_args $ before , window_args $ after , time_type )
439- all_dates <- date_seq_list $ all_dates
440- pad_early_dates <- date_seq_list $ pad_early_dates
441- pad_late_dates <- date_seq_list $ pad_late_dates
442-
443- slide_one_grp <- function (.data_group , .group_key , ... ) {
444- missing_times <- all_dates [! vec_in(all_dates , .data_group $ time_value )]
445- # `frollmean` requires a full window to compute a result. Add NA values
446- # to beginning and end of the group so that we get results for the
447- # first `before` and last `after` elements.
448- .data_group <- vec_rbind(
449- .data_group , # (tibble; epi_slide_opt uses .keep = FALSE)
450- new_tibble(vec_recycle_common(
451- time_value = c(missing_times , pad_early_dates , pad_late_dates ),
452- .real = FALSE
453- ))
454- ) %> %
455- `[`(vec_order(. $ time_value ), )
456-
457- if (f_from_package == " data.table" ) {
458- # Grouping should ensure that we don't have duplicate time values.
459- # Completion above should ensure we have at least .window_size rows. Check
460- # that we don't have more than .window_size rows (or fewer somehow):
461- if (nrow(.data_group ) != length(c(all_dates , pad_early_dates , pad_late_dates ))) {
462- cli_abort(
463- c(
464- " group contains an unexpected number of rows" ,
465- " i" = c(" Input data may contain `time_values` closer together than the
466- expected `time_step` size" )
467- ),
468- class = " epiprocess__epi_slide_opt__unexpected_row_number" ,
469- epiprocess__data_group = .data_group ,
470- epiprocess__group_key = .group_key
471- )
472- }
473-
474- # `frollmean` is 1-indexed, so create a new window width based on our
475- # `before` and `after` params. Right-aligned `frollmean` results'
476- # `ref_time_value`s will be `after` timesteps ahead of where they should
477- # be; shift results to the left by `after` timesteps.
478- if (window_args $ before != Inf ) {
479- window_size <- window_args $ before + window_args $ after + 1L
480- roll_output <- .f(x = .data_group [, input_col_names ], n = window_size , ... )
481- } else {
482- window_size <- list (seq_along(.data_group $ time_value ))
483- roll_output <- .f(x = .data_group [, input_col_names ], n = window_size , adaptive = TRUE , ... )
484- }
485- if (window_args $ after > = 1 ) {
486- .data_group [, output_col_names ] <- lapply(roll_output , function (out_col ) {
487- # Shift an appropriate amount of NA padding from the start to the end.
488- # (This padding will later be cut off when we filter down to the
489- # original time_values.)
490- c(out_col [(window_args $ after + 1L ): length(out_col )], rep(NA , window_args $ after ))
491- })
492- } else {
493- .data_group [, output_col_names ] <- roll_output
494- }
495- }
496- if (f_from_package == " slider" ) {
497- for (i in seq_along(input_col_names )) {
498- .data_group [, output_col_names [i ]] <- .f(
499- x = .data_group [[input_col_names [i ]]],
500- before = as.numeric(window_args $ before ),
501- after = as.numeric(window_args $ after ),
502- ...
503- )
504- }
501+ f_dots_baked <-
502+ if (rlang :: dots_n(... ) == 0L ) {
503+ # Leaving `.f` unchanged slightly improves computation speed and trims
504+ # debug stack traces:
505+ .f
506+ } else {
507+ purrr :: partial(.f , ... )
505508 }
506509
507- .data_group
508- }
509-
510510 result <- .x %> %
511- `[[<-`(" .real" , value = TRUE ) %> %
512- group_modify(slide_one_grp , ... , .keep = FALSE ) %> %
513- `[`(. $ .real , names(. ) != " .real" ) %> %
514- arrange_col_canonical() %> %
515- group_by(!!! .x_orig_groups )
511+ group_modify(function (grp_data , grp_key ) {
512+ epi_slide_opt_one_epikey(grp_data , f_dots_baked , f_from_package , before , after , unit_step , time_type , vctrs :: vec_set_intersect(ref_time_values , grp_data $ time_value ), names_info $ input_col_names , names_info $ output_col_names )
513+ }) %> %
514+ arrange_col_canonical()
516515
517516 if (.all_rows ) {
518- result [! vec_in(result $ time_value , ref_time_values ), output_col_names ] <- NA
519- } else if (user_provided_rtvs ) {
520- result <- result [vec_in(result $ time_value , ref_time_values ), ]
517+ ekt_names <- key_colnames(.x )
518+ result <- left_join(ungroup(.x ), result [c(ekt_names , output_col_names )], by = ekt_names )
521519 }
522520
523- if (! is_epi_df(result )) {
524- # `.all_rows` handling strips epi_df format and metadata.
525- # Restore them.
526- result <- reclass(result , attributes(.x )$ metadata )
527- }
521+ result <- group_by(result , !!! .x_orig_groups )
528522
529523 return (result )
530524}
0 commit comments