Skip to content

Commit 199bbe8

Browse files
author
Joshua Mayanja
committed
New Updates
1 parent 6b2dfc0 commit 199bbe8

19 files changed

+253
-32
lines changed

.env

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
ASTRA_DB_CLIENT_ID=oorvSWmoNPLetRKkQytSrMzm
2+
ASTRA_DB_CLIENT_SECRET=SR9g6xHuSmI5S+qPsNWiODwO1JiWZaaCCFU2pkP+MGzjt4HZ-0ujoo2to5dAB01sMfM_PCTg9OMqz8-fsEHgLYa.-XfUW-mLZiHlxkPuPF+-PIw+HpWyqIgt3YU2p-gq
3+
ASTRA_DB_CLIENT_TOKEN=AstraCS:oorvSWmoNPLetRKkQytSrMzm:3a16792c6faa9df4b2e38a2733ecc98cdef97bbab3892c9fccc7dae591332701

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
env/
22
spam-classifier/
33
zips/
4+
ignored/
45
*.pkl
56

app/__pycache__/config.cpython-37.pyc

882 Bytes
Binary file not shown.

app/__pycache__/db.cpython-37.pyc

1.01 KB
Binary file not shown.
838 Bytes
Binary file not shown.

app/__pycache__/main.cpython-37.pyc

1.12 KB
Binary file not shown.

app/__pycache__/ml.cpython-37.pyc

3.78 KB
Binary file not shown.

app/__pycache__/models.cpython-37.pyc

593 Bytes
Binary file not shown.

app/__pycache__/schema.cpython-37.pyc

333 Bytes
Binary file not shown.

app/config.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from functools import lru_cache
import os

from pydantic import BaseSettings, Field

# Allow cqlengine schema management (needed for the sync_table() call at startup).
os.environ['CQLENG_ALLOW_SCHEMA_MANAGEMENT'] = '1'

class Settings(BaseSettings):
    """Application settings read from the environment (or a local .env file)."""

    # Astra DB credentials consumed by app.db when opening the cluster connection.
    db_client_id: str = Field(..., env="ASTRA_DB_CLIENT_ID")
    db_client_secret: str = Field(..., env="ASTRA_DB_CLIENT_SECRET")

    class Config:
        # Fallback source for the variables above during local development.
        env_file = '.env'

@lru_cache(maxsize=None)
def get_settings():
    """Return a process-wide cached Settings instance (environment read once)."""
    return Settings()

app/db.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pathlib
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.cqlengine import connection

from app.config import get_settings


BASE_DIR = pathlib.Path(__file__).resolve().parent
# Astra secure-connect bundle location. NOTE(review): assumes the zip lives in
# app/ignored/ (a git-ignored directory) — confirm it is provisioned on deploy.
CLUSTER_BUNDLE = str(BASE_DIR / 'ignored' / 'astradb_connect.zip')

settings = get_settings()

ASTRA_DB_CLIENT_ID = settings.db_client_id
ASTRA_DB_CLIENT_SECRET = settings.db_client_secret

def get_cluster():
    """Build a Cluster configured for Astra DB via the secure-connect bundle."""
    cloud_config= {
        'secure_connect_bundle': CLUSTER_BUNDLE
    }
    auth_provider = PlainTextAuthProvider(ASTRA_DB_CLIENT_ID, ASTRA_DB_CLIENT_SECRET)
    cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider)
    return cluster

def get_session():
    """Connect to the cluster and register the session as cqlengine's default.

    Registration is what lets cqlengine model queries (e.g. Model.objects)
    work without an explicit connection argument; the session's str() form
    is used as the connection name.
    """
    cluster = get_cluster()
    session = cluster.connect()
    connection.register_connection(str(session), session=session)
    connection.set_default_connection(str(session))
    return session

