# main.py
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import tensorflow as tf
from classifier import Classifier
from utils import getGpus, lookDeeperIfNeeded
import zipfile
import io
from os import makedirs
import sys
from config import prevent_model_update, mlflow_tracking_uri
from pydantic import BaseModel
import shutil
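
# Shared classifier instance and FastAPI application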
classifier = Classifier()
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
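
# Request body for PUT /model: identifies a registered MLflow model by name and version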
class MlflowModel(BaseModel):
    name: str
    version: str
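
# Status endpoint: reports application metadata, loaded model information
# and environment versions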
@app.get("/")
async def root():
    response = {
        "application_name": "image classifier server",
        "author": "Maxime MOREILLON, Shion ITO",
        "version": "0.7.0",
        "model_loaded": classifier.model_loaded,
        "model_info": {**classifier.model_info},
        "mlflow_tracking_uri": mlflow_tracking_uri,
        "gpu": len(getGpus()),
        "update_allowed": not prevent_model_update,
        "versions": {
            "python": sys.version,
            "tensorflow": tf.__version__,
        },
    }
    if classifier.mlflow_model:
        response["mlflow_model"] = {**classifier.mlflow_model}
    return response
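
# Run inference on an uploaded image; optionally return a heatmap
# (only supported for Keras models).
# Example request (host and port are assumptions, adjust to your deployment):
#   curl -X POST http://localhost:8000/predict -F "image=@sample.jpg"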
@app.post("/predict")
async def predict(image: bytes = File(), heatmap: bool = False):
if classifier.model_info['type'] != 'keras' and heatmap:
raise HTTPException(status_code=400, detail="Heatmap is NOT available. Please upload KERAS model, if you want to use Heatmap.")
result = await classifier.predict(image, heatmap)
return result
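
# Replace the currently served model with an uploaded file: either a
# zipped model folder (.zip) or an ONNX model (.onnx).
# Example upload (hypothetical host, port and filename):
#   curl -X POST http://localhost:8000/model -F "model=@model.zip"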
@app.post("/model")
async def upload_model(model: UploadFile = File(...)):
if prevent_model_update:
raise HTTPException(status_code=403, detail="Model update is forbidden")
# save model file according to file extension
if model.filename.endswith('.zip'):
# reset model folder
shutil.rmtree("./model", ignore_errors=True)
makedirs("./model", exist_ok=True)
with io.BytesIO(await model.read()) as tmp_stream, zipfile.ZipFile(tmp_stream, 'r') as zip_ref:
zip_ref.extractall("./model")
# unify folder structure when unzipping
lookDeeperIfNeeded('./model')
elif model.filename.endswith('.onnx'):
# reset model folder
shutil.rmtree("./model", ignore_errors=True)
makedirs("./model", exist_ok=True)
file_path = f'./model/{model.filename}'
with open(file_path, "wb") as buffer:
shutil.copyfileobj(model.file, buffer)
else:
raise HTTPException(status_code=400, detail="Invalid file type. Only .zip and .onnx files are accepted.")
# load model
classifier.load_model_from_local()
return classifier.model_info["type"]

# Proxying the MLflow REST API for the classifier server GUI
# TODO: put these endpoints in a dedicated route
if mlflow_tracking_uri:
    import mlflow
    from mlflow import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
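
    # List registered MLflow models, filtered by a case-insensitive name search.
    # Note: the search term is interpolated directly into the filter string.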
@app.get("/mlflow/models")
async def getMlflowModels(search: str='', page_token: str=""):
models = []
filter_string = f"name ILIKE '%{search}%'"
res = client.search_registered_models(filter_string=filter_string, page_token=page_token)
for model in res:
models.append(model)
return {"models": models, "page_token": res.token}
@app.get("/mlflow/models/{model}/versions")
async def getMlflowModelVersions(model):
versions = []
for version in client.search_model_versions(f"name='{model}'"):
versions.append(version)
return versions
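
    # Load a specific model version from the MLflow registry.
    # Example request body (hypothetical values): {"name": "my-model", "version": "3"}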
    @app.put("/model")
    async def updateMlflowModel(mlflowModel: MlflowModel):
        if prevent_model_update:
            raise HTTPException(status_code=403, detail="Model update is forbidden")
        model = mlflowModel.name
        version = mlflowModel.version
        classifier.load_model_from_mlflow(model, version)
        return {
            "model": model,
            "version": version,
        }