Skip to content

Commit 00f3675

Browse files
committed
change inplace=True in ReLU
1 parent 48f6801 commit 00f3675

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

train_classification.py

+7
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ def parse_args():
4343
return parser.parse_args()
4444

4545

46+
def inplace_relu(m):
47+
classname = m.__class__.__name__
48+
if classname.find('ReLU') != -1:
49+
m.inplace=True
50+
51+
4652
def test(model, loader, num_class=40):
4753
mean_correct = []
4854
class_acc = np.zeros((num_class, 3))
@@ -126,6 +132,7 @@ def log_string(str):
126132

127133
classifier = model.get_model(num_class, normal_channel=args.use_normals)
128134
criterion = model.get_loss()
135+
classifier.apply(inplace_relu)
129136

130137
if not args.use_cpu:
131138
classifier = classifier.cuda()

train_partseg.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44
"""
55
import argparse
66
import os
7-
from data_utils.ShapeNetDataLoader import PartNormalDataset
87
import torch
98
import datetime
109
import logging
11-
from pathlib import Path
1210
import sys
1311
import importlib
1412
import shutil
15-
from tqdm import tqdm
1613
import provider
1714
import numpy as np
1815

16+
from pathlib import Path
17+
from tqdm import tqdm
18+
from data_utils.ShapeNetDataLoader import PartNormalDataset
19+
1920
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
2021
ROOT_DIR = BASE_DIR
2122
sys.path.append(os.path.join(ROOT_DIR, 'models'))
@@ -30,6 +31,11 @@
3031
seg_label_to_cat[label] = cat
3132

3233

34+
def inplace_relu(m):
35+
classname = m.__class__.__name__
36+
if classname.find('ReLU') != -1:
37+
m.inplace=True
38+
3339
def to_categorical(y, num_classes):
3440
""" 1-hot encodes a tensor """
3541
new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
@@ -111,6 +117,7 @@ def log_string(str):
111117

112118
classifier = MODEL.get_model(num_part, normal_channel=args.normal).cuda()
113119
criterion = MODEL.get_loss().cuda()
120+
classifier.apply(inplace_relu)
114121

115122
def weights_init(m):
116123
classname = m.__class__.__name__

train_semseg.py

+5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
for i, cat in enumerate(seg_classes.keys()):
3030
seg_label_to_cat[i] = cat
3131

32+
def inplace_relu(m):
33+
classname = m.__class__.__name__
34+
if classname.find('ReLU') != -1:
35+
m.inplace=True
3236

3337
def parse_args():
3438
parser = argparse.ArgumentParser('Model')
@@ -111,6 +115,7 @@ def log_string(str):
111115

112116
classifier = MODEL.get_model(NUM_CLASSES).cuda()
113117
criterion = MODEL.get_loss().cuda()
118+
classifier.apply(inplace_relu)
114119

115120
def weights_init(m):
116121
classname = m.__class__.__name__

0 commit comments

Comments
 (0)