Commit e932f4a

Plot of many learning curves colored by date

1 parent cf9818c, commit e932f4a

2 files changed: 294 additions, 47 deletions

plot_results.py

Lines changed: 48 additions & 47 deletions
@@ -148,51 +148,52 @@ def plot4(descriptors: List[str], metrics: Dict[str, str], name: str):
     plt.show()


-if not os.path.exists('plots'):
-    os.makedirs('plots')
-if not os.path.exists('plotspng'):
-    os.makedirs('plotspng')
-
-plot2dt('i17gv7pw', 'sidk0gu4')
-
-baseline = Experiment(descriptor="f2034f-hpsetstandard", label="baseline")
-adr_ablations = [
-    Experiment("f2034f-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "module cost, map curriculum"),
-    Experiment("f2034f-adr_variety0.0-adr_variety_schedule-hpsetstandard", "mothership damage, map curriculum"),
-    Experiment("f2034f-adr_variety0.0-adr_variety_schedule-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "map curriculum"),
-
-    Experiment("f2034f-adr_hstepsize0.0-hpsetstandard-linear_hardnessFalse-task_hardness150", "mothership damage, module cost, map randomization"),
-    Experiment("f2034f-adr_hstepsize0.0-hpsetstandard-linear_hardnessFalse-mothership_damage_scale0.0-mothership_damage_scale_schedule-task_hardness150", "module cost, map randomization"),
-    Experiment("f2034f-adr_hstepsize0.0-adr_variety0.0-adr_variety_schedule-hpsetstandard-linear_hardnessFalse-task_hardness150", "mothership damage, map randomization"),
-    Experiment("f2034f-adr_hstepsize0.0-adr_variety0.0-adr_variety_schedule-hpsetstandard-linear_hardnessFalse-mothership_damage_scale0.0-mothership_damage_scale_schedule-task_hardness150", "map randomization"),
-
-    Experiment("049430-batches_per_update64-bs256-hpsetstandard", "mothership damage, module cost, fixed map"),
-    Experiment("049430-batches_per_update64-bs256-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "module cost, fixed map"),
-    Experiment("049430-adr_variety0.0-adr_variety_schedule-batches_per_update64-bs256-hpsetstandard", "mothership damage, fixed map"),
-    Experiment("049430-adr_variety0.0-adr_variety_schedule-batches_per_update64-bs256-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "fixed map"),
-]
-ablations = [
-    Experiment("f2034f-hpsetstandard-partial_score0.0", "sparse reward"),
-    Experiment("f2034f-hpsetstandard-use_privilegedFalse", "non-omniscient value function"),
-    Experiment("f2034f-d_agent128-d_item64-hpsetstandard", "smaller network"),
-    Experiment("f2034f-batches_per_update64-bs256-hpsetstandard-rotational_invarianceFalse", "no rotational invariance"),
-    Experiment("7a9d92-hpsetstandard", "no shared spatial embeddings"),
-    *adr_ablations,
-]
-
-
-for xp in [baseline] + adr_ablations:
-    label = xp.label
-    score_mean, score_sem = final_score(xp.descriptor)
-    print(f"{label} {score_mean} {score_sem}")
-
-plot([baseline], tuple(EVAL_METRICS.values()), "Mean score against all opponents", "baseline")
-plot4([baseline], EVAL_METRICS, "breakdown")
-plot4([baseline, ablations[3]], EVAL_METRICS, "breakdown cost adr")
-
-
-for xp in ablations:
-    print(f"plotting {xp.label}")
-    plot([baseline, xp], tuple(EVAL_METRICS.values()), "Mean score against all opponents", xp.label)
-    plot4([baseline, xp], EVAL_METRICS, f"breakdown {xp.label}")
+if __name__ == '__main__':
+    if not os.path.exists('plots'):
+        os.makedirs('plots')
+    if not os.path.exists('plotspng'):
+        os.makedirs('plotspng')
+
+    plot2dt('i17gv7pw', 'sidk0gu4')
+
+    baseline = Experiment(descriptor="f2034f-hpsetstandard", label="baseline")
+    adr_ablations = [
+        Experiment("f2034f-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "module cost, map curriculum"),
+        Experiment("f2034f-adr_variety0.0-adr_variety_schedule-hpsetstandard", "mothership damage, map curriculum"),
+        Experiment("f2034f-adr_variety0.0-adr_variety_schedule-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "map curriculum"),
+
+        Experiment("f2034f-adr_hstepsize0.0-hpsetstandard-linear_hardnessFalse-task_hardness150", "mothership damage, module cost, map randomization"),
+        Experiment("f2034f-adr_hstepsize0.0-hpsetstandard-linear_hardnessFalse-mothership_damage_scale0.0-mothership_damage_scale_schedule-task_hardness150", "module cost, map randomization"),
+        Experiment("f2034f-adr_hstepsize0.0-adr_variety0.0-adr_variety_schedule-hpsetstandard-linear_hardnessFalse-task_hardness150", "mothership damage, map randomization"),
+        Experiment("f2034f-adr_hstepsize0.0-adr_variety0.0-adr_variety_schedule-hpsetstandard-linear_hardnessFalse-mothership_damage_scale0.0-mothership_damage_scale_schedule-task_hardness150", "map randomization"),
+
+        Experiment("049430-batches_per_update64-bs256-hpsetstandard", "mothership damage, module cost, fixed map"),
+        Experiment("049430-batches_per_update64-bs256-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "module cost, fixed map"),
+        Experiment("049430-adr_variety0.0-adr_variety_schedule-batches_per_update64-bs256-hpsetstandard", "mothership damage, fixed map"),
+        Experiment("049430-adr_variety0.0-adr_variety_schedule-batches_per_update64-bs256-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "fixed map"),
+    ]
+    ablations = [
+        Experiment("f2034f-hpsetstandard-partial_score0.0", "sparse reward"),
+        Experiment("f2034f-hpsetstandard-use_privilegedFalse", "non-omniscient value function"),
+        Experiment("f2034f-d_agent128-d_item64-hpsetstandard", "smaller network"),
+        Experiment("f2034f-batches_per_update64-bs256-hpsetstandard-rotational_invarianceFalse", "no rotational invariance"),
+        Experiment("7a9d92-hpsetstandard", "no shared spatial embeddings"),
+        *adr_ablations,
+    ]
+
+
+    for xp in [baseline] + adr_ablations:
+        label = xp.label
+        score_mean, score_sem = final_score(xp.descriptor)
+        print(f"{label} {score_mean} {score_sem}")
+
+    plot([baseline], tuple(EVAL_METRICS.values()), "Mean score against all opponents", "baseline")
+    plot4([baseline], EVAL_METRICS, "breakdown")
+    plot4([baseline, ablations[3]], EVAL_METRICS, "breakdown cost adr")
+
+
+    for xp in ablations:
+        print(f"plotting {xp.label}")
+        plot([baseline, xp], tuple(EVAL_METRICS.values()), "Mean score against all opponents", xp.label)
+        plot4([baseline, xp], EVAL_METRICS, f"breakdown {xp.label}")
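One practical effect of the new if __name__ == '__main__': guard, and a plausible motivation for it: the notebook added in this commit imports from this module (from plot_results import fetch_run_data), and without the guard that import would execute the entire plotting script as a side effect. A minimal sketch of the pattern, assuming a hypothetical main() wrapper (the actual diff simply indents the existing statements; exist_ok=True is an illustrative alternative to the explicit existence checks above):

    import os

    def main():
        # equivalent to the existence checks in the diff above
        os.makedirs('plots', exist_ok=True)
        os.makedirs('plotspng', exist_ok=True)
        # ... plotting calls would follow here ...

    if __name__ == '__main__':
        # runs only when executed as a script, never on import
        main()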

progress.ipynb

Lines changed: 246 additions & 0 deletions (new file)

The new file is a Jupyter notebook (nbformat 4, Python 3 kernel, language metadata for Python 3.7.5); its code cells are reproduced below.
Cell 1 (imports):

    from plot_results import fetch_run_data
    import matplotlib.pyplot as plt
    import numpy as np
    import wandb
    from functools import lru_cache
    import matplotlib.dates as mdates
    from datetime import datetime
Cell 2 (a memoized redefinition of fetch_run_data, shadowing the version imported in Cell 1; results are cached per descriptor and metric set):

    @lru_cache(maxsize=None)
    def fetch_run_data(descriptor: str, metrics):
        if isinstance(metrics, str):
            metrics = [metrics]
        else:
            metrics = list(metrics)
        api = wandb.Api()
        runs = api.runs("cswinter/deep-codecraft-vs", {"config.descriptor": descriptor})

        curves = []
        for run in runs:
            step = []
            value = []
            vals = run.history(keys=metrics, samples=100, pandas=False)
            for entry in vals:
                if metrics[0] in entry:
                    step.append(entry['_step'] * 1e-6)
                    meanvalue = np.array([entry[metric] for metric in metrics]).mean()
                    value.append(meanvalue)
            curves.append((np.array(step), np.array(value)))
        return curves, runs[0].summary["_timestamp"]
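Illustrative usage, not part of the commit (the descriptor is taken from the runs list in the next cell): because the function is wrapped in lru_cache, the metrics argument must be hashable, so pass a single metric name or a tuple of names rather than a list.

    curves, last_timestamp = fetch_run_data(
        "154506-agents15-hpsetstandard-steps150e6", "eval_mean_score")
    for steps, scores in curves:
        if len(steps):
            print(f"final: {steps[-1]:.1f}M samples, score {scores[-1]:.3f}")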
Cell 3 (descriptors of the runs to plot):

    runs = [
        "154506-agents15-hpsetstandard-steps150e6",
        "24e131-agents15-hpsetstandard-steps150e6",
        "613056-agents15-hpsetstandard-steps150e6",
        "87c1ab-hpsetstandard",
        "8af81d-hpsetstandard-num_self_play30-num_vs_aggro_replicator1-num_vs_destroyer2-num_vs_replicator1",
        "d33903-batches_per_update32-batches_per_update_schedule-hpsetstandard-lr0.001-lr_schedulecosine-steps150e6",
        "49b7fa-entropy_bonus0.02-entropy_bonus_schedulelin 20e6:0.005,60e6:0.0-hpsetstandard",
        "49b7fa-feat_dist_to_wallTrue-hpsetstandard",
        "b9bab7-hpsetstandard-max_hardness150",
        "46e0b2-hpsetstandard-spatial_attnFalse",
        "2d9e29-hpsetstandard",
        "30ed5b-hpsetstandard-max_hardness175",
        "fc244e-hpsetstandard-spatial_attnTrue-spatial_attn_lr_multiplier10.0",
        "0a5940-hpsetstandard-item_item_attn_layers1-item_item_spatial_attnTrue-item_item_spatial_attn_vfFalse-max_grad_norm200",
        "0a5940-hpsetstandard-mothership_damage_scale4.0-mothership_damage_scale_schedulelin 50e6:1.0,150:0.0",
        "83a3af-hpsetstandard-mothership_damage_scale4.0-mothership_damage_scale_schedulelin 50e6:0.0",
        "667ac7-hpsetstandard",
        "80a87d-entropy_bonus0.15-entropy_bonus_schedulelin 15e6:0.07,60e6:0.0-hpsetstandard",
        "80a87d-entropy_bonus0.2-entropy_bonus_schedulelin 15e6:0.1,60e6:0.0-final_lr5e-05-hpsetstandard-lr0.0005-vf_coef1.0",
        "c0b3b4-hpsetstandard-partial_score0",
        "9fc3de-hpsetstandard",
        "9fc3de-adr_hstepsize0.001-hpsetstandard-linear_hardnessFalse",
        "ac84c0-gamma0.9997-hpsetstandard",
        "a1210b-gamma_schedulecos 1.0-hpsetstandard",
        "b9f907-adr_average_cost_target1-hpsetstandard",
        "5fb181-hpsetstandard",
        "5fb181-hpsetstandard-steps150e6",
        "3c69a5-adr_average_cost_target0.5-adr_avg_cost_schedulelin 80e6:1.0-hpsetstandard",
        "35b3a7-hpsetstandard-nearby_mapFalse-steps150e6",
        "152ec3-hpsetstandard-nearby_mapFalse-steps125e6",
    ]
Cell 4 (one learning curve per run, colored by the run's completion date, with a colorbar mapping color to date):

    fig, ax = plt.subplots(figsize=(19, 10))
    cmap = plt.get_cmap('viridis')

    t0 = 1593959023.8568478
    tn = 1607756232
    ts = []
    for ri, run in enumerate(runs):
        #print(f"Fetching {run}")
        curves, date = fetch_run_data(run, "eval_mean_score")
        samples = []
        values = []
        for curve in curves:
            ax.plot(curve[0], curve[1], color=cmap((date-t0)/(tn-t0)), marker='o')
            for i, value in enumerate(curve[1]):
                if len(values) <= i:
                    samples.append(curve[0][i])
                    values.append([value])
                else:
                    values[i].append(value)
        #values = np.array([np.array(vals).mean() for vals in values])
        #ax.plot(samples, values, color=cmap((date-t0)/(tn-t0)), marker='o')
        #ts.append(mdates.date2num(datetime.fromtimestamp(date)))

    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize
    loc = mdates.AutoDateLocator()
    def dateformatter(x, pos=None):
        return datetime.fromtimestamp(x*(tn-t0)+t0).strftime('%Y-%m-%d')
    fig.colorbar(ScalarMappable(cmap=cmap), ticks=loc, format=dateformatter)

    ax.set_yticks([-1.0, -0.5, 0, 0.5, 1])
    ax.set_xlim(0, 200)
    ax.grid()
    fig.show()
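The colorbar here is mapped by hand: run dates are normalized to [0, 1] for the viridis colormap, and dateformatter converts colorbar tick positions back into calendar dates. An alternative sketch, not part of the commit (it reuses fig, ax, cmap, t0, and tn from the cell above), works in matplotlib date numbers so that AutoDateLocator and AutoDateFormatter choose and label the ticks:

    d0 = mdates.date2num(datetime.fromtimestamp(t0))
    dn = mdates.date2num(datetime.fromtimestamp(tn))
    loc = mdates.AutoDateLocator()
    sm = ScalarMappable(norm=Normalize(vmin=d0, vmax=dn), cmap=cmap)
    fig.colorbar(sm, ax=ax, ticks=loc, format=mdates.AutoDateFormatter(loc))
    # individual curves would then be colored with
    # cmap((mdates.date2num(datetime.fromtimestamp(date)) - d0) / (dn - d0))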
Cell 5 (the same data, but with values averaged across the runs that share a descriptor):

    fig, ax = plt.subplots(figsize=(20, 15))
    cmap = plt.get_cmap('viridis')

    t0 = 1593959023.8568478
    tn = 1607756232
    ts = []
    for ri, run in enumerate(runs):
        #print(f"Fetching {run}")
        curves, date = fetch_run_data(run, "eval_mean_score")
        samples = []
        values = []
        for curve in curves:
            for i, value in enumerate(curve[1]):
                if len(values) <= i:
                    samples.append(curve[0][i])
                    values.append([value])
                else:
                    values[i].append(value)
        values = np.array([np.array(vals).mean() for vals in values])
        ax.plot(samples, values)#, color=cmap((date-t0)/(tn-t0)))
        #ts.append(mdates.date2num(datetime.fromtimestamp(date)))

    #from matplotlib.cm import ScalarMappable
    #from matplotlib.colors import Normalize
    #loc = mdates.AutoDateLocator()
    #fig.colorbar(ScalarMappable(norm=Normalize(t0, tn), cmap=cmap))#, ticks=loc, format=mdates.AutoDateFormatter(loc))

    ax.set(xlabel='million samples', ylim=(-1, 1))
    ax.set_yticks([-1.0, -0.5, 0, 0.5, 1])
    ax.set_xlim(0, 200e6)
    #ax.set_xticks([0, 25, 50, 75, 100, 125])
    ax.legend(loc='upper left')
    ax.grid()
    fig.show()
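Two small caveats in this cell, noted as observations rather than changes to the committed code: fetch_run_data already converts _step into millions of samples, so ax.set_xlim(0, 200e6) likely pushes every curve against the left edge (the previous plot uses ax.set_xlim(0, 200)), and ax.legend(loc='upper left') has nothing to display because no artist is given a label. A hedged sketch of the adjusted calls:

    ax.plot(samples, values, label=run)  # label each averaged curve with its descriptor
    ax.set_xlim(0, 200)                  # x axis is already in millions of samples
    ax.legend(loc='upper left')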
Cell 6 (scratch query against the wandb API; note that this rebinds runs from the list of descriptors to a query result):

    api = wandb.Api()
    runs = api.runs("cswinter/deep-codecraft-vs", {"config.descriptor": runs[0]})
Cell 7 (display the query result):

    runs
Cell 8 (timestamp component returned by fetch_run_data for the last result):

    fetch_run_data(runs[-1], 'eval_mean_score')[1]
Cell 9: (empty)
Cell 10 (evaluation metrics recorded in the first result's summary):

    #help(runs[0])
    {metric: values for metric, values in runs[0].summary.items() if metric.startswith('eval')}
