
Commit d59eaf7

add apo to monte cover
1 parent 7a9e2c1 commit d59eaf7

File tree

2 files changed: +153 -0 lines changed


monte-cover/src/montecover/irm/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 """Monte Carlo coverage simulations for IRM."""
 
+from montecover.irm.apo import APOCoverageSimulation
 from montecover.irm.irm_ate import IRMATECoverageSimulation
 from montecover.irm.irm_ate_sensitivity import IRMATESensitivityCoverageSimulation
 from montecover.irm.irm_atte import IRMATTECoverageSimulation
@@ -8,6 +9,7 @@
 from montecover.irm.irm_gate import IRMGATECoverageSimulation
 
 __all__ = [
+    "APOCoverageSimulation",
     "IRMATECoverageSimulation",
     "IRMATESensitivityCoverageSimulation",
     "IRMATTECoverageSimulation",

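With this re-export, the new simulation class can be imported directly from the subpackage rather than from the module that defines it:

from montecover.irm import APOCoverageSimulation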
monte-cover/src/montecover/irm/apo.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
from typing import Any, Dict, Optional

import doubleml as dml
import numpy as np
import pandas as pd
from doubleml.datasets import make_irm_data_discrete_treatments

from montecover.base import BaseSimulation
from montecover.utils import create_learner_from_config


class APOCoverageSimulation(BaseSimulation):
    """Simulation class for coverage properties of DoubleMLAPOs for APO estimation."""

    def __init__(
        self,
        config_file: str,
        suppress_warnings: bool = True,
        log_level: str = "INFO",
        log_file: Optional[str] = None,
    ):
        super().__init__(
            config_file=config_file,
            suppress_warnings=suppress_warnings,
            log_level=log_level,
            log_file=log_file,
        )

        # Calculate oracle values
        self._calculate_oracle_values()

    def _process_config_parameters(self):
        """Process simulation-specific parameters from config"""
        # Process ML models in parameter grid
        assert "learners" in self.dml_parameters, "No learners specified in the config file"

        required_learners = ["ml_g", "ml_m"]
        for learner in self.dml_parameters["learners"]:
            for ml in required_learners:
                assert ml in learner, f"No {ml} specified in the config file"

    def _calculate_oracle_values(self):
        """Calculate oracle values for the simulation."""
        self.logger.info("Calculating oracle values")

        n_levels = self.dgp_parameters["n_levels"][0]
        data_apo_oracle = make_irm_data_discrete_treatments(
            n_obs=int(1e6), n_levels=n_levels, linear=self.dgp_parameters["linear"][0]
        )

        y0 = data_apo_oracle["oracle_values"]["y0"]
        ite = data_apo_oracle["oracle_values"]["ite"]
        d = data_apo_oracle["d"]

        average_ites = np.full(n_levels + 1, np.nan)
        apos = np.full(n_levels + 1, np.nan)
        for i in range(n_levels + 1):
            average_ites[i] = np.mean(ite[d == i]) * (i > 0)
            apos[i] = np.mean(y0) + average_ites[i]

        ates = np.full(n_levels, np.nan)
        for i in range(n_levels):
            ates[i] = apos[i + 1] - apos[0]

        self.logger.info(f"Levels and their counts:\n{np.unique(d, return_counts=True)}")
        self.logger.info(f"True APOs: {apos}")
        self.logger.info(f"True ATEs: {ates}")

        self.oracle_values = dict()
        self.oracle_values["apos"] = apos
        self.oracle_values["ates"] = ates

    def run_single_rep(self, dml_data: dml.DoubleMLData, dml_params: Dict[str, Any]) -> Dict[str, Any]:
        """Run a single repetition with the given parameters."""
        # Extract parameters
        learner_config = dml_params["learners"]
        learner_g_name, ml_g = create_learner_from_config(learner_config["ml_g"])
        learner_m_name, ml_m = create_learner_from_config(learner_config["ml_m"])
        treatment_level = dml_params["treatment_level"]
        trimming_threshold = dml_params["trimming_threshold"]

        # Model
        dml_model = dml.DoubleMLAPO(
            obj_dml_data=dml_data,
            ml_g=ml_g,
            ml_m=ml_m,
            treatment_level=treatment_level,
            trimming_threshold=trimming_threshold,
        )
        dml_model.fit()

        result = {
            "coverage": [],
        }
        for level in self.confidence_parameters["level"]:
            level_result = dict()
            level_result["coverage"] = self._compute_coverage(
                thetas=dml_model.coef,
                oracle_thetas=self.oracle_values["apos"][treatment_level],
                confint=dml_model.confint(level=level),
                joint_confint=None,
            )

            # add parameters to the result
            for res_metric in level_result.values():
                res_metric.update(
                    {
                        "Learner g": learner_g_name,
                        "Learner m": learner_m_name,
                        "level": level,
                    }
                )
            for key, res in level_result.items():
                result[key].append(res)

        return result

    def summarize_results(self):
        """Summarize the simulation results."""
        self.logger.info("Summarizing simulation results")

        # Group by parameter combinations
        groupby_cols = ["Learner g", "Learner m", "level"]
        aggregation_dict = {
            "Coverage": "mean",
            "CI Length": "mean",
            "Bias": "mean",
            "repetition": "count",
        }

        # Aggregate results (possibly multiple result dfs)
        result_summary = dict()
        for result_name, result_df in self.results.items():
            result_summary[result_name] = result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
            self.logger.debug(f"Summarized {result_name} results")

        return result_summary

    def _generate_dml_data(self, dgp_params: Dict[str, Any]) -> dml.DoubleMLData:
        """Generate data for the simulation."""
        data = make_irm_data_discrete_treatments(
            n_obs=dgp_params["n_obs"],
            n_levels=dgp_params["n_levels"],
            linear=dgp_params["linear"],
        )
        df_apo = pd.DataFrame(
            np.column_stack((data["y"], data["d"], data["x"])),
            columns=["y", "d"] + ["x" + str(i) for i in range(data["x"].shape[1])],
        )
        dml_data = dml.DoubleMLData(df_apo, "y", "d")
        return dml_data
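The oracle values follow directly from the DGP: apos[0] = E[Y(0)] and, for d > 0, apos[d] = E[Y(0)] + E[ITE | D = d]; the oracle ATEs are then apos[d] - apos[0] for d = 1, ..., n_levels. The config file consumed by BaseSimulation has to supply the keys read above: dgp_parameters (n_obs, n_levels, linear), dml_parameters (learners with ml_g and ml_m, treatment_level, trimming_threshold) and confidence_parameters (level). A minimal driver sketch for the new class is shown below; the config path, the log/output locations and the run_simulation/save_results entry points of BaseSimulation are assumptions and not part of this diff:

from montecover.irm import APOCoverageSimulation

# Hypothetical driver script: the YAML path, log/output locations and the
# BaseSimulation methods run_simulation()/save_results() are assumed here,
# they are not shown in this commit.
sim = APOCoverageSimulation(
    config_file="scripts/irm/apo_config.yml",  # assumed config location
    log_level="INFO",
    log_file="logs/apo_sim.log",  # optional log file
)
sim.run_simulation()  # assumed entry point looping over DGPs, dml_parameters and repetitions
sim.save_results(output_path="results/irm/", file_prefix="apo")  # assumed helper for writing summaries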
