-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathalgae.py
571 lines (430 loc) · 20.8 KB
/
algae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
"""Implementation of AlgaeDICE.
Based on the publication "AlgaeDICE: Policy Gradient from Arbitrary Experience"
by Ofir Nachum, Bo Dai, Ilya Kostrikov, Yinlam Chow, Lihong Li, Dale Schuurmans.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
import torch
import keras_utils as keras_utils
import torch.nn as nn
import torch.nn.functional as F
import math
from torch import distributions as pyd
# TensorFlow Probability distributions namespace (not referenced by the code
# shown in this file).
ds = tfp.distributions
# Module-wide compute device: the default CUDA device when one is available,
# otherwise the CPU. Networks and tensors below are placed on it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def soft_update_params(net, target_net, tau):
    """Polyak-average the parameters of ``net`` into ``target_net`` in place.

    Each target parameter is overwritten with
    ``tau * source + (1 - tau) * target``. Parameters are paired in
    ``parameters()`` order, so both modules must share the same architecture.
    With ``tau=1.0`` (and finite target values) this copies ``net`` exactly.
    """
    keep = 1 - tau
    for src, dst in zip(net.parameters(), target_net.parameters()):
        blended = tau * src.data + keep * dst.data
        dst.data.copy_(blended)
def torch_to_tf_tensor(x):
    """Convert a torch tensor into a TensorFlow tensor.

    The value is detached from the autograd graph, moved to the CPU and
    exported through NumPy before being passed to ``tf.convert_to_tensor``.
    """
    host_array = x.detach().cpu().numpy()
    return tf.convert_to_tensor(host_array)
def to_np(t):
    """Convert a torch tensor to a NumPy array.

    Returns:
      ``None`` if ``t`` is ``None``; an empty 1-D ``np.array([])`` if ``t``
      has no elements (whatever its shape); otherwise the detached CPU tensor
      as an ndarray (sharing memory when ``t`` already lives on the CPU).
    """
    if t is None:
        return None
    if t.nelement() == 0:
        return np.array([])
    return t.detach().cpu().numpy()
def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
    """Build a fully connected ReLU network as an ``nn.Sequential``.

    Args:
      input_dim: Input feature size.
      hidden_dim: Width of every hidden layer.
      output_dim: Output feature size.
      hidden_depth: Number of hidden (Linear + ReLU) layers; 0 yields a single
        Linear layer. Any non-zero value below 1 behaves like 1.
      output_mod: Optional module appended after the final Linear layer.

    Returns:
      The assembled ``nn.Sequential``.
    """
    if hidden_depth == 0:
        # No hidden layers: one affine map straight to the output.
        layers = [nn.Linear(input_dim, output_dim)]
    else:
        # One input->hidden layer followed by (hidden_depth - 1)
        # hidden->hidden layers, each followed by a ReLU.
        sizes = [input_dim] + [hidden_dim] * max(hidden_depth, 1)
        layers = []
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU(inplace=False))
        layers.append(nn.Linear(hidden_dim, output_dim))
    if output_mod is not None:
        layers.append(output_mod)
    return nn.Sequential(*layers)
class TanhTransform(pyd.transforms.Transform):
    """Bijective transform ``y = tanh(x)`` from the real line onto (-1, 1).

    Any two instances compare equal (equality is by type).
    """
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        # cache_size=1 lets the inverse reuse the pre-image of the most recent
        # forward call instead of re-evaluating atanh near the +/-1 boundary.
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        """Inverse hyperbolic tangent, computed as 0.5 * (log1p(x) - log1p(-x))."""
        return 0.5 * (torch.log1p(x) - torch.log1p(-x))

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # Deliberately not clamped to the open interval: clamping may degrade
        # some algorithms; rely on cache_size=1 for exact round-trips instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        """Return log|d tanh(x)/dx| = log(1 - tanh(x)^2) in a stable form.

        Uses the identity log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)),
        following the formulation referenced at
        https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        """
        softplus_term = F.softplus(-2. * x)
        return 2. * (math.log(2.) - x - softplus_term)
class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal(loc, scale) distribution pushed through a tanh bijection.

    Samples therefore lie in (-1, 1). ``loc`` and ``scale`` are kept as
    attributes alongside the underlying ``base_dist``.
    """

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.base_dist = pyd.Normal(loc, scale)
        super().__init__(self.base_dist, [TanhTransform()])

    @property
    def mean(self):
        """``loc`` passed through every transform in order (i.e. tanh(loc)).

        Note this is the squashed Gaussian mean, which in general differs from
        the expectation of the squashed random variable.
        """
        squashed = self.loc
        for transform in self.transforms:
            squashed = transform(squashed)
        return squashed
