-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
executable file
·195 lines (168 loc) · 6.7 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# -*- coding: utf-8 -*-
# @Time : 2021/9/1 18:30
# @Author : xinyuan tu
# @File : utils.py
# @Software: PyCharm
import torch
import torch.nn as nn
import os
import time
import math
import numpy as np
import torch
import logging
#from lib.render import *
def load_state_dict(module,state_dict):
# model_dict = module.state_dict()
# # #pretrained_dict = {k.replace("opt_layer.",""): v for k, v in state_dict.items() if k.replace("opt_layer.","") in model_dict}
# pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
# print(len(pretrained_dict))
# model_dict.update(pretrained_dict)
# module.load_state_dict(model_dict)
# return module
def _load_state_dict_into_module(state_dict, module, strict=True):
own_state = module.state_dict()
count = 0
for name, param in state_dict.items():
#name = name.replace('.module','')
if name in own_state:
if isinstance(param, nn.Parameter):
# backwards compatibility for serialized parameters
param = param.data
try:
own_state[name].resize_as_(param)
own_state[name].copy_(param)
except Exception:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
_load_state_dict_into_module(state_dict,module)
def dict_to_cuda_old(input_info):
# 把整型转换为浮点型 用于图像处理
for key, value in input_info.items():
if key=='basename':
continue
# 其实没有类型和数据类型转换也ok
# 不清楚是在哪里完成的,可能是dataloader
if type(value) is np.ndarray:
value=torch.from_numpy(value);
torch_int_list=[torch.uint8,torch.int,torch.int8,torch.int16,torch.int32]
if value.dtype in torch_int_list:
value=value.float()
input_info[key] = value.cuda(non_blocking=False)
return input_info
def dict_to_cuda(input_info):
# 不把整型转换为浮点型 用于整数标签处理
for key, value in input_info.items():
if key=='basename':
continue
# 其实没有类型和数据类型转换也ok
# 不清楚是在哪里完成的,可能是dataloader
if type(value) is np.ndarray:
value=torch.from_numpy(value);
# 不将整型改为浮点,因为有身份、表情标签需要为整型
if type(value) is int or type(value) is float:
value=torch.tensor(value)
# 将double转为float,因为Linear只能接受float
if value.dtype==torch.double:
value=value.float()
if type(value) is torch.Tensor:
# torch_int_list=[torch.uint8,torch.int,torch.int8,torch.int16,torch.int32]
# if value.dtype in torch_int_list:
# value=value.float()
input_info[key] = value.cuda(non_blocking=False)
# input_info[key] = value.cuda()
return input_info
def del_cuda(input_info,loss,output):
for key, value in input_info.items():
if key=='basename':
continue
del value
for key, value in loss.items():
if key=='basename':
continue
del value
for key, value in output.items():
if key=='basename':
continue
del value
def make_dirs(arg):
t = time.gmtime()
event_path = os.path.join(arg.event_path,arg.state, str(t.tm_mon).rjust(2, '0') + str(t.tm_mday).rjust(2, '0') + str((t.tm_hour+8)%24).rjust(2, '0') + str((t.tm_min)%60).rjust(2, '0'))
while(os.path.isdir(event_path)):
event_path=event_path+"_2"
if not os.path.isdir(event_path):
os.makedirs(event_path)
checkpoint_path = os.path.join(event_path, 'checkpoint_path')
if not os.path.isdir(checkpoint_path):
os.makedirs(checkpoint_path)
summary_path = os.path.join(event_path,'summary_path')
if not os.path.isdir(summary_path):
os.makedirs(summary_path)
# for visualize
if arg.test_save:
output_path = os.path.join(event_path, 'output')
if not os.path.isdir(output_path):
os.makedirs(output_path)
return event_path
def save_obj_pts(filename, vertices, faces):
import os
with open(filename, 'w') as f:
f.write('# %s\n' % os.path.basename(filename))
f.write('#\n')
f.write('\n')
for vertex in vertices:
f.write('v %.8f %.8f %.8f\n' % (vertex[0], vertex[1], vertex[2]))
f.write('\n');
for face in faces:
f.write('f %d %d %d\n' % (face[0], face[1], face[2]));
f.write('\n');
def test_dict_to_cuda():
test_1=np.zeros([3,3],dtype=np.uint8);
test_2=np.zeros([3,3],dtype=np.float);
test_3=torch.zeros([3,3],dtype=torch.int8);
test_4=torch.zeros([3,3],dtype=torch.float32,requires_grad=False);
example_dict={
"test_1":test_1,
"test_2":test_2,
"test_3":test_3,
"test_4":test_4
}
# print(type(test_1));
# torch_int_list=[torch.int,torch.int8,torch.int16,torch.int32]
# print(test_3.dtype in torch_int_list)
out_dict=dict_to_cuda(example_dict);
print(out_dict)
return
# 关于logging的设置
# 这里也统一放在了utils.py中
import logging
import os
import time
from configuration import *
modelconfig = args
def configure_logging(event_path):
root_logger = logging.getLogger("")
root_logger.setLevel(logging.INFO)
# create console handler and set level to debug
console = logging.StreamHandler()
console.setLevel(logging.INFO)
root_logger.addHandler(console)
log_filename = os.path.join(os.path.join(event_path,'checkpoint_path'),'log.txt')
logbook = logging.FileHandler(filename=log_filename, mode="a", encoding="utf-8")
logbook.setLevel(logging.INFO)
root_logger.addHandler(logbook)
return root_logger
def setup_logging_and_parse_arguments(logger):
for argument, value in sorted(vars(modelconfig).items()):
logger.info('{}: {}'.format( argument, value))
def main():
test_dict_to_cuda();
return
if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES']="3"
main()