Skip to content

Commit 6335c6e

Browse files
authored
Add files via upload
1 parent 3489a56 commit 6335c6e

17 files changed

+404
-0
lines changed

image_retrieval.py

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import os
2+
import cv2
3+
import time
4+
from datetime import timedelta
5+
from retrieval.create_thumb_images import create_thumb_images
6+
from flask import Flask, render_template, request, redirect, url_for, make_response,jsonify, flash
7+
from retrieval.retrieval import load_model, load_data, extract_feature, load_query_image, sort_img, extract_feature_query
8+
9+
# Create thumb images.
10+
create_thumb_images(full_folder='./static/image_database/',
11+
thumb_folder='./static/thumb_images/',
12+
suffix='',
13+
height=200,
14+
del_former_thumb=True,
15+
)
16+
17+
# Prepare data set.
18+
data_loader = load_data(data_path='./static/image_database/',
19+
batch_size=2,
20+
shuffle=False,
21+
transform='default',
22+
)
23+
24+
# Prepare model.
25+
model = load_model(pretrained_model='./retrieval/models/net_best.pth', use_gpu=True)
26+
27+
# Extract database features.
28+
gallery_feature, image_paths = extract_feature(model=model, dataloaders=data_loader)
29+
30+
# Picture extension supported.
31+
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'JPG', 'PNG', 'bmp', 'jpeg', 'JPEG'])
32+
def allowed_file(filename):
33+
return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS
34+
35+
app = Flask(__name__)
36+
# Set static file cache expiration time
37+
# app.send_file_max_age_default = timedelta(seconds=1)
38+
39+
40+
@app.route('/', methods=['POST', 'GET']) # add route
41+
def image_retrieval():
42+
43+
basepath = os.path.dirname(__file__) # current path
44+
upload_path = os.path.join(basepath, 'static/upload_image','query.jpg')
45+
46+
if request.method == 'POST':
47+
if request.form['submit'] == 'upload':
48+
if len(request.files) == 0:
49+
return render_template('upload_finish.html', message='Please select a picture file!')
50+
else:
51+
f = request.files['picture']
52+
53+
if not (f and allowed_file(f.filename)):
54+
# return jsonify({"error": 1001, "msg": "Examine picture extension, only png, PNG, jpg, JPG, or bmp supported."})
55+
return render_template('upload_finish.html', message='Examine picture extension, png、PNG、jpg、JPG、bmp support.')
56+
else:
57+
58+
f.save(upload_path)
59+
60+
# transform image format and name with opencv.
61+
img = cv2.imread(upload_path)
62+
cv2.imwrite(os.path.join(basepath, 'static/upload_image', 'query.jpg'), img)
63+
64+
return render_template('upload_finish.html', message='Upload successfully!')
65+
66+
elif request.form['submit'] == 'retrieval':
67+
start_time = time.time()
68+
# Query.
69+
query_image = load_query_image('./static/upload_image/query.jpg')
70+
# Extract query features.
71+
query_feature = extract_feature_query(model=model, img=query_image)
72+
# Sort.
73+
similarity, index = sort_img(query_feature, gallery_feature)
74+
sorted_paths = [image_paths[i] for i in index]
75+
76+
print(sorted_paths)
77+
tmb_images = ['./static/thumb_images/' + os.path.split(sorted_path)[1] for sorted_path in sorted_paths]
78+
# sorted_files = [os.path.split(sorted_path)[1] for sorted_path in sorted_paths]
79+
80+
return render_template('retrieval.html', message="Retrieval finished, cost {:3f} seconds.".format(time.time() - start_time),
81+
sml1=similarity[0], sml2=similarity[1], sml3=similarity[2], sml4=similarity[3], sml5=similarity[4], sml6=similarity[5],
82+
img1_tmb=tmb_images[0], img2_tmb=tmb_images[1],img3_tmb=tmb_images[2],img4_tmb=tmb_images[3],img5_tmb=tmb_images[4],img6_tmb=tmb_images[5])
83+
84+
return render_template('upload.html')
85+
86+
87+
if __name__ == '__main__':
88+
# app.debug = True
89+
app.run(host='0.0.0.0', port=8080, debug=True)

retrieval/create_thumb_images.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import os
2+
import cv2
3+
4+
5+
def del_file(path):
6+
for i in os.listdir(path):
7+
path_file = os.path.join(path, i)
8+
if os.path.isfile(path_file):
9+
os.remove(path_file)
10+
else:
11+
del_file(path_file)
12+
13+
14+
def create_thumb_images(full_folder, thumb_folder, suffix='thumb', height=100, del_former_thumb=False):
15+
if del_former_thumb:
16+
del_file(thumb_folder)
17+
for image_file in os.listdir(full_folder):
18+
image = cv2.imread(full_folder + image_file)
19+
height_src, width_src, _ = image.shape
20+
#print('width: {}, height: {}'.format(width_src, height_src))
21+
22+
width = (height / height_src) * width_src
23+
# print(' Thumb width: {}, height: {}'.format(width, height))
24+
25+
resized_image = cv2.resize(image, (int(width), int(height)))
26+
27+
image_name, image_extension = os.path.splitext(image_file)
28+
cv2.imwrite(thumb_folder + image_name + suffix + image_extension, resized_image)
29+
print('Creating thumb images finished.')
30+
31+
32+
if __name__ == '__main__':
33+
create_thumb_images(full_folder='./save_pic/',
34+
thumb_folder='./thumb_images/',
35+
suffix='',
36+
height=200,
37+
del_former_thumb=True,
38+
)

