-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathmodels.py
59 lines (47 loc) · 2.16 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from torch import nn
import torch
from torch.nn import functional as F
import numpy as np
def init_weights_biases(size):
v = 1.0 / np.sqrt(size[0])
return torch.FloatTensor(size).uniform_(-v, v)
class Actor(nn.Module):
    """MLP policy network mapping a concatenated (state, goal) vector to an action.

    Three ReLU hidden layers followed by a tanh output, so every action
    component lies in (-1, 1).  Following DDPG, the final layer's weights and
    biases are initialized uniformly in [-initial_w, initial_w] so the policy
    starts out near-zero.
    """

    def __init__(self, n_states, n_actions, n_goals, n_hidden1=256, n_hidden2=256, n_hidden3=256, initial_w=3e-3):
        """Build the network.

        Args:
            n_states: shape tuple of the state; only its first entry is used.
            n_actions: dimensionality of the action output.
            n_goals: dimensionality of the goal vector appended to the state.
            n_hidden1/2/3: widths of the three hidden layers.
            initial_w: half-width of the uniform init range of the output layer.
        """
        # Call nn.Module.__init__ first (PyTorch convention: required before
        # registering submodules, and safest for plain attributes too).
        super(Actor, self).__init__()
        self.n_states = n_states[0]
        self.n_actions = n_actions
        self.n_goals = n_goals
        self.n_hidden1 = n_hidden1
        self.n_hidden2 = n_hidden2
        self.n_hidden3 = n_hidden3
        self.initial_w = initial_w
        self.fc1 = nn.Linear(in_features=self.n_states + self.n_goals, out_features=self.n_hidden1)
        self.fc2 = nn.Linear(in_features=self.n_hidden1, out_features=self.n_hidden2)
        self.fc3 = nn.Linear(in_features=self.n_hidden2, out_features=self.n_hidden3)
        self.output = nn.Linear(in_features=self.n_hidden3, out_features=self.n_actions)
        # Fix: `initial_w` was stored but never applied.  Per DDPG, the final
        # layer is initialized in a small uniform range so initial actions
        # (pre-tanh) are close to zero.
        nn.init.uniform_(self.output.weight, -self.initial_w, self.initial_w)
        nn.init.uniform_(self.output.bias, -self.initial_w, self.initial_w)

    def forward(self, x):
        """Return tanh-squashed actions for x of shape (batch, n_states + n_goals)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        output = torch.tanh(self.output(x))  # TODO add scale of the action
        return output
class Critic(nn.Module):
    """MLP Q-network scoring a (state, goal, action) triple with a scalar value.

    The state/goal vector and the action are concatenated at the input,
    passed through three ReLU hidden layers, and reduced to a single
    unbounded Q-value.  Following DDPG, the final layer's weights and biases
    are initialized uniformly in [-initial_w, initial_w].
    """

    def __init__(self, n_states, n_goals, n_hidden1=256, n_hidden2=256, n_hidden3=256, initial_w=3e-3, action_size=1):
        """Build the network.

        Args:
            n_states: shape tuple of the state; only its first entry is used.
            n_goals: dimensionality of the goal vector appended to the state.
            n_hidden1/2/3: widths of the three hidden layers.
            initial_w: half-width of the uniform init range of the output layer.
            action_size: dimensionality of the action concatenated at the input.
        """
        # Call nn.Module.__init__ first (PyTorch convention: required before
        # registering submodules, and safest for plain attributes too).
        super(Critic, self).__init__()
        self.n_states = n_states[0]
        self.n_goals = n_goals
        self.n_hidden1 = n_hidden1
        self.n_hidden2 = n_hidden2
        self.n_hidden3 = n_hidden3
        self.initial_w = initial_w
        self.action_size = action_size
        self.fc1 = nn.Linear(in_features=self.n_states + self.n_goals + self.action_size, out_features=self.n_hidden1)
        self.fc2 = nn.Linear(in_features=self.n_hidden1, out_features=self.n_hidden2)
        self.fc3 = nn.Linear(in_features=self.n_hidden2, out_features=self.n_hidden3)
        self.output = nn.Linear(in_features=self.n_hidden3, out_features=1)
        # Fix: `initial_w` was stored but never applied.  Per DDPG, the final
        # layer is initialized in a small uniform range so initial Q-estimates
        # are close to zero.
        nn.init.uniform_(self.output.weight, -self.initial_w, self.initial_w)
        nn.init.uniform_(self.output.bias, -self.initial_w, self.initial_w)

    def forward(self, x, a):
        """Return Q(x, a) of shape (batch, 1).

        Args:
            x: (batch, n_states + n_goals) state/goal input.
            a: (batch, action_size) action input, concatenated on the last dim.
        """
        x = F.relu(self.fc1(torch.cat([x, a], dim=-1)))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        output = self.output(x)
        return output