Skip to content

Commit

Permalink
Added the feature to include custom predictors
Browse files Browse the repository at this point in the history
  • Loading branch information
kkannan7 committed Feb 9, 2024
1 parent a9eec2f commit ed6698a
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
10 changes: 8 additions & 2 deletions src/prog_server/models/prog_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,26 @@ class ProgServer():
def __init__(self):
self.process = None

def run(self, host=DEFAULT_HOST, port=DEFAULT_PORT, debug=False, models={}, **kwargs) -> None:
def run(self, host=DEFAULT_HOST, port=DEFAULT_PORT, debug=False, models={}, predictors={}, **kwargs) -> None:
"""Run the server (blocking)
Keyword Args:
host (str, optional): Server host. Defaults to '127.0.0.1'.
port (int, optional): Server port. Defaults to 8555.
debug (bool, optional): If the server is started in debug mode
models (dict[str, PrognosticsModel]): a dictionary of extra models to consider. The key is the name used to identify it.
models (dict[str, PrognosticsModel]): a dictionary of extra models to consider. The key is the name used to identify it.
predictors (dict[str, predictors.Predictor]): a dictionary of extra predictors to consider. The key is the name used to identify it.
"""
if not isinstance(models, dict):
raise TypeError("Extra models (`model` arg in prog_server.run() or start()) must be in a dictionary in the form `name: model_name`")

session.extra_models.update(models)

if not isinstance(predictors, dict):
raise TypeError("Custom Predictors (`predictors` arg in prog_server.run() or start()) must be in a dictionary in the form `name: pred_name`")

session.extra_predictors.update(predictors)

self.host = host
self.port = port
self.process = app.run(host=host, port=port, debug=debug)
Expand Down
25 changes: 20 additions & 5 deletions src/prog_server/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from threading import Lock

extra_models = {}
extra_predictors = {}


class Session():
Expand Down Expand Up @@ -82,14 +83,28 @@ def __init__(self, session_id,

# Predictor
try:
pred_class = getattr(predictors, pred_name)
if pred_name in extra_predictors:
pred_class = extra_predictors[pred_name]
else:
pred_class = getattr(predictors, pred_name)
except AttributeError:
abort(400, f"Invalid predictor name {pred_name}")
app.logger.debug(f"Creating Predictor of type {self.pred_name}")
try:
self.pred = pred_class(self.model, **pred_cfg)
except Exception as e:
abort(400, f"Could not instantiate predictor with input: {e}")
if isinstance(pred_class, type) and issubclass(pred_class, predictors.Predictor):
# pred_class is a class, either from progpy or custom classes
try:
self.pred = pred_class(self.model, **pred_cfg)
except Exception as e:
abort(400, f"Could not instantiate predictor with input: {e}")
elif isinstance(pred_class, predictors.Predictor):
# pred_class is an instance of predictors.Predictor - use the object instead
# This happens for user predictors that are added to the server at startup.
self.pred = deepcopy(pred_class)
# Apply any configuration changes, overriding model config.
self.pred.parameters.update(pred_cfg)
else:
abort(400, f"Invalid model type {type(pred_name)} for model {pred_name}. For custom classes, the model must be either an instantiated PrognosticsModel subclass or classmame")

self.pred_cfg = self.pred.parameters

# State Estimator
Expand Down
36 changes: 36 additions & 0 deletions tests/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import time
import unittest
import numpy as np
import prog_client, prog_server
from progpy import PrognosticsModel
from progpy.predictors import MonteCarlo
from progpy.uncertain_data import MultivariateNormalDist
from progpy.models import ThrownObject

Expand Down Expand Up @@ -243,6 +245,40 @@ def test_custom_models(self):
def tearDownClass(cls):
prog_server.stop()

def test_custom_predictors(self):
# Restart server with model
prog_server.stop()
#define the model
ball = ThrownObject(thrower_height=1.5, throwing_speed=20)
prog_server.start(models={'ball': ball}, port=9883)
ball_session = prog_client.Session('ball', port=9883)
initial_state = ball.initialize()
#call the external/extra predictor (here from progpy)
mc = MonteCarlo(ball)

x = MultivariateNormalDist(initial_state.keys(), initial_state.values(), np.diag([x_i*0.01 for x_i in initial_state.values()]))

PREDICTION_HORIZON = 7.7
STEP_SIZE = 0.01
NUM_SAMPLES = 500

# Make Prediction
mc_results = mc.predict(x, n_samples=NUM_SAMPLES, dt=STEP_SIZE, horizon = PREDICTION_HORIZON)

metrics = mc_results.time_of_event.metrics()
print("\nPredicted Time of Event:")
print(metrics) # Note this takes some time
fig = mc_results.time_of_event.plot_hist(keys = 'impact')
fig = mc_results.time_of_event.plot_hist(keys = 'falling')

print("\nSamples where falling occurs before horizon: {:.2f}%".format(metrics['falling']['number of samples']/NUM_SAMPLES * 100))
print("\nSamples where impact occurs before horizon: {:.2f}%".format(metrics['impact']['number of samples']/NUM_SAMPLES * 100))

# Restart (to reset port)
prog_server.stop()
prog_server.start()


# This allows the module to be executed directly
def run_tests():
l = unittest.TestLoader()
Expand Down

0 comments on commit ed6698a

Please sign in to comment.