retrieval/retrieval.py

+191
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from __future__ import print_function, division
4+
5+
import argparse
6+
import torch
7+
import torch.nn as nn
8+
from torch.autograd import Variable
9+
from torchvision import datasets, models, transforms
10+
import os
11+
12+
13+
from torch.utils.data import dataloader, Dataset
14+
from PIL import Image
15+
16+
17+
def get_file_list(file_path_list, sort=True):
18+
"""
19+
Get list of file paths in one folder.
20+
:param file_path: A folder path or path list.
21+
:return: file list: File path list of
22+
"""
23+
import random
24+
if isinstance(file_path_list, str):
25+
file_path_list = [file_path_list]
26+
file_lists = []
27+
for file_path in file_path_list:
28+
assert os.path.isdir(file_path)
29+
file_list = os.listdir(file_path)
30+
if sort:
31+
file_list.sort()
32+
else:
33+
random.shuffle(file_list)
34+
file_list = [file_path + file for file in file_list]
35+
file_lists.append(file_list)
36+
if len(file_lists) == 1:
37+
file_lists = file_lists[0]
38+
return file_lists
39+
40+
41+
class Gallery(Dataset):
42+
"""
43+
Images in database.
44+
"""
45+
46+
def __init__(self, image_paths, transform=None):
47+
super().__init__()
48+
49+
self.image_paths = image_paths
50+
self.transform = transform
51+
52+
def __getitem__(self, index):
53+
image_path = self.image_paths[index]
54+
image = Image.open(image_path).convert('RGB')
55+
56+
if self.transform is not None:
57+
image = self.transform(image)
58+
59+
return image, image_path
60+
61+
def __len__(self):
62+
return len(self.image_paths)
63+
64+
65+
def load_data(data_path, batch_size=1, shuffle=False, transform='default'):
66+
data_transform = transforms.Compose([
67+
transforms.Resize(256),
68+
transforms.CenterCrop(224),
69+
transforms.ToTensor(),
70+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
71+
]) if transform == 'default' else transform
72+
73+
image_path_list = get_file_list(data_path)
74+
75+
gallery_data = Gallery(image_paths=image_path_list,
76+
transform=data_transform,
77+
)
78+
79+
data_loader = dataloader.DataLoader(dataset=gallery_data,
80+
batch_size=batch_size,
81+
shuffle=shuffle,
82+
num_workers=0,
83+
)
84+
return data_loader
85+
86+
87+
def extract_feature(model, dataloaders, use_gpu=True):
88+
features = torch.FloatTensor()
89+
path_list = []
90+
91+
use_gpu = use_gpu and torch.cuda.is_available()
92+
for img, path in dataloaders:
93+
img = img.cuda() if use_gpu else img
94+
input_img = Variable(img.cuda())
95+
outputs = model(input_img)
96+
ff = outputs.data.cpu()
97+
# norm feature
98+
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
99+
ff = ff.div(fnorm.expand_as(ff))
100+
features = torch.cat((features, ff), 0)
101+
path_list += list(path)
102+
return features, path_list
103+
104+
105+
def extract_feature_query(model, img, use_gpu=True):
106+
c, h, w = img.size()
107+
img = img.view(-1,c,h,w)
108+
use_gpu = use_gpu and torch.cuda.is_available()
109+
img = img.cuda() if use_gpu else img
110+
input_img = Variable(img)
111+
outputs = model(input_img)
112+
ff = outputs.data.cpu()
113+
fnorm = torch.norm(ff,p=2,dim=1, keepdim=True)
114+
ff = ff.div(fnorm.expand_as(ff))
115+
return ff
116+
117+
118+
def load_query_image(query_path):
119+
data_transforms = transforms.Compose([
120+
transforms.Resize(256),
121+
transforms.CenterCrop(224),
122+
transforms.ToTensor(),
123+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
124+
])
125+
query_image = datasets.folder.default_loader(query_path)
126+
query_image = data_transforms(query_image)
127+
return query_image
128+
129+
130+
def load_model(pretrained_model=None, use_gpu=True):
131+
"""
132+
133+
:param check_point: Pretrained model path.
134+
:return:
135+
"""
136+
model = models.resnet50(pretrained=False)
137+
num_ftrs = model.fc.in_features
138+
add_block = []
139+
add_block += [nn.Linear(num_ftrs, 30)] #number of training classes
140+
model.fc = nn.Sequential(*add_block)
141+
model.load_state_dict(torch.load(pretrained_model))
142+
143+
# remove the final fc layer
144+
model.fc = nn.Sequential()
145+
# change to test modal
146+
model = model.eval()
147+
use_gpu = use_gpu and torch.cuda.is_available()
148+
if use_gpu:
149+
model = model.cuda()
150+
return model
151+
152+
153+
# sort the images
154+
def sort_img(qf, gf):
155+
score = gf*qf
156+
score = score.sum(1)
157+
# predict index
158+
s, index = score.sort(dim=0, descending=True)
159+
s = s.cpu().data.numpy()
160+
import numpy as np
161+
s = np.around(s, 3)
162+
return s, index
163+
164+
165+
if __name__ == '__main__':
166+
167+
# Prepare data.
168+
data_loader = load_data(data_path='./test_pytorch/gallery/images/',
169+
batch_size=2,
170+
shuffle=False,
171+
transform='default',
172+
)
173+
174+
# Prepare model.
175+
model = load_model(pretrained_model='./model/ft_ResNet50/net_best.pth', use_gpu=True)
176+
177+
# Extract database features.
178+
gallery_feature, image_paths = extract_feature(model=model, dataloaders=data_loader)
179+
180+
# Query.
181+
query_image = load_query_image('./test_pytorch/query/query.jpg')
182+
183+
# Extract query features.
184+
query_feature = extract_feature_query(model=model, img=query_image)
185+
186+
# Sort.
187+
similarity, index = sort_img(query_feature, gallery_feature)
188+
189+
sorted_paths = [image_paths[i] for i in index]
190+
print(sorted_paths)
191+

