1
+ {
2
+ "metadata" : {
3
+ "language_info" : {
4
+ "codemirror_mode" : {
5
+ "name" : " ipython" ,
6
+ "version" : 3
7
+ },
8
+ "file_extension" : " .py" ,
9
+ "mimetype" : " text/x-python" ,
10
+ "name" : " python" ,
11
+ "nbconvert_exporter" : " python" ,
12
+ "pygments_lexer" : " ipython3" ,
13
+ "version" : " 3.7.5-final"
14
+ },
15
+ "orig_nbformat" : 2 ,
16
+ "kernelspec" : {
17
+ "name" : " python3" ,
18
+ "display_name" : " Python 3.7.5 64-bit ('dcc': conda)" ,
19
+ "metadata" : {
20
+ "interpreter" : {
21
+ "hash" : " d660374ac31277b9ea7ee26abd64205cafc644f5ebc9b3efcbdb7eb83107acd0"
22
+ }
23
+ }
24
+ }
25
+ },
26
+ "nbformat" : 4 ,
27
+ "nbformat_minor" : 2 ,
28
+ "cells" : [
29
+ {
30
+ "cell_type" : " code" ,
31
+ "execution_count" : null ,
32
+ "metadata" : {},
33
+ "outputs" : [],
34
+ "source" : [
35
+ " from plot_results import fetch_run_data\n " ,
36
+ " import matplotlib.pyplot as plt\n " ,
37
+ " import numpy as np\n " ,
38
+ " import wandb\n " ,
39
+ " from functools import lru_cache\n " ,
40
+ " import matplotlib.dates as mdates\n " ,
41
+ " from datetime import datetime \n "
42
+ ]
43
+ },
44
+ {
45
+ "cell_type" : " code" ,
46
+ "execution_count" : null ,
47
+ "metadata" : {},
48
+ "outputs" : [],
49
+ "source" : [
50
+ " \n " ,
51
+ " @lru_cache(maxsize=None)\n " ,
52
+ " def fetch_run_data(descriptor: str, metrics):\n " ,
53
+ " if isinstance(metrics, str):\n " ,
54
+ " metrics = [metrics]\n " ,
55
+ " else:\n " ,
56
+ " metrics = list(metrics)\n " ,
57
+ " api = wandb.Api()\n " ,
58
+ " runs = api.runs(\" cswinter/deep-codecraft-vs\" , {\" config.descriptor\" : descriptor})\n " ,
59
+ " \n " ,
60
+ " curves = []\n " ,
61
+ " for run in runs:\n " ,
62
+ " step = []\n " ,
63
+ " value = []\n " ,
64
+ " vals = run.history(keys=metrics, samples=100, pandas=False)\n " ,
65
+ " for entry in vals:\n " ,
66
+ " if metrics[0] in entry:\n " ,
67
+ " step.append(entry['_step'] * 1e-6)\n " ,
68
+ " meanvalue = np.array([entry[metric] for metric in metrics]).mean()\n " ,
69
+ " value.append(meanvalue)\n " ,
70
+ " curves.append((np.array(step), np.array(value)))\n " ,
71
+ " return curves, runs[0].summary[\" _timestamp\" ]"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type" : " code" ,
76
+ "execution_count" : null ,
77
+ "metadata" : {},
78
+ "outputs" : [],
79
+ "source" : [
80
+ " runs = [\n " ,
81
+ " \" 154506-agents15-hpsetstandard-steps150e6\" ,\n " ,
82
+ " \" 24e131-agents15-hpsetstandard-steps150e6\" ,\n " ,
83
+ " \" 613056-agents15-hpsetstandard-steps150e6\" ,\n " ,
84
+ " \" 87c1ab-hpsetstandard\" ,\n " ,
85
+ " \" 8af81d-hpsetstandard-num_self_play30-num_vs_aggro_replicator1-num_vs_destroyer2-num_vs_replicator1\" ,\n " ,
86
+ " \" d33903-batches_per_update32-batches_per_update_schedule-hpsetstandard-lr0.001-lr_schedulecosine-steps150e6\" ,\n " ,
87
+ " \" 49b7fa-entropy_bonus0.02-entropy_bonus_schedulelin 20e6:0.005,60e6:0.0-hpsetstandard\" ,\n " ,
88
+ " \" 49b7fa-feat_dist_to_wallTrue-hpsetstandard\" ,\n " ,
89
+ " \" b9bab7-hpsetstandard-max_hardness150\" ,\n " ,
90
+ " \" 46e0b2-hpsetstandard-spatial_attnFalse\" ,\n " ,
91
+ " \" 2d9e29-hpsetstandard\" ,\n " ,
92
+ " \" 30ed5b-hpsetstandard-max_hardness175\" ,\n " ,
93
+ " \" fc244e-hpsetstandard-spatial_attnTrue-spatial_attn_lr_multiplier10.0\" ,\n " ,
94
+ " \" 0a5940-hpsetstandard-item_item_attn_layers1-item_item_spatial_attnTrue-item_item_spatial_attn_vfFalse-max_grad_norm200\" ,\n " ,
95
+ " \" 0a5940-hpsetstandard-mothership_damage_scale4.0-mothership_damage_scale_schedulelin 50e6:1.0,150:0.0\" ,\n " ,
96
+ " \" 83a3af-hpsetstandard-mothership_damage_scale4.0-mothership_damage_scale_schedulelin 50e6:0.0\" ,\n " ,
97
+ " \" 667ac7-hpsetstandard\" ,\n " ,
98
+ " \" 80a87d-entropy_bonus0.15-entropy_bonus_schedulelin 15e6:0.07,60e6:0.0-hpsetstandard\" ,\n " ,
99
+ " \" 80a87d-entropy_bonus0.2-entropy_bonus_schedulelin 15e6:0.1,60e6:0.0-final_lr5e-05-hpsetstandard-lr0.0005-vf_coef1.0\" ,\n " ,
100
+ " \" c0b3b4-hpsetstandard-partial_score0\" ,\n " ,
101
+ " \" 9fc3de-hpsetstandard\" ,\n " ,
102
+ " \" 9fc3de-adr_hstepsize0.001-hpsetstandard-linear_hardnessFalse\" ,\n " ,
103
+ " \" ac84c0-gamma0.9997-hpsetstandard\" ,\n " ,
104
+ " \" a1210b-gamma_schedulecos 1.0-hpsetstandard\" ,\n " ,
105
+ " \" b9f907-adr_average_cost_target1-hpsetstandard\" ,\n " ,
106
+ " \" 5fb181-hpsetstandard\" ,\n " ,
107
+ " \" 5fb181-hpsetstandard-steps150e6\" ,\n " ,
108
+ " \" 3c69a5-adr_average_cost_target0.5-adr_avg_cost_schedulelin 80e6:1.0-hpsetstandard\" ,\n " ,
109
+ " \" 35b3a7-hpsetstandard-nearby_mapFalse-steps150e6\" ,\n " ,
110
+ " \" 152ec3-hpsetstandard-nearby_mapFalse-steps125e6\" ,\n " ,
111
+ " ]"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type" : " code" ,
116
+ "execution_count" : null ,
117
+ "metadata" : {
118
+ "tags" : []
119
+ },
120
+ "outputs" : [],
121
+ "source" : [
122
+ " fig, ax = plt.subplots(figsize=(19, 10))\n " ,
123
+ " cmap = plt.get_cmap('viridis')\n " ,
124
+ " \n " ,
125
+ " t0 = 1593959023.8568478\n " ,
126
+ " tn = 1607756232\n " ,
127
+ " ts = []\n " ,
128
+ " for ri, run in enumerate(runs):\n " ,
129
+ " #print(f\" Fetching {run}\" )\n " ,
130
+ " curves, date = fetch_run_data(run, \" eval_mean_score\" )\n " ,
131
+ " samples = []\n " ,
132
+ " values = []\n " ,
133
+ " for curve in curves:\n " ,
134
+ " ax.plot(curve[0], curve[1], color=cmap((date-t0)/(tn-t0)), marker='o')\n " ,
135
+ " for i, value in enumerate(curve[1]):\n " ,
136
+ " if len(values) <= i:\n " ,
137
+ " samples.append(curve[0][i])\n " ,
138
+ " values.append([value])\n " ,
139
+ " else:\n " ,
140
+ " values[i].append(value)\n " ,
141
+ " #values = np.array([np.array(vals).mean() for vals in values])\n " ,
142
+ " #ax.plot(samples, values, color=cmap((date-t0)/(tn-t0)), marker='o')\n " ,
143
+ " #ts.append(mdates.date2num(datetime.fromtimestamp(date)))\n " ,
144
+ " \n " ,
145
+ " from matplotlib.cm import ScalarMappable\n " ,
146
+ " from matplotlib.colors import Normalize\n " ,
147
+ " loc = mdates.AutoDateLocator()\n " ,
148
+ " def dateformatter(x, pos=None):\n " ,
149
+ " return datetime.fromtimestamp(x*(tn-t0)+t0).strftime('%Y-%m-%d')\n " ,
150
+ " fig.colorbar(ScalarMappable(cmap=cmap), ticks=loc, format=dateformatter)\n " ,
151
+ " \n " ,
152
+ " ax.set_yticks([-1.0, -0.5, 0, 0.5, 1])\n " ,
153
+ " ax.set_xlim(0, 200)\n " ,
154
+ " ax.grid()\n " ,
155
+ " fig.show()"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type" : " code" ,
160
+ "execution_count" : null ,
161
+ "metadata" : {},
162
+ "outputs" : [],
163
+ "source" : [
164
+ " fig, ax = plt.subplots(figsize=(20, 15))\n " ,
165
+ " cmap = plt.get_cmap('viridis')\n " ,
166
+ " \n " ,
167
+ " t0 = 1593959023.8568478\n " ,
168
+ " tn = 1607756232\n " ,
169
+ " ts = []\n " ,
170
+ " for ri, run in enumerate(runs):\n " ,
171
+ " #print(f\" Fetching {run}\" )\n " ,
172
+ " curves, date = fetch_run_data(run, \" eval_mean_score\" )\n " ,
173
+ " samples = []\n " ,
174
+ " values = []\n " ,
175
+ " for curve in curves:\n " ,
176
+ " for i, value in enumerate(curve[1]):\n " ,
177
+ " if len(values) <= i:\n " ,
178
+ " samples.append(curve[0][i])\n " ,
179
+ " values.append([value])\n " ,
180
+ " else:\n " ,
181
+ " values[i].append(value)\n " ,
182
+ " values = np.array([np.array(vals).mean() for vals in values])\n " ,
183
+ " ax.plot(samples, values)#, color=cmap((date-t0)/(tn-t0)))\n " ,
184
+ " #ts.append(mdates.date2num(datetime.fromtimestamp(date)))\n " ,
185
+ " \n " ,
186
+ " #from matplotlib.cm import ScalarMappable\n " ,
187
+ " #from matplotlib.colors import Normalize\n " ,
188
+ " #loc = mdates.AutoDateLocator()\n " ,
189
+ " #fig.colorbar(ScalarMappable(norm=Normalize(t0, tn), cmap=cmap))#, ticks=loc, format=mdates.AutoDateFormatter(loc))\n " ,
190
+ " \n " ,
191
+ " ax.set(xlabel='million samples', ylim=(-1, 1))\n " ,
192
+ " ax.set_yticks([-1.0, -0.5, 0, 0.5, 1])\n " ,
193
+ " ax.set_xlim(0, 200e6)\n " ,
194
+ " #ax.set_xticks([0, 25, 50, 75, 100, 125])\n " ,
195
+ " ax.legend(loc='upper left')\n " ,
196
+ " ax.grid()\n " ,
197
+ " fig.show()"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type" : " code" ,
202
+ "execution_count" : null ,
203
+ "metadata" : {},
204
+ "outputs" : [],
205
+ "source" : [
206
+ " api = wandb.Api()\n " ,
207
+ " runs = api.runs(\" cswinter/deep-codecraft-vs\" , {\" config.descriptor\" : runs[0]})"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type" : " code" ,
212
+ "execution_count" : null ,
213
+ "metadata" : {},
214
+ "outputs" : [],
215
+ "source" : [
216
+ " runs"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type" : " code" ,
221
+ "execution_count" : null ,
222
+ "metadata" : {},
223
+ "outputs" : [],
224
+ "source" : [
225
+ " fetch_run_data(runs[-1], 'eval_mean_score')[1]"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type" : " code" ,
230
+ "execution_count" : null ,
231
+ "metadata" : {},
232
+ "outputs" : [],
233
+ "source" : []
234
+ },
235
+ {
236
+ "cell_type" : " code" ,
237
+ "execution_count" : null ,
238
+ "metadata" : {},
239
+ "outputs" : [],
240
+ "source" : [
241
+ " #help(runs[0])\n " ,
242
+ " {metric: values for metric, values in runs[0].summary.items() if metric.startswith('eval')}"
243
+ ]
244
+ }
245
+ ]
246
+ }
0 commit comments