Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions libemg/emg_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import re
from matplotlib.animation import FuncAnimation
from functools import partial
from typing import Callable

from libemg.utils import get_windows
from libemg.environments.controllers import RegressorController, ClassifierController
Expand Down Expand Up @@ -307,21 +308,34 @@ def add_majority_vote(self, num_samples=5):

def add_velocity(self, train_windows, train_labels,
velocity_metric_handle = None,
velocity_mapping_handle = None):
velocity_mapping_handle: str | None | Callable[[int], int] = None):
"""Adds velocity (i.e., proportional) control where a multiplier is generated for the level of contraction intensity.

Note, that when using this optional, ramp contractions should be captured for training.

Parameters
-----------
train_windows: np.ndarray
The training windows extracted from the offline data handler.
train_labels: np.ndarray
The labels associated with the train windows. Allows for per class proportional control mapping.
velocity_mapping_handle: function or a string (valid options: "SIGMOID", "SQUARED", "LOG", "RELU")
A function that maps the proportionality (bounded between 0-1) to some other value.
"""
self.velocity_metric_handle = velocity_metric_handle
self.velocity_mapping_handle = velocity_mapping_handle
if isinstance(velocity_mapping_handle, str):
if not velocity_mapping_handle in ['SIGMOID', 'SQUARED', 'LOG', 'RELU']:
print("Invalid velocity mapping... Defaulting to linear.")
else:
if velocity_mapping_handle == 'SQUARED':
self.velocity_mapping_handle = lambda x: x**2
# TODO: Fill out rest
else:
self.velocity_metric_handle = velocity_metric_handle
self.velocity_mapping_handle = velocity_mapping_handle

self.velocity = True

self.th_min_dic, self.th_max_dic = self._set_up_velocity_control(train_windows, train_labels)



'''
---------------------- Private Helper Functions ----------------------
Expand Down
Loading