Skip to content

Commit e7635a4

Browse files
committed
revert back to 6454fd4
1 parent c1d7a16 commit e7635a4

File tree

1 file changed

+103
-172
lines changed

1 file changed

+103
-172
lines changed

src/emgdecompy/viz.py

Lines changed: 103 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
# Copyright (C) 2022 Daniel King, Jasmine Ortega, Rada Rudyak, Rowan Sivanandam
2+
# This script contains functions used to visualize the results of the
3+
# blind source separation algorithm based off of Francesco Negro et al 2016 J. Neural Eng. 13 026027.
4+
15
from codecs import raw_unicode_escape_decode
6+
import ipywidgets as widgets
27
import numpy as np
38
import pandas as pd
49
import altair as alt
510
import panel as pn
6-
from panel.interact import interact, interactive, fixed, interact_manual
7-
from panel import widgets
811
import math
912
from sklearn.metrics import mean_squared_error
1013
from emgdecompy.preprocessing import (
@@ -46,6 +49,50 @@ def RMSE(arr1, arr2):
4649
return RMSE
4750

4851

52+
def mismatch_score(mu_data, peak_data, mu_index, method=RMSE, channel=-1):
53+
"""
54+
Evaluates how well a given peak contributes to a given MUAP.
55+
This is called by muap_plot() function and is used to include error in the title of the muap plot.
56+
57+
Parameters
58+
----------
59+
mu_data: dict
60+
Dictionary containing MUAP shapes for each motor unit.
61+
peak_data: dict
62+
Dictionary containing shapes for a given peak per channel.
63+
mu_index: int
64+
Index of motor unit to examine
65+
method: function name
66+
Function to use for evaluating discrepency between mu_data and peak_data.
67+
Default: RMSE.
68+
channel: int
69+
Channel to run evaluation on.
70+
Defaul = -1 and it means average of all channels.
71+
72+
Returns
73+
-------
74+
float
75+
Root Mean Square Error of MU data vs Peak data.
76+
"""
77+
if channel == -1: # For all channels, we can just
78+
# straight up compare RMSE across the board
79+
mu_sig = mu_data[f"mu_{mu_index}"]["signal"]
80+
peak_sig = peak_data[f"mu_{mu_index}"]["signal"]
81+
score = RMSE(mu_sig, peak_sig)
82+
83+
else: # Otherwise, filter for a given channel
84+
# filter mu_data for signal data that channel
85+
indexes = np.where(mu_data[f"mu_{mu_index}"]["channel"] == channel)
86+
mu_sig = mu_data[f"mu_{mu_index}"]["signal"][indexes]
87+
88+
indexes = np.where(peak_data[f"mu_{mu_index}"]["channel"] == channel)
89+
peak_sig = peak_data[f"mu_{mu_index}"]["signal"][indexes]
90+
91+
score = RMSE(mu_sig, peak_sig)
92+
93+
return score
94+
95+
4996
def muap_dict(raw, pt, l=31):
5097
"""
5198
Returns multi-level dictionary containing sample number, average signal, and channel
@@ -173,51 +220,8 @@ def muap_dict_by_peak(raw, peak, mu_index=0, l=31):
173220

174221
return shape_dict
175222

176-
def mismatch_score(muap_dict, peak_dict, mu_index, method=RMSE, channel=-1):
177-
"""
178-
Evaluates how well a given peak contributes to a given MUAP.
179-
This is called by muap_plot() function and is used to include error in the title of the muap plot.
180-
181-
Parameters
182-
----------
183-
muap_dict: dict
184-
Dictionary containing MUAP shapes for each motor unit.
185-
peak_dict: dict
186-
Dictionary containing shapes for a given peak per channel.
187-
mu_index: int
188-
Index of motor unit to examine
189-
method: function name
190-
Function to use for evaluating discrepency between mu_dict and peak_dict.
191-
Default: RMSE.
192-
channel: int
193-
Channel to run evaluation on.
194-
Default = -1 and it means average of all channels.
195-
196-
Returns
197-
-------
198-
float
199-
Root Mean Square Error of MU data vs Peak data.
200-
"""
201-
if channel == -1: # For all channels, we can just
202-
# straight up compare RMSE across the board
203-
mu_sig = muap_dict[f"mu_{mu_index}"]["signal"]
204-
peak_sig = peak_dict[f"mu_{mu_index}"]["signal"]
205-
score = RMSE(mu_sig, peak_sig)
206-
207-
else: # Otherwise, filter for a given channel
208-
# filter mu_dict for signal data that channel
209-
indexes = np.where(muap_dict[f"mu_{mu_index}"]["channel"] == channel)
210-
mu_sig = muap_dict[f"mu_{mu_index}"]["signal"][indexes]
211-
212-
indexes = np.where(peak_dict[f"mu_{mu_index}"]["channel"] == channel)
213-
peak_sig = peak_dict[f"mu_{mu_index}"]["signal"][indexes]
214-
215-
score = RMSE(mu_sig, peak_sig)
216223

217-
return score
218-
219-
220-
def channel_preset(preset="standard"):
224+
def channel_preset(name="standard"):
221225
"""
222226
Returns a dictionary with two keys:
223227
'sort_order' with the list to order channels,
@@ -226,7 +230,7 @@ def channel_preset(preset="standard"):
226230
227231
Parameters
228232
----------
229-
preset: str
233+
name: str
230234
Name of the preset to use
231235
232236
Returns
@@ -239,7 +243,7 @@ def channel_preset(preset="standard"):
239243
240244
Examples
241245
--------
242-
>>> channel_preset(preset='vert63')
246+
>>> channel_preset(name='vert63')
243247
{
244248
'cols': 5,
245249
'sort_order': [
@@ -248,11 +252,11 @@ def channel_preset(preset="standard"):
248252
}
249253
"""
250254

251-
if preset == "standard":
255+
if name == "standard":
252256
sort_order = list(range(0, 64, 1))
253257
cols = 8
254258

255-
elif preset == "vert63":
259+
elif name == "vert63":
256260
sort_order = [
257261
63,
258262
38,
@@ -425,19 +429,15 @@ def muap_plot(
425429
def pulse_plot(pt, c_sq_mean, mu_index, sel_type="single"):
426430
"""
427431
Plot firings for a given motor unit.
428-
429432
Parameters
430433
----------
431-
pt: np.array
434+
pulse_train: np.array
432435
Pulse train.
433436
c_sq_mean: np.array
434437
Centered, squared and averaged firings over the duration of the trial.
435438
mu_index: int
436439
Motor Unit of interest to plot firings for.
437440
Default is None and means return all pulses.
438-
sel_type: str
439-
Whether to select single points or intervals
440-
441441
Returns
442442
-------
443443
altair plot object
@@ -581,10 +581,39 @@ def pulse_plot(pt, c_sq_mean, mu_index, sel_type="single"):
581581
return chart_top & chart_rate & chart_pulse
582582

583583

584-
def select_peak(
585-
selection, mu_index, raw, shape_dict, pt, preset="standard", method=RMSE
586-
):
584+
def create_widget_dd(options, value=0, desc="Motor Unit:", disabled=False):
587585
"""
586+
Create a dropdown widget.
587+
588+
Parameters
589+
----------
590+
options: list
591+
Options for the dropdown.
592+
value: int or str
593+
Original value to be selected.
594+
desc: str
595+
Description to be displayed above the widget.
596+
disabled: bool
597+
Whether the widget is disabled by default
598+
599+
Returns
600+
-------
601+
widget object: dropdown widget to be used in altair interactions.
602+
"""
603+
604+
widget = widgets.Dropdown(
605+
options=options,
606+
value=value,
607+
description=desc,
608+
disabled=disabled,
609+
)
610+
611+
return widget
612+
613+
614+
def select_peak(selection, mu_index, raw, shape_dict, pt):
615+
"""
616+
Interactivity function for the Firing plot.
588617
Retrieves a given peak (if any) and re-graphs MUAP plot via muap_plot() function.
589618
Called within dashboard() function, binded to the peak selection on pulse graphs.
590619
@@ -611,28 +640,17 @@ def select_peak(
611640
altair plot object
612641
613642
"""
614-
global selected_peak
615-
616643
if not selection:
617-
plot = muap_plot(shape_dict, mu_index, l=31, preset=preset, method=RMSE)
618-
selected_peak = -1
644+
plot = muap_plot(shape_dict, mu_index, l=31)
619645

620646
else:
621647
print(selection)
622648
sel = selection[0] - 1
623-
# for some reason beyond my grasp these are 1-indexed
649+
# for some reason beyond my grast these are 1-indexed
624650
peak = pt[mu_index][sel]
625651

626652
peak_data = muap_dict_by_peak(raw, peak, mu_index=mu_index, l=31)
627-
plot = muap_plot(
628-
shape_dict,
629-
mu_index,
630-
peak_data,
631-
l=31,
632-
peak=str(peak),
633-
preset=preset,
634-
method=RMSE,
635-
)
653+
plot = muap_plot(shape_dict, mu_index, peak_data, l=31, peak=str(peak))
636654

637655
return pn.Column(
638656
pn.Row(
@@ -662,16 +680,16 @@ def remove_false_peak(decomp_results, mu_index, peak):
662680
"""
663681

664682
decomp_results["MUPulses"] = list(decomp_results["MUPulses"])
665-
decomp_results["MUPulses"][mu_index] = np.delete(
666-
decomp_results["MUPulses"][mu_index],
667-
np.argwhere(decomp_results["MUPulses"][mu_index] == peak),
683+
decomp_results["MUPulses"][0][mu_index] = np.delete(
684+
decomp_results["MUPulses"][0][mu_index],
685+
np.argwhere(decomp_results["MUPulses"][0][mu_index][0] == peak),
668686
)
669687
decomp_results["MUPulses"] = np.array(decomp_results["MUPulses"], dtype=object)
670688

671689
return decomp_results
672690

673691

674-
def dashboard(decomp_results, raw, mu_index=0, preset="standard", method=RMSE):
692+
def dashboard(decomp_results, raw, mu_index=0):
675693
"""
676694
Parent function for creating interactive visual component of decomposition.
677695
Dashboard consists of four plots:
@@ -692,23 +710,10 @@ def dashboard(decomp_results, raw, mu_index=0, preset="standard", method=RMSE):
692710
mu_index: int
693711
Currently plotted Motor Unit.
694712
695-
method: function name
696-
Function to use for evaluating discrepency between mu_data and peak_data.
697-
Default: RMSE.
698-
699-
preset: str
700-
Name of the preset to use
701-
702713
Returns
703714
-------
704715
panel object containing interactive altair plots
705716
"""
706-
# A little hacky, because I don't know how to pass params to the button
707-
# Delete button uses these to pass preset and method to muap_plot
708-
global gl_preset
709-
global gl_method
710-
gl_preset = preset
711-
gl_method = method
712717

713718
signal = flatten_signal(raw)
714719
signal = np.apply_along_axis(
@@ -725,94 +730,20 @@ def dashboard(decomp_results, raw, mu_index=0, preset="standard", method=RMSE):
725730
c_sq_mean = c_sq.mean(axis=0)
726731

727732
pt = decomp_results["MUPulses"]
733+
# # from raw data
734+
# pt = raw_data["MUPulses"].squeeze()
728735

729736
shape_dict = muap_dict(raw, pt, l=31)
730737
pulse = pulse_plot(pt, c_sq_mean, mu_index, sel_type="interval")
731-
pulse_pn = pn.pane.Vega(pulse, debounce=10)
732-
mu_charts_pn = pn.bind(
733-
select_peak,
734-
pulse_pn.selection.param.sel_peak,
735-
mu_index,
736-
raw,
737-
shape_dict,
738-
pt,
739-
preset,
740-
method,
741-
)
742-
743-
button_del = pn.widgets.Button(
744-
name="Delete Selected Peak", button_type="primary", width=50
745-
)
746-
button_del.on_click(b_click)
747-
748-
res = pn.Column(
749-
button_del,
750-
pulse_pn,
751-
mu_charts_pn,
752-
)
753-
754-
return res
755-
756-
757-
def b_click(event):
758-
"""
759-
Function triggered by clicking "Delete Selected Peak" button on the dashboard
760-
Bound to the button widget inside dashboard() function
761-
Deletes selected peak from the output variable and reruns the dashboard
762-
763-
Parameters
764-
----------
765-
event: event
766-
event that triggered the funciton
767-
768-
Returns
769-
-------
770-
Null
771-
"""
772-
if selected_peak > -1:
773-
774-
# Get the peak and the selected MU index
775-
###############################
776-
peak = dash_p[1][0][1].object.data.iloc[selected_peak]["Pulse"]
777-
mu_index = dash_p[0][0].value
778-
779-
# Change decomp_results:
780-
###############################
781-
global output
782-
output = remove_false_peak(output, mu_index, peak)
783-
784-
# Reconstruct the plot:
785-
###############################
786-
raw = raw_data_dict["SIG"]
787-
decomp_results = output
788-
signal = flatten_signal(raw)
789-
signal = np.apply_along_axis(
790-
butter_bandpass_filter,
791-
axis=1,
792-
arr=signal,
793-
lowcut=10,
794-
highcut=900,
795-
fs=2048,
796-
order=6,
797-
)
798-
centered = center_matrix(signal)
799-
c_sq = centered ** 2
800-
c_sq_mean = c_sq.mean(axis=0)
801-
pt = decomp_results["MUPulses"]
802-
shape_dict = muap_dict(raw, pt, l=31)
803-
pulse = pulse_plot(pt, c_sq_mean, mu_index, sel_type="interval")
804-
pulse_pn = pn.pane.Vega(pulse, debounce=10)
805-
dash_p[1][0][1] = pulse_pn
806-
807-
# Also redo mu_charts graph so that it no longer selects the deleted peak:
808-
mu_charts_pn = pn.bind(
738+
vega_pane = pn.pane.Vega(pulse, debounce=10)
739+
return pn.Column(
740+
vega_pane,
741+
pn.bind(
809742
select_peak,
810-
pulse_pn.selection.param.sel_peak,
743+
vega_pane.selection.param.sel_peak,
811744
mu_index,
812745
raw,
813746
shape_dict,
814747
pt,
815-
preset=gl_preset,
816-
method=gl_method,
817-
)
818-
dash_p[1][0][2] = mu_charts_pn
748+
),
749+
)

0 commit comments

Comments
 (0)