Skip to content

Commit d3158b1

Browse files
committed
add apos to monte cover
1 parent a684923 commit d3158b1

File tree

2 files changed

+166
-0
lines changed

2 files changed

+166
-0
lines changed

monte-cover/src/montecover/irm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Monte Carlo coverage simulations for IRM."""
22

33
from montecover.irm.apo import APOCoverageSimulation
4+
from montecover.irm.apos import APOSCoverageSimulation
45
from montecover.irm.irm_ate import IRMATECoverageSimulation
56
from montecover.irm.irm_ate_sensitivity import IRMATESensitivityCoverageSimulation
67
from montecover.irm.irm_atte import IRMATTECoverageSimulation
@@ -10,6 +11,7 @@
1011

1112
__all__ = [
1213
"APOCoverageSimulation",
14+
"APOSCoverageSimulation",
1315
"IRMATECoverageSimulation",
1416
"IRMATESensitivityCoverageSimulation",
1517
"IRMATTECoverageSimulation",
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
from typing import Any, Dict, Optional
2+
3+
import doubleml as dml
4+
import numpy as np
5+
import pandas as pd
6+
from doubleml.datasets import make_irm_data_discrete_treatments
7+
8+
from montecover.base import BaseSimulation
9+
from montecover.utils import create_learner_from_config
10+
11+
12+
class APOSCoverageSimulation(BaseSimulation):
13+
"""Simulation class for coverage properties of DoubleMLAPOs for APO estimation."""
14+
15+
def __init__(
16+
self,
17+
config_file: str,
18+
suppress_warnings: bool = True,
19+
log_level: str = "INFO",
20+
log_file: Optional[str] = None,
21+
):
22+
super().__init__(
23+
config_file=config_file,
24+
suppress_warnings=suppress_warnings,
25+
log_level=log_level,
26+
log_file=log_file,
27+
)
28+
29+
# Calculate oracle values
30+
self._calculate_oracle_values()
31+
32+
def _process_config_parameters(self):
33+
"""Process simulation-specific parameters from config"""
34+
# Process ML models in parameter grid
35+
assert "learners" in self.dml_parameters, "No learners specified in the config file"
36+
37+
required_learners = ["ml_g", "ml_m"]
38+
for learner in self.dml_parameters["learners"]:
39+
for ml in required_learners:
40+
assert ml in learner, f"No {ml} specified in the config file"
41+
42+
def _calculate_oracle_values(self):
43+
"""Calculate oracle values for the simulation."""
44+
self.logger.info("Calculating oracle values")
45+
46+
n_levels = self.dgp_parameters["n_levels"][0]
47+
data_apo_oracle = make_irm_data_discrete_treatments(
48+
n_obs=int(1e6), n_levels=n_levels, linear=self.dgp_parameters["linear"][0]
49+
)
50+
51+
y0 = data_apo_oracle["oracle_values"]["y0"]
52+
ite = data_apo_oracle["oracle_values"]["ite"]
53+
d = data_apo_oracle["d"]
54+
55+
average_ites = np.full(n_levels + 1, np.nan)
56+
apos = np.full(n_levels + 1, np.nan)
57+
for i in range(n_levels + 1):
58+
average_ites[i] = np.mean(ite[d == i]) * (i > 0)
59+
apos[i] = np.mean(y0) + average_ites[i]
60+
61+
ates = np.full(n_levels, np.nan)
62+
for i in range(n_levels):
63+
ates[i] = apos[i + 1] - apos[0]
64+
65+
self.logger.info(f"Levels and their counts:\n{np.unique(d, return_counts=True)}")
66+
self.logger.info(f"True APOs: {apos}")
67+
self.logger.info(f"True ATEs: {ates}")
68+
69+
self.oracle_values = dict()
70+
self.oracle_values["apos"] = apos
71+
self.oracle_values["ates"] = ates
72+
73+
def run_single_rep(self, dml_data: dml.DoubleMLData, dml_params: Dict[str, Any]) -> Dict[str, Any]:
74+
"""Run a single repetition with the given parameters."""
75+
# Extract parameters
76+
learner_config = dml_params["learners"]
77+
learner_g_name, ml_g = create_learner_from_config(learner_config["ml_g"])
78+
learner_m_name, ml_m = create_learner_from_config(learner_config["ml_m"])
79+
treatment_levels = dml_params["treatment_levels"]
80+
trimming_threshold = dml_params["trimming_threshold"]
81+
82+
# Model
83+
dml_model = dml.DoubleMLAPOS(
84+
obj_dml_data=dml_data,
85+
ml_g=ml_g,
86+
ml_m=ml_m,
87+
treatment_levels=treatment_levels,
88+
trimming_threshold=trimming_threshold,
89+
)
90+
dml_model.fit()
91+
dml_model.bootstrap(n_rep_boot=2000)
92+
93+
causal_contrast_model = dml_model.causal_contrast(reference_levels=0)
94+
causal_contrast_model.bootstrap(n_rep_boot=2000)
95+
96+
result = {
97+
"coverage": [],
98+
"causal_contrast": [],
99+
}
100+
for level in self.confidence_parameters["level"]:
101+
level_result = dict()
102+
level_result["coverage"] = self._compute_coverage(
103+
thetas=dml_model.coef,
104+
oracle_thetas=self.oracle_values["apos"],
105+
confint=dml_model.confint(level=level),
106+
joint_confint=dml_model.confint(level=level, joint=True),
107+
)
108+
level_result["causal_contrast"] = self._compute_coverage(
109+
thetas=causal_contrast_model.thetas,
110+
oracle_thetas=self.oracle_values["ates"],
111+
confint=causal_contrast_model.confint(level=level),
112+
joint_confint=causal_contrast_model.confint(level=level, joint=True),
113+
)
114+
115+
# add parameters to the result
116+
for res_metric in level_result.values():
117+
res_metric.update(
118+
{
119+
"Learner g": learner_g_name,
120+
"Learner m": learner_m_name,
121+
"level": level,
122+
}
123+
)
124+
for key, res in level_result.items():
125+
result[key].append(res)
126+
127+
return result
128+
129+
def summarize_results(self):
130+
"""Summarize the simulation results."""
131+
self.logger.info("Summarizing simulation results")
132+
133+
# Group by parameter combinations
134+
groupby_cols = ["Learner g", "Learner m", "level"]
135+
aggregation_dict = {
136+
"Coverage": "mean",
137+
"CI Length": "mean",
138+
"Bias": "mean",
139+
"Uniform Coverage": "mean",
140+
"Uniform CI Length": "mean",
141+
"repetition": "count",
142+
}
143+
144+
# Aggregate results (possibly multiple result dfs)
145+
result_summary = dict()
146+
for result_name, result_df in self.results.items():
147+
result_summary[result_name] = result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
148+
self.logger.debug(f"Summarized {result_name} results")
149+
150+
return result_summary
151+
152+
def _generate_dml_data(self, dgp_params: Dict[str, Any]) -> dml.DoubleMLData:
153+
"""Generate data for the simulation."""
154+
data = make_irm_data_discrete_treatments(
155+
n_obs=dgp_params["n_obs"],
156+
n_levels=dgp_params["n_levels"],
157+
linear=dgp_params["linear"],
158+
)
159+
df_apo = pd.DataFrame(
160+
np.column_stack((data["y"], data["d"], data["x"])),
161+
columns=["y", "d"] + ["x" + str(i) for i in range(data["x"].shape[1])],
162+
)
163+
dml_data = dml.DoubleMLData(df_apo, "y", "d")
164+
return dml_data

0 commit comments

Comments
 (0)