torch_utils.py

import torch
from torch.autograd import grad


def get_device():
    '''
    Return a torch.device object: a CUDA device if one is available and a
    CPU device otherwise.
    '''
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')


save_dir = 'saved-sessions'


def apply_update(parameterized_fun, update):
    '''
    Add update to the weights of parameterized_fun.

    Parameters
    ----------
    parameterized_fun : torch.nn.Sequential
        the function approximator to be updated
    update : torch.FloatTensor
        a flattened version of the update to be applied
    '''
    n = 0

    for param in parameterized_fun.parameters():
        numel = param.numel()
        param_update = update[n:n + numel].view(param.size())
        param.data += param_update
        n += numel
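
# A minimal usage sketch (the names `model` and `step` are illustrative,
# not part of this module): the update is a single 1-D tensor with one
# entry per scalar weight, consumed in parameter order.
#
#     step = torch.zeros(sum(p.numel() for p in model.parameters()))
#     apply_update(model, step)  # a zero update leaves the weights unchanged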


def flatten(vecs):
    '''
    Return an unrolled, concatenated copy of vecs.

    Parameters
    ----------
    vecs : list
        a list of PyTorch Tensor objects

    Returns
    -------
    flattened : torch.FloatTensor
        the flattened version of vecs
    '''
    flattened = torch.cat([v.view(-1) for v in vecs])

    return flattened
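
# For example (shapes chosen for illustration), tensors of shapes (2, 2)
# and (3,) flatten into a single vector of shape (7,):
#
#     flatten([torch.ones(2, 2), torch.zeros(3)]).shape  # torch.Size([7])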


def flat_grad(functional_output, inputs, retain_graph=False, create_graph=False):
    '''
    Return a flattened view of the gradients of functional_output w.r.t.
    inputs.

    Parameters
    ----------
    functional_output : torch.FloatTensor
        the output of the function for which the gradient is to be computed
    inputs : torch.FloatTensor (with requires_grad=True)
        the variables w.r.t. which the gradient will be computed
    retain_graph : bool
        whether to keep the computational graph in memory after computing
        the gradient (not required if create_graph is True)
    create_graph : bool
        whether to create a computational graph of the gradient computation
        itself

    Returns
    -------
    flat_grads : torch.FloatTensor
        a flattened view of the gradients of functional_output w.r.t. inputs
    '''
    if create_graph:
        # building the graph of the gradient requires retaining the graph
        # of the original computation
        retain_graph = True

    grads = grad(functional_output, inputs, retain_graph=retain_graph,
                 create_graph=create_graph)
    flat_grads = flatten(grads)

    return flat_grads
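
# A minimal sketch: for f(x) = sum(x ** 2), the flat gradient is 2 * x.
#
#     x = torch.randn(3, requires_grad=True)
#     g = flat_grad((x ** 2).sum(), x)  # equal to 2 * x, shape (3,)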


def get_flat_params(parameterized_fun):
    '''
    Get a flattened view of the parameters of a function approximator.

    Parameters
    ----------
    parameterized_fun : torch.nn.Sequential
        the function approximator for which the parameters are to be returned

    Returns
    -------
    flat_params : torch.FloatTensor
        a flattened view of the parameters of parameterized_fun
    '''
    parameters = parameterized_fun.parameters()
    flat_params = flatten(parameters)

    return flat_params
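

# A minimal end-to-end sketch tying these helpers together. The toy
# network, data, and step size below are assumptions for illustration,
# not part of the original module.
if __name__ == '__main__':
    torch.manual_seed(0)
    device = get_device()

    model = torch.nn.Sequential(
        torch.nn.Linear(4, 8),
        torch.nn.Tanh(),
        torch.nn.Linear(8, 1)
    ).to(device)

    x = torch.randn(16, 4, device=device)
    y = torch.randn(16, 1, device=device)
    loss = ((model(x) - y) ** 2).mean()

    # flattened gradient of the loss w.r.t. all model parameters
    grads = flat_grad(loss, model.parameters())

    # take a plain gradient-descent step using the flat update format
    before = get_flat_params(model)
    apply_update(model, -1e-2 * grads)
    after = get_flat_params(model)

    assert torch.allclose(after, before - 1e-2 * grads)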