import pickle
from collections import deque
from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

from DQN_utils import preprocess
from visualization import plot_rewards, plot_option_ratio


class HDQN_agent:
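    """Hierarchical DQN (h-DQN) agent.

    A meta-controller picks an option every `option_len` environment steps and
    stores it in `network.current_option`; actions are then selected by the
    shared network under that option, and each option head is trained on its
    own filtered reward signal (see `filter_reward`).
    """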
    def __init__(self, env, network, buffer, epsilon=0.25, batch_size=32, option_len=10):
        # config
        self.env = env
        self.epsilon = epsilon
        self.batch_size = batch_size
        self.window = 100
        self.skip_frames = 4
        # network
        self.network = network
        self.target_network = deepcopy(network)
        self.tau = network.tau
        # init
        self.training_rewards = []
        # one loss history per option head, plus one for the meta-controller
        self.losses = {"meta": []}
        for o in range(self.network.n_options):
            self.losses[o] = []
        self.mean_training_rewards = []
        self.mean_validation_rewards = []
        self.sync_eps = []
        self.rewards = 0
        self.step_count = 0
        self.s_0 = preprocess(self.env.reset())
        self.state_buffer = deque(maxlen=self.tau)
        for _ in range(self.tau):
            self.state_buffer.append(np.zeros(self.s_0.shape))
        self.next_state_buffer = deepcopy(self.state_buffer)
        # HDQN
        self.option_len = option_len
        self.meta_step_count = 0
        self.meta_state = np.stack([deepcopy(self.state_buffer)])
        self.meta_rewards = 0
        self.meta_buffer = buffer
        self.option_buffer = deepcopy(buffer)
        self.meta_buffer.burn_in = batch_size
        # analysis
        self.train_ep_option_dist = np.zeros(self.network.n_options)
        self.train_ep_option_ratio = {o: {} for o in range(self.network.n_options)}
        self.eval_option_dist = np.zeros(self.network.n_options)
    def take_step(self, mode='train'):
        """Select an action (random in 'explore' mode, epsilon-greedy otherwise) and apply it for `skip_frames` frames."""
        r_reg = 0
        state_buffer = deepcopy(self.state_buffer)
        if mode == 'explore':
            action = self.env.action_space.sample()
        else:
            action = self.network.get_action(np.stack([self.state_buffer]), epsilon=self.epsilon)
        self.step_count += 1
        for _ in range(self.skip_frames):
            self.state_buffer.append(self.s_0)
            s_1_raw_i, r_raw_i, done, _ = self.env.step(action)
            self.rewards += r_raw_i
            s_1_i = preprocess(s_1_raw_i)
            self.next_state_buffer.append(s_1_i.copy())
            self.s_0 = s_1_i.copy()
            r_reg = max(r_reg, r_raw_i)  # register the max reward of the skipped frames
            if done:
                break
        self.meta_rewards += self.filter_reward(r_reg, done, network="meta")
        self.option_buffer.append(state_buffer, action, r_reg, done, deepcopy(self.next_state_buffer))
        return done
    def take_option(self, mode='train'):
        """Store the finished option's meta-transition and select the next option."""
        # Observe outcome of the option that just finished
        done = False
        meta_next_state = np.stack([deepcopy(self.state_buffer)])
        self.meta_rewards /= self.option_len  # scale down to roughly [-1, 1]
        self.meta_rewards = max(min(self.meta_rewards, 1), -1)  # truncate to [-1, 1]
        self.meta_buffer.append(deepcopy(self.meta_state), self.network.current_option, self.meta_rewards, done, meta_next_state)
        # Reset meta rewards and set next state
        self.meta_rewards = 0
        self.meta_state = meta_next_state
        # Take next option
        if mode == 'explore':
            next_option = np.random.choice(self.network.n_options)
        elif mode == 'flight':
            next_option = 0
        else:
            next_option = self.network.get_option(self.meta_state, epsilon=self.epsilon)
        self.meta_step_count += 1
        self.network.current_option = next_option
        # log it
        if mode != "eval":
            self.train_ep_option_dist[next_option] += 1
        else:
            self.eval_option_dist[next_option] += 1
        return next_option
    # HDQN training
    def train(self, gamma=0.99, max_episodes=10000, batch_size=32,
              network_update_frequency=4, network_sync_frequency=2000,
              network_save_frequency=100, network_evaluate_frequency=100,
              n_val_episodes=10, start_from_eps=0, checkpoint_path=None,
              checkpoint_prefix="hdqn", plot_result=False,
              epsilon_start=None, epsilon_end=None, epsilon_final_episode=None,
              op_ratio_n_bins=100, meta_burn_in_ep=0, average_option_losses=False):
        """Train the option heads and the meta-controller jointly, with periodic target syncing, evaluation and checkpointing."""
        self.checkpoint_path = checkpoint_path
        self.checkpoint_prefix = checkpoint_prefix
        self.gamma = gamma
        self.network_evaluate_frequency = network_evaluate_frequency
        self.mean_validation_rewards = {}  # possible to not reset
        self.options_network_update_frequency = network_update_frequency * self.network.n_options
        self.meta_network_update_frequency = network_update_frequency * self.option_len
        self.op_ratio_window = max(round((max_episodes - start_from_eps) / op_ratio_n_bins), 1)
        self.meta_burn_in_ep = meta_burn_in_ep
        self.average_option_losses = average_option_losses
        # Annealing
        eps_incr = 0
        if not (epsilon_start is None or epsilon_end is None or epsilon_final_episode is None):
            self.epsilon = epsilon_start
            eps_incr = (epsilon_end - epsilon_start) / epsilon_final_episode
        # Here it is possible to start from a nonzero episode (see DQN_agent.py)
        # Populate the replay buffers
        print("Populating replay buffer...")
        p_step = 0
        while self.option_buffer.burn_in_capacity() < 1 or self.meta_buffer.burn_in_capacity() < 1:
            if p_step % self.option_len == 0:
                self.take_option(mode="explore")
            done = self.take_step(mode='explore')
            if done:
                self.s_0 = preprocess(self.env.reset())
            p_step += 1
        # Start learning
        print("Beginning training...")
        ep = start_from_eps
        training = True
        while training:
            # reset state and reward
            self.s_0 = preprocess(self.env.reset())
            for _ in range(self.tau):
                self.state_buffer.append(np.zeros(self.s_0.shape))
            self.next_state_buffer = deepcopy(self.state_buffer)
            self.rewards = 0
            done = False
            # let option-heads specialize before learning to choose between them
            meta_mode = "explore" if ep < self.meta_burn_in_ep else "train"
            # Play a game
            while not done:
                # Take next option
                if self.step_count % self.option_len == 0:
                    self.take_option(mode=meta_mode)
                # Take next action
                done = self.take_step(mode='train')
                # Update options and meta networks
                if self.step_count % (self.options_network_update_frequency + 1) == 0:
                    # update all networks at the same rate
                    self.update(network='options')
                    self.update(network='meta')
                # Sync networks
                if self.step_count % network_sync_frequency == 0:
                    self.target_network.load_state_dict(self.network.state_dict())
                    self.sync_eps.append(ep)
            # Evaluate networks
            if ep % network_evaluate_frequency == 0:
                self.eval_performance(n_val_episodes, ep)
            # Save networks and learning curves
            if (ep + 1) % network_save_frequency == 0:
                # network
                filename = 'checkpoint_' + str(ep) + '_eps.pth'
                if self.checkpoint_prefix != "":
                    filename = self.checkpoint_prefix + "_" + filename
                self.network.save_weights(filename)
                # learning curves
                self.save_learning_curves(prefix=self.checkpoint_prefix)
            # Log rewards and option usage
            self.training_rewards.append(self.rewards)
            mean_rewards = np.mean(self.training_rewards[-self.window:])
            self.mean_training_rewards.append(mean_rewards)
            if meta_mode != "eval" and ep != 0 and ep % self.op_ratio_window == 0:
                op_ratios = self.train_ep_option_dist / max(np.sum(self.train_ep_option_dist), 1)
                for o in range(self.network.n_options):
                    self.train_ep_option_ratio[o][ep] = op_ratios[o]
                self.train_ep_option_dist = np.zeros(self.network.n_options)
            ep += 1
            txt = "\rEpisode {:d} Mean Rewards {:.2f}\t\t"
            print(txt.format(ep, mean_rewards), end="")
            # Anneal epsilon
            if epsilon_final_episode is not None and ep <= epsilon_final_episode:
                self.epsilon += eps_incr
            # Wrap it up
            if ep >= max_episodes:
                print('\nEpisode limit reached.')
                if plot_result:
                    self.plot_results()
                return
    def calculate_options_loss(self, batch, option):
        dev = self.network.device
        states, actions, rewards, dones, next_states = batch
        # filter rewards based on option head
        rewards = [self.filter_reward(rewards[i], dones[i], "options", option) for i in range(self.batch_size)]
        rewards_t = torch.FloatTensor(rewards).to(device=dev)
        actions_t = torch.LongTensor(np.array(actions)).reshape(-1, 1).to(device=dev)
        dones_t = torch.ByteTensor(dones).to(dtype=torch.bool).to(device=dev)
        q_vals_raw = self.network.get_qvals(states)
        qvals = torch.gather(q_vals_raw, 1, actions_t)
        q_vals_next_raw = self.target_network.get_qvals(next_states)
        qvals_next = torch.max(q_vals_next_raw, dim=-1)[0].detach()
        qvals_next[dones_t] = 0  # zero out terminal states
        expected_qvals = (self.gamma * qvals_next + rewards_t).reshape(-1, 1)
        loss = nn.MSELoss()(qvals, expected_qvals)
        return loss
    def calculate_meta_loss(self, batch):
        dev = self.network.device
        states, options, rewards, dones, next_states = batch
        states = np.stack(states).squeeze()
        next_states = np.stack(next_states).squeeze()
        rewards_t = torch.FloatTensor(rewards).to(device=dev)
        options_t = torch.LongTensor(np.array(options)).reshape(-1, 1).to(device=dev)
        dones_t = torch.ByteTensor(dones).to(dtype=torch.bool).to(device=dev)
        o_vals_raw = self.network.get_ovals(states)
        ovals = torch.gather(o_vals_raw, 1, options_t)
        o_vals_next_raw = self.target_network.get_ovals(next_states)
        ovals_next = torch.max(o_vals_next_raw, dim=-1)[0].detach()
        ovals_next[dones_t] = 0  # zero out terminal states
        expected_ovals = (self.gamma * ovals_next + rewards_t).reshape(-1, 1)
        loss = nn.MSELoss()(ovals, expected_ovals)
        return loss
    def update(self, network="options"):
        """Sample a batch and take one gradient step on the meta-controller or on all option heads."""
        self.network.optimizer.zero_grad()
        if network == "meta":
            batch = self.meta_buffer.sample_batch(batch_size=self.batch_size)
            loss = self.calculate_meta_loss(batch)
            if self.network.device == 'cuda':
                self.losses['meta'].append(loss.detach().cpu().numpy())
            else:
                self.losses['meta'].append(loss.detach().numpy())
        elif network == "options":
            losses = []
            for i in range(self.network.n_options):
                batch = self.option_buffer.sample_batch(batch_size=self.batch_size)
                losses.append(self.calculate_options_loss(batch, option=i))
                if self.network.device == 'cuda':
                    self.losses[i].append(losses[i].detach().cpu().numpy())
                else:
                    self.losses[i].append(losses[i].detach().numpy())
            loss = sum(losses)
            if self.average_option_losses:
                loss /= self.network.n_options
        # Backprop
        loss.backward()
        nn.utils.clip_grad_norm_(self.network.parameters(), self.network.clip_val)
        self.network.optimizer.step()
    def eval_performance(self, n_val_episodes, eps):
        rewards = []
        for _ in range(n_val_episodes):
            r = float(self.play_a_game())
            rewards.append(r)
        self.mean_validation_rewards[int(eps)] = np.mean(np.array(rewards))
    def filter_reward(self, r, done, network, opt=None):
        """Zero out rewards outside the option's band and clip the result to [-1, 1]."""
        if network == "options":
            if opt is None:
                opt = self.network.current_option
            # game-specific reward bands: head 0 ignores all positive rewards
            # (it only gets -1 on termination), head 1 keeps rewards equal to 10,
            # head 2 keeps rewards equal to 50, head 3 keeps rewards above 50
            if (opt == 0) or (opt == 1 and r != 10) or (opt == 2 and r != 50) or (opt == 3 and r <= 50):
                r = 0
            if opt == 0 and done:
                r = -1
        return max(-1, min(1, r))
    def play_a_game(self):
        """Play one evaluation episode and return the accumulated raw reward."""
        self.state_buffer = deque(maxlen=self.tau)  # init state buffer
        init_state = preprocess(self.env.reset())
        for _ in range(self.tau):
            self.state_buffer.append(np.zeros(init_state.shape))
        r = 0
        done = False
        step_count = 0
        while not done:
            if step_count % self.option_len == 0:
                self.take_option(mode='eval')
            action = self.network.get_action([self.state_buffer])
            state, reward, done, _ = self.env.step(action)
            step_count += 1
            self.state_buffer.append(preprocess(state))
            r += reward
        return r
    def save_learning_curves(self, name="learning_curves", prefix=""):
        if self.checkpoint_path is None:
            return
        learning_curves = {
            "train": self.training_rewards,
            "mean_train": self.mean_training_rewards,
            "val": self.mean_validation_rewards,
            "option_ratio": self.train_ep_option_ratio,
            "losses": self.losses
        }
        path = self.checkpoint_path + name + ".pkl"
        if prefix != "":
            path = self.checkpoint_path + prefix + "_" + name + ".pkl"
        with open(path, 'wb+') as f:
            pickle.dump(learning_curves, f, pickle.HIGHEST_PROTOCOL)
    def plot_results(self, start_from_eps=0):
        plot_rewards(self.training_rewards, self.mean_training_rewards, self.mean_validation_rewards, start_from_eps)
        plot_option_ratio(self.train_ep_option_ratio, self.op_ratio_window)
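

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, kept commented out).
# `HDQNNetwork` and `ExperienceReplayBuffer` are hypothetical placeholders:
# any objects exposing the attributes used above would work
# (network: tau, n_options, current_option, device, optimizer, clip_val,
#  get_action, get_option, get_qvals, get_ovals, save_weights;
#  buffer: append, sample_batch, burn_in_capacity, burn_in).
# The environment is assumed to be an Atari-style gym env with the classic
# 4-tuple step API used in take_step/play_a_game.
#
# import gym
# from HDQN_network import HDQNNetwork            # hypothetical network class
# from DQN_utils import ExperienceReplayBuffer    # hypothetical buffer class
#
# env = gym.make("MsPacman-v0")                   # assumed environment
# network = HDQNNetwork(env, n_options=4, tau=4)
# buffer = ExperienceReplayBuffer(memory_size=100000, burn_in=10000)
# agent = HDQN_agent(env, network, buffer, epsilon=0.25, option_len=10)
# agent.train(gamma=0.99, max_episodes=5000,
#             epsilon_start=1.0, epsilon_end=0.05, epsilon_final_episode=1000,
#             checkpoint_path="./checkpoints/", checkpoint_prefix="hdqn")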