From 4f87072ad4995e680acab407cf0960b3f038380b Mon Sep 17 00:00:00 2001 From: CamDavidsonPilon Date: Mon, 20 Jan 2025 09:32:29 -0500 Subject: [PATCH] protocols can target multiple devices, but it now passed in as an arg to run --- pioreactor/calibrations/__init__.py | 51 +++++++++-------------- pioreactor/calibrations/od_calibration.py | 48 ++++----------------- pioreactor/calibrations/utils.py | 19 +++++---- pioreactor/cli/calibrations.py | 4 +- 4 files changed, 43 insertions(+), 79 deletions(-) diff --git a/pioreactor/calibrations/__init__.py b/pioreactor/calibrations/__init__.py index 9b767d81..520dbc7f 100644 --- a/pioreactor/calibrations/__init__.py +++ b/pioreactor/calibrations/__init__.py @@ -26,9 +26,18 @@ class CalibrationProtocol: + protocol_name: str + target_device: str | list[str] + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - calibration_protocols[(cls.target_device, cls.protocol_name)] = cls + if isinstance(cls.target_device, str): + calibration_protocols[(cls.target_device, cls.protocol_name)] = cls + elif isinstance(cls.target_device, list): + for device in cls.target_device: + calibration_protocols[(device, cls.protocol_name)] = cls + else: + raise ValueError("target_device must be a string or a list of strings") def run(self, *args, **kwargs): raise NotImplementedError("Subclasses must implement this method.") @@ -38,7 +47,7 @@ class SingleVialODProtocol(CalibrationProtocol): target_device = "od" protocol_name = "single_vial" - def run(self) -> structs.ODCalibration: + def run(self, *args) -> structs.ODCalibration: from pioreactor.calibrations.od_calibration import run_od_calibration return run_od_calibration() @@ -48,47 +57,27 @@ class BatchVialODProtocol(CalibrationProtocol): target_device = "od" protocol_name = "batch_vial" - def run(self) -> structs.ODCalibration: - from pioreactor.calibrations.od_calibration import run_od_calibration - - return run_od_calibration() - - -class DurationBasedMediaPumpProtocol(CalibrationProtocol): - target_device = "media_pump" - protocol_name = "duration_based" - - def run(self) -> structs.SimplePeristalticPumpCalibration: - from pioreactor.calibrations.pump_calibration import run_pump_calibration - - return run_pump_calibration(self.target_device) - - -class DurationBasedAltMediaPumpProtocol(CalibrationProtocol): - target_device = "alt_media_pump" - protocol_name = "duration_based" - - def run(self) -> structs.SimplePeristalticPumpCalibration: - from pioreactor.calibrations.pump_calibration import run_pump_calibration - - return run_pump_calibration(self.target_device) + def run(self, *args) -> structs.ODCalibration: + raise NotImplementedError("Not implemented yet") -class DurationBasedWasteMediaPumpProtocol(CalibrationProtocol): - target_device = "waste_pump" +class DurationBasedPumpProtocol(CalibrationProtocol): + target_device = ["media_pump", "alt_media_pump", "waste_pump"] protocol_name = "duration_based" - def run(self) -> structs.SimplePeristalticPumpCalibration: + def run(self, target_device: str) -> structs.SimplePeristalticPumpCalibration: from pioreactor.calibrations.pump_calibration import run_pump_calibration - return run_pump_calibration(self.target_device) + return run_pump_calibration(target_device) class DCBasedStirringProtocol(CalibrationProtocol): target_device = "stirring" protocol_name = "dc_based" - def run(self, min_dc: str | None = None, max_dc: str | None = None) -> structs.SimpleStirringCalibration: + def run( + self, target_device: str, min_dc: str | None = None, max_dc: str | None = None + ) -> structs.SimpleStirringCalibration: from pioreactor.calibrations.stirring_calibration import run_stirring_calibration return run_stirring_calibration( diff --git a/pioreactor/calibrations/od_calibration.py b/pioreactor/calibrations/od_calibration.py index ed8ddc64..f5bb8e35 100644 --- a/pioreactor/calibrations/od_calibration.py +++ b/pioreactor/calibrations/od_calibration.py @@ -179,42 +179,6 @@ def start_stirring(): return st -def plot_data( - x, - y, - title, - x_min=None, - x_max=None, - interpolation_curve=None, - highlight_recent_point=True, -): - import plotext as plt # type: ignore - - plt.clf() - - plt.scatter(x, y, marker="hd") - - if highlight_recent_point: - plt.scatter([x[-1]], [y[-1]], color=204, marker="hd") - - plt.theme("pro") - plt.title(title) - plt.xlabel("OD600") - plt.ylabel("OD Reading (Raw)") - - plt.plot_size(105, 22) - - if interpolation_curve: - plt.plot(sorted(x), [interpolation_curve(x_) for x_ in sorted(x)], color=204) - plt.plot_size(145, 26) - - plt.xlim(x_min, x_max) - plt.yfrequency(6) - plt.xfrequency(6) - - plt.show() - - def start_recording_and_diluting( st: Stirrer, initial_od600: pt.OD, @@ -283,10 +247,12 @@ def get_voltage_from_adc() -> pt.Voltage: for i in range(n_samples): clear() - plot_data( + utils.plot_data( inferred_od600s, voltages, title="OD Calibration (ongoing)", + x_label="OD600", + y_label="Voltage", x_min=minimum_od600, x_max=initial_od600, ) @@ -326,10 +292,12 @@ def get_voltage_from_adc() -> pt.Voltage: else: # executed if the loop did not break clear() - plot_data( + utils.plot_data( inferred_od600s, voltages, title="OD Calibration (ongoing)", + x_label="OD600", + y_label="Voltage", x_min=minimum_od600, x_max=initial_od600, ) @@ -350,10 +318,12 @@ def get_voltage_from_adc() -> pt.Voltage: sleep(1.0) clear() - plot_data( + utils.plot_data( inferred_od600s, voltages, title="OD Calibration (ongoing)", + x_label="OD600", + y_label="Voltage", x_min=minimum_od600, x_max=initial_od600, ) diff --git a/pioreactor/calibrations/utils.py b/pioreactor/calibrations/utils.py index f6d22c80..e623ac44 100644 --- a/pioreactor/calibrations/utils.py +++ b/pioreactor/calibrations/utils.py @@ -65,15 +65,18 @@ def curve_callable(x): raise NotImplementedError() -def linspace(start: float, stop: float, num: int = 50): - num = int(num) - start = start * 1.0 - stop = stop * 1.0 +def linspace(start: float, stop: float, num: int = 50) -> list[float]: + def linspace_(start: float, stop: float, num: int = 50): + num = int(num) + start = start * 1.0 + stop = stop * 1.0 - step = (stop - start) / (num - 1) + step = (stop - start) / (num - 1) - for i in range(num): - yield start + step * i + for i in range(num): + yield start + step * i + + return list(linspace_(start, stop, num)) def plot_data( @@ -93,7 +96,7 @@ def plot_data( if interpolation_curve: x_min, x_max = min(x) - 0.1, max(x) + 0.1 - xs = list(linspace(x_min, x_max, num=100)) + xs = linspace(x_min, x_max, num=100) ys = [interpolation_curve(x_) for x_ in xs] plt.plot(xs, ys, color=204) plt.plot_size(145, 26) diff --git a/pioreactor/cli/calibrations.py b/pioreactor/cli/calibrations.py index 15fc52c7..c5ba8007 100644 --- a/pioreactor/cli/calibrations.py +++ b/pioreactor/cli/calibrations.py @@ -92,8 +92,9 @@ def run_calibration(ctx, device: str, protocol_name: str | None, y: bool) -> Non # Dispatch to the assistant function for that device if protocol_name is None and device in DEFAULT_PROTOCOLS: protocol_name = DEFAULT_PROTOCOLS[device] + elif protocol_name is None: + raise ValueError("Must provide protocol name: --protocol-name ") - assert protocol_name is not None assistant = calibration_protocols.get((device, protocol_name)) if assistant is None: click.echo( @@ -104,6 +105,7 @@ def run_calibration(ctx, device: str, protocol_name: str | None, y: bool) -> Non # Run the assistant function to get the final calibration data calibration_struct = assistant().run( + target_device=device, **{ctx.args[i][2:].replace("-", "_"): ctx.args[i + 1] for i in range(0, len(ctx.args), 2)}, )