Skip to content

Add support to print the predictions and to use mean_value when mean_file is not supplied #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: rc3-bvlc-ppc
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion python/caffe/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,14 @@ def set_mean(self, in_, mean):
if len(ms) != 3:
raise ValueError('Mean shape invalid')
if ms != self.inputs[in_][1:]:
raise ValueError('Mean shape incompatible with input shape.')
print(self.inputs[in_])
in_shape = self.inputs[in_][1:]
m_min, m_max = mean.min(), mean.max()
normal_mean = (mean - m_min) / (m_max - m_min)
mean = resize_image(normal_mean.transpose((1,2,0)),
in_shape[1:]).transpose((2,0,1)) * \
(m_max - m_min) + m_min
#raise ValueError('Mean shape incompatible with input shape.')
self.mean[in_] = mean

def set_input_scale(self, in_, scale):
Expand Down
56 changes: 54 additions & 2 deletions python/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
By default it configures and runs the Caffe reference ImageNet model.
"""
import numpy as np
import pandas as pd
import os
import sys
import argparse
Expand Down Expand Up @@ -80,6 +81,17 @@ def main(argv):
help="Order to permute input channels. The default converts " +
"RGB -> BGR since BGR is the Caffe default by way of OpenCV."
)
parser.add_argument(
"--labels_file",
default=os.path.join(pycaffe_dir,
"../data/ilsvrc12/synset_words.txt"),
help="Readable label definition file."
)
parser.add_argument(
"--print_results",
action='store_true',
help="Write output text to stdout rather than serializing to a file."
)
parser.add_argument(
"--ext",
default='jpg',
Expand All @@ -93,6 +105,10 @@ def main(argv):
mean, channel_swap = None, None
if args.mean_file:
mean = np.load(args.mean_file)
else:
# channel-wise mean
mean = np.array([104,117,123])

if args.channel_swap:
channel_swap = [int(s) for s in args.channel_swap.split(',')]

Expand Down Expand Up @@ -126,12 +142,48 @@ def main(argv):

# Classify.
start = time.time()
predictions = classifier.predict(inputs, not args.center_only)
scores = classifier.predict(inputs, not args.center_only).flatten()
print("Done in %.2f s." % (time.time() - start))

# The script has been updated to support --print_results option.
# Ref - http://stackoverflow.com/questions/37265197/classify-py-is-not-taking-argument-print-results
# However, the labels format supported here has been modified, such that the file can have shorttext
# corresponding to the category classes instead of the general format.
# The commented part correspond to the general format of labels file which has mapping between
# synset_id and the text.

if args.print_results:
# with open(args.labels_file) as f:
# labels_df = pd.DataFrame([
# {
# 'synset_id': l.strip().split(' ')[0],
# 'name': ' '.join(l.strip().split(' ')[1:]).split(',')[0]
# }
# for l in f.readlines()
# ])
# labels_df.synset_id = labels_df.synset_id.astype(np.int64)
# labels = labels_df.sort('synset_id')['name'].values

labels_file = open(args.labels_file, 'r')
labels = labels_file.readlines()

indices = (-scores).argsort()[:5]
# predictions = labels[indices]
# print(predictions)
# meta = [
# (p, '%.5f' % scores[i])
# for i, p in zip(indices, predictions)
# ]
# print meta
print("---------------------------------")
print("The top 5 predictions are")
for i in indices:
print('%.4f %s' % (scores[i] , labels[i].strip('\n')))
print("---------------------------------")

# Save
print("Saving results into %s" % args.output_file)
np.save(args.output_file, predictions)
np.save(args.output_file, scores)


if __name__ == '__main__':
Expand Down