train.py
import argparse
import glob
import random
from collections import deque

import gym
import numpy as np
import pybullet_envs  # noqa: F401 -- importing registers the PyBullet environments with gym
import torch
import wandb

from agent import IQL
from buffer import ReplayBuffer
from utils import save, collect_random
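
# NOTE: this script assumes the pre-0.26 gym API (env.seed(), the Monitor
# wrapper, and four-value env.step() returns); newer gym/gymnasium releases
# changed or removed these calls.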

def get_config():
    parser = argparse.ArgumentParser(description='RL')
    parser.add_argument("--run_name", type=str, default="IQL", help="Run name, default: IQL")
    parser.add_argument("--env", type=str, default="Pendulum-v1", help="Gym environment name, default: Pendulum-v1")
    parser.add_argument("--episodes", type=int, default=200, help="Number of episodes, default: 200")
    parser.add_argument("--buffer_size", type=int, default=200_000, help="Maximal replay buffer size, default: 200_000")
    parser.add_argument("--seed", type=int, default=1, help="Seed, default: 1")
    parser.add_argument("--log_video", type=int, default=0, help="Log agent behaviour to wandb when set to 1, default: 0")
    parser.add_argument("--save_every", type=int, default=100, help="Saves the network every x episodes, default: 100")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size, default: 256")
    parser.add_argument("--hidden_size", type=int, default=256, help="Hidden layer size of the networks, default: 256")
    parser.add_argument("--learning_rate", type=float, default=3e-4, help="Learning rate, default: 3e-4")
    parser.add_argument("--temperature", type=float, default=3, help="Temperature (beta) for advantage-weighted policy extraction, default: 3")
    parser.add_argument("--expectile", type=float, default=0.7, help="Expectile used for the value function loss, default: 0.7")
    parser.add_argument("--tau", type=float, default=5e-3, help="Soft update factor for the target networks, default: 5e-3")
    args = parser.parse_args()
    return args

def train(config):
    np.random.seed(config.seed)
    random.seed(config.seed)
    torch.manual_seed(config.seed)
    env = gym.make(config.env)
    env.seed(config.seed)
    env.action_space.seed(config.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    steps = 0
    average10 = deque(maxlen=10)

    with wandb.init(project="IQL", name=config.run_name, config=config):

        agent = IQL(state_size=env.observation_space.shape[0],
                    action_size=env.action_space.shape[0],
                    learning_rate=config.learning_rate,
                    hidden_size=config.hidden_size,
                    tau=config.tau,
                    temperature=config.temperature,
                    expectile=config.expectile,
                    device=device)

        wandb.watch(agent, log="gradients", log_freq=10)

        buffer = ReplayBuffer(buffer_size=config.buffer_size, batch_size=config.batch_size, device=device)

        # Pre-fill the replay buffer with transitions from a random policy.
        collect_random(env=env, dataset=buffer, num_samples=5000)

        if config.log_video:
            env = gym.wrappers.Monitor(env, './video', video_callable=lambda x: x % 10 == 0, force=True)

        for i in range(1, config.episodes + 1):
            state = env.reset()
            episode_steps = 0
            rewards = 0
            while True:
                action = agent.get_action(state)
                steps += 1
                next_state, reward, done, _ = env.step(action)
                buffer.add(state, action, reward, next_state, done)
                # One gradient update per environment step on a sampled mini-batch.
                policy_loss, critic1_loss, critic2_loss, value_loss = agent.learn(buffer.sample())
                state = next_state
                rewards += reward
                episode_steps += 1
                if done:
                    break

            average10.append(rewards)
            print("Episode: {} | Reward: {} | Policy Loss: {} | Steps: {}".format(i, rewards, policy_loss, steps))

            wandb.log({"Reward": rewards,
                       "Average10": np.mean(average10),
                       "Policy Loss": policy_loss,
                       "Value Loss": value_loss,
                       "Critic 1 Loss": critic1_loss,
                       "Critic 2 Loss": critic2_loss,
                       "Steps": steps,
                       "Episode": i,
                       "Buffer size": len(buffer)})

            if (i % 10 == 0) and config.log_video:
                mp4list = glob.glob('video/*.mp4')
                if len(mp4list) > 1:
                    mp4 = mp4list[-2]
                    wandb.log({"gameplays": wandb.Video(mp4, caption='episode: ' + str(i - 10), fps=4, format="gif"),
                               "Episode": i})

            if i % config.save_every == 0:
                save(config, save_name="IQL", model=agent.actor_local, wandb=wandb, ep=0)


if __name__ == "__main__":
    config = get_config()
    train(config)
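
# Example invocation (a sketch, assuming the repo's agent.py, buffer.py and
# utils.py are importable and the listed dependencies are installed):
#   python train.py --env Pendulum-v1 --episodes 200 --seed 1 --log_video 0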