# source https://github.com/kevinzakka/pytorch-goodies
def orthogonal_regularization(model, device, reg=1e-6):
    """Soft orthogonality penalty over a model's non-bias parameters.

    Every parameter whose name does not contain ``'bias'`` is flattened to a
    matrix ``W`` of shape ``(shape[0], -1)``, and ``reg * sum(|W @ W.T - I|)``
    is accumulated.

    Args:
      model: ``nn.Module`` whose ``named_parameters()`` are penalized.
      device: Device on which the accumulator tensor is created.
      reg: Penalty coefficient (default 1e-6).

    Returns:
      A tensor of shape ``(1,)`` holding the penalty. Because the computation
      runs under ``torch.enable_grad()``, it is differentiable with respect to
      the parameters even when called inside ``torch.no_grad()``.
    """
    with torch.enable_grad():
        orth_loss = torch.zeros(1, device=device)
        for name, param in model.named_parameters():
            if 'bias' in name:
                continue
            # reshape (unlike view) also accepts non-contiguous parameters.
            param_flat = param.reshape(param.shape[0], -1)
            gram = torch.mm(param_flat, param_flat.t())
            # Build the identity directly on the parameter's device and dtype
            # instead of allocating on CPU and copying it over.
            identity = torch.eye(param_flat.shape[0], device=param_flat.device,
                                 dtype=param_flat.dtype)
            orth_loss = orth_loss + reg * (gram - identity).abs().sum()
    return orth_loss
def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    Linear weights get a plain orthogonal init; Conv2d/ConvTranspose2d weights
    get an orthogonal init scaled by the ReLU gain. In both cases an existing
    bias is zeroed. Any other module is left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        relu_gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data, relu_gain)
    else:
        return
    # Layers built with bias=False have m.bias = None, which has no .data.
    if hasattr(m.bias, 'data'):
        m.bias.data.fill_(0.0)
class Actor(nn.Module):
    """Tanh-squashed Gaussian policy network.

    A 2-hidden-layer MLP (width 256) maps observations to a mean and a raw
    log-std for each action dimension; the log-std is squashed into
    ``log_std_bounds`` and the result is wrapped in a ``SquashedNormal``.
    """

    def __init__(self, state_dim, action_dim, action_range, log_std_bounds=(-5, 2)):
        """Build the policy.

        Args:
          state_dim: Observation size.
          action_dim: Action size.
          action_range: Stored on the module; not used by this class itself.
          log_std_bounds: (min, max) range for the policy log-std. A tuple
            default avoids sharing one mutable list across instances.
        """
        super(Actor, self).__init__()
        self.log_std_bounds = log_std_bounds
        self.trunk = mlp(state_dim, 256, 2 * action_dim, hidden_depth=2)
        self.action_range = action_range
        self.outputs = dict()
        self.apply(weight_init)

    def forward(self, obs):
        """Return the action distribution for ``obs``.

        Also records the latest ``mu`` and ``std`` in ``self.outputs``.
        """
        mu, log_std = self.trunk(obs).chunk(2, dim=-1)
        # Squash the raw log-std into (-1, 1) with tanh, then rescale it
        # affinely into [log_std_min, log_std_max].
        log_std = torch.tanh(log_std)
        log_std_min, log_std_max = self.log_std_bounds
        log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)
        std = log_std.exp()
        self.outputs['mu'] = mu
        self.outputs['std'] = std
        return SquashedNormal(mu, std)
