File tree 3 files changed +22
-3
lines changed
3 files changed +22
-3
lines changed Original file line number Diff line number Diff line change @@ -43,6 +43,12 @@ def parse_args():
43
43
return parser .parse_args ()
44
44
45
45
46
+ def inplace_relu (m ):
47
+ classname = m .__class__ .__name__
48
+ if classname .find ('ReLU' ) != - 1 :
49
+ m .inplace = True
50
+
51
+
46
52
def test (model , loader , num_class = 40 ):
47
53
mean_correct = []
48
54
class_acc = np .zeros ((num_class , 3 ))
@@ -126,6 +132,7 @@ def log_string(str):
126
132
127
133
classifier = model .get_model (num_class , normal_channel = args .use_normals )
128
134
criterion = model .get_loss ()
135
+ classifier .apply (inplace_relu )
129
136
130
137
if not args .use_cpu :
131
138
classifier = classifier .cuda ()
Original file line number Diff line number Diff line change 4
4
"""
5
5
import argparse
6
6
import os
7
- from data_utils .ShapeNetDataLoader import PartNormalDataset
8
7
import torch
9
8
import datetime
10
9
import logging
11
- from pathlib import Path
12
10
import sys
13
11
import importlib
14
12
import shutil
15
- from tqdm import tqdm
16
13
import provider
17
14
import numpy as np
18
15
16
+ from pathlib import Path
17
+ from tqdm import tqdm
18
+ from data_utils .ShapeNetDataLoader import PartNormalDataset
19
+
19
20
BASE_DIR = os .path .dirname (os .path .abspath (__file__ ))
20
21
ROOT_DIR = BASE_DIR
21
22
sys .path .append (os .path .join (ROOT_DIR , 'models' ))
30
31
seg_label_to_cat [label ] = cat
31
32
32
33
34
+ def inplace_relu (m ):
35
+ classname = m .__class__ .__name__
36
+ if classname .find ('ReLU' ) != - 1 :
37
+ m .inplace = True
38
+
33
39
def to_categorical (y , num_classes ):
34
40
""" 1-hot encodes a tensor """
35
41
new_y = torch .eye (num_classes )[y .cpu ().data .numpy (),]
@@ -111,6 +117,7 @@ def log_string(str):
111
117
112
118
classifier = MODEL .get_model (num_part , normal_channel = args .normal ).cuda ()
113
119
criterion = MODEL .get_loss ().cuda ()
120
+ classifier .apply (inplace_relu )
114
121
115
122
def weights_init (m ):
116
123
classname = m .__class__ .__name__
Original file line number Diff line number Diff line change 29
29
for i , cat in enumerate (seg_classes .keys ()):
30
30
seg_label_to_cat [i ] = cat
31
31
32
+ def inplace_relu (m ):
33
+ classname = m .__class__ .__name__
34
+ if classname .find ('ReLU' ) != - 1 :
35
+ m .inplace = True
32
36
33
37
def parse_args ():
34
38
parser = argparse .ArgumentParser ('Model' )
@@ -111,6 +115,7 @@ def log_string(str):
111
115
112
116
classifier = MODEL .get_model (NUM_CLASSES ).cuda ()
113
117
criterion = MODEL .get_loss ().cuda ()
118
+ classifier .apply (inplace_relu )
114
119
115
120
def weights_init (m ):
116
121
classname = m .__class__ .__name__
You can’t perform that action at this time.
0 commit comments