Skip to content

Commit

Permalink
Predict tip amounts, pulling from and pushing to Kafka
Browse files Browse the repository at this point in the history
  • Loading branch information
erikperkins committed May 18, 2023
1 parent 8e6d39d commit aba36c8
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 3 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
.idea/
*.iml
venv/
mlruns/
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
*.iml
venv/
__pycache__/
mlruns/
46 changes: 43 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,44 @@
# Previous placeholder entry point (presumably the removed side of this diff
# hunk, which reports 3 deletions): idles forever, sleeping one second per
# iteration.
# NOTE(review): `sleep` is not imported anywhere in the lines visible here.
if __name__ == "__main__":
while True:
sleep(1)
import sentry_sdk
import json
import logging
from kafka import KafkaConsumer
from kafka import KafkaProducer
from model.tip import Model
from model.tip import TripValidationError

# Kafka endpoint plus the topic read from (TRIPS) and the topic written to (TIPS).
BOOTSTRAP_SERVER = 'kafka-service.kafka.svc.cluster.local:9092'
TRIPS = 'trips'
TIPS = 'tips'

logger = logging.getLogger('mariotte')

# Sentry is initialised before any Kafka client is constructed.
sentry_sdk.init(dsn="https://[email protected]/5", traces_sample_rate=1.0)


def _to_json_bytes(payload):
    """Serialize an outgoing value as UTF-8 encoded JSON."""
    return json.dumps(payload).encode('utf-8')


# Incoming message values are decoded from JSON.
consumer = KafkaConsumer(
    bootstrap_servers=[BOOTSTRAP_SERVER],
    value_deserializer=json.loads,
)

# Outgoing message values are encoded with _to_json_bytes.
producer = KafkaProducer(
    bootstrap_servers=[BOOTSTRAP_SERVER],
    value_serializer=_to_json_bytes,
)

# Instantiated once at import time and reused for every message.
model = Model()

def main():
    """Consume trips from Kafka, predict a tip for each, and publish the result.

    Subscribes the module-level consumer to the TRIPS topic. For every
    message, the decoded trip is passed to ``model.predict``; the prediction
    is stored under ``'predicted_tip'`` and the trip is sent to the TIPS
    topic. Trips rejected with ``TripValidationError`` are logged and
    skipped. Any other exception propagates unchanged and ends the loop.
    """
    # subscribe() is documented as taking a list of topic names.
    consumer.subscribe([TRIPS])
    for message in consumer:
        trip = message.value
        try:
            # Only predict() can raise TripValidationError, so keep the try
            # body down to this one call.
            tip = model.predict(trip)
        except TripValidationError as e:
            # Log the reason so skipped records can be diagnosed afterwards.
            logger.warning("Skipping invalid record: %r", e)
            continue
        trip['predicted_tip'] = tip
        producer.send(TIPS, trip)


if __name__ == "__main__":
    main()
Empty file added model/__init__.py
Empty file.
67 changes: 67 additions & 0 deletions model/tip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from pandas import DataFrame
from pandas import Timestamp
from math import floor
import mlflow
import os


MODEL_URI = "models:/GPUTipPipeline/Production"


class TripValidationError(Exception):
    """Raised when a trip message is missing features or has mistyped values."""

    def __init__(self, message):
        super().__init__(message)


