@@ -996,6 +996,7 @@ pred_arx_geo_pool <- usa_archive_dv |> epix_slide(
996
996
.versions = fc_time_values
997
997
)
998
998
```
999
+ [Note]{.primary}: geo-pooling is the default in `epipredict`
999
1000
1000
1001
``` {r arx-no-geo-pooling}
1001
1002
ma_archive_dv <- usa_archive_dv$DT |> filter(geo_value == "ma") |> as_epi_archive()
@@ -1055,9 +1056,9 @@ getAccuracy(ca, pred_ca_geo_pool, "CA")
1055
1056
## Predictions (geo-pooling, $h=28$)
1056
1057
1057
1058
``` {r finalized-ma-ny-tx}
1058
- ma <- df |> filter(geo_value == "ma")
1059
- ny <- df |> filter(geo_value == "ny")
1060
- tx <- df |> filter(geo_value == "tx")
1059
+ ma <- cases_deaths |> filter(geo_value == "ma")
1060
+ ny <- cases_deaths |> filter(geo_value == "ny")
1061
+ tx <- cases_deaths |> filter(geo_value == "tx")
1061
1062
```
1062
1063
1063
1064
``` {r arx-geo-pooling-plot}
@@ -1078,6 +1079,8 @@ pred_arx_geo_pool |>
1078
1079
1079
1080
```
1080
1081
1082
+
1083
+
1081
1084
``` {r error-geo-pooling-all-states}
1082
1085
rbind(getAccuracy(ca,
1083
1086
pred_arx_geo_pool |>
@@ -1157,7 +1160,7 @@ pred_arx_geo_pool_7 <- usa_archive_dv |>
1157
1160
predictors = c("deaths", "doctor_visits"),
1158
1161
trainer = linear_reg() |> set_engine("lm"),
1159
1162
args_list = arx_args_list(
1160
- lags = 0, #c(0, 7, 14),
1163
+ lags = 0,
1161
1164
ahead = 7,
1162
1165
quantile_levels = c(0.1, 0.9))
1163
1166
)$predictions |>
@@ -1290,7 +1293,7 @@ pred_qr_geo_pool <- usa_archive_dv |>
1290
1293
predictors = c("deaths", "doctor_visits"),
1291
1294
trainer = quantile_reg(),
1292
1295
args_list = arx_args_list(
1293
- lags = 0, #c(0, 7, 14),
1296
+ lags = 0,
1294
1297
ahead = 28,
1295
1298
quantile_levels = c(0.1, 0.9))
1296
1299
)$predictions |>
@@ -1339,6 +1342,45 @@ rbind(getAccuracy(ca,
1339
1342
"TX"))
1340
1343
```
1341
1344
1345
+ ## Predictions (geo-pooling + linear regression, $h=28$)
1346
+
1347
+ ``` {r arx-geo-pooling-plot-lm}
1348
+ #| fig-width: 7
1349
+ pred_arx_geo_pool |>
1350
+ filter(geo_value %in% c("ca", "ma", "ny", "tx")) |>
1351
+ ggplot(aes(target_date, .pred)) +
1352
+ geom_line(data = rbind(ca, ma, ny, tx), aes(x = time_value, y = deaths),
1353
+ inherit.aes = FALSE, na.rm = TRUE, alpha = .5) +
1354
+ geom_line(col = tertiary) +
1355
+ geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), alpha = .3, fill = tertiary) +
1356
+ geom_vline(xintercept = t0_date) +
1357
+ geom_vline(xintercept = t0_date + 28, lty = 2) +
1358
+ facet_wrap(vars(geo_value), scales = 'free_y') +
1359
+ labs(x = "", y = "Deaths per 100k people") +
1360
+ scale_y_continuous(expand = expansion(c(0, .05))) +
1361
+ theme(legend.position = "none")
1362
+
1363
+ ```
1364
+
1365
+ ``` {r error-geo-pooling-all-states-lm}
1366
+ rbind(getAccuracy(ca,
1367
+ pred_arx_geo_pool |>
1368
+ filter(geo_value == "ca" & target_date %in% ca$time_value),
1369
+ "CA"),
1370
+ getAccuracy(ma,
1371
+ pred_arx_geo_pool |>
1372
+ filter(geo_value == "ma" & target_date %in% ma$time_value),
1373
+ "MA"),
1374
+ getAccuracy(ny,
1375
+ pred_arx_geo_pool |>
1376
+ filter(geo_value == "ny" & target_date %in% ny$time_value),
1377
+ "NY"),
1378
+ getAccuracy(tx,
1379
+ pred_arx_geo_pool |>
1380
+ filter(geo_value == "tx" & target_date %in% tx$time_value),
1381
+ "TX"))
1382
+ ```
1383
+
1342
1384
1343
1385
# Build a forecaster from scratch
1344
1386
0 commit comments