-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from erikperkins/model
Model
- Loading branch information
Showing
8 changed files
with
323 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
.idea/ | ||
*.iml | ||
venv/ | ||
mlruns/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
*.iml | ||
venv/ | ||
__pycache__/ | ||
mlruns/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,44 @@ | ||
# Placeholder entry point: when executed as a script, idle forever in
# one-second sleeps so the process stays alive.
# NOTE(review): `sleep` is not imported in the visible lines -- presumably
# `from time import sleep` sits just above; confirm.
if __name__ == "__main__":
    while True:
        sleep(1)
import sentry_sdk | ||
import json | ||
import logging | ||
from kafka import KafkaConsumer | ||
from kafka import KafkaProducer | ||
from model.tip import Model | ||
from model.tip import TripValidationError | ||
|
||
# Kafka broker address and the topics this service reads from / writes to.
BOOTSTRAP_SERVER = 'kafka-service.kafka.svc.cluster.local:9092'
TRIPS = 'trips'
TIPS = 'tips'

logger = logging.getLogger('mariotte')

# Error/trace reporting; every transaction is sampled (rate 1.0).
# NOTE(review): the DSN below looks redacted in this copy -- confirm the real
# value before deploying.
sentry_sdk.init(
    dsn="https://[email protected]/5",
    traces_sample_rate=1.0,
)


def _to_json_bytes(payload):
    """Serialize an outgoing record to UTF-8 encoded JSON."""
    return json.dumps(payload).encode('utf-8')


# Incoming record values are decoded from JSON.
consumer = KafkaConsumer(
    bootstrap_servers=[BOOTSTRAP_SERVER],
    value_deserializer=json.loads,
)

# Outgoing record values are encoded as UTF-8 JSON.
producer = KafkaProducer(
    bootstrap_servers=[BOOTSTRAP_SERVER],
    value_serializer=_to_json_bytes,
)

# Instantiated once at import time and reused for every message.
model = Model()
|
||
if __name__ == "__main__":
    # Consume trip records, attach the model's predicted tip, and republish
    # the enriched record on the tips topic.
    # subscribe() is documented as taking a list of topic names.
    consumer.subscribe([TRIPS])
    for message in consumer:
        trip = message.value
        try:
            # Only predict() can raise TripValidationError, so keep the try
            # body to that single call.
            tip = model.predict(trip)
        except TripValidationError as e:
            # Invalid records are dropped, but the reason is logged so they
            # can be diagnosed. Any other exception propagates unchanged.
            logger.warning("Skipping invalid record: %s", e)
            continue
        trip['predicted_tip'] = tip
        producer.send(TIPS, trip)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from pandas import DataFrame | ||
from pandas import Timestamp | ||
from math import floor | ||
import mlflow | ||
import os | ||
|
||
|
||
MODEL_URI = "models:/GPUTipPipeline/Production"


class TripValidationError(Exception):
    """Raised when a trip message fails validation."""

    def __init__(self, message):
        super().__init__(message)


