1+ # import required libraries
2+ import tensorflow as tf
3+ import streamlit as st
4+ import cv2
5+ from PIL import Image , ImageOps
6+ import numpy as np
7+ # load the model
8+ model = tf .keras .models .load_model ("sortit_model(v4).hdf5" )
9+ # header
10+ st .write (
11+ """
12+ # SortIt Image Classifier
13+ """
14+
15+ )
16+ st .write (
17+ """
18+ This app classifies images of cats, cars, dogs, bicycles, and motorcycles
19+ """
20+ )
21+ # upload an image for prediction
22+ file = st .file_uploader ("Please upload an image file" , type = ["jpg" ,"png" ])
23+
24+ # function that imports image resizes it, and runs prediction
25+ def import_and_predict (img_data ,model ):
26+ # specify the image size
27+ size = (150 ,150 )
28+ # import the image and resize it
29+ image = ImageOps .fit (img_data ,size ,Image .ANTIALIAS )
30+ # convert image to numpy array
31+ image = np .asarray (image )
32+ # convert the image from bgr to rgb
33+ # img = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
34+ # resize the numpy array
35+ img_resize = (cv2 .resize (image ,dsize = (75 ,75 ),interpolation = cv2 .INTER_CUBIC ))/ 255
36+ # reshape the image
37+ img_reshape = img_resize [np .newaxis ,...]
38+ # predictions = np.argmax(model.predict(img_reshape), axis=-1)
39+ # run prediction on the image
40+ prediction = model .predict (img_reshape )
41+ return prediction
42+
43+ if file is None :
44+ st .text ("Please upload an image file" )
45+ else :
46+ #open the image and convert it to rgb
47+ image = Image .open (file ).convert ('RGB' )
48+ st .image (image ,use_column_width = True )
49+ prediction = import_and_predict (image ,model )
50+ # print the prediction
51+ st .write (prediction )
0 commit comments