class Model():
    """Predict tip from trip data."""

    # Message fields grouped by the exact Python type their value must have.
    # None is accepted for every field.
    _INT_FIELDS = (
        'pickup_location_id',
        'dropoff_location_id',
        'payment_type',
    )
    _FLOAT_FIELDS = (
        'passenger_count',
        'trip_distance',
        'fare_amount',
        'extra',
        'mta_tax',
        'tolls_amount',
        'improvement_surcharge',
        'congestion_surcharge',
        'airport_fee',
        'tip_amount',
    )
    _DATETIME_FIELDS = (
        'pickup_datetime',
        'dropoff_datetime',
    )

    def __init__(self):
        """Set the MLflow tracking URI and load the regressor at MODEL_URI."""
        mlflow.set_tracking_uri("https://mlflow.cauchy.link")
        self.regressor = mlflow.sklearn.load_model(MODEL_URI)

    def predict(self, message):
        """
        Predict tip using loaded model.

        The message is validated, turned into a one-row DataFrame with
        datetime and categorical columns cast explicitly, and passed to the
        regressor. The prediction is floored to two decimal places.

        Raises:
            TripValidationError: if the message fails `validate`.
        """
        self.validate(message)

        data = DataFrame([message]).astype({
            'pickup_datetime': 'datetime64[ns]',
            'dropoff_datetime': 'datetime64[ns]',
            'pickup_location_id': 'category',
            'dropoff_location_id': 'category',
            'payment_type': 'category'
        })

        # One input row, so exactly one prediction is unpacked.
        tip, = self.regressor.predict(data)
        # Floor (not round) to whole hundredths.
        return floor(100 * tip) / 100.

    @staticmethod
    def _require_type(message, fields, expected):
        """Raise TripValidationError unless each field is `expected` or None."""
        for field in fields:
            value = message[field]
            # Exact type comparison on purpose: isinstance() would accept
            # bool for int fields, and this check must keep rejecting ints
            # (e.g. 0) for float fields.
            if type(value) not in (expected, type(None)):
                raise TripValidationError(
                    f"{field} must be {expected.__name__} or None, "
                    f"got {type(value).__name__}"
                )

    def validate(self, message):
        """
        Validate message structure. Ensure all expected features are present,
        and all types are correct.

        Checks are explicit raises rather than `assert` statements, so they
        still run when Python is started with -O.

        Raises:
            TripValidationError: if a feature named by the regressor is
                missing, a value has the wrong type, a datetime string cannot
                be parsed, or the message is not dict-like. Unexpected
                errors are wrapped and chained as the cause.
        """
        try:
            keys = set(message.keys())
            missing = set(self.regressor.feature_names_in_) - keys
            if missing:
                raise TripValidationError(
                    f"missing features: {sorted(missing, key=str)}"
                )

            self._require_type(message, self._INT_FIELDS, int)
            self._require_type(message, self._FLOAT_FIELDS, float)
            self._require_type(message, self._DATETIME_FIELDS, str)

            # Parsing raises for strings that are not valid timestamps.
            for field in self._DATETIME_FIELDS:
                Timestamp(message[field])
        except TripValidationError:
            raise
        except Exception as e:
            # Covers KeyError for absent keys, unparsable timestamps and
            # non-dict messages; keep the original error as the cause.
            raise TripValidationError(e) from e
59 changes: 59 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
alembic==1.11.0
blinker==1.6.2
certifi==2023.5.7
charset-normalizer==3.1.0
click==8.1.3
cloudpickle==2.2.1
contourpy==1.0.7
cycler==0.11.0
databricks-cli==0.17.7
docker==6.1.2
entrypoints==0.4
Flask==2.3.2
fonttools==4.39.4
gitdb==4.0.10
GitPython==3.1.31
greenlet==2.0.2
gunicorn==20.1.0
idna==3.4
importlib-metadata==6.6.0
itsdangerous==2.1.2
Jinja2==3.1.2
joblib==1.2.0
kafka-python==2.0.2
kiwisolver==1.4.4
Mako==1.2.4
Markdown==3.4.3
MarkupSafe==2.1.2
matplotlib==3.7.1
mlflow==2.3.2
numpy==1.24.3
oauthlib==3.2.2
packaging==23.1
pandas==2.0.1
Pillow==9.5.0
protobuf==4.23.0
pyarrow==11.0.0
PyJWT==2.7.0
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2023.3
PyYAML==6.0
querystring-parser==1.2.4
requests==2.30.0
scikit-learn==1.2.2
scipy==1.10.1
sentry-sdk==1.23.0
six==1.16.0
smmap==5.0.0
SQLAlchemy==2.0.13
sqlparse==0.4.4
tabulate==0.9.0
threadpoolctl==3.1.0
typing_extensions==4.5.0
tzdata==2023.3
urllib3==1.26.15
websocket-client==1.5.1
Werkzeug==2.3.4
xgboost==1.7.5
zipp==3.15.0
Empty file added test/__init__.py
Empty file.
152 changes: 152 additions & 0 deletions test/test_tip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import unittest
from unittest.mock import Mock
from unittest.mock import patch
from model.tip import Model
from model.tip import TripValidationError
from numpy import array

