Skip to content

Commit b2a3b37

Browse files
authored
Merge pull request #53 from cmu-delphi/alicecima-geo-pool-and-qr
Improve geo-pooling and quantile regression slides
2 parents b76546b + c6d7d69 commit b2a3b37

File tree

1 file changed

+47
-5
lines changed

1 file changed

+47
-5
lines changed

slides/day2-afternoon.qmd

Lines changed: 47 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -996,6 +996,7 @@ pred_arx_geo_pool <- usa_archive_dv |> epix_slide(
996996
.versions = fc_time_values
997997
)
998998
```
999+
[Note]{.primary}: geo-pooling is the default in `epipredict`
9991000

10001001
```{r arx-no-geo-pooling}
10011002
ma_archive_dv <- usa_archive_dv$DT |> filter(geo_value == "ma") |> as_epi_archive()
@@ -1055,9 +1056,9 @@ getAccuracy(ca, pred_ca_geo_pool, "CA")
10551056
## Predictions (geo-pooling, $h=28$)
10561057

10571058
```{r finalized-ma-ny-tx}
1058-
ma <- df |> filter(geo_value == "ma")
1059-
ny <- df |> filter(geo_value == "ny")
1060-
tx <- df |> filter(geo_value == "tx")
1059+
ma <- cases_deaths |> filter(geo_value == "ma")
1060+
ny <- cases_deaths |> filter(geo_value == "ny")
1061+
tx <- cases_deaths |> filter(geo_value == "tx")
10611062
```
10621063

10631064
```{r arx-geo-pooling-plot}
@@ -1078,6 +1079,8 @@ pred_arx_geo_pool |>
10781079
10791080
```
10801081

1082+
1083+
10811084
```{r error-geo-pooling-all-states}
10821085
rbind(getAccuracy(ca,
10831086
pred_arx_geo_pool |>
@@ -1157,7 +1160,7 @@ pred_arx_geo_pool_7 <- usa_archive_dv |>
11571160
predictors = c("deaths", "doctor_visits"),
11581161
trainer = linear_reg() |> set_engine("lm"),
11591162
args_list = arx_args_list(
1160-
lags = 0, #c(0, 7, 14),
1163+
lags = 0,
11611164
ahead = 7,
11621165
quantile_levels = c(0.1, 0.9))
11631166
)$predictions |>
@@ -1290,7 +1293,7 @@ pred_qr_geo_pool <- usa_archive_dv |>
12901293
predictors = c("deaths", "doctor_visits"),
12911294
trainer = quantile_reg(),
12921295
args_list = arx_args_list(
1293-
lags = 0, #c(0, 7, 14),
1296+
lags = 0,
12941297
ahead = 28,
12951298
quantile_levels = c(0.1, 0.9))
12961299
)$predictions |>
@@ -1339,6 +1342,45 @@ rbind(getAccuracy(ca,
13391342
"TX"))
13401343
```
13411344

1345+
## Predictions (geo-pooling + linear regression, $h=28$)
1346+
1347+
```{r arx-geo-pooling-plot-lm}
1348+
#| fig-width: 7
1349+
pred_arx_geo_pool |>
1350+
filter(geo_value %in% c("ca", "ma", "ny", "tx")) |>
1351+
ggplot(aes(target_date, .pred)) +
1352+
geom_line(data = rbind(ca, ma, ny, tx), aes(x = time_value, y = deaths),
1353+
inherit.aes = FALSE, na.rm = TRUE, alpha = .5) +
1354+
geom_line(col = tertiary) +
1355+
geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), alpha = .3, fill = tertiary) +
1356+
geom_vline(xintercept = t0_date) +
1357+
geom_vline(xintercept = t0_date + 28, lty = 2) +
1358+
facet_wrap(vars(geo_value), scales = 'free_y') +
1359+
labs(x = "", y = "Deaths per 100k people") +
1360+
scale_y_continuous(expand = expansion(c(0, .05))) +
1361+
theme(legend.position = "none")
1362+
1363+
```
1364+
1365+
```{r error-geo-pooling-all-states-lm}
1366+
rbind(getAccuracy(ca,
1367+
pred_arx_geo_pool |>
1368+
filter(geo_value == "ca" & target_date %in% ca$time_value),
1369+
"CA"),
1370+
getAccuracy(ma,
1371+
pred_arx_geo_pool |>
1372+
filter(geo_value == "ma" & target_date %in% ma$time_value),
1373+
"MA"),
1374+
getAccuracy(ny,
1375+
pred_arx_geo_pool |>
1376+
filter(geo_value == "ny" & target_date %in% ny$time_value),
1377+
"NY"),
1378+
getAccuracy(tx,
1379+
pred_arx_geo_pool |>
1380+
filter(geo_value == "tx" & target_date %in% tx$time_value),
1381+
"TX"))
1382+
```
1383+
13421384

13431385
# Build a forecaster from scratch
13441386

0 commit comments

Comments
 (0)