Skip to content

Analog threshold improvements #141

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions devices/rotary_encoder.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@ def __init__(
falling_event=None,
bytes_per_sample=2,
reverse=False,
triggers=None,
):
assert output in ("velocity", "position"), "ouput argument must be 'velocity' or 'position'."
assert bytes_per_sample in (2, 4), "bytes_per_sample must be 2 or 4"
@@ -28,6 +29,7 @@ def __init__(
self.position = 0
self.velocity = 0
self.sampling_rate = sampling_rate

Analog_input.__init__(
self,
None,
@@ -37,6 +39,7 @@ def __init__(
rising_event,
falling_event,
data_type={2: "h", 4: "i"}[bytes_per_sample],
triggers=triggers,
)

def read_sample(self):
108 changes: 108 additions & 0 deletions devices/schmitt_trigger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from pyControl.hardware import IO_object, assign_ID, interrupt_queue
import pyControl.framework as fw
import pyControl.state_machine as sm


class Crossing:
above = "above"
below = "below"
none = "none"


class SchmittTrigger(IO_object):
"""
Generates framework events when an analog signal goes above an upper threshold and/or below a lower threshold.
The rising event is triggered when signal > upper bound, falling event is triggered when signal < lower bound.

This trigger implements hysteresis, which is a technique to prevent rapid oscillations or "bouncing" of events:
- Hysteresis creates a "dead zone" between the upper and lower thresholds
- Once a rising event is triggered (when signal crosses above the upper bound),
it cannot be triggered again until the signal has fallen below the lower bound
- Similarly, once a falling event is triggered (when signal crosses below the lower bound),
it cannot be triggered again until the signal has risen above the upper bound

This behavior is particularly useful for noisy signals that might otherwise rapidly cross a single threshold
multiple times, generating unwanted repeated events.
"""

def __init__(self, bounds, rising_event=None, falling_event=None):
if rising_event is None and falling_event is None:
raise ValueError("Either rising_event or falling_event or both must be specified.")
self.rising_event = rising_event
self.falling_event = falling_event
self.bounds = bounds
self.timestamp = 0
assign_ID(self)

def run_start(self):
self.set_bounds(self.bounds)

def set_bounds(self, threshold):
if isinstance(threshold, tuple):
threshold_requirements_str = "The threshold must be a tuple of two integers (lower_bound, upper_bound) where lower_bound <= upper_bound."
if len(threshold) != 2:
raise ValueError("{} is not a valid threshold. {}".format(threshold, threshold_requirements_str))
lower, upper = threshold
if not upper >= lower:
raise ValueError(
"{} is not a valid threshold because the lower bound {} is greater than the upper bound {}. {}".format(
threshold, lower, upper, threshold_requirements_str
)
)
self.upper_threshold = upper
self.lower_threshold = lower
else:
raise ValueError("{} is not a valid threshold. {}".format(threshold, threshold_requirements_str))
self.reset_crossing = True

content = {"bounds": (self.lower_threshold, self.upper_threshold)}
if self.rising_event is not None:
content["rising_event"] = self.rising_event
if self.falling_event is not None:
content["falling_event"] = self.falling_event
fw.data_output_queue.put(
fw.Datatuple(
fw.current_time,
fw.THRSH_TYP,
"s",
str(content),
)
)

def _initialise(self):
# Set event codes for rising and falling events.
self.rising_event_ID = sm.events[self.rising_event] if self.rising_event in sm.events else False
self.falling_event_ID = sm.events[self.falling_event] if self.falling_event in sm.events else False
self.threshold_active = self.rising_event_ID or self.falling_event_ID

def _process_interrupt(self):
# Put event generated by threshold crossing in event queue.
if self.was_above:
fw.event_queue.put(fw.Datatuple(self.timestamp, fw.EVENT_TYP, "i", self.rising_event_ID))
else:
fw.event_queue.put(fw.Datatuple(self.timestamp, fw.EVENT_TYP, "i", self.falling_event_ID))

@micropython.native
def check(self, sample):
if self.reset_crossing:
# this gets run when the first sample is taken and whenever the threshold is changed
self.reset_crossing = False
self.was_above = sample > self.upper_threshold
self.was_below = sample < self.lower_threshold
self.last_crossing = Crossing.none
return
is_above_threshold = sample > self.upper_threshold
is_below_threshold = sample < self.lower_threshold

if is_above_threshold and not self.was_above and self.last_crossing != Crossing.above:
self.timestamp = fw.current_time
self.last_crossing = Crossing.above
if self.rising_event_ID:
interrupt_queue.put(self.ID)
elif is_below_threshold and not self.was_below and self.last_crossing != Crossing.below:
self.timestamp = fw.current_time
self.last_crossing = Crossing.below
if self.falling_event_ID:
interrupt_queue.put(self.ID)

self.was_above, self.was_below = is_above_threshold, is_below_threshold
4 changes: 3 additions & 1 deletion source/communication/data_logger.py
Original file line number Diff line number Diff line change
@@ -73,7 +73,7 @@ def write_info_line(self, subtype, content, time=0):
self.data_file.write(self.tsv_row_str("info", time, subtype, content))

def tsv_row_str(self, rtype, time, subtype="", content=""):
time_str = f"{time/1000:.3f}" if isinstance(time, int) else time
time_str = f"{time / 1000:.3f}" if isinstance(time, int) else time
return f"{time_str}\t{rtype}\t{subtype}\t{content}\n"

def copy_task_file(self, data_dir, tasks_dir, dir_name="task_files"):
@@ -140,6 +140,8 @@ def data_to_string(self, new_data, prettify=False, max_len=60):
var_str += f'\t\t\t"{var_name}": {var_value}\n'
var_str += "\t\t\t}"
data_string += self.tsv_row_str("variable", time, nd.subtype, content=var_str)
elif nd.type == MsgType.THRSH: # Threshold
data_string += self.tsv_row_str("threshold", time, nd.subtype, content=nd.content)
elif nd.type == MsgType.WARNG: # Warning
data_string += self.tsv_row_str("warning", time, content=nd.content)
elif nd.type in (MsgType.ERROR, MsgType.STOPF): # Error or stop framework.
5 changes: 5 additions & 0 deletions source/communication/message.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@ class MsgType(Enum):
ERROR = b"!!" # Error
STOPF = b"X" # Stop framework
ANLOG = b"A" # Analog
THRSH = b"T" # Threshold

@classmethod
def from_byte(cls, byte_value):
@@ -51,5 +52,9 @@ def get_subtype(self, subtype_char):
"t": "task",
"a": "api",
"u": "user",
"s": "trigger",
},
MsgType.THRSH: {
"s": "set",
},
}[self][subtype_char]
2 changes: 1 addition & 1 deletion source/communication/pycboard.py
Original file line number Diff line number Diff line change
@@ -492,7 +492,7 @@ def process_data(self):
self.timestamp = msg_timestamp
if msg_type in (MsgType.EVENT, MsgType.STATE):
content = int(content_bytes.decode()) # Event/state ID.
elif msg_type in (MsgType.PRINT, MsgType.WARNG):
elif msg_type in (MsgType.PRINT, MsgType.WARNG, MsgType.THRSH):
content = content_bytes.decode() # Print or error string.
elif msg_type == MsgType.VARBL:
content = content_bytes.decode() # JSON string
1 change: 1 addition & 0 deletions source/pyControl/framework.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@ class pyControlError(BaseException): # Exception for pyControl errors.
VARBL_TYP = b"V" # Variable change : (time, VARBL_TYP, [g]et/user_[s]et/[a]pi_set/[p]rint/s[t]art/[e]nd, json_str)
WARNG_TYP = b"!" # Warning : (time, WARNG_TYP, "", print_string)
STOPF_TYP = b"X" # Stop framework : (time, STOPF_TYP, "", "")
THRSH_TYP = b"T" # Threshold : (time, THRSH_TYP, [s]et)

