Skip to content

Commit 358b60c

Browse files
committed
Merge pull request BVLC#3471 from beijbom/clean-datalayer-tutorial
[example] tutorial on python data layers and multilabel classification
2 parents 2c5a24e + 9f8f777 commit 358b60c

File tree

3 files changed

+815
-0
lines changed

3 files changed

+815
-0
lines changed

examples/pascal-multilabel-with-datalayer.ipynb

+478
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# imports
2+
import json
3+
import time
4+
import pickle
5+
import scipy.misc
6+
import skimage.io
7+
import caffe
8+
9+
import numpy as np
10+
import os.path as osp
11+
12+
from xml.dom import minidom
13+
from random import shuffle
14+
from threading import Thread
15+
from PIL import Image
16+
17+
from tools import SimpleTransformer
18+
19+
20+
class PascalMultilabelDataLayerSync(caffe.Layer):
21+
22+
"""
23+
This is a simple syncronous datalayer for training a multilabel model on
24+
PASCAL.
25+
"""
26+
27+
def setup(self, bottom, top):
28+
29+
self.top_names = ['data', 'label']
30+
31+
# === Read input parameters ===
32+
33+
# params is a python dictionary with layer parameters.
34+
params = eval(self.param_str)
35+
36+
# Check the paramameters for validity.
37+
check_params(params)
38+
39+
# store input as class variables
40+
self.batch_size = params['batch_size']
41+
42+
# Create a batch loader to load the images.
43+
self.batch_loader = BatchLoader(params, None)
44+
45+
# === reshape tops ===
46+
# since we use a fixed input image size, we can shape the data layer
47+
# once. Else, we'd have to do it in the reshape call.
48+
top[0].reshape(
49+
self.batch_size, 3, params['im_shape'][0], params['im_shape'][1])
50+
# Note the 20 channels (because PASCAL has 20 classes.)
51+
top[1].reshape(self.batch_size, 20)
52+
53+
print_info("PascalMultilabelDataLayerSync", params)
54+
55+
def forward(self, bottom, top):
56+
"""
57+
Load data.
58+
"""
59+
for itt in range(self.batch_size):
60+
# Use the batch loader to load the next image.
61+
im, multilabel = self.batch_loader.load_next_image()
62+
63+
# Add directly to the caffe data layer
64+
top[0].data[itt, ...] = im
65+
top[1].data[itt, ...] = multilabel
66+
67+
def reshape(self, bottom, top):
68+
"""
69+
There is no need to reshape the data, since the input is of fixed size
70+
(rows and columns)
71+
"""
72+
pass
73+
74+
def backward(self, top, propagate_down, bottom):
75+
"""
76+
These layers does not back propagate
77+
"""
78+
pass
79+
80+
81+
class BatchLoader(object):
82+
83+
"""
84+
This class abstracts away the loading of images.
85+
Images can either be loaded singly, or in a batch. The latter is used for
86+
the asyncronous data layer to preload batches while other processing is
87+
performed.
88+
"""
89+
90+
def __init__(self, params, result):
91+
self.result = result
92+
self.batch_size = params['batch_size']
93+
self.pascal_root = params['pascal_root']
94+
self.im_shape = params['im_shape']
95+
# get list of image indexes.
96+
list_file = params['split'] + '.txt'
97+
self.indexlist = [line.rstrip('\n') for line in open(
98+
osp.join(self.pascal_root, 'ImageSets/Main', list_file))]
99+
self._cur = 0 # current image
100+
# this class does some simple data-manipulations
101+
self.transformer = SimpleTransformer()
102+
103+
print "BatchLoader initialized with {} images".format(
104+
len(self.indexlist))
105+
106+
def load_next_image(self):
107+
"""
108+
Load the next image in a batch.
109+
"""
110+
# Did we finish an epoch?
111+
if self._cur == len(self.indexlist):
112+
self._cur = 0
113+
shuffle(self.indexlist)
114+
115+
# Load an image
116+
index = self.indexlist[self._cur] # Get the image index
117+
image_file_name = index + '.jpg'
118+
im = np.asarray(Image.open(
119+
osp.join(self.pascal_root, 'JPEGImages', image_file_name)))
120+
im = scipy.misc.imresize(im, self.im_shape) # resize
121+
122+
# do a simple horizontal flip as data augmentation
123+
flip = np.random.choice(2)*2-1
124+
im = im[:, ::flip, :]
125+
126+
# Load and prepare ground truth
127+
multilabel = np.zeros(20).astype(np.float32)
128+
anns = load_pascal_annotation(index, self.pascal_root)
129+
for label in anns['gt_classes']:
130+
# in the multilabel problem we don't care how MANY instances
131+
# there are of each class. Only if they are present.
132+
# The "-1" is b/c we are not interested in the background
133+
# class.
134+
multilabel[label - 1] = 1
135+
136+
self._cur += 1
137+
return self.transformer.preprocess(im), multilabel
138+
139+
140+
def load_pascal_annotation(index, pascal_root):
141+
"""
142+
This code is borrowed from Ross Girshick's FAST-RCNN code
143+
(https://github.com/rbgirshick/fast-rcnn).
144+
It parses the PASCAL .xml metadata files.
145+
See publication for further details: (http://arxiv.org/abs/1504.08083).
146+
147+
Thanks Ross!
148+
149+
"""
150+
classes = ('__background__', # always index 0
151+
'aeroplane', 'bicycle', 'bird', 'boat',
152+
'bottle', 'bus', 'car', 'cat', 'chair',
153+
'cow', 'diningtable', 'dog', 'horse',
154+
'motorbike', 'person', 'pottedplant',
155+
'sheep', 'sofa', 'train', 'tvmonitor')
156+
class_to_ind = dict(zip(classes, xrange(21)))
157+
158+
filename = osp.join(pascal_root, 'Annotations', index + '.xml')
159+
# print 'Loading: {}'.format(filename)
160+
161+
def get_data_from_tag(node, tag):
162+
return node.getElementsByTagName(tag)[0].childNodes[0].data
163+
164+
with open(filename) as f:
165+
data = minidom.parseString(f.read())
166+
167+
objs = data.getElementsByTagName('object')
168+
num_objs = len(objs)
169+
170+
boxes = np.zeros((num_objs, 4), dtype=np.uint16)
171+
gt_classes = np.zeros((num_objs), dtype=np.int32)
172+
overlaps = np.zeros((num_objs, 21), dtype=np.float32)
173+
174+
# Load object bounding boxes into a data frame.
175+
for ix, obj in enumerate(objs):
176+
# Make pixel indexes 0-based
177+
x1 = float(get_data_from_tag(obj, 'xmin')) - 1
178+
y1 = float(get_data_from_tag(obj, 'ymin')) - 1
179+
x2 = float(get_data_from_tag(obj, 'xmax')) - 1
180+
y2 = float(get_data_from_tag(obj, 'ymax')) - 1
181+
cls = class_to_ind[
182+
str(get_data_from_tag(obj, "name")).lower().strip()]
183+
boxes[ix, :] = [x1, y1, x2, y2]
184+
gt_classes[ix] = cls
185+
overlaps[ix, cls] = 1.0
186+
187+
overlaps = scipy.sparse.csr_matrix(overlaps)
188+
189+
return {'boxes': boxes,
190+
'gt_classes': gt_classes,
191+
'gt_overlaps': overlaps,
192+
'flipped': False,
193+
'index': index}
194+
195+
196+
def check_params(params):
197+
"""
198+
A utility function to check the parameters for the data layers.
199+
"""
200+
assert 'split' in params.keys(
201+
), 'Params must include split (train, val, or test).'
202+
203+
required = ['batch_size', 'pascal_root', 'im_shape']
204+
for r in required:
205+
assert r in params.keys(), 'Params must include {}'.format(r)
206+
207+
208+
def print_info(name, params):
209+
"""
210+
Ouput some info regarding the class
211+
"""
212+
print "{} initialized for split: {}, with bs: {}, im_shape: {}.".format(
213+
name,
214+
params['split'],
215+
params['batch_size'],
216+
params['im_shape'])

