From aba36c88af3fa44c1287efb6f74bc76035144602 Mon Sep 17 00:00:00 2001
From: Erik Perkins
Date: Wed, 17 May 2023 10:27:18 -0700
Subject: [PATCH] Predict tip amounts, pulling from and pushing to Kafka

---
 .dockerignore     |   1 +
 .gitignore        |   1 +
 main.py           |  47 +++++++++++++-
 model/__init__.py |   0
 model/tip.py      |  87 ++++++++++++++++++++++++++
 requirements.txt  |  59 ++++++++++++++++++
 test/__init__.py  |   0
 test/test_tip.py  | 152 ++++++++++++++++++++++++++++++++++++++++++++++
 8 files changed, 344 insertions(+), 3 deletions(-)
 create mode 100644 model/__init__.py
 create mode 100644 model/tip.py
 create mode 100644 test/__init__.py
 create mode 100644 test/test_tip.py

diff --git a/.dockerignore b/.dockerignore
index 0cd01c3..c41bf07 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -2,3 +2,4 @@
 .idea/
 *.iml
 venv/
+mlruns/
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 65c6f87..2b0773b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@
 *.iml
 venv/
 __pycache__/
+mlruns/
\ No newline at end of file
diff --git a/main.py b/main.py
index 131558e..a2317dd 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,45 @@
-if __name__ == "__main__":
-    while True:
-        sleep(1)
+import sentry_sdk
+import json
+import logging
+from kafka import KafkaConsumer
+from kafka import KafkaProducer
+from model.tip import Model
+from model.tip import TripValidationError
+
+BOOTSTRAP_SERVER = 'kafka-service.kafka.svc.cluster.local:9092'
+TRIPS = 'trips'
+TIPS = 'tips'
+
+logger = logging.getLogger('mariotte')
+
+sentry_sdk.init(
+    dsn = "https://89376a19f9d244c3b3e64f0bd599821c@sentry.cauchy.link/5",
+    traces_sample_rate = 1.0
+)
 