# Event_queue -----------------------------------------------------------------

90 changes: 66 additions & 24 deletions source/pyControl/hardware.py
Original file line number Diff line number Diff line change
@@ -235,25 +235,35 @@ class Analog_input(IO_object):
# streams data to computer. Optionally can generate framework events when voltage
# goes above / below specified value theshold.

def __init__(self, pin, name, sampling_rate, threshold=None, rising_event=None, falling_event=None, data_type="H"):
if rising_event or falling_event:
self.threshold = Analog_threshold(threshold, rising_event, falling_event)
else:
self.threshold = False
def __init__(
self,
pin,
name,
sampling_rate,
threshold=None,
rising_event=None,
falling_event=None,
data_type="H",
triggers=None,
):
self.triggers = triggers if triggers is not None else []
if threshold is not None:
self.triggers.append(Analog_threshold(threshold, rising_event, falling_event))

self.timer = pyb.Timer(available_timers.pop())
if pin: # pin argument can be None when Analog_input subclassed.
self.ADC = pyb.ADC(pin)
self.read_sample = self.ADC.read
self.name = name
self.Analog_channel = Analog_channel(name, sampling_rate, data_type)
self.channel = Analog_channel(name, sampling_rate, data_type)
assign_ID(self)

def _run_start(self):
# Start sampling timer, initialise threshold, aquire first sample.
self.timer.init(freq=self.Analog_channel.sampling_rate)
self.timer.init(freq=self.channel.sampling_rate)
self.timer.callback(self._timer_ISR)
if self.threshold:
self.threshold.run_start(self.read_sample())
for trigger in self.triggers:
trigger.run_start()
self._timer_ISR(0)

