Skip to content

Commit

Permalink
Merge pull request #27 from nasa/custom_estimators
Browse files Browse the repository at this point in the history
Custom estimators
  • Loading branch information
teubert authored Aug 8, 2024
2 parents 025c4ca + c83bd1c commit 3bc6579
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 11 deletions.
10 changes: 8 additions & 2 deletions src/prog_server/models/prog_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@ class ProgServer():
def __init__(self):
self.process = None

def run(self, host=DEFAULT_HOST, port=DEFAULT_PORT, debug=False, models={}, predictors={}, **kwargs) -> None:
def run(self, host=DEFAULT_HOST, port=DEFAULT_PORT, debug=False, models={}, predictors={}, state_estimators={}, **kwargs) -> None:
"""Run the server (blocking)
Keyword Args:
host (str, optional): Server host. Defaults to '127.0.0.1'.
port (int, optional): Server port. Defaults to 8555.
debug (bool, optional): If the server is started in debug mode
models (dict[str, PrognosticsModel]): a dictionary of extra models to consider. The key is the name used to identify it.
predictors (dict[str, predictors.Predictor]): a dictionary of extra predictors to consider. The key is the name used to identify it.
predictors (dict[str, predictors.Predictor]): a dictionary of extra predictors to consider. The key is the name used to identify it.
state_estimators (dict[str, state_estimators.StateEstimator]): a dictionary of extra estimators to consider. The key is the name used to identify it.
"""
if not isinstance(models, dict):
raise TypeError("Extra models (`model` arg in prog_server.run() or start()) must be in a dictionary in the form `name: model_name`")
Expand All @@ -39,6 +40,11 @@ def run(self, host=DEFAULT_HOST, port=DEFAULT_PORT, debug=False, models={}, pred

session.extra_predictors.update(predictors)

if not isinstance(state_estimators, dict):
raise TypeError("Custom Estimator (`state_estimators` arg in prog_server.run() or start()) must be in a dictionary in the form `name: state_est_name`")

session.extra_estimators.update(state_estimators)

self.host = host
self.port = port
self.process = app.run(host=host, port=port, debug=debug)
Expand Down
36 changes: 27 additions & 9 deletions src/prog_server/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

extra_models = {}
extra_predictors = {}

extra_estimators = {}

class Session():
def __init__(self, session_id,
Expand Down Expand Up @@ -116,23 +116,41 @@ def __init__(self, session_id,
# Otherwise, will have to be initialized later
# Check state estimator and predictor data
try:
getattr(state_estimators, state_est_name)
if self.state_est_name not in extra_estimators:
getattr(state_estimators, state_est_name)
except AttributeError:
abort(400, f"Invalid state estimator name {state_est_name}")

def __initialize(self, x0, predict_queue=True):
app.logger.debug("Initializing...")
state_est_class = getattr(state_estimators, self.state_est_name)
#Estimator
try:
if self.state_est_name in extra_estimators:
state_est_class = extra_estimators[self.state_est_name]
else:
state_est_class = getattr(state_estimators, self.state_est_name)
except AttributeError:
abort(400, f"Invalid state estimator name {self.state_est_name}")
app.logger.debug(f"Creating State Estimator of type {self.state_est_name}")

if isinstance(x0, str):
x0 = json.loads(x0)
x0 = json.loads(x0) #loads the initial state
if set(self.model.states) != set(list(x0.keys())):
abort(400, f"Initial state must have every state in model. states. Had {list(x0.keys())}, needed {self.model.states}")

try:
self.state_est = state_est_class(self.model, x0, **self.state_est_cfg)
except Exception as e:
abort(400, f"Could not instantiate state estimator with input: {e}")

if isinstance(state_est_class, type) and issubclass(state_est_class, state_estimators.StateEstimator):
try:
self.state_est = state_est_class(self.model, x0, **self.state_est_cfg)
except Exception as e:
abort(400, f"Could not instantiate state estimator with input: {e}")
elif isinstance(state_est_class, state_estimators.StateEstimator):
# state_est_class is an instance of state_estimators.StateEstimator - use the object instead
# This happens for user state estimators that are added to the server at startup.
self.state_est = deepcopy(state_est_class)
# Apply any configuration changes, overriding estimator config
self.state_est.parameters.update(self.state_est_cfg)
else:
abort(400, f"Invalid state estimator type {type(self.state_est_name)} for estimator {self.state_est_name}. For custom classes, the state estimator must be mentioned with quotes in the est argument")

self.initialized = True
if predict_queue:
Expand Down
52 changes: 52 additions & 0 deletions tests/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from progpy.predictors import MonteCarlo
from progpy.uncertain_data import MultivariateNormalDist
from progpy.models import ThrownObject
from progpy.models.thrown_object import LinearThrownObject
from progpy.state_estimators import KalmanFilter
from progpy.state_estimators import ParticleFilter
from progpy.uncertain_data import MultivariateNormalDist
from progpy.uncertain_data import UnweightedSamples


class IntegrationTest(unittest.TestCase):
Expand Down Expand Up @@ -307,6 +312,53 @@ def test_custom_predictors(self):
prog_server.stop()
prog_server.start()

def test_custom_estimators(self):
# Restart server with model
prog_server.stop()

#define the custom model
ball = LinearThrownObject(thrower_height=1.5, throwing_speed=20)
x_guess = ball.StateContainer({'x': 1.75, 'v': 35})
kf = KalmanFilter(ball, x_guess)
with self.assertRaises(Exception):
# state_estimators not a dictionary
prog_server.start(models ={'ball':ball}, port=9883, state_estimators=[20])
prog_server.start(models ={'ball':ball}, port=9883, state_estimators={'kf':kf})
ball_session = prog_client.Session('ball', port=9883, state_est='kf')

# time step (s)
dt = 0.01
x = ball.initialize()
# Initial input
u = ball.InputContainer({})

# Iterate forward 1 second and compare
x = ball.next_state(x, u, 1)
ball_session.send_data(time=1, x=x['x'])

t, x_s = ball_session.get_state()

# To check if the output state is multivariate normal distribution
self.assertIsInstance(x_s, MultivariateNormalDist)

# Setup Particle Filter
pf = ParticleFilter(ball, x_guess)
prog_server.stop()
prog_server.start(models ={'ball':ball}, port=9883, state_estimators={'pf': pf, 'kf':kf})
ball_session = prog_client.Session('ball', port=9883, state_est='pf')

# Iterate forward 1 second and compare
x = ball.next_state(x, u, 1)
ball_session.send_data(time=1, x=x['x'])

t, x_s = ball_session.get_state()

# Ensure that PF output is unweighted samples
self.assertIsInstance(x_s, UnweightedSamples)

prog_server.stop()
prog_server.start()


# This allows the module to be executed directly
def run_tests():
Expand Down

0 comments on commit 3bc6579

Please sign in to comment.