-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_model.py
62 lines (40 loc) · 1.65 KB
/
load_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# import tensorflow as tf
import tflite_runtime.interpreter as tflite
import numpy as np
from pathlib import Path
from PIL import Image
from urllib import request
# Remote location of the pre-trained mask classifier and its local cache path.
MODEL_URL = 'https://github.com/yeguacelestial/face-mask-detection-api/raw/main/mask_classifier.tflite'
mask_classifier_file = Path('mask_classifier.tflite')

# Download the model only once; later runs reuse the copy cached on disk.
if not mask_classifier_file.exists():
    print("[*] Downloading model...")
    # Fetch into a sibling '.part' file and rename it only after the transfer
    # completes, so an interrupted download never leaves a truncated
    # 'mask_classifier.tflite' behind that the exists() check would accept.
    # A stale '.part' from a failed run is simply overwritten next time.
    partial_file = mask_classifier_file.with_name(mask_classifier_file.name + '.part')
    request.urlretrieve(MODEL_URL, str(partial_file))
    partial_file.replace(mask_classifier_file)
    print("[+] Done.")
# Load the TFLite model from the cached file and allocate its tensors.
# Reuse mask_classifier_file rather than repeating the filename literal, so the
# download location and the load location cannot drift apart.
interpreter = tflite.Interpreter(model_path=str(mask_classifier_file))
interpreter.allocate_tensors()

# Tensor metadata reported by the interpreter (lists of per-tensor dicts).
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# True when the model's first input is float32; in that case pixel values are
# normalised before inference.
floating_model = input_details[0]['dtype'] == np.float32

# Spatial size of the model input. Reading shape[1]/shape[2] presumes an NHWC
# layout (batch, height, width, channels) — NOTE(review): confirm for this model.
height = input_details[0]['shape'][1]
width = input_details[0]['shape'][2]
def predict_mask_on_img(img_path):
    """Classify whether the image at *img_path* shows a face mask.

    The image is resized to the model's input size. For float models it is
    scaled to [-1, 1]. It is then run through the module-level TFLite
    interpreter. The squeezed model output is rounded to an int: 0 is reported
    as "mask detected", and any non-zero value as "no mask detected".

    Args:
        img_path: Filename, path or file object accepted by ``PIL.Image.open``.

    Returns:
        tuple: ``(mask_detected, scalar_result)``. ``mask_detected`` is a bool,
        and ``scalar_result`` is the rounded model output as an int.

    Raises:
        ValueError: If the model output holds more than one element, so it
            cannot be reduced to a single scalar.
    """
    # Open in a context manager so the underlying file handle is released
    # promptly instead of whenever the Image object is garbage-collected.
    with Image.open(img_path) as source_img:
        # Force 3-channel RGB so RGBA / palette / greyscale files do not feed a
        # wrongly-shaped tensor. NOTE(review): assumes the model expects RGB
        # input — confirm against input_details[0]['shape'].
        img = source_img.convert('RGB').resize((width, height))

    # Add the leading batch dimension expected by the interpreter.
    input_data = np.expand_dims(np.asarray(img), axis=0)
    if floating_model:
        # Map 0..255 pixel values into [-1, 1].
        input_data = (np.float32(input_data) - 127.5) / 127.5

    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # np.asscalar was removed in NumPy 1.23; ndarray.item() is its replacement
    # and likewise raises ValueError for multi-element arrays.
    scalar_result = int(round(np.squeeze(output_data).item()))
    print(f"\n[***] OUTPUT: {output_data}")

    # A rounded output of 0 means "mask"; anything else means "no mask".
    if scalar_result:
        print(f"[-] No se detectó una mascara en {img_path}: {scalar_result}")
        return False, scalar_result
    print(f"[+] Se detecto una mascara en {img_path}: {scalar_result}")
    return True, scalar_result