def _run_stop(self):
@@ -263,9 +273,10 @@ def _run_stop(self):
def _timer_ISR(self, t):
# Read a sample to the buffer, update write index.
sample = self.read_sample()
self.Analog_channel.put(sample)
if self.threshold:
self.threshold.check(sample)
self.channel.put(sample)
if self.triggers:
for trigger in self.triggers:
trigger.check(sample)

def record(self): # For backward compatibility.
pass
@@ -286,15 +297,21 @@ class Analog_channel(IO_object):
# data array bytes (variable)

def __init__(self, name, sampling_rate, data_type, plot=True):
assert data_type in ("b", "B", "h", "H", "i", "I"), "Invalid data_type."
assert not any(
[name == io.name for io in IO_dict.values() if isinstance(io, Analog_channel)]
), "Analog signals must have unique names."
if data_type not in ("b", "B", "h", "H", "i", "I"):
raise ValueError("Invalid data_type.")
if any([name == io.name for io in IO_dict.values() if isinstance(io, Analog_channel)]):
raise ValueError(
"Analog signals must have unique names.{} {}".format(
name, [io.name for io in IO_dict.values() if isinstance(io, Analog_channel)]
)
)

self.name = name
assign_ID(self)
self.sampling_rate = sampling_rate
self.data_type = data_type
self.plot = plot