class Critic(nn.Module):
    """Single state-action value network Q(s, a).

    Observation and action are concatenated on the last axis and fed through a
    2-hidden-layer MLP (width 256) with a scalar output. The most recent output
    is cached in ``self.outputs['q1']``.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.Q1 = mlp(state_dim + action_dim, 256, 1, 2)
        self.outputs = {}
        self.apply(weight_init)

    def forward(self, obs, action):
        """Return Q(obs, action); batch sizes of the two inputs must match."""
        assert obs.size(0) == action.size(0)
        q1 = self.Q1(torch.cat((obs, action), dim=-1))
        self.outputs['q1'] = q1
        return q1
class DoubleCritic(nn.Module):
    """Two independent Q(s, a) heads evaluated on the same input.

    Each head is a 2-hidden-layer MLP (width 256) over the concatenation of
    observation and action. The latest outputs are cached in
    ``self.outputs['q1']`` and ``self.outputs['q2']``.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        joint_dim = state_dim + action_dim
        # Q1 is built before Q2 so parameter order and init RNG use are fixed.
        self.Q1 = mlp(joint_dim, 256, 1, 2)
        self.Q2 = mlp(joint_dim, 256, 1, 2)
        self.outputs = {}
        self.apply(weight_init)

    def forward(self, obs, action):
        """Return (Q1(obs, action), Q2(obs, action)); batch sizes must match."""
        assert obs.size(0) == action.size(0)
        joint = torch.cat((obs, action), dim=-1)
        q1 = self.Q1(joint)
        q2 = self.Q2(joint)
        self.outputs['q1'] = q1
        self.outputs['q2'] = q2
        return q1, q2
