
"""

+import time
+
import hydra
import numpy as np
import torch
import tqdm
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
+    log_metrics,
+    make_continuous_cql_optimizer,
+    make_continuous_loss,
    make_cql_model,
-    make_cql_optimizer,
    make_environment,
-    make_loss,
    make_offline_replay_buffer,
)


@hydra.main(config_path=".", config_name="offline_config", version_base="1.1")
def main(cfg: "DictConfig"):  # noqa: F821
+    # Create logger
    exp_name = generate_exp_name("CQL-offline", cfg.env.exp_name)
    logger = None
    if cfg.logger.backend:
@@ -37,49 +41,96 @@ def main(cfg: "DictConfig"): # noqa: F821
            experiment_name=exp_name,
            wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
        )
-
+    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)
    device = torch.device(cfg.optim.device)

-    # Make Env
+    # Create env
    train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs)

-    # Make Buffer
+    # Create replay buffer
    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

-    # Make Model
+    # Create agent
    model = make_cql_model(cfg, train_env, eval_env, device)

-    # Make Loss
-    loss_module, target_net_updater = make_loss(cfg.loss, model)
+    # Create loss
+    loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)

-    # Make Optimizer
-    optimizer = make_cql_optimizer(cfg.optim, loss_module)
+    # Create optimizer
+    (
+        policy_optim,
+        critic_optim,
+        alpha_optim,
+        alpha_prime_optim,
+    ) = make_continuous_cql_optimizer(cfg, loss_module)
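+    # One optimizer per loss term: actor, critics, entropy temperature (alpha) and,
+    # when the Lagrange variant is enabled, the CQL multiplier (alpha prime), so each
+    # loss below can be stepped independently.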

    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

-    r0 = None
-    l0 = None
-
    gradient_steps = cfg.optim.gradient_steps
+    policy_eval_start = cfg.optim.policy_eval_start
    evaluation_interval = cfg.logger.eval_iter
    eval_steps = cfg.logger.eval_steps

+    # Training loop
+    start_time = time.time()
    for i in range(gradient_steps):
        pbar.update(1)
+        # sample data
        data = replay_buffer.sample()
-        # loss
-        loss_vals = loss_module(data)
-        # backprop
-        actor_loss = loss_vals["loss_actor"]
+        # compute loss
+        loss_vals = loss_module(data.clone().to(device))
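+        # the batch is cloned and moved to the training device before the loss pass,
+        # presumably so that loss-side writes do not mutate the replay-buffer storage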
+
+        # The official CQL implementation uses a behavior-cloning loss for the
+        # first few update steps, which helps on some tasks.
+        if i >= policy_eval_start:
+            actor_loss = loss_vals["loss_actor"]
+        else:
+            actor_loss = loss_vals["loss_actor_bc"]
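+        # (the BC objective typically maximizes the policy's log-likelihood of the
+        # dataset actions, warm-starting the actor before the CQL objective takes over)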
        q_loss = loss_vals["loss_qvalue"]
-        value_loss = loss_vals["loss_value"]
-        loss_val = actor_loss + q_loss + value_loss
+        cql_loss = loss_vals["loss_cql"]
+
+        q_loss = q_loss + cql_loss
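+        # the conservative penalty, which pushes Q-values of out-of-distribution actions
+        # down relative to dataset actions, is folded into the TD loss so a single
+        # backward pass updates the critics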
+
+        alpha_loss = loss_vals["loss_alpha"]
+        alpha_prime_loss = loss_vals["loss_alpha_prime"]
+
+        # update model
+        alpha_optim.zero_grad()
+        alpha_loss.backward()
+        alpha_optim.step()

-        optimizer.zero_grad()
-        loss_val.backward()
-        optimizer.step()
+        policy_optim.zero_grad()
+        actor_loss.backward()
+        policy_optim.step()
+
+        if alpha_prime_optim is not None:
+            alpha_prime_optim.zero_grad()
+            alpha_prime_loss.backward(retain_graph=True)
+            alpha_prime_optim.step()
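+        # alpha prime is the Lagrange multiplier that adapts the strength of the
+        # conservative penalty; retain_graph=True keeps the shared graph alive for
+        # the q_loss.backward() call below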
+
+        critic_optim.zero_grad()
+        # TODO: if the losses are computed independently, retain_graph should not be needed here?
+        q_loss.backward(retain_graph=False)
+        critic_optim.step()
+
+        loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
+
+        # log metrics
+        to_log = {
+            "loss": loss.item(),
+            "loss_actor_bc": loss_vals["loss_actor_bc"].item(),
+            "loss_actor": loss_vals["loss_actor"].item(),
+            "loss_qvalue": q_loss.item(),
+            "loss_cql": cql_loss.item(),
+            "loss_alpha": alpha_loss.item(),
+            "loss_alpha_prime": alpha_prime_loss.item(),
+        }
+
+        # update qnet_target params
        target_net_updater.step()
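+        # the target critics track the online critics (a Polyak/EMA update with
+        # TorchRL's SoftUpdate, a periodic hard copy with HardUpdate)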

        # evaluation
@@ -88,20 +139,13 @@ def main(cfg: "DictConfig"): # noqa: F821
            eval_td = eval_env.rollout(
                max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
            )
+            eval_reward = eval_td["next", "reward"].sum(1).mean().item()
+            to_log["evaluation_reward"] = eval_reward
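+            # mean undiscounted return: rewards summed over the time dimension,
+            # then averaged across the evaluation envs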

-        if r0 is None:
-            r0 = eval_td["next", "reward"].sum(1).mean().item()
-        if l0 is None:
-            l0 = loss_val.item()
-
-        for key, value in loss_vals.items():
-            logger.log_scalar(key, value.item(), i)
-        eval_reward = eval_td["next", "reward"].sum(1).mean().item()
-        logger.log_scalar("evaluation_reward", eval_reward, i)
+        log_metrics(logger, to_log, i)

-        pbar.set_description(
-            f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), evaluation_reward: {eval_reward: 4.4f} (init={r0: 4.4f})"
-        )
+    pbar.close()
+    print(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":