Commit 31f0830

Pandoro authored and lucasb-eyer committed
Add embedding and evaluation code.
1 parent 6a16199 commit 31f0830

File tree

7 files changed, +558 −23 lines changed


README.md (+71 −3)
## Computing embeddings

Given a trained net, one often wants to compute the embeddings of a set of pictures for further processing.
This can be done with the `embed.py` script, which can also serve as inspiration for using a trained model in a larger program.

The following invocation computes the embeddings of the Market1501 query set using some network:

```
python embed.py \
    --experiment_root ~/experiments/my_experiment \
    --dataset data/market1501_query.csv \
    --filename test_embeddings.h5
```

The embeddings will be written into the HDF5 file at `~/experiments/my_experiment/test_embeddings.h5` as dataset `embs`.
Most relevant settings are automatically loaded from the experiment's `args.json` file, but some can be overruled on the commandline.
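Loading the stored embeddings back is a one-liner with `h5py`; a minimal self-contained sketch (the shapes and file name here are made up for illustration, but the dataset name `embs` matches what `embed.py` writes):

```python
import h5py
import numpy as np

# For illustration, fabricate a small embedding matrix and store it the
# same way `embed.py` does: as dataset `embs` in an HDF5 file.
with h5py.File('test_embeddings.h5', 'w') as f:
    f.create_dataset('embs', data=np.random.rand(4, 128).astype(np.float32))

# Loading the embeddings for further processing:
with h5py.File('test_embeddings.h5', 'r') as f:
    embs = f['embs'][()]  # NumPy array of shape (n_images, embedding_dim)

print(embs.shape)  # (4, 128)
```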

If the training was performed using data augmentation (highly recommended),
one can invest some more time in the embedding step in order to compute augmented embeddings,
which are usually more robust and perform better in downstream tasks.

The following is an example that computes extensively augmented embeddings:

```
python embed.py \
    --experiment_root ~/experiments/my_experiment \
    --dataset data/market1501_query.csv \
    --filename test_embeddings_augmented.h5 \
    --flip_augment \
    --crop_augment five \
    --aggregator mean
```

This will take 10 times longer, because we perform a total of 10 augmentations per image (2 flips times 5 crops).
All individual embeddings are also stored in the `.h5` file, so the required disk space increases accordingly.
One question is how the embeddings of the various augmentations should be combined.
When training with the Euclidean metric in the loss, simply taking the mean makes the most sense,
which is also what the above invocation does through `--aggregator mean`.
But if one trains a normalized embedding (by using a `_normalize` head, for instance),
the embeddings *must* be re-normalized after averaging, and so one should use `--aggregator normalized_mean`.
The final combined embedding is again stored as `embs` in the `.h5` file, as usual.
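As a sketch of what the two aggregators compute (the concrete shapes here are illustrative; the idea is one embedding per augmentation per image):

```python
import numpy as np

# Assume 10 augmented embeddings (2 flips x 5 crops) for each of 3 images,
# each 4-dimensional; shape: (n_augmentations, n_images, embedding_dim).
aug_embs = np.random.rand(10, 3, 4)

# `--aggregator mean`: average over the augmentation axis.
mean_embs = np.mean(aug_embs, axis=0)

# `--aggregator normalized_mean`: after averaging, re-normalize each
# embedding to unit length, as required for normalized embeddings.
norm_embs = mean_embs / np.linalg.norm(mean_embs, axis=1, keepdims=True)

print(mean_embs.shape)  # (3, 4)
print(np.linalg.norm(norm_embs, axis=1))  # all ones
```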

# Evaluating embeddings

Once the embeddings have been generated, it is a good idea to compute CMC curves and mAP for evaluation.
With only minor modifications, the embedding `.h5` files can be used in
[the official Market1501 MATLAB evaluation code](https://github.com/zhunzhong07/IDE-baseline-Market-1501),
which is exactly what we did for the paper.

For convenience, and to spite MATLAB, we also implemented our own evaluation code in Python.
This code additionally depends on [scikit-learn](http://scikit-learn.org/stable/),
and still uses TensorFlow only for re-using the same metric implementation as the training code, for consistency.
We verified that it produces the exact same results as the reference implementation.

The following is an example of evaluating a Market1501 model; notice it takes a lot of parameters :smile::

```
./evaluate.py \
    --excluder market1501 \
    --query_dataset data/market1501_query.csv \
    --query_embeddings ~/experiments/my_experiment/market1501_query_embeddings.h5 \
    --gallery_dataset data/market1501_test.csv \
    --gallery_embeddings ~/experiments/my_experiment/market1501_test_embeddings.h5 \
    --metric euclidean \
    --filename ~/experiments/my_experiment/market1501_evaluation.json
```

The only thing that really needs explaining here is the `excluder`.
For some datasets, especially multi-camera ones,
one often excludes gallery pictures of the query person that were taken by the same camera as the query image.
This way, one gets more of a feeling for across-camera performance.
Additionally, the Market1501 dataset contains some "junk" images in the gallery which should be ignored too.
All of this is taken care of by `excluders`.
We provide one for the Market1501 dataset, and a `diagonal` one, which should be used where there is no such restriction,
for example for the Stanford Online Products dataset.
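To make the idea concrete, here is a hypothetical illustration of the kind of mask an excluder computes for a single query; the person/camera IDs and the `'junk'` label are invented for this sketch, not taken from the actual `excluders` implementation:

```python
import numpy as np

# Hypothetical query: person id and the camera that took the query shot.
query_pid, query_cam = 'p01', 'c1'

# Hypothetical gallery annotations.
gallery_pids = np.array(['p01', 'p01', 'p02', 'junk'])
gallery_cams = np.array(['c1',  'c2',  'c1',  'c1'])

# Exclude same-person shots from the same camera, as well as junk images;
# masked gallery entries are ignored when scoring this query.
mask = ((gallery_pids == query_pid) & (gallery_cams == query_cam)) \
       | (gallery_pids == 'junk')

print(mask)  # [ True False False  True]
```

The second gallery image of `p01` stays in, because it comes from a different camera; that is exactly the cross-camera matching one wants to measure.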

# Independent re-implementations

aggregators.py (+16)
```
import numpy as np


def mean(embs):
    # Average the embeddings over the augmentation axis (axis 0).
    return np.mean(embs, axis=0)


def normalized_mean(embs):
    # Average first, then re-normalize each embedding to unit length.
    embs = mean(embs)
    return embs / np.linalg.norm(embs, axis=1, keepdims=True)


AGGREGATORS = {
    'mean': mean,
    'normalized_mean': normalized_mean,
}
```

common.py (+22 −20)
```
@@ -103,9 +103,10 @@ def load_dataset(csv_file, image_root, fail_on_missing=True):
         csv_file (string, file-like object): The csv data file to load.
         image_root (string): The path to which the image files as stored in the
             csv file are relative to. Used for verification purposes.
-        fail_on_missing (bool): If one or more files from the dataset are not
-            present in the `image_root`, either raise an IOError (if True) or
-            remove it from the returned dataset (if False).
+            If this is `None`, no verification at all is made.
+        fail_on_missing (bool or None): If one or more files from the dataset
+            are not present in the `image_root`, either raise an IOError (if
+            True) or remove it from the returned dataset (if False).
 
     Returns:
         (pids, fids) a tuple of numpy string arrays corresponding to the PIDs,
@@ -117,23 +118,24 @@ def load_dataset(csv_file, image_root, fail_on_missing=True):
     dataset = np.genfromtxt(csv_file, delimiter=',', dtype='|U')
     pids, fids = dataset.T
 
-    # Check if all files exist
-    missing = np.full(len(fids), False, dtype=bool)
-    for i, fid in enumerate(fids):
-        missing[i] = not os.path.isfile(os.path.join(image_root, fid))
-
-    missing_count = np.sum(missing)
-    if missing_count > 0:
-        if fail_on_missing:
-            raise IOError('Using the `{}` file and `{}` as an image root {}/'
-                          '{} images are missing'.format(
-                              csv_file, image_root, missing_count, len(fids)))
-        else:
-            print('[Warning] removing {} missing file(s) from the'
-                  ' dataset.'.format(missing_count))
-            # We simply remove the missing files.
-            fids = fids[np.logical_not(missing)]
-            pids = pids[np.logical_not(missing)]
+    # Possibly check if all files exist
+    if image_root is not None:
+        missing = np.full(len(fids), False, dtype=bool)
+        for i, fid in enumerate(fids):
+            missing[i] = not os.path.isfile(os.path.join(image_root, fid))
+
+        missing_count = np.sum(missing)
+        if missing_count > 0:
+            if fail_on_missing:
+                raise IOError('Using the `{}` file and `{}` as an image root {}/'
+                              '{} images are missing'.format(
+                                  csv_file, image_root, missing_count, len(fids)))
+            else:
+                print('[Warning] removing {} missing file(s) from the'
+                      ' dataset.'.format(missing_count))
+                # We simply remove the missing files.
+                fids = fids[np.logical_not(missing)]
+                pids = pids[np.logical_not(missing)]
 
     return pids, fids
```
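With `image_root=None`, the new code path skips the filesystem check, so loading reduces to parsing the CSV. A self-contained sketch of that path (mimicking the parsing lines of `load_dataset` rather than importing it; the CSV rows are invented examples in the Market1501 `pid,fid` format):

```python
import io
import numpy as np

# A tiny dataset CSV: each row is "pid,fid" as used by this repository.
csv_file = io.StringIO("0001,0001_c1s1_001051_00.jpg\n"
                       "0002,0002_c1s1_000451_03.jpg\n")

# These two lines are exactly what `load_dataset` does before the
# (now optional) existence check.
dataset = np.genfromtxt(csv_file, delimiter=',', dtype='|U')
pids, fids = dataset.T

print(pids)  # person IDs as strings
print(fids)  # relative image file names
```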