DDQNexample.py
import gym
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from CS698R_DRLagents import DDQN
from CS698R_DRLagents.exploration_strategies import decayWrapper, selectEpsilonGreedyAction, selectGreedyAction
# make a gym environment
env = gym.make('CartPole-v0')
# pick a suitable device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# create the deep network
class net(nn.Module):
    """A simple fully connected Q-network: inDim -> hDim[0] -> ... -> hDim[-1] -> outDim."""
    def __init__(self, inDim, outDim, hDim, activation=F.relu):
        super(net, self).__init__()
        self.inputlayer = nn.Linear(inDim, hDim[0])
        self.hiddenlayers = nn.ModuleList([nn.Linear(hDim[i], hDim[i+1]) for i in range(len(hDim)-1)])
        self.outputlayer = nn.Linear(hDim[-1], outDim)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(device)
        t = self.activation(self.inputlayer(x))
        for layer in self.hiddenlayers:
            t = self.activation(layer(t))
        t = self.outputlayer(t)
        return t
Qnetwork = net(inDim=4, outDim=2, hDim=[8,8], activation=F.relu).to(device)
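# quick sanity check (not part of the original example): a CartPole-v0 observation has
# 4 features and there are 2 discrete actions, so the network should map a (1, 4) batch
# to (1, 2) Q-values
with torch.no_grad():
    dummy_q = Qnetwork(torch.zeros(1, 4))
print(dummy_q.shape)  # expected: torch.Size([1, 2])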
# create the exploration and exploitation strategies
explorationStrategyTrain = decayWrapper(selectEpsilonGreedyAction, 0.5, 0.05, 500, device=device)
DDQNagent = DDQN(Qnetwork, env, 0, 0.8, 10, 10000, 512, optim.Adam, 0.001, 800, 1, explorationStrategyTrain, selectGreedyAction, 5, device=device)
# train the agent and evaluate
train_stats = DDQNagent.trainAgent()
eval_rewards = DDQNagent.evaluateAgent()
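# take a quick look at what was returned; the exact structure of train_stats and
# eval_rewards depends on the CS698R_DRLagents library, so this just prints them as-is
print('training stats:', train_stats)
print('evaluation rewards:', eval_rewards)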
# see the agent in action
for i_episode in range(5):
    observation = env.reset()
    for t in range(600):
        env.render()
        action = selectGreedyAction(DDQNagent.policy_network, torch.tensor([observation], dtype=torch.float32, device=device))
        observation, reward, done, info = env.step(action.item())
        if done:
            print("Episode finished after {} timesteps".format(t+1))
            break
env.close()
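# optional (not in the original script): persist the trained weights with standard
# PyTorch serialization, assuming the trained network is exposed as
# DDQNagent.policy_network, as used in the rollout above
torch.save(DDQNagent.policy_network.state_dict(), 'ddqn_cartpole.pt')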