Skip to content

Commit 0821ef3

Browse files
committed
WIP: devnet notebook
1 parent 8a880ec commit 0821ef3

File tree

2 files changed

+428
-0
lines changed

2 files changed

+428
-0
lines changed

docs/notebooks/devnet.ipynb

Lines changed: 228 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "ea4ae65a-d555-4b54-96f9-11eed006adc2",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"# %pip uninstall -y coniferest\n",
11+
"# %pip install 'git+https://github.com/snad-space/coniferest@fix-devent-celeba'"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": 2,
17+
"id": "3d9577061e9494ed",
18+
"metadata": {
19+
"ExecuteTime": {
20+
"end_time": "2024-03-13T15:41:49.204695Z",
21+
"start_time": "2024-03-13T15:41:49.201344Z"
22+
},
23+
"collapsed": false,
24+
"jupyter": {
25+
"outputs_hidden": false
26+
}
27+
},
28+
"outputs": [],
29+
"source": [
30+
"from collections import defaultdict\n",
31+
"\n",
32+
"import matplotlib.pyplot as plt\n",
33+
"import numpy as np\n",
34+
"from tqdm import tqdm\n",
35+
"\n",
36+
"from coniferest.aadforest import AADForest\n",
37+
"from coniferest.datasets import Dataset, DevNetDataset\n",
38+
"from coniferest.isoforest import IsolationForest\n",
39+
"from coniferest.label import Label\n",
40+
"from coniferest.pineforest import PineForest\n",
41+
"from coniferest.session.oracle import OracleSession, create_oracle_session"
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": 3,
47+
"id": "initial_id",
48+
"metadata": {
49+
"ExecuteTime": {
50+
"end_time": "2024-03-13T15:41:49.210919Z",
51+
"start_time": "2024-03-13T15:41:49.206277Z"
52+
}
53+
},
54+
"outputs": [],
55+
"source": [
56+
"class Compare:\n",
57+
" models = {\n",
58+
" 'Isolation Forest': IsolationForest,\n",
59+
" 'AAD': AADForest,\n",
60+
" 'Pine Forest': PineForest,\n",
61+
" }\n",
62+
" \n",
63+
" def __init__(self, dataset: Dataset, *, iterations=100, n_jobs=-1):\n",
64+
" self.model_kwargs = {\n",
65+
" 'n_trees': 128,\n",
66+
" 'n_jobs': n_jobs,\n",
67+
" }\n",
68+
" self.session_kwargs = {\n",
69+
" 'data': dataset.data,\n",
70+
" 'labels': dataset.labels,\n",
71+
" 'max_iterations': iterations,\n",
72+
" }\n",
73+
" self.results = {}\n",
74+
" self.steps = np.arange(1, iterations + 1)\n",
75+
" self.total_anomaly_fraction = np.mean(dataset.labels == Label.A)\n",
76+
"\n",
77+
" def get_sessions(self, random_seed):\n",
78+
" model_kwargs = self.model_kwargs | {'random_seed': random_seed}\n",
79+
"\n",
80+
" return {\n",
81+
" name: create_oracle_session(model=model(**model_kwargs), **self.session_kwargs)\n",
82+
" for name, model in self.models.items()\n",
83+
" }\n",
84+
"\n",
85+
" def run(self, random_seeds):\n",
86+
" results = defaultdict(dict)\n",
87+
" \n",
88+
" for random_seed in tqdm(random_seeds):\n",
89+
" sessions = self.get_sessions(random_seed)\n",
90+
" for name, session in sessions.items():\n",
91+
" session.run()\n",
92+
" anomalies = np.cumsum(np.array(list(session.known_labels.values())) == Label.A)\n",
93+
" results[name][random_seed] = anomalies\n",
94+
"\n",
95+
" self.results |= results\n",
96+
" return self\n",
97+
" \n",
98+
" def plot(self, dataset_name: str, savefig=False):\n",
99+
" plt.figure(figsize=(8, 6))\n",
100+
" plt.title(f'Dataset: {dataset_name}')\n",
101+
"\n",
102+
" for name, anomalies_dict in self.results.items():\n",
103+
" anomalies = np.stack(list(anomalies_dict.values()))\n",
104+
" q10, median, q90 = np.quantile(anomalies, [0.1, 0.5, 0.9], axis = 0)\n",
105+
"\n",
106+
" plt.plot(self.steps, median, alpha=0.75, label=name)\n",
107+
" plt.fill_between(self.steps, q10, q90, alpha=0.5)\n",
108+
"\n",
109+
"        plt.plot(self.steps, self.steps * self.total_anomaly_fraction, ls='--', color='grey', label='Theoretical random')\n",
110+
"\n",
111+
" plt.xlabel('Iteration')\n",
112+
" plt.ylabel('Number of anomalies')\n",
113+
" plt.grid()\n",
114+
" plt.legend()\n",
115+
" if savefig:\n",
116+
"            plt.savefig(f'{dataset_name}.pdf')\n",
117+
" \n",
118+
" return self"
119+
]
120+
},
121+
{
122+
"cell_type": "code",
123+
"execution_count": null,
124+
"id": "71c337b3577915d5",
125+
"metadata": {
126+
"collapsed": false,
127+
"jupyter": {
128+
"outputs_hidden": false
129+
}
130+
},
131+
"outputs": [
132+
{
133+
"name": "stdout",
134+
"output_type": "stream",
135+
"text": [
136+
"['donors', 'census', 'fraud', 'celeba', 'backdoor', 'campaign', 'thyroid']\n",
137+
"donors\n"
138+
]
139+
},
140+
{
141+
"name": "stderr",
142+
"output_type": "stream",
143+
"text": [
144+
" 60%|██████████████████████████████████▏ | 12/20 [1:56:30<1:25:02, 637.84s/it]"
145+
]
146+
}
147+
],
148+
"source": [
149+
"print(DevNetDataset.avialble_datasets)\n",
150+
"\n",
151+
"seeds = range(20)\n",
152+
"\n",
153+
"for dataset in DevNetDataset.avialble_datasets:\n",
154+
" print(dataset)\n",
155+
" %time compare = Compare(DevNetDataset(dataset), iterations=100, n_jobs=10).run(seeds).plot(dataset, savefig=True)\n",
156+
" plt.show()"
157+
]
158+
},
159+
{
160+
"cell_type": "code",
161+
"execution_count": null,
162+
"id": "603f9b12-b5ca-470e-95ba-34e4c6571687",
163+
"metadata": {},
164+
"outputs": [],
165+
"source": [
166+
"%time compare = Compare(DevNetDataset(\"thyroid\"), iterations=7200, n_jobs=15).run([0]).plot(f'{dataset}_full', savefig=True)\n",
167+
"plt.show()"
168+
]
169+
},
170+
{
171+
"cell_type": "code",
172+
"execution_count": null,
173+
"id": "7e7fb96f-b3a4-4f33-8389-466ad23b9da6",
174+
"metadata": {},
175+
"outputs": [],
176+
"source": []
177+
}
178+
],
179+
"metadata": {
180+
"kernelspec": {
181+
"display_name": "Python 3 (ipykernel)",
182+
"language": "python",
183+
"name": "python3"
184+
},
185+
"language_info": {
186+
"codemirror_mode": {
187+
"name": "ipython",
188+
"version": 3
189+
},
190+
"file_extension": ".py",
191+
"mimetype": "text/x-python",
192+
"name": "python",
193+
"nbconvert_exporter": "python",
194+
"pygments_lexer": "ipython3",
195+
"version": "3.12.3"
196+
}
197+
},
198+
"nbformat": 4,
199+
"nbformat_minor": 5
200+
}

0 commit comments

Comments
 (0)