core_cuda.py
import numpy as np
import torch
import torch.nn as nn


def combined_shape(length, shape=None):
    # Shape of a buffer holding `length` items, each of shape `shape`.
    if shape is None:
        return (length,)
    return (length, shape) if np.isscalar(shape) else (length, *shape)
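# Examples:
#   combined_shape(5)         -> (5,)
#   combined_shape(5, 3)      -> (5, 3)
#   combined_shape(5, (3, 4)) -> (5, 3, 4)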


def mlp(sizes, activation, output_activation=nn.Identity):
    # Feed-forward network: hidden layers use `activation`, the final layer
    # uses `output_activation` (identity by default, i.e. a plain linear output).
    layers = []
    for j in range(len(sizes) - 1):
        act = activation if j < len(sizes) - 2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j + 1]), act()]
    return nn.Sequential(*layers)
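# For instance, mlp([4, 64, 64, 2], nn.ReLU) builds
#   Linear(4, 64) -> ReLU -> Linear(64, 64) -> ReLU -> Linear(64, 2) -> Identity.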


def count_vars(module):
    # Total number of parameters in a module.
    return sum(np.prod(p.shape) for p in module.parameters())


class MLPQFunction(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        self.q = mlp([obs_dim + act_dim] + list(hidden_sizes) + [1], activation)

    def forward(self, obs, act):
        q = self.q(torch.cat([obs, act], dim=-1))
        return torch.squeeze(q, -1)  # Critical to ensure q has right shape.


class MLPQFunction_quantile(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, quantiles):
        super().__init__()
        # Same MLP as above, but with one output unit per quantile.
        self.q = mlp([obs_dim + act_dim] + list(hidden_sizes) + [len(quantiles)], activation)

    def forward(self, obs, act):
        q = self.q(torch.cat([obs, act], dim=-1))  # shape: (batch, len(quantiles))
        return torch.squeeze(q, -1)  # Critical to ensure q has right shape.
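# Usage sketch (the dimensions and quantile grid below are made up for
# illustration, not taken from this file):
#   qf = MLPQFunction_quantile(obs_dim=8, act_dim=2, hidden_sizes=(64, 64),
#                              activation=nn.ReLU, quantiles=np.linspace(0.05, 0.95, 10))
#   obs, act = torch.randn(32, 8), torch.randn(32, 2)
#   qf(obs, act).shape  # torch.Size([32, 10]), one value per quantile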


# --------------------------------------------------------------------------
# Earlier per-quantile-head attempt, kept for reference. It does not work:
# the forward pass calls torch.cat(..., out=...) on tensors that require
# grad, which raises:
#   cat(): functions with out=... arguments don't support automatic
#   differentiation, but one of the arguments requires grad.
"""
def mlp_quantile(quantiles):
    # One small linear head per quantile.
    outputs = []
    for i, quantile in enumerate(quantiles):
        outputs.append(nn.Sequential(nn.Linear(1, 1)))
    return outputs


class MLPQFunction_quantile(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, quantiles):
        super().__init__()
        self.q = mlp([obs_dim + act_dim] + list(hidden_sizes) + [1], activation)
        self.out = mlp_quantile(quantiles)

    def forward(self, obs, act):
        q = self.q(torch.cat([obs, act], dim=-1))
        quin = []
        for i in range(len(self.out)):
            # This is where the autograd error above is raised.
            torch.cat(quin, out=torch.squeeze(self.out[i](q), -1))
        return quin
"""