app/encoders.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import json
2+
import numpy as np
3+
4+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that serializes NumPy scalars and arrays.

    json.dumps only calls ``default`` for objects it cannot serialize
    natively; each NumPy type is mapped to its built-in equivalent.
    """

    def default(self, obj):
        # The three checks are disjoint, so their order does not matter.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        # Anything else defers to the base class, which raises TypeError.
        return super().default(obj)
14+
15+
16+
def encode_to_json(data, as_py=True):
    """Serialize *data* (which may contain NumPy types) via NumpyEncoder.

    Returns the JSON string, or — when *as_py* is true (the default) — that
    string parsed back into plain Python objects, so every NumPy value has
    been coerced to a built-in type.
    """
    as_text = json.dumps(data, cls=NumpyEncoder)
    return json.loads(as_text) if as_py else as_text

app/main.py

+64-30
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
1-
import json
2-
from multiprocessing.spawn import spawn_main
31
import pathlib
42
from typing import Optional
53
from fastapi import FastAPI
6-
from keras.models import load_model
7-
from keras_preprocessing.text import tokenizer_from_json
8-
from keras_preprocessing.sequence import pad_sequences
4+
from fastapi.responses import StreamingResponse
5+
from app.config import get_settings
6+
from app.db import get_session
7+
8+
from app.ml import SpamModel
9+
from app.models import SpamInference
10+
11+
from cassandra.cqlengine.management import sync_table
12+
from cassandra.query import SimpleStatement
13+
14+
from app.schema import Query
915

1016
app = FastAPI(
1117
version="1.0.0",
1218
title="DrexSpam",
1319
description="An Artificial Intelligence based Spam detector API using machine learning",
1420
)
1521

22+
SETTINGS = get_settings()
23+
1624
BASE_DIR = pathlib.Path(__file__).resolve().parent
1725

1826
MODEL_DIR = BASE_DIR.parent / "models"
@@ -22,34 +30,60 @@
2230
SPAM_METADATA_PATH = SPAM_MODEL_DIR / "spam-classifer-metadata.json"
2331

2432
SPAM_MODEL = None
25-
SPAM_TOKENIZER = None
26-
SPAM_METADATA = {}
27-
LEGEND_INVERTED = {}
33+
DB_SESSION = None
34+
SPAM_INFERENCE = SpamInference
2835

2936
@app.on_event("startup")
def on_startup():
    """Load the spam model and open the DB session once at application startup.

    Populates the module-level SPAM_MODEL and DB_SESSION globals used by the
    route handlers below.
    """
    global SPAM_MODEL, DB_SESSION
    SPAM_MODEL = SpamModel(
        model_path = SPAM_MODEL_PATH,
        tokenizer_path = SPAM_TOKENIZER_PATH,
        metadata_path= SPAM_METADATA_PATH,
    )
    DB_SESSION = get_session()
    # Create/alter the Cassandra table so it matches the SpamInference model.
    sync_table(SPAM_INFERENCE)
4946

5047
@app.get("/")
def read_index(q: Optional[str] = None):
    """Health-check endpoint; the ``q`` query parameter is accepted but unused."""
    return {"hello": "world"}
50+
51+
@app.post("/")
def create_infercence(q: Query):
    """Classify the posted text and persist the top prediction.

    Returns the newly created SpamInference row (label + confidence for the
    winning class, plus the original query text).
    """
    global SPAM_MODEL
    text = q.query or "Hello world"
    top_pred = SPAM_MODEL.predict_text(text).get("top")
    return SPAM_INFERENCE.objects.create(query=text, **top_pred)
60+
61+
@app.get("/inferences")
def get_inferences():
    """Return every stored inference row as a list."""
    return list(SPAM_INFERENCE.objects.all())
65+
66+
@app.get("/inferences/{my_uuid}")
def get_inference_detail(my_uuid):
    """Look up a single inference by its primary-key uuid.

    NOTE(review): cqlengine's .get() raises when no row matches — confirm a
    404 translation is desired here.
    """
    return SPAM_INFERENCE.objects.get(uuid=my_uuid)
70+
71+
def fetch_row(statement: "SimpleStatement", fetch_size: int, session=None):
    """Stream query results as CSV lines, following driver paging.

    Yields the CSV header first, then one line per row, fetching pages of
    *fetch_size* rows at a time via the driver's paging_state mechanism.

    Fixes a paging bug: the previous ``while has_pages:`` loop skipped the
    body entirely when the whole result fit in a single page
    (has_more_pages was False up front), so no data rows were yielded.
    """
    statement.fetch_size = fetch_size
    result_set = session.execute(statement)
    yield "uuid,label,confidence,query,model_version\n"
    while True:
        # Always drain the current page — even when it is the only page.
        for row in result_set.current_rows:
            yield f"{row['uuid']},{row['label']},{row['confidence']},{row['query']},{row['model_version']}\n"
        if not result_set.has_more_pages:
            break
        result_set = session.execute(statement, paging_state=result_set.paging_state)
81+
82+
@app.get("/dataset")
def export_inferences():
    """Stream up to 10,000 stored inferences as a CSV download.

    Uses fetch_row() so pages of 25 rows are streamed to the client instead
    of materializing the whole result set in memory.
    """
    global DB_SESSION
    cql_query = "SELECT * FROM spam_inferences.spam_inference LIMIT 10000"
    statement = SimpleStatement(cql_query)
    # media_type makes the response advertise CSV (removed dead commented-out code).
    return StreamingResponse(fetch_row(statement, 25, DB_SESSION), media_type="text/csv")
89+

app/ml.py

+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from dataclasses import dataclass
2+
import json
3+
import numpy as np
4+
from pathlib import Path
5+
from typing import Any, List, Optional
6+
from importlib_metadata import metadata
7+
from keras.models import load_model
8+
from keras_preprocessing.sequence import pad_sequences
9+
from keras_preprocessing.text import tokenizer_from_json
10+
11+
from app.encoders import NumpyEncoder, encode_to_json
12+
13+
14+
@dataclass
class SpamModel:
    """Drex Machine Learning Spam Classifier Model.

    Wraps a Keras model plus its tokenizer and metadata JSON; artifacts are
    loaded from disk in __post_init__ and each loader method raises if the
    corresponding artifact is missing.
    """

    model_path: Path
    metadata_path: Optional[Path] = None
    tokenizer_path: Optional[Path] = None

    # Loaded in __post_init__. Deliberately left un-annotated so the
    # dataclass machinery does not treat them as __init__ fields.
    model = None
    tokenizer = None
    metadata = None

    def __post_init__(self):
        # Each artifact is optional at load time: only read what exists on disk.
        if self.model_path.exists():
            self.model = load_model(self.model_path)
        if self.tokenizer_path:
            if self.tokenizer_path.exists():
                if self.is_json(self.tokenizer_path):
                    tokenizer_text = self.tokenizer_path.read_text()
                    self.tokenizer = tokenizer_from_json(tokenizer_text)
        if self.metadata_path:
            if self.metadata_path.exists():
                if self.is_json(self.metadata_path):
                    metadata_text = self.metadata_path.read_text()
                    self.metadata = json.loads(metadata_text)

    def get_model(self):
        """Return the loaded Keras model; raise if loading failed."""
        if not self.model:
            raise Exception("Model not loaded")
        return self.model

    def get_tokenizer(self):
        """Return the loaded tokenizer; raise if loading failed."""
        if not self.tokenizer:
            raise Exception("Tokenizer not loaded")
        return self.tokenizer

    def get_metadata(self):
        """Return the metadata dict; raise if loading failed."""
        if not self.metadata:
            raise Exception("Metadata not loaded")
        return self.metadata

    def get_sequences_from_text(self, texts: List[str]):
        """Tokenize raw strings into lists of integer token ids."""
        tokenizer = self.get_tokenizer()
        sequences = tokenizer.texts_to_sequences(texts)
        return sequences

    def get_input_from_sequences(self, sequences: List[Any]):
        """Pad/truncate token sequences to the model's expected input length."""
        metadata = self.get_metadata()
        # Fall back to 280 when metadata omits max_sequence.
        maxlen = metadata.get("max_sequence") or 280
        x_input = pad_sequences(sequences, maxlen)
        return x_input

    def get_label_legend_inverted(self):
        """Return the index->label mapping; it must describe exactly 2 classes."""
        metadata = self.get_metadata()
        legend = metadata.get("labels_legend_inverted") or {}
        if len(legend.keys()) != 2:
            raise Exception("Legend invalid")
        return legend

    def get_label_pred(self, index: int, val):
        """Pair a class index with its confidence as {"label", "confidence"}."""
        label_legend_inverted = self.get_label_legend_inverted()
        # JSON object keys are strings, hence the str(index) lookup.
        labeled_pred = {
            "label": label_legend_inverted[str(index)],
            "confidence": val,
        }
        return labeled_pred

    def get_top_label_pred(self, preds):
        """Return the labeled prediction with the highest confidence."""
        top_index = np.argmax(preds)
        top_pred = self.get_label_pred(top_index, preds[top_index])
        return top_pred

    def is_json(self, path: Path):
        """Return True when *path* has a .json filename suffix."""
        if path.name.endswith(".json"):
            return True
        return False

    def predict_text(self, query: str, include_top=True, encode_json=True):
        """Classify one string.

        Returns {"predictions": [...], "top": {...}} (the "top" key only when
        *include_top* is true). When *encode_json* is true the result is
        round-tripped through NumpyEncoder so it holds only plain Python types.
        """
        model = self.get_model()
        sequences = self.get_sequences_from_text([query])
        x_input = self.get_input_from_sequences(sequences)
        # Single query in, single prediction row out.
        preds_array = model.predict(x_input)[0]
        preds = [self.get_label_pred(i, x) for i, x in enumerate(list(preds_array))]
        results = {"predictions": preds}
        if include_top:
            results["top"] = self.get_top_label_pred(preds_array)
        if encode_json:
            results = encode_to_json(results)
        return results

app/models.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import uuid
2+
from cassandra.cqlengine import columns
3+
from cassandra.cqlengine.models import Model
4+
5+
class SpamInference(Model):
    """One stored spam-classifier prediction (cqlengine Cassandra model)."""

    __keyspace__ = "spam_inferences"
    # NOTE(review): uuid1 is time-based and embeds host info; uuid4 may be
    # preferable unless time-ordered ids are intentional — confirm.
    uuid = columns.UUID(primary_key=True, default=uuid.uuid1)
    query = columns.Text()
    label = columns.Text()
    confidence = columns.Float()
    model_version = columns.Text(default="v1")

app/schema.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from pydantic import BaseModel
2+
3+
class Query(BaseModel):
    """Request body for the POST / inference endpoint: the text to classify."""

    query: str

models/spam/spam-classifer-metadata.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
"ham": 0,
88
"spam": 1
99
},
10-
"max_sequence": 300,
10+
"max_sequence": 280,
1111
"max_words": 280
1212
}

models/spam/spam-classifer-tokenizer.json

+1-1
Large diffs are not rendered by default.

models/spam/spam-model.h5

0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)