Skip to content

Commit

Permalink
Merge pull request #23 from nasa/custom_predictors
Browse files Browse the repository at this point in the history
Custom predictors
  • Loading branch information
kjjarvis authored Mar 27, 2024
2 parents a9eec2f + 03ec4db commit eee004c
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 7 deletions.
10 changes: 8 additions & 2 deletions src/prog_server/models/prog_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,26 @@ class ProgServer():
def __init__(self):
self.process = None

def run(self, host=DEFAULT_HOST, port=DEFAULT_PORT, debug=False, models={}, **kwargs) -> None:
def run(self, host=DEFAULT_HOST, port=DEFAULT_PORT, debug=False, models={}, predictors={}, **kwargs) -> None:
"""Run the server (blocking)
Keyword Args:
host (str, optional): Server host. Defaults to '127.0.0.1'.
port (int, optional): Server port. Defaults to 8555.
debug (bool, optional): If the server is started in debug mode
models (dict[str, PrognosticsModel]): a dictionary of extra models to consider. The key is the name used to identify it.
models (dict[str, PrognosticsModel]): a dictionary of extra models to consider. The key is the name used to identify it.
predictors (dict[str, predictors.Predictor]): a dictionary of extra predictors to consider. The key is the name used to identify it.
"""
if not isinstance(models, dict):
raise TypeError("Extra models (`model` arg in prog_server.run() or start()) must be in a dictionary in the form `name: model_name`")

session.extra_models.update(models)

if not isinstance(predictors, dict):
raise TypeError("Custom Predictors (`predictors` arg in prog_server.run() or start()) must be in a dictionary in the form `name: pred_name`")

session.extra_predictors.update(predictors)

self.host = host
self.port = port
self.process = app.run(host=host, port=port, debug=debug)
Expand Down
25 changes: 20 additions & 5 deletions src/prog_server/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from threading import Lock

extra_models = {}
extra_predictors = {}


class Session():
Expand Down Expand Up @@ -82,14 +83,28 @@ def __init__(self, session_id,

# Predictor
try:
pred_class = getattr(predictors, pred_name)
if pred_name in extra_predictors:
pred_class = extra_predictors[pred_name]
else:
pred_class = getattr(predictors, pred_name)
except AttributeError:
abort(400, f"Invalid predictor name {pred_name}")
app.logger.debug(f"Creating Predictor of type {self.pred_name}")
try:
self.pred = pred_class(self.model, **pred_cfg)
except Exception as e:
abort(400, f"Could not instantiate predictor with input: {e}")
if isinstance(pred_class, type) and issubclass(pred_class, predictors.Predictor):
# pred_class is a class, either from progpy or custom classes
try:
self.pred = pred_class(self.model, **pred_cfg)
except Exception as e:
abort(400, f"Could not instantiate predictor with input: {e}")
elif isinstance(pred_class, predictors.Predictor):
# pred_class is an instance of predictors.Predictor - use the object instead
# This happens for user predictors that are added to the server at startup.
self.pred = deepcopy(pred_class)
# Apply any configuration changes, overriding predictor config.
self.pred.parameters.update(pred_cfg)
else:
abort(400, f"Invalid predictor type {type(pred_name)} for predictor {pred_name}. For custom classes, the predictor must be mentioned with quotes in the pred argument")

self.pred_cfg = self.pred.parameters

# State Estimator
Expand Down
33 changes: 33 additions & 0 deletions tests/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import time
import unittest
import numpy as np
import prog_client, prog_server
from progpy import PrognosticsModel
from progpy.predictors import MonteCarlo
from progpy.uncertain_data import MultivariateNormalDist
from progpy.models import ThrownObject

Expand Down Expand Up @@ -243,6 +245,37 @@ def test_custom_models(self):
def tearDownClass(cls):
prog_server.stop()

def test_custom_predictors(self):
# Restart server with model
prog_server.stop()
#define the model
ball = ThrownObject(thrower_height=1.5, throwing_speed=20)
#call the external/extra predictor (here from progpy)
mc = MonteCarlo(ball)
with self.assertRaises(Exception):
prog_server.start(port=9883, predictors=[1, 2])

prog_server.start(port=9883, predictors={'mc':mc})
ball_session = prog_client.Session('ThrownObject', port=9883, pred='mc')

#check that the prediction completes successfully without error
#get_prediction_status from the session - for errors
sessions_in_progress = True
STEP = 15 # Time to wait between pinging server (s)
while sessions_in_progress:
sessions_in_progress = False
status = ball_session.get_prediction_status()
if status['in progress'] != 0:
print(f'\tSession {ball_session.session_id} is still in progress')
sessions_in_progress = True
time.sleep(STEP)
print(f'\tSession {ball_session.session_id} complete')
print(status)
# Restart (to reset port)
prog_server.stop()
prog_server.start()


# This allows the module to be executed directly
def run_tests():
l = unittest.TestLoader()
Expand Down

0 comments on commit eee004c

Please sign in to comment.