-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
103 additions
and
210 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
{ | ||
"ui":{ | ||
"title":"Corgi butt or loaf of bread?", | ||
"icon":"assets/icon/corgi-icon.png", | ||
"markdown":{ | ||
"star":"[](https://github.com/Kawaeee/butt_or_bread)", | ||
"release":"[](https://github.com/Kawaeee/butt_or_bread/releases/tag/v1.1)", | ||
"visitor":"" | ||
}, | ||
"mode":{ | ||
"upload":{ | ||
"main_label":"Upload an image" | ||
}, | ||
"select":{ | ||
"main_label":"Select pre-configured image", | ||
"class_label":"Pick a labels:", | ||
"corgi_label":"Pick your favorite corgi butt image 🐕:", | ||
"bread_label":"Pick your favorite loaf of bread image 🍞:" | ||
} | ||
} | ||
}, | ||
"model":{ | ||
"url":"https://github.com/Kawaeee/butt_or_bread/releases/download/v1.1/buttbread_resnet152_3.h5", | ||
"label":{ | ||
"corgi":"Corgi butt 🐕", | ||
"bread":"Loaf of bread 🍞" | ||
} | ||
}, | ||
"image":{ | ||
"base_path":"assets/images/", | ||
"corgi":{ | ||
"A loaf of corgi":"corgi/corgi_1.jpg", | ||
"Corgi butt pressed against window":"corgi/corgi_2.jpg", | ||
"Corgi butt wearing a glasses":"corgi/corgi_3.jpg", | ||
"Thicc corgi butt post":"corgi/corgi_4.jpg", | ||
"Cute corgi butt walking outdoor":"corgi/corgi_5.jpg" | ||
}, | ||
"bread":{ | ||
"A close up of a corgi butt bread":"bread/bread_1.jpg", | ||
"A loaf of bread on the wooden table":"bread/bread_2.jpg", | ||
"Big loaf of bread":"bread/bread_3.jpg", | ||
"Burnt version of corgi butt bread":"bread/bread_4.jpg", | ||
"Corgi butt bun":"bread/bread_5.jpg" | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,248 +1,95 @@ | ||
import streamlit as st | ||
from streamlit.logger import get_logger | ||
|
||
import gc | ||
import time | ||
import json | ||
import os | ||
import requests | ||
import psutil | ||
|
||
import streamlit as st | ||
from streamlit.logger import get_logger | ||
from PIL import Image | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch.nn import functional as F | ||
from torchvision import models, transforms | ||
|
||
os.environ["LRU_CACHE_CAPACITY"] = "1" | ||
from butt_or_bread.core import ButtBreadClassifier | ||
from butt_or_bread.utils import health_check | ||
|
||
# Create Streamlit logger | ||
st_logger = get_logger(__name__) | ||
|
||
st.set_option("deprecation.showfileUploaderEncoding", False) | ||
|
||
st.set_page_config( | ||
layout="centered", | ||
page_title="Corgi butt or loaf of bread?", | ||
page_icon="icon/corgi-icon.png", | ||
) | ||
# Load Streamlit configuration file | ||
with open("streamlit_app.json") as cfg_file: | ||
st_app_cfg = json.load(cfg_file) | ||
|
||
# Markdown | ||
repo = "[](https://github.com/Kawaeee/butt_or_bread)" | ||
version = "[](https://github.com/Kawaeee/butt_or_bread/releases/tag/v1.1)" | ||
follower = "[](https://github.com/Kawaeee)" | ||
visitor = "" | ||
|
||
model_url_path = "https://github.com/Kawaeee/butt_or_bread/releases/download/v1.1/buttbread_resnet152_3.h5" | ||
|
||
# Test images | ||
test_images_path = "test_images" | ||
labels = ["Corgi butt 🐕", "Loaf of bread 🍞"] | ||
|
||
corgi_images_file = [ | ||
"corgi_1.jpg", | ||
"corgi_2.jpg", | ||
"corgi_3.jpg", | ||
"corgi_4.jpg", | ||
"corgi_5.jpg", | ||
] | ||
|
||
corgi_images_name = [ | ||
"A loaf of corgi", | ||
"Corgi butt pressed against window", | ||
"Corgi butt wearing a glasses", | ||
"Thicc corgi butt post", | ||
"Cute corgi butt walking outdoor", | ||
] | ||
|
||
corgi_images_dict = { | ||
name: os.path.join(test_images_path, c_file) | ||
for name, c_file in zip(corgi_images_name, corgi_images_file) | ||
} | ||
|
||
bread_images_file = [ | ||
"bread_1.jpg", | ||
"bread_2.jpg", | ||
"bread_3.jpg", | ||
"bread_4.jpg", | ||
"bread_5.jpg", | ||
] | ||
|
||
bread_images_name = [ | ||
"A close up of a corgi butt bread", | ||
"A loaf of bread on the wooden table", | ||
"Big loaf of bread", | ||
"Burnt version of corgi butt bread", | ||
"Corgi butt bun", | ||
] | ||
|
||
bread_images_dict = { | ||
name: os.path.join(test_images_path, b_file) | ||
for name, b_file in zip(bread_images_name, bread_images_file) | ||
} | ||
|
||
# Model configuration | ||
# Streamlit server does not provide GPU, So we will go with CPU! | ||
processing_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
img_normalizer = transforms.Normalize( | ||
mean=[0.485, 0.456, 0.406], | ||
std=[0.229, 0.224, 0.225], | ||
) | ||
|
||
img_transformer = transforms.Compose( | ||
[ | ||
transforms.Resize((224, 224)), | ||
transforms.ToTensor(), | ||
img_normalizer, | ||
] | ||
ui_cfg = st_app_cfg["ui"] | ||
model_cfg = st_app_cfg["model"] | ||
image_cfg = st_app_cfg["image"] | ||
|
||
st.set_page_config( | ||
layout="centered", | ||
page_title=ui_cfg["title"], | ||
page_icon=ui_cfg["icon"], | ||
) | ||
|
||
|
||
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=3, ttl=300) | ||
def initialize_model(device=processing_device): | ||
"""Retrieves the butt_bread trained model and maps it to the CPU by default, can also specify GPU here.""" | ||
|
||
model = models.resnet152(pretrained=False).to(device) | ||
model.fc = nn.Sequential( | ||
nn.Linear(2048, 128), | ||
nn.ReLU(inplace=True), | ||
nn.Linear(128, 2), | ||
).to(device) | ||
|
||
model.load_state_dict(torch.load("buttbread_resnet152_3.h5", map_location=device)) | ||
model.eval() | ||
|
||
return model | ||
|
||
|
||
def predict(img, model): | ||
"""Make a prediction on a single image""" | ||
def get_classifier(): | ||
"""Allow butt_bread model caching""" | ||
classifier = ButtBreadClassifier(model_url=model_cfg["url"]) | ||
classifier.download() | ||
classifier.initialize() | ||
|
||
input_img = img_transformer(img).float() | ||
input_img = input_img.unsqueeze(0) | ||
|
||
pred_logits_tensor = model(input_img) | ||
pred_probs = F.softmax(pred_logits_tensor, dim=1).cpu().data.numpy() | ||
|
||
bread_prob = pred_probs[0][0] | ||
butt_prob = pred_probs[0][1] | ||
|
||
json_output = { | ||
"name": img.filename, | ||
"format": img.format, | ||
"mode": img.mode, | ||
"width": img.width, | ||
"height": img.height, | ||
"prediction": { | ||
"labels": { | ||
"Corgi butt 🐕": "{:.3%}".format(float(butt_prob)), | ||
"Loaf of bread 🍞": "{:.3%}".format(float(bread_prob)), | ||
} | ||
}, | ||
} | ||
|
||
input_img = None | ||
pred_logit_tensor = None | ||
pred_probs = None | ||
|
||
return json_output | ||
|
||
|
||
def download_model(): | ||
"""Download model weight, if model does not exist in Streamlit server.""" | ||
|
||
if os.path.isfile("buttbread_resnet152_3.h5") is False: | ||
print("Downloading butt_bread model !!") | ||
req = requests.get(model_url_path, allow_redirects=True) | ||
open("buttbread_resnet152_3.h5", "wb").write(req.content) | ||
req = None | ||
|
||
return True | ||
|
||
|
||
def health_check(): | ||
""" "Check CPU/Memory/Disk usage of deployed machine""" | ||
|
||
cpu_percent = psutil.cpu_percent(0.15) | ||
total_memory = psutil.virtual_memory().total / float(1 << 30) | ||
used_memory = psutil.virtual_memory().used / float(1 << 30) | ||
total_disk = psutil.disk_usage("/").total / float(1 << 30) | ||
used_disk = psutil.disk_usage("/").used / float(1 << 30) | ||
|
||
cpu_usage = "CPU Usage: {:.2f}%".format(cpu_percent) | ||
memory_usage = "Memory usage: {:,.2f}G/{:,.2f}G".format(used_memory, total_memory) | ||
disk_usage = "Disk usage: {:,.2f}G/{:,.2f}G".format(used_disk, total_disk) | ||
|
||
return " | ".join([cpu_usage, memory_usage, disk_usage]) | ||
return classifier | ||
|
||
|
||
if __name__ == "__main__": | ||
img_file = None | ||
img = None | ||
prediction = None | ||
|
||
download_model() | ||
model = initialize_model() | ||
image_file, image, prediction = None, None, None | ||
classifier = get_classifier() | ||
|
||
st_logger.info("[DEBUG] %s", health_check(), exc_info=0) | ||
st_logger.info("[INFO] Initialize %s model successfully", "buttbread_resnet152_3.h5", exc_info=0) | ||
|
||
st.title("Corgi butt or loaf of bread? 🐕🍞") | ||
st.markdown(version + " " + repo + " " + visitor + " " + follower, unsafe_allow_html=True) | ||
st.title(ui_cfg["title"]) | ||
st.markdown(f'{ui_cfg["markdown"]["release"]} {ui_cfg["markdown"]["star"]} {ui_cfg["markdown"]["visitor"]}', unsafe_allow_html=True) | ||
|
||
processing_mode = st.radio("", ("Upload an image", "Select pre-configured image")) | ||
mode = st.radio("", [ui_cfg["mode"]["upload"]["main_label"], ui_cfg["mode"]["select"]["main_label"]]) | ||
|
||
if processing_mode == "Upload an image": | ||
img_file = st.file_uploader("Upload an image", accept_multiple_files=False) | ||
elif processing_mode == "Select pre-configured image": | ||
img_labels = st.selectbox("Pick a labels:", labels) | ||
if mode == ui_cfg["mode"]["upload"]["main_label"]: | ||
image_file = st.file_uploader(mode, accept_multiple_files=False) | ||
elif mode == ui_cfg["mode"]["select"]["main_label"]: | ||
class_label = st.selectbox(ui_cfg["mode"]["select"]["class_label"], model_cfg["label"].values()) | ||
|
||
if img_labels == labels[0]: | ||
corgi_list = st.selectbox("Pick your favorite corgi butt image 🐕:", corgi_images_name) | ||
img_file = corgi_images_dict[corgi_list] | ||
elif img_labels == labels[1]: | ||
bread_list = st.selectbox("Pick your favorite loaf of bread image 🍞:", bread_images_name) | ||
img_file = bread_images_dict[bread_list] | ||
if class_label == model_cfg["label"]["corgi"]: | ||
image_label = st.selectbox(ui_cfg["mode"]["select"]["corgi_label"], [*image_cfg["corgi"]]) | ||
image_file = os.path.join(image_cfg["base_path"], image_cfg["corgi"][image_label]) | ||
elif class_label == model_cfg["label"]["bread"]: | ||
image_label = st.selectbox(ui_cfg["mode"]["select"]["bread_label"], [*image_cfg["bread"]]) | ||
image_file = os.path.join(image_cfg["base_path"], image_cfg["bread"][image_label]) | ||
|
||
if img_file: | ||
if image_file: | ||
try: | ||
img = Image.open(img_file) | ||
image = Image.open(image_file) | ||
|
||
if img.mode != "RGB": | ||
tmp_format = img.format | ||
img = img.convert("RGB") | ||
img.format = tmp_format | ||
if processing_mode == "Upload an image": | ||
img.filename = img_file.name | ||
elif processing_mode == "Select pre-configured image": | ||
img.filename = os.path.basename(img_file) | ||
if image.mode != "RGB": | ||
temporary_format = image.format | ||
image = image.convert("RGB") | ||
image.format = temporary_format | ||
|
||
prediction = predict(img, model) | ||
if mode == ui_cfg["mode"]["upload"]["main_label"]: | ||
image.filename = image_file.name | ||
elif mode == ui_cfg["mode"]["select"]["main_label"]: | ||
image.filename = os.path.basename(image_file) | ||
|
||
st_logger.info("[DEBUG] %s", health_check(), exc_info=0) | ||
st_logger.info("[INFO] Predict %s image successfully", img.filename, exc_info=0) | ||
prediction = classifier.predict(image) | ||
|
||
except Exception as e: | ||
st.error("ERROR: Unable to predict {} ({}) !!!".format(img_file.name, img_file.type)) | ||
st_logger.error("[ERROR] Unable to predict %s (%s) !!!", img_file.name, img_file.type, exc_info=0) | ||
st_logger.info("[DEBUG] %s", health_check(), exc_info=0) | ||
st_logger.info("[INFO] Predict %s image successfully", image.filename, exc_info=0) | ||
|
||
img_file = None | ||
img = None | ||
prediction = None | ||
except Exception as ex: | ||
st.error("ERROR: Unable to predict {} ({}) !!!".format(image_file.name, image_file.type)) | ||
st_logger.error("[ERROR] Unable to predict %s (%s) !!!", image_file.name, image_file.type, exc_info=0) | ||
image_file, image, prediction = None, None, None | ||
|
||
if img is not None or prediction is not None: | ||
if image is not None or prediction is not None: | ||
st.header("Here is the image you've chosen") | ||
resized_image = img.resize((400, 400)) | ||
resized_image = image.resize((400, 400)) | ||
st.image(resized_image) | ||
st.write("Prediction:") | ||
st.json(prediction) | ||
|
||
img = None | ||
resized_image = None | ||
prediction = None | ||
|
||
del model | ||
gc.collect() | ||
model = None |