class Model():
    """Predict tip from trip data."""

    # Exact type required for each message field; None is always accepted.
    # The check is on the exact type (not isinstance), so an int is rejected
    # where a float is required and a bool is rejected where an int is
    # required.
    _FIELD_TYPES = {
        'pickup_location_id': int,
        'dropoff_location_id': int,
        'payment_type': int,
        'passenger_count': float,
        'trip_distance': float,
        'fare_amount': float,
        'extra': float,
        'mta_tax': float,
        'tolls_amount': float,
        'improvement_surcharge': float,
        'congestion_surcharge': float,
        'airport_fee': float,
        'tip_amount': float,
        'pickup_datetime': str,
        'dropoff_datetime': str,
    }

    def __init__(self):
        """Point MLflow at the tracking server and load the regressor."""
        mlflow.set_tracking_uri("https://mlflow.cauchy.link")
        self.regressor = mlflow.sklearn.load_model(MODEL_URI)

    def predict(self, message):
        """
        Predict tip using loaded model.

        The message is validated first, then converted to a one-row
        DataFrame with datetime and categorical columns cast explicitly.

        Returns the single prediction rounded down to two decimal places.
        Raises TripValidationError if the message is invalid.
        """
        self.validate(message)

        data = DataFrame([message]).astype({
            'pickup_datetime': 'datetime64[ns]',
            'dropoff_datetime': 'datetime64[ns]',
            'pickup_location_id': 'category',
            'dropoff_location_id': 'category',
            'payment_type': 'category',
        })

        # Exactly one row in, so exactly one prediction is unpacked.
        tip, = self.regressor.predict(data)
        return floor(100 * tip) / 100.

    def validate(self, message):
        """
        Validate message structure. Ensure all expected features are present,
        and all types are correct.

        Checks are explicit raises rather than assert statements, so they
        still run when Python is started with -O.

        Raises TripValidationError, chained to the underlying error, on any
        failure.
        """
        try:
            missing = (
                set(self.regressor.feature_names_in_) - set(message.keys())
            )
            if missing:
                raise ValueError(
                    "missing features: %s" % sorted(missing, key=str)
                )

            for field, expected in self._FIELD_TYPES.items():
                value = message[field]
                if value is not None and type(value) is not expected:
                    raise TypeError(
                        "%s must be %s or None, got %s" % (
                            field, expected.__name__, type(value).__name__
                        )
                    )

            # Parsing raises if a datetime string is not a valid timestamp.
            Timestamp(message['pickup_datetime'])
            Timestamp(message['dropoff_datetime'])
        except Exception as e:
            raise TripValidationError(e) from e
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
alembic==1.11.0 | ||
blinker==1.6.2 | ||
certifi==2023.5.7 | ||
charset-normalizer==3.1.0 | ||
click==8.1.3 | ||
cloudpickle==2.2.1 | ||
contourpy==1.0.7 | ||
cycler==0.11.0 | ||
databricks-cli==0.17.7 | ||
docker==6.1.2 | ||
entrypoints==0.4 | ||
Flask==2.3.2 | ||
fonttools==4.39.4 | ||
gitdb==4.0.10 | ||
GitPython==3.1.31 | ||
greenlet==2.0.2 | ||
gunicorn==20.1.0 | ||
idna==3.4 | ||
importlib-metadata==6.6.0 | ||
itsdangerous==2.1.2 | ||
Jinja2==3.1.2 | ||
joblib==1.2.0 | ||
kafka-python==2.0.2 | ||
kiwisolver==1.4.4 | ||
Mako==1.2.4 | ||
Markdown==3.4.3 | ||
MarkupSafe==2.1.2 | ||
matplotlib==3.7.1 | ||
mlflow==2.3.2 | ||
numpy==1.24.3 | ||
oauthlib==3.2.2 | ||
packaging==23.1 | ||
pandas==2.0.1 | ||
Pillow==9.5.0 | ||
protobuf==4.23.0 | ||
pyarrow==11.0.0 | ||
PyJWT==2.7.0 | ||
pyparsing==3.0.9 | ||
python-dateutil==2.8.2 | ||
pytz==2023.3 | ||
PyYAML==6.0 | ||
querystring-parser==1.2.4 | ||
requests==2.30.0 | ||
scikit-learn==1.2.2 | ||
scipy==1.10.1 | ||
sentry-sdk==1.23.0 | ||
six==1.16.0 | ||
smmap==5.0.0 | ||
SQLAlchemy==2.0.13 | ||
sqlparse==0.4.4 | ||
tabulate==0.9.0 | ||
threadpoolctl==3.1.0 | ||
typing_extensions==4.5.0 | ||
tzdata==2023.3 | ||
urllib3==1.26.15 | ||
websocket-client==1.5.1 | ||
Werkzeug==2.3.4 | ||
xgboost==1.7.5 | ||
zipp==3.15.0 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
import unittest | ||
from unittest.mock import Mock | ||
from unittest.mock import patch | ||
from model.tip import Model | ||
from model.tip import TripValidationError | ||
from numpy import array | ||
|
||
class TestModel(unittest.TestCase):
    """Unit tests for Model.predict and Model.validate with a mocked regressor."""

    @patch('model.tip.mlflow', Mock())
    def setUp(self):
        # mlflow is patched so constructing Model does not load a real model;
        # the regressor is then replaced with a Mock returning a fixed value.
        self.model = Model()
        self.model.regressor = Mock()
        # validate() only takes set() of this attribute, so a plain list of
        # names is sufficient.
        self.model.regressor.feature_names_in_ = ['pickup_datetime']
        self.model.regressor.predict.return_value = array([3.0700364])

        # A fully valid trip record; tests copy and corrupt it.
        self.message = {
            'pickup_datetime': '2022-05-17 12:11:03',
            'dropoff_datetime': '2022-05-17 12:30:10',
            'pickup_location_id': 237,
            'dropoff_location_id': 79,
            'passenger_count': 2.0,
            'trip_distance': 2.39,
            'payment_type': 1,
            'fare_amount': 13.5,
            'extra': 0.0,
            'mta_tax': 0.5,
            'tolls_amount': 0.0,
            'improvement_surcharge': 0.3,
            'congestion_surcharge': 2.5,
            'airport_fee': 0.0,
            'tip_amount': 3.36,
        }

    def tearDown(self):
        self.model.regressor.reset_mock()

    @patch('model.tip.mlflow.sklearn.load_model')
    def test_init_loads_model(self, mock_load_model):
        """Constructing a Model calls mlflow.sklearn.load_model."""
        Model()
        mock_load_model.assert_called()

    def test_predict(self):
        """predict() floors the prediction to cents and validates first."""
        tip = self.model.predict(self.message)
        self.assertEqual(tip, 3.07)

        bad_message = self.message.copy()
        del bad_message['pickup_datetime']
        with self.assertRaises(TripValidationError):
            self.model.predict(bad_message)

    def test_validate(self):
        """validate() accepts valid or None-valued fields, rejects the rest."""
        try:
            self.model.validate(self.message)
        except TripValidationError:
            self.fail('TripValidationError raised on valid input')

        # None is an accepted value for a field.
        incomplete_message = self.message.copy()
        incomplete_message['passenger_count'] = None
        try:
            self.model.validate(incomplete_message)
        except TripValidationError:
            self.fail('TripValidationError raised on valid input')

        bad_message = self.message.copy()
        del bad_message['pickup_datetime']
        with self.assertRaises(TripValidationError, msg='missing covariate'):
            self.model.validate(bad_message)

        # (failure label, field to corrupt, invalid value)
        invalid_cases = [
            ('invalid pickup_location_id', 'pickup_location_id', '1'),
            ('invalid dropoff_location_id', 'dropoff_location_id', '1'),
            ('invalid payment_type', 'payment_type', '1'),
            ('invalid passenger_count', 'passenger_count', 0),
            ('invalid trip_distance', 'trip_distance', 0),
            ('invalid fare_amount', 'fare_amount', 0),
            ('invalid extra', 'extra', 0),
            ('invalid mta_tax', 'mta_tax', 0),
            ('invalid tolls_amount', 'tolls_amount', 0),
            ('invalid improvement_surcharge', 'improvement_surcharge', 0),
            ('invalid congestion_surcharge', 'congestion_surcharge', 0),
            ('invalid airport_fee', 'airport_fee', 0),
            ('invalid tip_amount', 'tip_amount', 0),
            ('invalid pickup_datetime', 'pickup_datetime', 0),
            ('invalid dropoff_datetime', 'dropoff_datetime', 0),
            ('invalid pickup_datetime', 'pickup_datetime', 'Monday at noon'),
            ('invalid dropoff_datetime', 'dropoff_datetime', 'Monday at noon'),
        ]
        for label, field, value in invalid_cases:
            # subTest reports each case separately, so one failure does not
            # hide the remaining cases.
            with self.subTest(label, field=field, value=value):
                bad_message = self.message.copy()
                bad_message[field] = value
                with self.assertRaises(TripValidationError, msg=label):
                    self.model.validate(bad_message)