class ALGAE(object):
    """Class performing AlgaeDICE training.

    Holds a tanh-squashed Gaussian actor, a single or double critic (the
    ``nu`` function of AlgaeDICE) with a target copy, a learnable entropy
    temperature stored as ``log_alpha``, and a scalar multiplier ``_lambda``
    used by the undiscounted (``discount == 1``) objective. Running averages
    of the losses are kept in ``tf.keras`` metrics and written with
    ``tf.summary.scalar`` every ``log_interval`` steps.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 action_range,
                 log_interval,
                 actor_lr=1e-3,
                 critic_lr=1e-3,
                 alpha_init=1.0,
                 learn_alpha=True,
                 algae_alpha=1.0,
                 use_dqn=True,
                 use_init_states=True,
                 exponent=2.0):
        """Creates networks.

        Args:
          state_dim: State size.
          action_dim: Action size.
          action_range: (low, high) pair; `act` clamps actions to it.
          log_interval: Log losses every N steps.
          actor_lr: Actor learning rate.
          critic_lr: Critic (and `_lambda`) learning rate.
          alpha_init: Initial temperature value for causal entropy regularization.
          learn_alpha: Whether to learn alpha or not.
          algae_alpha: Algae regularization weight.
          use_dqn: Whether to use double networks for target value.
          use_init_states: Whether to use initial states in objective. Stored
            only; `update` currently always reuses the sampled states.
          exponent: Exponent p of function f(x) = |x|^p / p.

        Raises:
          ValueError: If `exponent` is not greater than 1.
        """
        # Fail fast, before any network is built or moved to the device.
        if exponent <= 1:
            raise ValueError('Exponent must be greater than 1, but received %f.' %
                             exponent)

        self.action_range = action_range
        self.actor = Actor(state_dim, action_dim, action_range).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.use_init_states = use_init_states

        critic_cls = DoubleCritic if use_dqn else Critic
        self.critic = critic_cls(state_dim, action_dim).to(device)
        self.critic_target = critic_cls(state_dim, action_dim).to(device)
        soft_update_params(self.critic, self.critic_target, tau=1.0)
        # The target network is only changed through soft_update_params (which
        # writes .data), so its parameters never need gradients. Disabling them
        # stops fit_actor's backward pass from accumulating unused .grad
        # buffers; gradients still flow *through* it to the sampled actions.
        for param in self.critic_target.parameters():
            param.requires_grad_(False)

        # Multiplier for the discount == 1 objective. It is optimized jointly
        # with the critic (as in the reference TF implementation); when
        # discount != 1 it receives no gradient and Adam leaves it unchanged.
        self._lambda = torch.tensor(0.0, device=device, requires_grad=True)
        self.critic_optimizer = torch.optim.Adam(
            list(self.critic.parameters()) + [self._lambda], lr=critic_lr)

        # np.log returns a float64 scalar; pin float32 so alpha does not
        # silently promote every loss it multiplies to float64.
        self.log_alpha = torch.tensor(np.log(alpha_init), dtype=torch.float32,
                                      device=device, requires_grad=True)
        self.learn_alpha = learn_alpha
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha])
        self.log_interval = log_interval
        self.algae_alpha = algae_alpha
        self.use_dqn = use_dqn
        self.exponent = exponent
        self.device = device

        # f(x) = |x|^p / p penalizes Bellman residuals. fgrad is x^(p-1) on the
        # residual clipped to [0, 1e6], i.e. f' restricted to non-negative
        # residuals, used as a stop-gradient weight in the actor loss.
        self.f = lambda resid: torch.pow(torch.abs(resid), self.exponent) / self.exponent
        clip_resid = lambda resid: torch.clamp(resid, 0.0, 1e6)
        self.fgrad = lambda resid: torch.pow(clip_resid(resid), self.exponent - 1)

        # Running averages for logging.
        self.avg_actor_loss = tf.keras.metrics.Mean('actor_loss', dtype=tf.float32)
        self.avg_alpha_loss = tf.keras.metrics.Mean('alpha_loss', dtype=tf.float32)
        self.avg_actor_entropy = tf.keras.metrics.Mean('actor_entropy', dtype=tf.float32)
        self.avg_alpha = tf.keras.metrics.Mean('alpha', dtype=tf.float32)
        self.avg_lambda = tf.keras.metrics.Mean('lambda', dtype=tf.float32)
        self.avg_critic_loss = tf.keras.metrics.Mean('critic_loss', dtype=tf.float32)

        self.training = True
        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Put the actor and online critic in train (or eval) mode."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    @property
    def alpha(self):
        """Entropy temperature, exp(log_alpha)."""
        return self.log_alpha.exp()

    def sample(self, obs):
        """Draw a reparameterized action and its log-probability.

        Returns:
          (action, log_prob), where log_prob is summed over the last axis with
          keepdim=True.
        """
        dist = self.actor(obs)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        return action, log_prob

    def act(self, obs, sample=False):
        """Select an action for a single, unbatched observation.

        Args:
          obs: One observation (array-like).
          sample: If True, sample from the policy; otherwise use its mean.

        Returns:
          A NumPy action clamped to `self.action_range`.
        """
        obs = torch.FloatTensor(obs).to(device)
        obs = obs.unsqueeze(0)
        with torch.no_grad():
            dist = self.actor(obs)
            action = dist.sample() if sample else dist.mean
            action = action.clamp(*self.action_range)
        assert action.ndim == 2 and action.shape[0] == 1
        return to_np(action[0])

    def critic_mix(self, s, a):
        """Blend critics as 0.05 * online + 0.95 * target.

        With double critics the target part is the element-wise minimum of the
        two target heads, and one blended value is returned per online head.
        """
        if self.use_dqn:
            target_q1, target_q2 = self.critic_target(s, a)
            target_q = torch.min(target_q1, target_q2)
            q1, q2 = self.critic(s, a)
            return (q1 * 0.05 + target_q * 0.95), (q2 * 0.05 + target_q * 0.95)
        return self.critic(s, a) * 0.05 + self.critic_target(s, a) * 0.95

    def fit_critic(self, states, actions, next_states, rewards, masks, discount,
                   init_states):
        """Updates critic parameters (and `_lambda` when discount == 1).

        Args:
          states: A batch of states.
          actions: A batch of actions.
          next_states: A batch of next states.
          rewards: A batch of rewards.
          masks: A batch of masks indicating the end of the episodes.
          discount: An MDP discount factor.
          init_states: A batch of init states from the MDP.

        Returns:
          Critic loss.
        """
        with torch.no_grad():
            init_actions, _ = self.sample(init_states)
            next_actions, next_log_probs = self.sample(next_states)

        if self.use_dqn:
            # ========== double Q ========== #
            with torch.no_grad():
                # r(s,a) + gamma * (nu(s',a') - temp * log pi(a'|s')), per head.
                target_q1, target_q2 = self.critic_mix(next_states, next_actions)
                target_q1 = target_q1 - self.alpha * next_log_probs
                target_q2 = target_q2 - self.alpha * next_log_probs
                target_q1 = rewards + discount * masks * target_q1
                target_q2 = rewards + discount * masks * target_q2
            q1, q2 = self.critic(states, actions)
            init_q1, init_q2 = self.critic(init_states, init_actions)
            if discount == 1:
                critic_loss1 = torch.mean(
                    self.f(self._lambda + self.algae_alpha + target_q1 - q1) -
                    self.algae_alpha * self._lambda)
                critic_loss2 = torch.mean(
                    self.f(self._lambda + self.algae_alpha + target_q2 - q2) -
                    self.algae_alpha * self._lambda)
            else:
                critic_loss1 = torch.mean(
                    self.f(target_q1 - q1) + (1 - discount) * init_q1 * self.algae_alpha)
                critic_loss2 = torch.mean(
                    self.f(target_q2 - q2) + (1 - discount) * init_q2 * self.algae_alpha)
            critic_loss = critic_loss1 + critic_loss2
        else:
            # ========== single Q ========== #
            with torch.no_grad():
                # r(s,a) + gamma * (nu(s',a') - temp * log pi(a'|s'))
                target_q = self.critic_mix(next_states, next_actions)
                target_q = target_q - self.alpha * next_log_probs
                target_q = rewards + discount * masks * target_q
            q = self.critic(states, actions)  # nu(s, a)
            init_q = self.critic(init_states, init_actions)  # nu(s_0, a_0)
            if discount == 1:
                critic_loss = torch.mean(
                    self.f(self._lambda + self.algae_alpha + target_q - q) -
                    self.algae_alpha * self._lambda)
            else:
                # Equation 25: f(Bellman residual) plus
                # algae_alpha * (1 - gamma) * E[nu(s_0, a_0)].
                critic_loss = torch.mean(
                    self.f(target_q - q) + (1 - discount) * init_q * self.algae_alpha)

        # The optimizer covers the critic parameters and _lambda.
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        return critic_loss

    def fit_actor(self, states, actions, next_states, rewards, masks, discount,
                  target_entropy, init_states):
        """Updates actor parameters and, if `learn_alpha`, the temperature.

        Args:
          states: A batch of states.
          actions: A batch of actions.
          next_states: A batch of next states.
          rewards: A batch of rewards.
          masks: A batch of masks indicating the end of the episodes.
          discount: An MDP discount factor.
          target_entropy: Target entropy value for alpha.
          init_states: A batch of init states from the MDP.

        Returns:
          (actor_loss, alpha_loss, entropy estimate -log pi(a'|s')).
        """
        init_actions, _ = self.sample(init_states)
        next_actions, next_log_probs = self.sample(next_states)

        if self.use_dqn:
            # ========== double Q ========== #
            target_q1, target_q2 = self.critic_mix(next_states, next_actions)
            target_q1 = target_q1 - self.alpha.detach() * next_log_probs
            target_q2 = target_q2 - self.alpha.detach() * next_log_probs
            target_q1 = rewards + discount * masks * target_q1
            target_q2 = rewards + discount * masks * target_q2
            q1, q2 = self.critic(states, actions)
            init_q1, init_q2 = self.critic(init_states, init_actions)
            if discount == 1:
                actor_loss1 = -torch.mean(
                    self.fgrad(self._lambda + self.algae_alpha + target_q1 - q1).detach() *
                    (target_q1 - q1))
                actor_loss2 = -torch.mean(
                    self.fgrad(self._lambda + self.algae_alpha + target_q2 - q2).detach() *
                    (target_q2 - q2))
            else:
                actor_loss1 = -torch.mean(
                    self.fgrad(target_q1 - q1).detach() * (target_q1 - q1) +
                    (1 - discount) * init_q1 * self.algae_alpha)
                actor_loss2 = -torch.mean(
                    self.fgrad(target_q2 - q2).detach() * (target_q2 - q2) +
                    (1 - discount) * init_q2 * self.algae_alpha)
            loss = (actor_loss1 + actor_loss2) / 2.0
        else:
            # ========== single Q ========== #
            target_q = self.critic_mix(next_states, next_actions)
            target_q = target_q - self.alpha.detach() * next_log_probs
            target_q = rewards + discount * masks * target_q
            q = self.critic(states, actions)
            init_q = self.critic(init_states, init_actions)
            if discount == 1:
                loss = -torch.mean(
                    self.fgrad(self._lambda + self.algae_alpha + target_q - q).detach() *
                    (target_q - q))
            else:
                # Chain rule: d/dx [(a - b)^k / k] = (a - b)^(k-1) * d(a - b)/dx,
                # with the residual clipped to be non-negative inside fgrad.
                loss = -torch.mean(
                    self.fgrad(target_q - q).detach() * (target_q - q) +
                    (1 - discount) * init_q * self.algae_alpha)

        actor_loss = loss + orthogonal_regularization(self.actor.trunk, self.device)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Temperature loss; log-probs are detached so only log_alpha is trained.
        alpha_loss = torch.mean(self.alpha * (-next_log_probs.data - target_entropy))
        if self.learn_alpha:
            self.log_alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.log_alpha_optimizer.step()
        return actor_loss, alpha_loss, -next_log_probs

    def update(self,
               replay_buffer,
               total_timesteps,
               discount=0.99,
               tau=0.005,
               target_entropy=0,
               actor_update_freq=2):
        """Performs a single training step for critic and actor.

        Args:
          replay_buffer: Object whose `sample()` returns a batch
            (states, actions, rewards, next_states, masks).
          total_timesteps: Global step; drives the actor-update and logging
            cadence and is used as the summary step.
          discount: A discount used to compute returns.
          tau: A soft updates discount.
          target_entropy: A target entropy for alpha.
          actor_update_freq: A frequency of the actor network updates.

        Returns:
          None. Losses are accumulated into the running metrics and logged.
        """
        states, actions, rewards, next_states, masks = replay_buffer.sample()
        # TODO: draw initial states from a dedicated buffer when
        # self.use_init_states is set; the sampled states are reused for now.
        init_states = states

        # ========== critic update ========== #
        critic_loss = self.fit_critic(states, actions, next_states, rewards, masks,
                                      discount, init_states)
        self.avg_critic_loss(torch_to_tf_tensor(critic_loss))
        if tf.equal(total_timesteps % self.log_interval, 0):
            # Log at the real training step (previously this was always 0).
            tf.summary.scalar('train/critic_loss', self.avg_critic_loss.result(),
                              step=total_timesteps)
            keras_utils.my_reset_states(self.avg_critic_loss)

        # ========= actor update & critic target update ========= #
        if tf.equal(total_timesteps % actor_update_freq, 0):
            actor_loss, alpha_loss, entropy = self.fit_actor(states, actions,
                                                             next_states, rewards,
                                                             masks, discount,
                                                             target_entropy,
                                                             init_states)
            soft_update_params(self.critic, self.critic_target, tau=tau)
            self.avg_actor_loss(torch_to_tf_tensor(actor_loss))
            self.avg_alpha_loss(torch_to_tf_tensor(alpha_loss))
            self.avg_actor_entropy(torch_to_tf_tensor(entropy))
            self.avg_alpha(torch_to_tf_tensor(self.alpha))
            self.avg_lambda(torch_to_tf_tensor(self._lambda))
            if tf.equal(total_timesteps % self.log_interval, 0):
                print('critic loss : {} | actor loss : {}'.format(critic_loss.data.cpu().numpy(),actor_loss.data.cpu().numpy()))
                train_measurements = [
                    ('train/actor_loss', self.avg_actor_loss.result()),
                    ('train/alpha_loss', self.avg_alpha_loss.result()),
                    ('train/actor entropy', self.avg_actor_entropy.result()),
                    ('train/alpha', self.avg_alpha.result()),
                    ('train/lambda', self.avg_lambda.result()),
                ]
                for (label, value) in train_measurements:
                    tf.summary.scalar(label, value, step=total_timesteps)
                keras_utils.my_reset_states(self.avg_actor_loss)
                keras_utils.my_reset_states(self.avg_alpha_loss)
                keras_utils.my_reset_states(self.avg_actor_entropy)
                keras_utils.my_reset_states(self.avg_alpha)
                keras_utils.my_reset_states(self.avg_lambda)