class TestModel(unittest.TestCase):
    """Tests for model.tip.Model with the regressor and MLflow mocked out."""

    # Fields grouped by the kind of wrong value used to provoke a rejection.
    INT_FIELDS = (
        'pickup_location_id',
        'dropoff_location_id',
        'payment_type',
    )
    FLOAT_FIELDS = (
        'passenger_count',
        'trip_distance',
        'fare_amount',
        'extra',
        'mta_tax',
        'tolls_amount',
        'improvement_surcharge',
        'congestion_surcharge',
        'airport_fee',
        'tip_amount',
    )
    DATETIME_FIELDS = (
        'pickup_datetime',
        'dropoff_datetime',
    )

    @patch('model.tip.mlflow', Mock())
    def setUp(self):
        """Build a Model whose regressor is a Mock, plus one valid message."""
        self.model = Model()
        self.model.regressor = Mock()
        self.model.regressor.feature_names_in_ = {'pickup_datetime': '2022-05-17 12:11:03'}
        self.model.regressor.predict.return_value = array([3.0700364])

        self.message = {
            'pickup_datetime': '2022-05-17 12:11:03',
            'dropoff_datetime': '2022-05-17 12:30:10',
            'pickup_location_id': 237,
            'dropoff_location_id': 79,
            'passenger_count': 2.0,
            'trip_distance': 2.39,
            'payment_type': 1,
            'fare_amount': 13.5,
            'extra': 0.0,
            'mta_tax': 0.5,
            'tolls_amount': 0.0,
            'improvement_surcharge': 0.3,
            'congestion_surcharge': 2.5,
            'airport_fee': 0.0,
            'tip_amount': 3.36
        }

    def tearDown(self):
        self.model.regressor.reset_mock()

    def _with(self, **overrides):
        """Return a copy of the valid message with some values replaced."""
        message = self.message.copy()
        message.update(overrides)
        return message

    def _without(self, key):
        """Return a copy of the valid message with one key removed."""
        message = self.message.copy()
        del message[key]
        return message

    # Patch the whole mlflow module so that constructing Model neither loads a
    # real model nor calls the real mlflow.set_tracking_uri.
    @patch('model.tip.mlflow')
    def test_init_loads_model(self, mock_mlflow):
        Model()
        mock_mlflow.sklearn.load_model.assert_called()

    def test_predict(self):
        """A valid trip is floored to cents; an invalid one is rejected."""
        tip = self.model.predict(self.message)
        self.assertEqual(tip, 3.07)

        with self.assertRaises(TripValidationError):
            self.model.predict(self._without('pickup_datetime'))

    def test_validate(self):
        """Valid messages pass; each kind of malformed message is rejected."""
        valid_cases = [
            ('complete message', self.message),
            ('passenger_count is None', self._with(passenger_count=None)),
        ]
        for label, message in valid_cases:
            with self.subTest(label):
                try:
                    self.model.validate(message)
                except TripValidationError:
                    self.fail('TripValidationError raised on valid input')

        # (failure message, offending value, message under test)
        invalid_cases = [
            ('missing covariate', None, self._without('pickup_datetime')),
        ]
        # Strings where ints are required.
        invalid_cases += [
            (f'invalid {field}', '1', self._with(**{field: '1'}))
            for field in self.INT_FIELDS
        ]
        # Ints where floats are required.
        invalid_cases += [
            (f'invalid {field}', 0, self._with(**{field: 0}))
            for field in self.FLOAT_FIELDS
        ]
        # Non-string and unparsable datetimes.
        for value in (0, 'Monday at noon'):
            invalid_cases += [
                (f'invalid {field}', value, self._with(**{field: value}))
                for field in self.DATETIME_FIELDS
            ]

        for msg, value, message in invalid_cases:
            # subTest keeps one failing case from hiding the others.
            with self.subTest(msg, value=value):
                with self.assertRaises(TripValidationError, msg=msg):
                    self.model.validate(message)

0 comments on commit aba36c8

Please sign in to comment.