Skip to content

Commit aa4b8e5

Browse files
committed
Add instructions and sample code for using TriNet.
1 parent 4fa93e7 commit aa4b8e5

File tree

2 files changed

+344
-2
lines changed

2 files changed

+344
-2
lines changed

README.md

+44-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,47 @@
11
# Triplet-based Person Re-Identification
2+
23
Code for reproducing the results of our "In Defense of the Triplet Loss for Person Re-Identification" paper.
34

4-
# Publication Pending!
5-
The publication of all code is pending acceptance of the paper. Since we've been asked several times, this work was a submission to ICCV'17 and you can have look at the [timeline](http://iccv2017.thecvf.com/submission/timeline), so if all goes well, code will come online sometime this summer.
5+
Both main authors are currently in an internship.
6+
We will publish the full training code after our internships, which is end of September 2017.
7+
(By "Watching" this project on github, you will receive e-mails about updates to this repo.)
8+
Meanwhile, we provide the pre-trained weights for the TriNet model, as well as some rudimentary example code for using it to compute embeddings, see below.
9+
10+
# Pretrained Models
11+
12+
This is a first, simple release. A better more generic script will follow in a few months, but this should be enough to get started trying out our models!
13+
14+
As a first step, download the weights for the TriNet model [trained on MARS](https://omnomnom.vision.rwth-aachen.de/data/trinet-mars.npz) or trained on [Market1501](https://omnomnom.vision.rwth-aachen.de/data/trinet-market1501.npz).
15+
(Pre-trained LuNet models will follow.)
16+
17+
Next, create a file (`files.txt`) which contains the full path to the image files you want to embed, one filename per line, like so:
18+
19+
```
20+
/path/to/file1.png
21+
/path/to/file2.jpg
22+
```
23+
24+
Finally, run the `trinet_embed.py` script, passing both the above file and the weights file you want to use, like so:
25+
26+
```
27+
python trinet_embed.py files.txt /path/to/trinet-mars.npz
28+
```
29+
30+
And it will output one comma-separated line for each file, containing the filename followed by the embedding, like so:
31+
32+
```
33+
/path/to/file1.png,-1.234,5.678,...
34+
/path/to/file2.jpg,9.876,-1.234,...
35+
```
36+
37+
You could for example redirect it to a file for further processing:
38+
39+
```
40+
python trinet_embed.py files.txt /path/to/trinet-market1501.npz >embeddings.csv
41+
```
42+
43+
You can now do meaningful work by comparing these embeddings using the Euclidean distance, for example, try some K-means clustering!
44+
45+
A couple notes:
46+
- The script depends on both [Theano](http://deeplearning.net/software/theano/install.html) and [Lasagne](http://lasagne.readthedocs.io/en/latest/user/installation.html) being correctly installed.
47+
- The input files should be crops of a full person standing upright, and they will be resized to `288x144` before being passed to the network.

trinet_embed.py

+300
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
#!/usr/bin/env python
2+
from __future__ import print_function
3+
import numpy as np
4+
import cv2
5+
import pickle
6+
import sys
7+
8+
9+
if len(sys.argv) != 3:
10+
print("Usage: {} IMAGE_LIST_FILE MODEL_WEIGHT_FILE".format(sys.argv[0]))
11+
sys.exit(1)
12+
13+
# Specify the path to a Market-1501 image that should be embedded and the location of the weights we provided.
14+
image_list = list(map(str.strip, open(sys.argv[1]).readlines()))
15+
weight_fname = sys.argv[2]
16+
17+
18+
19+
# Setup the pretrained ResNet
20+
21+
#This is based on the Lasagne ResNet-50 example with slight modifications to allow for different input sizes.
22+
#The original can be found at: https://github.com/Lasagne/Recipes/blob/master/examples/resnet50/ImageNet%20Pretrained%20Network%20(ResNet-50).ipynb
23+
import theano
24+
import lasagne
25+
from lasagne.layers import InputLayer
26+
from lasagne.layers import Conv2DLayer as ConvLayer
27+
from lasagne.layers import BatchNormLayer
28+
from lasagne.layers import Pool2DLayer as PoolLayer
29+
from lasagne.layers import NonlinearityLayer
30+
from lasagne.layers import ElemwiseSumLayer
31+
from lasagne.layers import DenseLayer
32+
from lasagne.nonlinearities import rectify, softmax
33+
34+
35+
def build_simple_block(incoming_layer, names,
36+
num_filters, filter_size, stride, pad,
37+
use_bias=False, nonlin=rectify):
38+
"""Creates stacked Lasagne layers ConvLayer -> BN -> (ReLu)
39+
40+
Parameters:
41+
----------
42+
incoming_layer : instance of Lasagne layer
43+
Parent layer
44+
45+
names : list of string
46+
Names of the layers in block
47+
48+
num_filters : int
49+
Number of filters in convolution layer
50+
51+
filter_size : int
52+
Size of filters in convolution layer
53+
54+
stride : int
55+
Stride of convolution layer
56+
57+
pad : int
58+
Padding of convolution layer
59+
60+
use_bias : bool
61+
Whether to use bias in conlovution layer
62+
63+
nonlin : function
64+
Nonlinearity type of Nonlinearity layer
65+
66+
Returns
67+
-------
68+
tuple: (net, last_layer_name)
69+
net : dict
70+
Dictionary with stacked layers
71+
last_layer_name : string
72+
Last layer name
73+
"""
74+
net = []
75+
net.append((
76+
names[0],
77+
ConvLayer(incoming_layer, num_filters, filter_size, stride, pad,
78+
flip_filters=False, nonlinearity=None) if use_bias
79+
else ConvLayer(incoming_layer, num_filters, filter_size, stride, pad, b=None,
80+
flip_filters=False, nonlinearity=None)
81+
))
82+
83+
net.append((
84+
names[1],
85+
BatchNormLayer(net[-1][1])
86+
))
87+
if nonlin is not None:
88+
net.append((
89+
names[2],
90+
NonlinearityLayer(net[-1][1], nonlinearity=nonlin)
91+
))
92+
93+
return dict(net), net[-1][0]
94+
95+
96+
def build_residual_block(incoming_layer, ratio_n_filter=1.0, ratio_size=1.0, has_left_branch=False,
97+
upscale_factor=4, ix=''):
98+
"""Creates two-branch residual block
99+
100+
Parameters:
101+
----------
102+
incoming_layer : instance of Lasagne layer
103+
Parent layer
104+
105+
ratio_n_filter : float
106+
Scale factor of filter bank at the input of residual block
107+
108+
ratio_size : float
109+
Scale factor of filter size
110+
111+
has_left_branch : bool
112+
if True, then left branch contains simple block
113+
114+
upscale_factor : float
115+
Scale factor of filter bank at the output of residual block
116+
117+
ix : int
118+
Id of residual block
119+
120+
Returns
121+
-------
122+
tuple: (net, last_layer_name)
123+
net : dict
124+
Dictionary with stacked layers
125+
last_layer_name : string
126+
Last layer name
127+
"""
128+
simple_block_name_pattern = ['res%s_branch%i%s', 'bn%s_branch%i%s', 'res%s_branch%i%s_relu']
129+
130+
net = {}
131+
132+
# right branch
133+
net_tmp, last_layer_name = build_simple_block(
134+
incoming_layer, list(map(lambda s: s % (ix, 2, 'a'), simple_block_name_pattern)),
135+
int(lasagne.layers.get_output_shape(incoming_layer)[1]*ratio_n_filter), 1, int(1.0/ratio_size), 0)
136+
net.update(net_tmp)
137+
138+
net_tmp, last_layer_name = build_simple_block(
139+
net[last_layer_name], list(map(lambda s: s % (ix, 2, 'b'), simple_block_name_pattern)),
140+
lasagne.layers.get_output_shape(net[last_layer_name])[1], 3, 1, 1)
141+
net.update(net_tmp)
142+
143+
net_tmp, last_layer_name = build_simple_block(
144+
net[last_layer_name], list(map(lambda s: s % (ix, 2, 'c'), simple_block_name_pattern)),
145+
lasagne.layers.get_output_shape(net[last_layer_name])[1]*upscale_factor, 1, 1, 0,
146+
nonlin=None)
147+
net.update(net_tmp)
148+
149+
right_tail = net[last_layer_name]
150+
left_tail = incoming_layer
151+
152+
# left branch
153+
if has_left_branch:
154+
net_tmp, last_layer_name = build_simple_block(
155+
incoming_layer, list(map(lambda s: s % (ix, 1, ''), simple_block_name_pattern)),
156+
int(lasagne.layers.get_output_shape(incoming_layer)[1]*4*ratio_n_filter), 1, int(1.0/ratio_size), 0,
157+
nonlin=None)
158+
net.update(net_tmp)
159+
left_tail = net[last_layer_name]
160+
161+
net['res%s' % ix] = ElemwiseSumLayer([left_tail, right_tail], coeffs=1)
162+
net['res%s_relu' % ix] = NonlinearityLayer(net['res%s' % ix], nonlinearity=rectify, name = 'res%s_relu' % ix)
163+
164+
return net, 'res%s_relu' % ix
165+
166+
167+
def build_model(input_size):
168+
net = {}
169+
net['input'] = InputLayer(input_size)
170+
sub_net, parent_layer_name = build_simple_block(
171+
net['input'], ['conv1', 'bn_conv1', 'conv1_relu'],
172+
64, 7, 2, 3, use_bias=True)
173+
net.update(sub_net)
174+
net['pool1'] = PoolLayer(net[parent_layer_name], pool_size=3, stride=2, pad=0, mode='max', ignore_border=False)
175+
block_size = list('abc')
176+
parent_layer_name = 'pool1'
177+
for c in block_size:
178+
if c == 'a':
179+
sub_net, parent_layer_name = build_residual_block(net[parent_layer_name], 1, 1, True, 4, ix='2%s' % c)
180+
else:
181+
sub_net, parent_layer_name = build_residual_block(net[parent_layer_name], 1.0/4, 1, False, 4, ix='2%s' % c)
182+
net.update(sub_net)
183+
184+
block_size = list('abcd')
185+
for c in block_size:
186+
if c == 'a':
187+
sub_net, parent_layer_name = build_residual_block(
188+
net[parent_layer_name], 1.0/2, 1.0/2, True, 4, ix='3%s' % c)
189+
else:
190+
sub_net, parent_layer_name = build_residual_block(net[parent_layer_name], 1.0/4, 1, False, 4, ix='3%s' % c)
191+
net.update(sub_net)
192+
193+
block_size = list('abcdef')
194+
for c in block_size:
195+
if c == 'a':
196+
sub_net, parent_layer_name = build_residual_block(
197+
net[parent_layer_name], 1.0/2, 1.0/2, True, 4, ix='4%s' % c)
198+
else:
199+
sub_net, parent_layer_name = build_residual_block(net[parent_layer_name], 1.0/4, 1, False, 4, ix='4%s' % c)
200+
net.update(sub_net)
201+
202+
block_size = list('abc')
203+
for c in block_size:
204+
if c == 'a':
205+
sub_net, parent_layer_name = build_residual_block(
206+
net[parent_layer_name], 1.0/2, 1.0/2, True, 4, ix='5%s' % c)
207+
else:
208+
sub_net, parent_layer_name = build_residual_block(net[parent_layer_name], 1.0/4, 1, False, 4, ix='5%s' % c)
209+
net.update(sub_net)
210+
net['pool5'] = PoolLayer(net[parent_layer_name], pool_size=7, stride=1, pad=0,
211+
mode='average_exc_pad', ignore_border=False)
212+
213+
return net
214+
215+
216+
#Setup the original network
217+
resnet = build_model(input_size=(None, 3, 256,128))
218+
219+
#Now we modify the network's final pooling layer and add 2 new layers at the end to predict the 128-dimensional embedding.
220+
#Different input size.
221+
inp = resnet['input']
222+
223+
network_features = resnet['pool5']
224+
network_features.pool_size=(8,4)
225+
226+
#New additional final layer
227+
network = lasagne.layers.batch_norm(lasagne.layers.DenseLayer(
228+
network_features,
229+
num_units=1024,
230+
nonlinearity=lasagne.nonlinearities.rectify,
231+
W=lasagne.init.GlorotUniform('relu'),
232+
b=None))
233+
234+
network_out = lasagne.layers.DenseLayer(
235+
network,
236+
num_units=128,
237+
nonlinearity=None,
238+
W=lasagne.init.Orthogonal())
239+
240+
241+
242+
#Setup the function to predict the embeddings.
243+
predict_features = theano.function(
244+
inputs=[inp.input_var],
245+
outputs=lasagne.layers.get_output(network_out, deterministic=True))
246+
247+
248+
#Set the parameters
249+
with np.load(weight_fname) as f:
250+
param_values = [f['arr_%d' % i] for i in range(len(f.files))]
251+
lasagne.layers.set_all_param_values(network_out, param_values)
252+
253+
254+
255+
#We subtract the per-channel mean of the "mean image" as loaded from the original ResNet-50 weight dump.
256+
#For simplcity, we just hardcode it here.
257+
im_mean = np.asarray([103.0626238, 115.90288257, 123.15163084], dtype=np.float32)
258+
259+
260+
261+
# a little helper function to create a test-time augmentation batch.
262+
def get_augmentation_batch(image, im_mean):
263+
#Resize it correctly, as needed by the test time augmentation.
264+
image = cv2.resize(image, (128+16, 256+32))
265+
266+
#Change into CHW format
267+
image = np.rollaxis(image,2)
268+
269+
#Setup storage for the batch
270+
batch = np.zeros((10,3,256,128), dtype=np.float32)
271+
272+
#Four corner crops and the center crop
273+
batch[0] = image[:,16:-16, 8:-8] #Center crop
274+
batch[1] = image[:, :-32, :-16] #Top left
275+
batch[2] = image[:, :-32, 16:] #Top right
276+
batch[3] = image[:, 32:, :-16] #Bottom left
277+
batch[4] = image[:, 32:, 16:] #Bottom right
278+
279+
#Flipping
280+
batch[5:] = batch[:5,:,:,::-1]
281+
282+
#Subtract the mean
283+
batch = batch-im_mean[None,:,None,None]
284+
285+
return batch
286+
287+
288+
289+
for image_filename in image_list:
290+
print(image_filename, end=",")
291+
sys.stdout.flush()
292+
293+
image = cv2.imread(image_filename)
294+
if image is None:
295+
raise ValueError("Couldn't load image {}".format(image_filename))
296+
297+
#Setup a batch of images and use the function to predict the embedding.
298+
batch = get_augmentation_batch(image, im_mean)
299+
embedding = np.mean(predict_features(batch), axis=0)
300+
print(','.join(map(str, embedding)))

0 commit comments

Comments
 (0)