+consumer = KafkaConsumer(
+    bootstrap_servers = [BOOTSTRAP_SERVER],
+    value_deserializer = lambda x: json.loads(x)
+)
+
+producer = KafkaProducer(
+    bootstrap_servers = [BOOTSTRAP_SERVER],
+    value_serializer = lambda x: json.dumps(x).encode('utf-8')
+)
+
+model = Model()
+
+if __name__ == "__main__":
+    consumer.subscribe(TRIPS)
+    for message in consumer:
+        trip = message.value
+        try:
+            tip = model.predict(trip)
+            trip['predicted_tip'] = tip
+            producer.send(TIPS, trip)
+        except TripValidationError as e:
+            # Invalid trips are skipped, not fatal: log why the record was
+            # rejected and keep consuming. Any other exception propagates
+            # out of the loop unchanged.
+            logger.warning("Skipping invalid record: %s", e)
+            continue
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/model/tip.py b/model/tip.py
new file mode 100644
index 0000000..1d96d2a
--- /dev/null
+++ b/model/tip.py
@@ -0,0 +1,87 @@
+from pandas import DataFrame
+from pandas import Timestamp
+from math import floor
+import mlflow
+import os
+
+
+MODEL_URI = "models:/GPUTipPipeline/Production"
+
+# Exact type each trip field must have whenever its value is not None.
+FIELD_TYPES = {
+    'pickup_location_id': int,
+    'dropoff_location_id': int,
+    'payment_type': int,
+    'passenger_count': float,
+    'trip_distance': float,
+    'fare_amount': float,
+    'extra': float,
+    'mta_tax': float,
+    'tolls_amount': float,
+    'improvement_surcharge': float,
+    'congestion_surcharge': float,
+    'airport_fee': float,
+    'tip_amount': float,
+    'pickup_datetime': str,
+    'dropoff_datetime': str
+}
+
+
+class TripValidationError(Exception):
+    """Raised when a trip message fails validation."""
+    def __init__(self, message):
+        super().__init__(message)
+
+class Model():
+    """Predict tip from trip data."""
+    def __init__(self):
+        mlflow.set_tracking_uri("https://mlflow.cauchy.link")
+        self.regressor = mlflow.sklearn.load_model(MODEL_URI)
+
+    def predict(self, message):
+        """
+        Predict tip using loaded model.
+
+        Returns the prediction rounded down to two decimal places. Raises
+        TripValidationError if the message fails validation.
+        """
+        self.validate(message)
+
+        data = DataFrame([message]).astype({
+            'pickup_datetime': 'datetime64[ns]',
+            'dropoff_datetime': 'datetime64[ns]',
+            'pickup_location_id': 'category',
+            'dropoff_location_id': 'category',
+            'payment_type': 'category'
+        })
+
+        tip, = self.regressor.predict(data)
+        return floor(100 * tip) / 100.
+
+    def validate(self, message):
+        """
+        Validate message structure. Ensure all expected features are present,
+        and all types are correct. A value of None is accepted for any field.
+
+        Raises TripValidationError, chained to the underlying error, if the
+        message is invalid.
+        """
+        try:
+            features = set(self.regressor.feature_names_in_)
+            missing = features - set(message.keys())
+            if missing:
+                raise ValueError('missing features: %s' % sorted(missing))
+
+            for field, expected in FIELD_TYPES.items():
+                value = message[field]
+                # Compare exact types, as isinstance would let bool through
+                # as int. Explicit raises are used because assert statements
+                # are removed when Python runs with -O.
+                if value is not None and type(value) is not expected:
+                    raise TypeError('%s must be %s or None, got %s' % (
+                        field, expected.__name__, type(value).__name__))
+
+            Timestamp(message['pickup_datetime'])
+            Timestamp(message['dropoff_datetime'])
+        except Exception as e:
+            raise TripValidationError(e) from e
diff --git a/requirements.txt b/requirements.txt
index e69de29..ab32434 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -0,0 +1,59 @@
+alembic==1.11.0
+blinker==1.6.2
+certifi==2023.5.7
+charset-normalizer==3.1.0
+click==8.1.3
+cloudpickle==2.2.1
+contourpy==1.0.7
+cycler==0.11.0
+databricks-cli==0.17.7
+docker==6.1.2
+entrypoints==0.4
+Flask==2.3.2
+fonttools==4.39.4
+gitdb==4.0.10
+GitPython==3.1.31
+greenlet==2.0.2
+gunicorn==20.1.0
+idna==3.4
+importlib-metadata==6.6.0
+itsdangerous==2.1.2
+Jinja2==3.1.2
+joblib==1.2.0
+kafka-python==2.0.2
+kiwisolver==1.4.4
+Mako==1.2.4
+Markdown==3.4.3
+MarkupSafe==2.1.2
+matplotlib==3.7.1
+mlflow==2.3.2
+numpy==1.24.3
+oauthlib==3.2.2
+packaging==23.1
+pandas==2.0.1
+Pillow==9.5.0
+protobuf==4.23.0
+pyarrow==11.0.0
+PyJWT==2.7.0
+pyparsing==3.0.9
+python-dateutil==2.8.2
+pytz==2023.3
+PyYAML==6.0
+querystring-parser==1.2.4
+requests==2.30.0
+scikit-learn==1.2.2
+scipy==1.10.1
+sentry-sdk==1.23.0
+six==1.16.0
+smmap==5.0.0
+SQLAlchemy==2.0.13
+sqlparse==0.4.4
+tabulate==0.9.0
+threadpoolctl==3.1.0
+typing_extensions==4.5.0
+tzdata==2023.3
+urllib3==1.26.15
+websocket-client==1.5.1
+Werkzeug==2.3.4
+xgboost==1.7.5
+zipp==3.15.0
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/test_tip.py b/test/test_tip.py
new file mode 100644
index 0000000..7508532
--- /dev/null
+++ b/test/test_tip.py
@@ -0,0 +1,152 @@
+import unittest
+from unittest.mock import Mock
+from unittest.mock import patch
+from model.tip import Model
+from model.tip import TripValidationError
+from numpy import array
+
+class TestModel(unittest.TestCase):
+    @patch('model.tip.mlflow', Mock())
+    def setUp(self):
+        self.model = Model()
+        self.model.regressor = Mock()
+        self.model.regressor.feature_names_in_ = {'pickup_datetime': '2022-05-17 12:11:03'}
+        self.model.regressor.predict.return_value = array([3.0700364])
+
+        self.message = {
+            'pickup_datetime': '2022-05-17 12:11:03',
+            'dropoff_datetime': '2022-05-17 12:30:10',
+            'pickup_location_id': 237,
+            'dropoff_location_id': 79,
+            'passenger_count': 2.0,
+            'trip_distance': 2.39,
+            'payment_type': 1,
+            'fare_amount': 13.5,
+            'extra': 0.0,
+            'mta_tax': 0.5,
+            'tolls_amount': 0.0,
+            'improvement_surcharge': 0.3,
+            'congestion_surcharge': 2.5,
+            'airport_fee': 0.0,
+            'tip_amount': 3.36
+        }
+
+    def tearDown(self):
+        self.model.regressor.reset_mock()
+
+    @patch('model.tip.mlflow.sklearn.load_model')
+    def test_init_loads_model(self, mock_load_model):
+        model = Model()
+        mock_load_model.assert_called()
+
+    def test_predict(self):
+        tip = self.model.predict(self.message)
+        self.assertEqual(tip, 3.07)
+
+        with self.assertRaises(TripValidationError):
+            bad_message = self.message.copy()
+            del(bad_message['pickup_datetime'])
+            self.model.predict(bad_message)
+
+    def test_validate(self):
+        try:
+            self.model.validate(self.message)
+        except TripValidationError:
+            raise self.failureException('TripValidationError raised on valid input')
+
+        try:
+            incomplete_message = self.message.copy()
+            incomplete_message['passenger_count'] = None
+            self.model.validate(incomplete_message)
+        except TripValidationError:
+            raise self.failureException('TripValidationError raised on valid input')
+
+        with self.assertRaises(TripValidationError, msg = 'missing covariate'):
+            bad_message = self.message.copy()
+            del(bad_message['pickup_datetime'])
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid pickup_location_id'):
+            bad_message = self.message.copy()
+            bad_message['pickup_location_id'] = '1'
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid dropoff_location_id'):
+            bad_message = self.message.copy()
+            bad_message['dropoff_location_id'] = '1'
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid payment_type'):
+            bad_message = self.message.copy()
+            bad_message['payment_type'] = '1'
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid passenger_count'):
+            bad_message = self.message.copy()
+            bad_message['passenger_count'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid trip_distance'):
+            bad_message = self.message.copy()
+            bad_message['trip_distance'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid fare_amount'):
+            bad_message = self.message.copy()
+            bad_message['fare_amount'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid extra'):
+            bad_message = self.message.copy()
+            bad_message['extra'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid mta_tax'):
+            bad_message = self.message.copy()
+            bad_message['mta_tax'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid tolls_amount'):
+            bad_message = self.message.copy()
+            bad_message['tolls_amount'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid improvement_surcharge'):
+            bad_message = self.message.copy()
+            bad_message['improvement_surcharge'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid congestion_surcharge'):
+            bad_message = self.message.copy()
+            bad_message['congestion_surcharge'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid airport_fee'):
+            bad_message = self.message.copy()
+            bad_message['airport_fee'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid tip_amount'):
+            bad_message = self.message.copy()
+            bad_message['tip_amount'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid pickup_datetime'):
+            bad_message = self.message.copy()
+            bad_message['pickup_datetime'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid dropoff_datetime'):
+            bad_message = self.message.copy()
+            bad_message['dropoff_datetime'] = 0
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid pickup_datetime'):
+            bad_message = self.message.copy()
+            bad_message['pickup_datetime'] = 'Monday at noon'
+            self.model.validate(bad_message)
+
+        with self.assertRaises(TripValidationError, msg = 'invalid dropoff_datetime'):
+            bad_message = self.message.copy()
+            bad_message['dropoff_datetime'] = 'Monday at noon'
+            self.model.validate(bad_message)