Commit 4147c4d ("update read me")
Author: wbw520
Parent: b8c43dc
14 files changed: +310 -61 lines

README.md (+5 -5)
````diff
@@ -20,19 +20,19 @@ Using the following command for training
 ```
 python main_recon.py --num_classes 10 --num_cpt 20 --lr 0.001 --epoch 50 --lr_drop 30
 ```
-Use the following command for the inference of a sample. You can change the index to select different input samples. Change deactivate (deactivate one concept, 1 to num_class) and see the changes of reconstruction. Change top-sample (top-10 in the paper) to show more details for a concept. Visualization for the input sample and all concepts are shown at folder "vis" and "vis_pp", respectively.
+Use the following command for the inference of a sample. You can change the index to select different input samples. Change top-sample (top-10 in the paper) to show more details for a concept. Visualization for the input sample and all concepts are shown at folder "vis" and "vis_pp", respectively.
 ```
-python vis_recon.py --num_classes 10 --num_cpt 20 --index 0 --top_sample 10 ---deactivate -1
+python vis_recon.py --num_classes 10 --num_cpt 20 --index 0 --top_sample 20 ---deactivate -1
 ```
 
 #### Usage for CUB200, ImageNet, Synthetic (matplot) and Custom
 We first pre-train the backbone and then train the whole model. For ImageNet, Synthetic (matplot) and Custom, just change the name for dataset.
 ```
 Pre-training of backbone:
-python main_retri.py --num_classes 50 --num_cpt 20 --base_model resnet18 --lr 0.0005 --epoch 60 --lr_drop 40 --pre_train True --dataset CUB200 --dataset_dir "your dir"
+python main_contrast.py --num_classes 50 --num_cpt 20 --base_model resnet18 --lr 0.0005 --epoch 60 --lr_drop 40 --pre_train True --dataset CUB200 --dataset_dir "your dir"
 
 Training for BotCL:
-python main_retri.py --num_classes 50 --num_cpt 20 --base_model resnet18 --lr 0.0005 --epoch 60 --lr_drop 40 --dataset CUB200 --dataset_dir "your dir" --weak_supervision_bias 0.1 --quantity_bias 0.1 --distinctiveness_bias 0.01 --consistence_bias 0.05
+python main_contrast.py --num_classes 50 --num_cpt 20 --base_model resnet18 --lr 0.0005 --epoch 60 --lr_drop 40 --dataset CUB200 --dataset_dir "your dir" --weak_supervision_bias 0.1 --quantity_bias 0.1 --distinctiveness_bias 0.05 --consistence_bias 0.01
 ```
 
 Use the following commend to visualize the learned concept.
@@ -42,5 +42,5 @@ First run process.py to extarct the activation for all dataset samples:
 python process.py
 
 Then see the generated concepts by:
-python vis_retri.py --num_classes 50 --num_cpt 20 --base_model resnet18 --index 300 --top_sample 10 --dataset CUB200
+python vis_contrast.py --num_classes 50 --num_cpt 20 --base_model resnet18 --index 300 --top_sample 20 --dataset CUB200
 ```
````
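For orientation, the four `*_bias` flags in the BotCL command above weight auxiliary loss terms. A schematic of how such weights typically combine into the total objective, purely as an assumption for readability (the repo's exact loss composition is not shown in this commit):

```python
# Schematic only: hypothetical weighted sum of BotCL's loss terms,
# using the flag names from the command above.
def total_loss(cls_loss, quantity, distinctiveness, consistence, args):
    return (cls_loss
            + args.quantity_bias * quantity
            + args.distinctiveness_bias * distinctiveness
            + args.consistence_bias * consistence)
```

Under this reading, the README change shifts weight from the consistence term (0.05 to 0.01) to the distinctiveness term (0.01 to 0.05).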

configs.py (+9 -9)
```diff
@@ -2,8 +2,8 @@
 
 import argparse
 parser = argparse.ArgumentParser(description="PyTorch implementation of cpt")
-parser.add_argument('--dataset', type=str, default="ImageNet")
-parser.add_argument('--dataset_dir', type=str, default="/home/wangbowen/DATA")
+parser.add_argument('--dataset', type=str, default="imagenet")
+parser.add_argument('--dataset_dir', type=str, default="/data/li")
 parser.add_argument('--output_dir', type=str, default="saved_model")
 # ========================= Model Configs ==========================
 parser.add_argument('--num_classes', default=50, type=int, help='category for classification')
@@ -20,24 +20,24 @@
 parser.add_argument('--layer', default=1, help='layers for fc, default as one')
 # ========================= Training Configs ==========================
 parser.add_argument('--weak_supervision_bias', type=float, default=0.1, help='weight fot the weak supervision branch')
-parser.add_argument('--att_bias', type=float, default=0.5, help='used to prevent overflow, default as 0.1')
+parser.add_argument('--att_bias', type=float, default=0.1, help='used to prevent overflow, default as 0.1')
 parser.add_argument('--quantity_bias', type=float, default=0.1, help='force each concept to be binary')
-parser.add_argument('--distinctiveness_bias', type=float, default=0.1, help='refer to paper')
+parser.add_argument('--distinctiveness_bias', type=float, default=0.01, help='refer to paper')
 parser.add_argument('--consistence_bias', type=float, default=0.05, help='refer to paper')
 # ========================= Learning Configs ==========================
 parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
-parser.add_argument('--lr', default=0.0001, type=float)
+parser.add_argument('--lr', default=0.0005, type=float)
 parser.add_argument('--batch_size', default=256, type=int)
-parser.add_argument('--epoch', default=200, type=int)
-parser.add_argument('--lr_drop', default=160, type=float, nargs="+",
+parser.add_argument('--epoch', default=40, type=int)
+parser.add_argument('--lr_drop', default=30, type=float, nargs="+",
                     metavar='LRSteps', help='epochs to decay learning rate by 10')
 # ========================= Machine Configs ==========================
 parser.add_argument('--num_workers', default=4, type=int)
-parser.add_argument('--device', default='cuda:2', help='device to use for training / testing')
+parser.add_argument('--device', default='cuda:0', help='device to use for training / testing')
 # ========================= Demo Configs ==========================
 parser.add_argument('--index', default=0, type=int)
 parser.add_argument('--use_weight', default=False, help='whether use fc weight for the generation of attention mask')
-parser.add_argument('--top_samples', default=50, type=int, help='top n activated samples')
+parser.add_argument('--top_samples', default=20, type=int, help='top n activated samples')
 # parser.add_argument('--demo_cls', default="n01498041", type=str)
 parser.add_argument('--fre', default=1, type=int, help='frequent of show results during training')
 parser.add_argument('--deactivate', default=-1, type=int, help='the index of concept to be deativated')
```
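A quick way to sanity-check the new defaults without launching training; a minimal sketch, assuming `configs.py` is importable from the repo root:

```python
# Parse an empty argv so only the defaults apply.
from configs import parser

args = parser.parse_args([])
assert args.dataset == "imagenet"   # was "ImageNet"
assert args.lr == 0.0005            # was 0.0001
assert args.epoch == 40             # was 200
assert args.top_samples == 20       # was 50
print(args.device)                  # "cuda:0" (was "cuda:2")
```

One argparse subtlety worth knowing: with `nargs="+"`, a scalar default such as `30` is returned as-is (not wrapped in a list) when `--lr_drop` is omitted, so downstream code has to accept both a scalar and a list.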

loaders/get_loader.py (+3 -3)
```diff
@@ -62,7 +62,7 @@ def get_transform(args):
         transform = transforms.Compose([transforms.Resize([args.img_size, args.img_size]), transforms.ToTensor(),
                                         transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])])
         return {"train": transform, "val": transform}
-    elif args.dataset == "CUB200" or args.dataset == "ImageNet":
+    elif args.dataset == "CUB200" or args.dataset == "ImageNet" or args.dataset == "imagenet":
         transform_train = get_train_transformations(args, [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]])
         transform_val = get_val_transformations(args, [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]])
         return {"train": transform_train, "val": transform_val}
@@ -91,7 +91,7 @@ def select_dataset(args, transform):
         dataset_train = CUB_200(args, train=True, transform=transform["train"])
         dataset_val = CUB_200(args, train=False, transform=transform["val"])
         return dataset_train, dataset_val
-    elif args.dataset == "ImageNet":
+    elif args.dataset == "ImageNet" or args.dataset == "imagenet":
         dataset_train = ImageNet(args, "train", transform=transform["train"])
         dataset_val = ImageNet(args, "val", transform=transform["val"])
         return dataset_train, dataset_val
@@ -154,7 +154,7 @@ def filter(data):
         val_imgs = datasets.CIFAR10('./data/cifar10', train=False, download=True, transform=None).data
         val_labels = datasets.CIFAR10('./data/cifar10', train=False, download=True, transform=None).targets
         return train_imgs, train_labels, val_imgs, val_labels, cat
-    elif args.dataset == "ImageNet" or args.dataset == "Custom":
+    elif args.dataset == "ImageNet" or args.dataset == "imagenet" or args.dataset == "Custom":
         train = ImageNet(args, "train", transform=None).train
         val = ImageNet(args, "train", transform=None).val
         cat = ImageNet(args, "train", transform=None).category
```
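This change adds `"imagenet"` as an accepted spelling alongside `"ImageNet"` in three separate branches. A case-insensitive check would cover both spellings in one place; the sketch below is a hedged alternative design, not the repo's code:

```python
def is_dataset(name: str, *aliases: str) -> bool:
    # Case-insensitive dataset-name match, e.g. "ImageNet" == "imagenet".
    return name.lower() in {a.lower() for a in aliases}

assert is_dataset("imagenet", "ImageNet")
assert is_dataset("ImageNet", "ImageNet")
assert not is_dataset("CUB200", "ImageNet")
```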

main_retri.py → main_contrast.py (renamed, +1 -1)
```diff
@@ -60,7 +60,7 @@ def main():
             print("get better result, save current model.")
             torch.save(model.state_dict(), os.path.join(args.output_dir,
                        f"{args.dataset}_{args.base_model}_cls{args.num_classes}_" + f"cpt{args.num_cpt if not args.pre_train else ''}_" +
-                       f"{'use_slot_' + args.cpt_activation if not args.pre_train else 'no_slot'}.pt"))
+                       f"{'use_slot_' + args.cpt_activation if not args.pre_train else 'no_slot'}3.pt"))
 
 
 if __name__ == '__main__':
```
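The `3` appended before `.pt` changes every checkpoint filename this script writes. Purely as an illustration (the values below are hypothetical, not taken from the repo), the f-string now produces names like:

```python
dataset, base_model, num_classes = "CUB200", "resnet18", 50
num_cpt, pre_train, cpt_activation = 20, False, "att"  # hypothetical values

name = (f"{dataset}_{base_model}_cls{num_classes}_"
        + f"cpt{num_cpt if not pre_train else ''}_"
        + f"{'use_slot_' + cpt_activation if not pre_train else 'no_slot'}3.pt")
print(name)  # CUB200_resnet18_cls50_cpt20_use_slot_att3.pt
```

Note that the loaders in ACE.py and kmeans.py below appear to still look for `..._cpt_no_slot.pt` without the `3` suffix.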

ACE.py → model/ACE.py (moved, +2 -2)
```diff
@@ -12,7 +12,7 @@
 import shutil
 # from draws.draw_synthetic import draw_syn
 import torch
-from quantitative_eval import make_statistic
+from utils.quantitative_eval import make_statistic
 import json
 from model.retrieval.model_main import MainModel
 
@@ -268,7 +268,7 @@ def cal_ace(self):
     args.device = "cuda:1"
     device = torch.device(args.device)
     model_.to(device)
-    args.output_dir = "saved_model"
+    args.output_dir = "../saved_model"
     checkpoint = torch.load(os.path.join(args.output_dir,
                             f"{args.dataset}_{args.base_model}_cls{args.num_classes}_cpt_no_slot.pt"), map_location=device)
     model_.load_state_dict(checkpoint, strict=True)
```
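Since ACE.py moved into `model/`, the relative `"../saved_model"` only resolves correctly when the script is launched from inside `model/`. A common way to make such paths independent of the launch directory (an assumption, not what the repo does) is to anchor them to the file's own location:

```python
import os

# Hypothetical helper: resolve saved_model/ relative to this file
# rather than the current working directory.
OUTPUT_DIR = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "saved_model"))
```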

kmeans.py → model/kmeans.py (moved, +2 -3)
```diff
@@ -8,9 +8,8 @@
 import torch
 import json
 from sklearn.manifold import TSNE
-from quantitative_eval import make_statistic
+from utils.quantitative_eval import make_statistic
 from loaders.ImageNet import get_name
-import sklearn.metrics.pairwise as metrics
 from sklearn.decomposition import PCA
 from model.retrieval.model_main import MainModel
 import cv2
@@ -210,7 +209,7 @@ def draw(data, labels):
     args.device = "cuda:2"
     device = torch.device(args.device)
     model_.to(device)
-    args.output_dir = "saved_model"
+    args.output_dir = "../saved_model"
     checkpoint = torch.load(os.path.join(args.output_dir,
                             f"{args.dataset}_{args.base_model}_cls{args.num_classes}_cpt_no_slot.pt"),
                             map_location=device)
```

model/protopnet.py (new file, +243 lines)
```python
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F


class PPNet(nn.Module):

    def __init__(self, features, img_size, prototype_shape,
                 num_classes, init_weights=True,
                 prototype_activation_function='log',
                 add_on_layers_type='bottleneck'):

        super(PPNet, self).__init__()
        self.img_size = img_size
        self.prototype_shape = prototype_shape
        self.num_prototypes = prototype_shape[0]
        self.num_classes = num_classes
        self.epsilon = 1e-4

        # prototype_activation_function could be 'log', 'linear',
        # or a generic function that converts distance to similarity score
        self.prototype_activation_function = prototype_activation_function

        '''
        Here we are initializing the class identities of the prototypes
        Without domain specific knowledge we allocate the same number of
        prototypes for each class
        '''
        assert (self.num_prototypes % self.num_classes == 0)
        # a onehot indication matrix for each prototype's class identity
        self.prototype_class_identity = torch.zeros(self.num_prototypes,
                                                    self.num_classes)

        num_prototypes_per_class = self.num_prototypes // self.num_classes
        for j in range(self.num_prototypes):
            self.prototype_class_identity[j, j // num_prototypes_per_class] = 1

        # this has to be named features to allow the precise loading
        self.features = features

        if add_on_layers_type == 'bottleneck':
            add_on_layers = []
            current_in_channels = 512
            while (current_in_channels > self.prototype_shape[1]) or (len(add_on_layers) == 0):
                current_out_channels = max(self.prototype_shape[1], (current_in_channels // 2))
                add_on_layers.append(nn.Conv2d(in_channels=current_in_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                add_on_layers.append(nn.ReLU())
                add_on_layers.append(nn.Conv2d(in_channels=current_out_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                if current_out_channels > self.prototype_shape[1]:
                    add_on_layers.append(nn.ReLU())
                else:
                    assert (current_out_channels == self.prototype_shape[1])
                    add_on_layers.append(nn.Sigmoid())
                current_in_channels = current_in_channels // 2
            self.add_on_layers = nn.Sequential(*add_on_layers)
        else:
            self.add_on_layers = nn.Sequential(
                nn.Conv2d(in_channels=512, out_channels=self.prototype_shape[1],
                          kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[1], kernel_size=1),
                nn.Sigmoid()
            )

        self.prototype_vectors = nn.Parameter(torch.rand(self.prototype_shape),
                                              requires_grad=True)

        # do not make this just a tensor,
        # since it will not be moved automatically to gpu
        self.ones = nn.Parameter(torch.ones(self.prototype_shape),
                                 requires_grad=False)

        self.last_layer = nn.Linear(self.num_prototypes, self.num_classes,
                                    bias=False)  # do not use bias

        if init_weights:
            self._initialize_weights()

    def conv_features(self, x):
        '''
        the feature input to prototype layer
        '''
        x = self.features(x)
        x = self.add_on_layers(x)
        return x

    @staticmethod
    def _weighted_l2_convolution(input, filter, weights):
        '''
        input of shape N * c * h * w
        filter of shape P * c * h1 * w1
        weight of shape P * c * h1 * w1
        '''
        input2 = input ** 2
        input_patch_weighted_norm2 = F.conv2d(input=input2, weight=weights)

        filter2 = filter ** 2
        weighted_filter2 = filter2 * weights
        filter_weighted_norm2 = torch.sum(weighted_filter2, dim=(1, 2, 3))
        filter_weighted_norm2_reshape = filter_weighted_norm2.view(-1, 1, 1)

        weighted_filter = filter * weights
        weighted_inner_product = F.conv2d(input=input, weight=weighted_filter)

        # use broadcast
        intermediate_result = \
            - 2 * weighted_inner_product + filter_weighted_norm2_reshape
        # x2_patch_sum and intermediate_result are of the same shape
        distances = F.relu(input_patch_weighted_norm2 + intermediate_result)

        return distances

    def _l2_convolution(self, x):
        '''
        apply self.prototype_vectors as l2-convolution filters on input x
        '''
        x2 = x ** 2
        x2_patch_sum = F.conv2d(input=x2, weight=self.ones)

        p2 = self.prototype_vectors ** 2
        p2 = torch.sum(p2, dim=(1, 2, 3))
        # p2 is a vector of shape (num_prototypes,)
        # then we reshape it to (num_prototypes, 1, 1)
        p2_reshape = p2.view(-1, 1, 1)

        xp = F.conv2d(input=x, weight=self.prototype_vectors)
        intermediate_result = - 2 * xp + p2_reshape  # use broadcast
        # x2_patch_sum and intermediate_result are of the same shape
        distances = F.relu(x2_patch_sum + intermediate_result)

        return distances

    def prototype_distances(self, x):
        '''
        x is the raw input
        '''
        conv_features = self.conv_features(x)
        distances = self._l2_convolution(conv_features)
        return distances

    def distance_2_similarity(self, distances):
        if self.prototype_activation_function == 'log':
            return torch.log((distances + 1) / (distances + self.epsilon))
        elif self.prototype_activation_function == 'linear':
            return -distances
        else:
            return self.prototype_activation_function(distances)

    def forward(self, x):
        distances = self.prototype_distances(x)
        '''
        we cannot refactor the lines below for similarity scores
        because we need to return min_distances
        '''
        # global min pooling
        min_distances = -F.max_pool2d(-distances,
                                      kernel_size=(distances.size()[2],
                                                   distances.size()[3]))
        min_distances = min_distances.view(-1, self.num_prototypes)
        prototype_activations = self.distance_2_similarity(min_distances)
        logits = self.last_layer(prototype_activations)
        return logits, min_distances

    def push_forward(self, x):
        '''this method is needed for the pushing operation'''
        conv_output = self.conv_features(x)
        distances = self._l2_convolution(conv_output)
        return conv_output, distances

    def prune_prototypes(self, prototypes_to_prune):
        '''
        prototypes_to_prune: a list of indices each in
        [0, current number of prototypes - 1] that indicates the prototypes to
        be removed
        '''
        prototypes_to_keep = list(set(range(self.num_prototypes)) - set(prototypes_to_prune))

        self.prototype_vectors = nn.Parameter(self.prototype_vectors.data[prototypes_to_keep, ...],
                                              requires_grad=True)

        self.prototype_shape = list(self.prototype_vectors.size())
        self.num_prototypes = self.prototype_shape[0]

        # changing self.last_layer in place
        # changing in_features and out_features make sure the numbers are consistent
        self.last_layer.in_features = self.num_prototypes
        self.last_layer.out_features = self.num_classes
        self.last_layer.weight.data = self.last_layer.weight.data[:, prototypes_to_keep]

        # self.ones is nn.Parameter
        self.ones = nn.Parameter(self.ones.data[prototypes_to_keep, ...],
                                 requires_grad=False)
        # self.prototype_class_identity is torch tensor
        # so it does not need .data access for value update
        self.prototype_class_identity = self.prototype_class_identity[prototypes_to_keep, :]

    def set_last_layer_incorrect_connection(self, incorrect_strength):
        '''
        the incorrect strength will be actual strength if -0.5 then input -0.5
        '''
        positive_one_weights_locations = torch.t(self.prototype_class_identity)
        negative_one_weights_locations = 1 - positive_one_weights_locations

        correct_class_connection = 1
        incorrect_class_connection = incorrect_strength
        self.last_layer.weight.data.copy_(
            correct_class_connection * positive_one_weights_locations
            + incorrect_class_connection * negative_one_weights_locations)

    def _initialize_weights(self):
        for m in self.add_on_layers.modules():
            if isinstance(m, nn.Conv2d):
                # every init technique has an underscore _ in the name
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        self.set_last_layer_incorrect_connection(incorrect_strength=-0.5)


def construct_PPNet(bone, img_size=224,
                    prototype_shape=(1500, 512, 1, 1), num_classes=15,
                    prototype_activation_function='log',
                    add_on_layers_type='bottleneck'):
    features = bone

    return PPNet(features=features,
                 img_size=img_size,
                 prototype_shape=prototype_shape,
                 num_classes=num_classes,
                 init_weights=True,
                 prototype_activation_function=prototype_activation_function,
                 add_on_layers_type=add_on_layers_type)
```
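`_l2_convolution` computes the squared L2 distance between every spatial patch x and every prototype p through the expansion ||x - p||^2 = ||x||^2 - 2<x, p> + ||p||^2, realizing each term with `F.conv2d`. Below is a minimal usage sketch for the new module (our assumption; the commit does not show how protopnet.py is wired in). `construct_PPNet` expects a backbone that emits 512 channels, e.g. a torchvision ResNet-18 with its pooling and fc head removed:

```python
import torch
import torch.nn as nn
import torchvision.models as models

from model.protopnet import construct_PPNet  # assumes the repo root is on sys.path

# ResNet-18 without avgpool/fc: output is (N, 512, 7, 7) for 224x224 input.
backbone = nn.Sequential(*list(models.resnet18(weights=None).children())[:-2])

# 1500 prototypes of shape (512, 1, 1) split evenly over 15 classes
# (1500 % 15 == 0, satisfying the assert in PPNet.__init__).
ppnet = construct_PPNet(backbone, img_size=224,
                        prototype_shape=(1500, 512, 1, 1), num_classes=15)

x = torch.randn(2, 3, 224, 224)            # dummy batch
logits, min_distances = ppnet(x)           # forward returns both
print(logits.shape, min_distances.shape)   # torch.Size([2, 15]) torch.Size([2, 1500])
```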