self.bytes_per_sample = {"b": 1, "B": 1, "h": 2, "H": 2, "i": 4, "I": 4}[data_type]
self.buffer_size = max(4, min(256 // self.bytes_per_sample, sampling_rate // 10))
self.buffers = (array(data_type, [0] * self.buffer_size), array(data_type, [0] * self.buffer_size))
@@ -345,15 +362,14 @@ def send_buffer(self, run_stop=False):


class Analog_threshold(IO_object):
# Generates framework events when an analog signal goes above or below specified threshold.
# Generates framework events when an analog signal goes above or below specified threshold value.

def __init__(self, threshold=None, rising_event=None, falling_event=None):
assert isinstance(
threshold, int
), "Integer threshold must be specified if rising or falling events are defined."
self.threshold = threshold
def __init__(self, threshold, rising_event=None, falling_event=None):
if rising_event is None and falling_event is None:
raise ValueError("Either rising_event or falling_event or both must be specified.")
self.rising_event = rising_event
self.falling_event = falling_event
self.threshold = threshold
self.timestamp = 0
self.crossing_direction = False
assign_ID(self)
@@ -364,8 +380,8 @@ def _initialise(self):
self.falling_event_ID = sm.events[self.falling_event] if self.falling_event in sm.events else False
self.threshold_active = self.rising_event_ID or self.falling_event_ID

def run_start(self, sample):
self.above_threshold = sample > self.threshold
def run_start(self):
self.set_threshold(self.threshold)

def _process_interrupt(self):
# Put event generated by threshold crossing in event queue.
@@ -376,14 +392,40 @@ def _process_interrupt(self):

@micropython.native
def check(self, sample):
if self.reset_above_threshold:
# this gets run when the first sample is taken and whenever the threshold is changed
self.reset_above_threshold = False
self.above_threshold = sample > self.threshold
return
new_above_threshold = sample > self.threshold
if new_above_threshold != self.above_threshold: # Threshold crossing.
self.above_threshold = new_above_threshold
if (self.above_threshold and self.rising_event_ID) or (not self.above_threshold and self.falling_event_ID):
self.timestamp = fw.current_time
self.crossing_direction = self.above_threshold

interrupt_queue.put(self.ID)

def set_threshold(self, threshold):
if not isinstance(threshold, int):
raise ValueError(f"Threshold must be an integer, got {type(threshold).__name__}.")
self.threshold = threshold
self.reset_above_threshold = True

content = {"value": self.threshold}
if self.rising_event is not None:
content["rising_event"] = self.rising_event
if self.falling_event is not None:
content["falling_event"] = self.falling_event
fw.data_output_queue.put(
fw.Datatuple(
fw.current_time,
fw.THRSH_TYP,
"s",
str(content),
)
)


# Digital Output --------------------------------------------------------------

26 changes: 22 additions & 4 deletions tasks/example/running_wheel.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
# Example of using a rotary encoder to measure running speed and trigger events when
# running starts and stops. The subject must run for 10 seconds to trigger reward delivery,
# then stop running for 5 seconds to initiate the next trial.
# If while running the subject exceeds a bonus velocity threshold, they earn a bonus
# and the reward duration is extended by a bonus duration.

from pyControl.utility import *
from devices import *
from pyControl.hardware import Analog_threshold

# Variables.

v.run_time = 10 * second # Time subject must run to obtain reward.
v.stop_time = 5 * second # Time subject must stop running to intiate the next trial.
v.reward_duration = 100 * ms # Time reward solenoid is open for.
v.velocity_threshold = 100 # Minimum encoder velocity treated as running (encoder counts/second).
v.bonus_velocity_threshold = 5000 # Encoder velocity that triggers bonus reward (encoder counts/second).
v.give_bonus = False # Whether to give bonus reward.
v.bonus_reward_duration = 50 * ms # Time to add to reward duration if bonus is earned.

running_trigger = Analog_threshold(
threshold=v.velocity_threshold,
rising_event="started_running",
falling_event="stopped_running",
)

bonus_trigger = Analog_threshold(
threshold=v.bonus_velocity_threshold,
rising_event="bonus_earned",
)
# Instantiate hardware - would normally be in a seperate hardware definition file.

board = Breakout_1_2() # Breakout board.
@@ -21,9 +37,7 @@
name="running_wheel",
sampling_rate=100,
output="velocity",
threshold=v.velocity_threshold,
rising_event="started_running",
falling_event="stopped_running",
triggers=[running_trigger, bonus_trigger],
) # Running wheel must be plugged into port 1 of breakout board.

solenoid = Digital_output(board.port_2.POW_A) # Reward delivery solenoid.
@@ -40,6 +54,7 @@
events = [
"started_running",
"stopped_running",
"bonus_earned",
"run_timer",
"stopped_timer",
"reward_timer",
@@ -70,18 +85,21 @@ def running_for_reward(event):
# If subject runs for long enough go to reward state.
# If subject stops go back to trial start.
if event == "entry":
v.give_bonus = False
set_timer("run_timer", v.run_time)
elif event == "stopped_running":
disarm_timer("run_timer")
goto_state("trial_start")
elif event == "bonus_earned":
v.give_bonus = True
elif event == "run_timer":
goto_state("reward")


def reward(event):
# Deliver reward then go to inter trial interval.
if event == "entry":
timed_goto_state("inter_trial_interval", v.reward_duration)
timed_goto_state("inter_trial_interval", v.reward_duration + v.bonus_reward_duration * v.give_bonus)
solenoid.on()
elif event == "exit":
solenoid.off()