examples/pycaffe/tools.py

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import numpy as np
2+
3+
4+
class SimpleTransformer:
5+
6+
"""
7+
SimpleTransformer is a simple class for preprocessing and deprocessing
8+
images for caffe.
9+
"""
10+
11+
def __init__(self, mean=[128, 128, 128]):
12+
self.mean = np.array(mean, dtype=np.float32)
13+
self.scale = 1.0
14+
15+
def set_mean(self, mean):
16+
"""
17+
Set the mean to subtract for centering the data.
18+
"""
19+
self.mean = mean
20+
21+
def set_scale(self, scale):
22+
"""
23+
Set the data scaling.
24+
"""
25+
self.scale = scale
26+
27+
def preprocess(self, im):
28+
"""
29+
preprocess() emulate the pre-processing occuring in the vgg16 caffe
30+
prototxt.
31+
"""
32+
33+
im = np.float32(im)
34+
im = im[:, :, ::-1] # change to BGR
35+
im -= self.mean
36+
im *= self.scale
37+
im = im.transpose((2, 0, 1))
38+
39+
return im
40+
41+
def deprocess(self, im):
42+
"""
43+
inverse of preprocess()
44+
"""
45+
im = im.transpose(1, 2, 0)
46+
im /= self.scale
47+
im += self.mean
48+
im = im[:, :, ::-1] # change to RGB
49+
50+
return np.uint8(im)
51+
52+
53+
class CaffeSolver:
54+
55+
"""
56+
Caffesolver is a class for creating a solver.prototxt file. It sets default
57+
values and can export a solver parameter file.
58+
Note that all parameters are stored as strings. Strings variables are
59+
stored as strings in strings.
60+
"""
61+
62+
def __init__(self, testnet_prototxt_path="testnet.prototxt",
63+
trainnet_prototxt_path="trainnet.prototxt", debug=False):
64+
65+
self.sp = {}
66+
67+
# critical:
68+
self.sp['base_lr'] = '0.001'
69+
self.sp['momentum'] = '0.9'
70+
71+
# speed:
72+
self.sp['test_iter'] = '100'
73+
self.sp['test_interval'] = '250'
74+
75+
# looks:
76+
self.sp['display'] = '25'
77+
self.sp['snapshot'] = '2500'
78+
self.sp['snapshot_prefix'] = '"snapshot"' # string withing a string!
79+
80+
# learning rate policy
81+
self.sp['lr_policy'] = '"fixed"'
82+
83+
# important, but rare:
84+
self.sp['gamma'] = '0.1'
85+
self.sp['weight_decay'] = '0.0005'
86+
self.sp['train_net'] = '"' + trainnet_prototxt_path + '"'
87+
self.sp['test_net'] = '"' + testnet_prototxt_path + '"'
88+
89+
# pretty much never change these.
90+
self.sp['max_iter'] = '100000'
91+
self.sp['test_initialization'] = 'false'
92+
self.sp['average_loss'] = '25' # this has to do with the display.
93+
self.sp['iter_size'] = '1' # this is for accumulating gradients
94+
95+
if (debug):
96+
self.sp['max_iter'] = '12'
97+
self.sp['test_iter'] = '1'
98+
self.sp['test_interval'] = '4'
99+
self.sp['display'] = '1'
100+
101+
def add_from_file(self, filepath):
102+
"""
103+
Reads a caffe solver prototxt file and updates the Caffesolver
104+
instance parameters.
105+
"""
106+
with open(filepath, 'r') as f:
107+
for line in f:
108+
if line[0] == '#':
109+
continue
110+
splitLine = line.split(':')
111+
self.sp[splitLine[0].strip()] = splitLine[1].strip()
112+
113+
def write(self, filepath):
114+
"""
115+
Export solver parameters to INPUT "filepath". Sorted alphabetically.
116+
"""
117+
f = open(filepath, 'w')
118+
for key, value in sorted(self.sp.items()):
119+
if not(type(value) is str):
120+
raise TypeError('All solver parameters must be strings')
121+
f.write('%s: %s\n' % (key, value))

0 commit comments

Comments
 (0)