custom_networks.py
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0
import torch as th
from torch import nn
from torch.nn.utils.parametrizations import spectral_norm

from utils import weight_init, final_weight_init
def apply_sn(m):
    # Wrap conv and linear layers with spectral normalization; leave other modules unchanged.
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        return spectral_norm(m)
    else:
        return m


def apply_sn_all(modules):
    # Apply spectral normalization to every conv/linear layer in each module (recursively).
    for module in modules:
        module.apply(apply_sn)


def get_cont_mean_norm(input):
    # Mean L2 norm of the flattened per-sample representations in a batch.
    input_shape = input.size()
    rs_input = input.view(input_shape[0], -1)
    return th.norm(rs_input, p=2, dim=-1, keepdim=True).mean()
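
# Illustrative usage of the helpers above (a sketch, not part of the original code):
#     apply_sn_all([nn.Linear(8, 8), nn.Conv2d(3, 3, 3)])   # spectrally normalize conv/linear layers in place
#     get_cont_mean_norm(th.randn(8, 3, 64, 64))            # scalar: mean per-sample L2 norm of the batch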
class ImpalaResidualBlock(nn.Module):
    # Pre-activation residual block used in the IMPALA CNN.
    def __init__(self, channels):
        super().__init__()
        self.residual_block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=3, padding=1))

    def forward(self, inputs):
        return self.residual_block(inputs) + inputs


class ImpalaResidualStack(nn.Module):
    # One IMPALA stack: 3x3 conv, stride-2 max pool, then two residual blocks.
    def __init__(self, in_channels, stack_channels):
        super().__init__()
        self.stack = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=stack_channels,
                      kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ImpalaResidualBlock(channels=stack_channels),
            ImpalaResidualBlock(channels=stack_channels))

    def forward(self, inputs):
        return self.stack(inputs)
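
# Shape sketch (illustrative, not from the original code): a stack sets the channel
# count to `stack_channels` and halves H and W through the stride-2 max pool, e.g.
#     ImpalaResidualStack(3, 16)(th.zeros(1, 3, 64, 64)).shape  # (1, 16, 32, 32)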
def make_impala_modules(obs_dims,
                        n_actions,
                        channels=[16, 32, 32],
                        hidden_units=256,
                        shared_conv_trunk=True,
                        shared_fc_head=True,
                        init=weight_init,
                        final_init=final_weight_init,
                        pre_final_sn=False,
                        auxiliary_dims=None,
                        ):
    in_channels, w, h = obs_dims[-3:]
    number_of_stacks = len(channels)
    # Each residual stack halves H and W via its stride-2 max pool, so the flattened
    # feature size shrinks by a factor of 4 per stack.
    flattened_dims = (w * h) // (4 ** number_of_stacks) * channels[-1]
    if auxiliary_dims:
        assert isinstance(auxiliary_dims, int)
        assert auxiliary_dims > 0
        actor_output_dims = n_actions + auxiliary_dims
    else:
        actor_output_dims = n_actions
    shared_modules = []
    actor_modules = []
    critic_modules = []
    if shared_conv_trunk:
        for stack_channels in channels:
            shared_modules.append(ImpalaResidualStack(in_channels,
                                                      stack_channels))
            in_channels = stack_channels
        shared_modules += [nn.Flatten(),
                           nn.ReLU()]
    else:
        assert not shared_fc_head
        for stack_channels in channels:
            actor_modules.append(ImpalaResidualStack(in_channels,
                                                     stack_channels))
            critic_modules.append(ImpalaResidualStack(in_channels,
                                                      stack_channels))
            in_channels = stack_channels
        actor_modules += [nn.Flatten(),
                          nn.ReLU()]
        critic_modules += [nn.Flatten(),
                           nn.ReLU()]
    if shared_fc_head:
        assert shared_conv_trunk
        # Single head emitting the policy logits (plus any auxiliary outputs) and one value.
        shared_modules += [nn.Linear(in_features=flattened_dims,
                                     out_features=hidden_units),
                           nn.ReLU(),
                           nn.Linear(in_features=hidden_units,
                                     out_features=actor_output_dims + 1)]
        if pre_final_sn:
            apply_sn_all(shared_modules[:-1])
        for m in shared_modules[:-1]:
            init(m)
        final_init(shared_modules[-1])
    else:
        # Separate actor and critic heads on top of the (possibly shared) conv features.
        actor_modules += [nn.Linear(in_features=flattened_dims,
                                    out_features=hidden_units),
                          nn.ReLU(),
                          nn.Linear(in_features=hidden_units,
                                    out_features=actor_output_dims)]
        critic_modules += [nn.Linear(in_features=flattened_dims,
                                     out_features=hidden_units),
                           nn.ReLU(),
                           nn.Linear(in_features=hidden_units,
                                     out_features=1)]
        if pre_final_sn:
            apply_sn_all(shared_modules)
            apply_sn_all(critic_modules[:-1])
            apply_sn_all(actor_modules[:-1])
        for m in shared_modules:
            init(m)
        for m in actor_modules[:-1]:
            init(m)
        for m in critic_modules[:-1]:
            init(m)
        final_init(actor_modules[-1])
        final_init(critic_modules[-1])
    modules = {}
    modules['shared_modules'] = shared_modules
    modules['actor_modules'] = actor_modules
    modules['critic_modules'] = critic_modules
    modules['auxiliary_dims'] = auxiliary_dims
    return modules
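

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the original training code.
    # The observation shape (3, 64, 64) and action count (15) are illustrative
    # assumptions; with the default shared trunk and head, the assembled network
    # maps a batch of observations to n_actions policy logits plus one value.
    modules = make_impala_modules(obs_dims=(3, 64, 64), n_actions=15)
    policy_value_net = nn.Sequential(*modules['shared_modules'])
    out = policy_value_net(th.zeros(4, 3, 64, 64))
    print(out.shape)  # expected: torch.Size([4, 16])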