static/image_database/10.jpg

55 KB
Loading

static/image_database/11.jpg

71.2 KB
Loading

static/image_database/2.jpg

68.9 KB
Loading

static/image_database/3.jpg

66 KB
Loading

static/image_database/4.jpg

82.6 KB
Loading

static/image_database/5.jpg

105 KB
Loading

static/image_database/6.jpg

73.5 KB
Loading

static/image_database/7.jpg

70.4 KB
Loading

static/image_database/8.jpg

93.5 KB
Loading

static/image_database/9.jpg

50.7 KB
Loading

static/upload_image/query.jpg

52.8 KB
Loading

templates/retrieval.html

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
<!DOCTYPE html>
2+
<html lang="en">
3+
<head>
4+
<meta charset="UTF-8">
5+
<title>Image Retrieval</title>
6+
</head>
7+
<body>
8+
<div style="padding-left: 60px;">
9+
<h1>Image Retrieval</h1>
10+
11+
<div style="width:450px; float:left; background:#FFFFFF">
12+
<form action="" enctype='multipart/form-data' method='POST'>
13+
<input type="file" name="picture" style="margin-top:20px;"/>
14+
<br>
15+
<i>Upload your picture, please click here: </i>
16+
<input type="submit" value="upload" name="submit" class="button-new" style="margin-top:15px;"/>
17+
<br>
18+
<i>Start retrieval, please click here: </i>
19+
<input type="submit" value="retrieval" name="submit" class="button-new" style="margin-top:15px;"/>
20+
</form>
21+
<h3>{{message}}</h3>
22+
<img src="{{ url_for('static', filename= './upload_image/query.jpg') }}" width="400" />
23+
</div>
24+
25+
<div style=" margin-left:450px; background:#FFFFFF">
26+
<h3>Retrieval result:</h3>
27+
<h4>Similarity of top 3: {{sml1}}, {{sml2}}, {{sml3}}</h4>
28+
<div>
29+
<img src="{{img1_tmb}}" height="200px" overflow=hidden/>
30+
<img src="{{img2_tmb}}" height="200px" overflow=hidden/>
31+
<img src="{{img3_tmb}}" height="200px" overflow=hidden/>
32+
</div>
33+
<h4>Similarity of 4 ~ 6: {{sml4}}, {{sml5}}, {{sml6}}</h4>
34+
<div>
35+
<img src="{{img4_tmb}}" height="200px" overflow=hidden/>
36+
<img src="{{img5_tmb}}" height="200px" overflow=hidden/>
37+
<img src="{{img6_tmb}}" height="200px" overflow=hidden/>
38+
</div>
39+
</div>
40+
</div>
41+
</body>
42+
</html>

templates/upload.html

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
<!DOCTYPE html>
2+
<html lang="en">
3+
<head>
4+
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
5+
<title>Image Retrieval</title>
6+
</head>
7+
<body>
8+
<div style="padding-left: 60px;">
9+
<h1>Image Retrieval</h1>
10+
<form action="" enctype='multipart/form-data' method='POST'>
11+
<input type="file" name="picture" style="margin-top:20px;"/>
12+
<br>
13+
<i>Upload your picture, please click here: </i>
14+
<input type="submit" value="upload" name="submit" class="button-new" style="margin-top:15px;"/>
15+
<br>
16+
<i>Start retrieval, please click here: </i>
17+
<input type="submit" value="retrieval" name="submit" class="button-new" style="margin-top:15px;"/>
18+
</form>
19+
</div>
20+
</body>
21+
</html>

0 commit comments

Comments
 (0)