1
+ # Copyright (C) 2022 Daniel King, Jasmine Ortega, Rada Rudyak, Rowan Sivanandam
2
+ # This script contains functions used to visualize the results of the
3
+ # blind source separation algorithm based off of Francesco Negro et al 2016 J. Neural Eng. 13 026027.
4
+
1
5
from codecs import raw_unicode_escape_decode
6
+ import ipywidgets as widgets
2
7
import numpy as np
3
8
import pandas as pd
4
9
import altair as alt
5
10
import panel as pn
6
- from panel .interact import interact , interactive , fixed , interact_manual
7
- from panel import widgets
8
11
import math
9
12
from sklearn .metrics import mean_squared_error
10
13
from emgdecompy .preprocessing import (
@@ -46,6 +49,50 @@ def RMSE(arr1, arr2):
46
49
return RMSE
47
50
48
51
52
+ def mismatch_score (mu_data , peak_data , mu_index , method = RMSE , channel = - 1 ):
53
+ """
54
+ Evaluates how well a given peak contributes to a given MUAP.
55
+ This is called by muap_plot() function and is used to include error in the title of the muap plot.
56
+
57
+ Parameters
58
+ ----------
59
+ mu_data: dict
60
+ Dictionary containing MUAP shapes for each motor unit.
61
+ peak_data: dict
62
+ Dictionary containing shapes for a given peak per channel.
63
+ mu_index: int
64
+ Index of motor unit to examine
65
+ method: function name
66
+ Function to use for evaluating discrepency between mu_data and peak_data.
67
+ Default: RMSE.
68
+ channel: int
69
+ Channel to run evaluation on.
70
+ Defaul = -1 and it means average of all channels.
71
+
72
+ Returns
73
+ -------
74
+ float
75
+ Root Mean Square Error of MU data vs Peak data.
76
+ """
77
+ if channel == - 1 : # For all channels, we can just
78
+ # straight up compare RMSE across the board
79
+ mu_sig = mu_data [f"mu_{ mu_index } " ]["signal" ]
80
+ peak_sig = peak_data [f"mu_{ mu_index } " ]["signal" ]
81
+ score = RMSE (mu_sig , peak_sig )
82
+
83
+ else : # Otherwise, filter for a given channel
84
+ # filter mu_data for signal data that channel
85
+ indexes = np .where (mu_data [f"mu_{ mu_index } " ]["channel" ] == channel )
86
+ mu_sig = mu_data [f"mu_{ mu_index } " ]["signal" ][indexes ]
87
+
88
+ indexes = np .where (peak_data [f"mu_{ mu_index } " ]["channel" ] == channel )
89
+ peak_sig = peak_data [f"mu_{ mu_index } " ]["signal" ][indexes ]
90
+
91
+ score = RMSE (mu_sig , peak_sig )
92
+
93
+ return score
94
+
95
+
49
96
def muap_dict (raw , pt , l = 31 ):
50
97
"""
51
98
Returns multi-level dictionary containing sample number, average signal, and channel
@@ -173,51 +220,8 @@ def muap_dict_by_peak(raw, peak, mu_index=0, l=31):
173
220
174
221
return shape_dict
175
222
176
- def mismatch_score (muap_dict , peak_dict , mu_index , method = RMSE , channel = - 1 ):
177
- """
178
- Evaluates how well a given peak contributes to a given MUAP.
179
- This is called by muap_plot() function and is used to include error in the title of the muap plot.
180
-
181
- Parameters
182
- ----------
183
- muap_dict: dict
184
- Dictionary containing MUAP shapes for each motor unit.
185
- peak_dict: dict
186
- Dictionary containing shapes for a given peak per channel.
187
- mu_index: int
188
- Index of motor unit to examine
189
- method: function name
190
- Function to use for evaluating discrepency between mu_dict and peak_dict.
191
- Default: RMSE.
192
- channel: int
193
- Channel to run evaluation on.
194
- Default = -1 and it means average of all channels.
195
-
196
- Returns
197
- -------
198
- float
199
- Root Mean Square Error of MU data vs Peak data.
200
- """
201
- if channel == - 1 : # For all channels, we can just
202
- # straight up compare RMSE across the board
203
- mu_sig = muap_dict [f"mu_{ mu_index } " ]["signal" ]
204
- peak_sig = peak_dict [f"mu_{ mu_index } " ]["signal" ]
205
- score = RMSE (mu_sig , peak_sig )
206
-
207
- else : # Otherwise, filter for a given channel
208
- # filter mu_dict for signal data that channel
209
- indexes = np .where (muap_dict [f"mu_{ mu_index } " ]["channel" ] == channel )
210
- mu_sig = muap_dict [f"mu_{ mu_index } " ]["signal" ][indexes ]
211
-
212
- indexes = np .where (peak_dict [f"mu_{ mu_index } " ]["channel" ] == channel )
213
- peak_sig = peak_dict [f"mu_{ mu_index } " ]["signal" ][indexes ]
214
-
215
- score = RMSE (mu_sig , peak_sig )
216
223
217
- return score
218
-
219
-
220
- def channel_preset (preset = "standard" ):
224
+ def channel_preset (name = "standard" ):
221
225
"""
222
226
Returns a dictionary with two keys:
223
227
'sort_order' with the list to order channels,
@@ -226,7 +230,7 @@ def channel_preset(preset="standard"):
226
230
227
231
Parameters
228
232
----------
229
- preset : str
233
+ name : str
230
234
Name of the preset to use
231
235
232
236
Returns
@@ -239,7 +243,7 @@ def channel_preset(preset="standard"):
239
243
240
244
Examples
241
245
--------
242
- >>> channel_preset(preset ='vert63')
246
+ >>> channel_preset(name ='vert63')
243
247
{
244
248
'cols': 5,
245
249
'sort_order': [
@@ -248,11 +252,11 @@ def channel_preset(preset="standard"):
248
252
}
249
253
"""
250
254
251
- if preset == "standard" :
255
+ if name == "standard" :
252
256
sort_order = list (range (0 , 64 , 1 ))
253
257
cols = 8
254
258
255
- elif preset == "vert63" :
259
+ elif name == "vert63" :
256
260
sort_order = [
257
261
63 ,
258
262
38 ,
@@ -425,19 +429,15 @@ def muap_plot(
425
429
def pulse_plot (pt , c_sq_mean , mu_index , sel_type = "single" ):
426
430
"""
427
431
Plot firings for a given motor unit.
428
-
429
432
Parameters
430
433
----------
431
- pt : np.array
434
+ pulse_train : np.array
432
435
Pulse train.
433
436
c_sq_mean: np.array
434
437
Centered, squared and averaged firings over the duration of the trial.
435
438
mu_index: int
436
439
Motor Unit of interest to plot firings for.
437
440
Default is None and means return all pulses.
438
- sel_type: str
439
- Whether to select single points or intervals
440
-
441
441
Returns
442
442
-------
443
443
altair plot object
@@ -581,10 +581,39 @@ def pulse_plot(pt, c_sq_mean, mu_index, sel_type="single"):
581
581
return chart_top & chart_rate & chart_pulse
582
582
583
583
584
- def select_peak (
585
- selection , mu_index , raw , shape_dict , pt , preset = "standard" , method = RMSE
586
- ):
584
+ def create_widget_dd (options , value = 0 , desc = "Motor Unit:" , disabled = False ):
587
585
"""
586
+ Create a dropdown widget.
587
+
588
+ Parameters
589
+ ----------
590
+ options: list
591
+ Options for the dropdown.
592
+ value: int or str
593
+ Original value to be selected.
594
+ desc: str
595
+ Description to be displayed above the widget.
596
+ disabled: bool
597
+ Whether the widget is disabled by default
598
+
599
+ Returns
600
+ -------
601
+ widget object: dropdown widget to be used in altair interactions.
602
+ """
603
+
604
+ widget = widgets .Dropdown (
605
+ options = options ,
606
+ value = value ,
607
+ description = desc ,
608
+ disabled = disabled ,
609
+ )
610
+
611
+ return widget
612
+
613
+
614
+ def select_peak (selection , mu_index , raw , shape_dict , pt ):
615
+ """
616
+ Interactivity function for the Firing plot.
588
617
Retrieves a given peak (if any) and re-graphs MUAP plot via muap_plot() function.
589
618
Called within dashboard() function, binded to the peak selection on pulse graphs.
590
619
@@ -611,28 +640,17 @@ def select_peak(
611
640
altair plot object
612
641
613
642
"""
614
- global selected_peak
615
-
616
643
if not selection :
617
- plot = muap_plot (shape_dict , mu_index , l = 31 , preset = preset , method = RMSE )
618
- selected_peak = - 1
644
+ plot = muap_plot (shape_dict , mu_index , l = 31 )
619
645
620
646
else :
621
647
print (selection )
622
648
sel = selection [0 ] - 1
623
- # for some reason beyond my grasp these are 1-indexed
649
+ # for some reason beyond my grast these are 1-indexed
624
650
peak = pt [mu_index ][sel ]
625
651
626
652
peak_data = muap_dict_by_peak (raw , peak , mu_index = mu_index , l = 31 )
627
- plot = muap_plot (
628
- shape_dict ,
629
- mu_index ,
630
- peak_data ,
631
- l = 31 ,
632
- peak = str (peak ),
633
- preset = preset ,
634
- method = RMSE ,
635
- )
653
+ plot = muap_plot (shape_dict , mu_index , peak_data , l = 31 , peak = str (peak ))
636
654
637
655
return pn .Column (
638
656
pn .Row (
@@ -662,16 +680,16 @@ def remove_false_peak(decomp_results, mu_index, peak):
662
680
"""
663
681
664
682
decomp_results ["MUPulses" ] = list (decomp_results ["MUPulses" ])
665
- decomp_results ["MUPulses" ][mu_index ] = np .delete (
666
- decomp_results ["MUPulses" ][mu_index ],
667
- np .argwhere (decomp_results ["MUPulses" ][mu_index ] == peak ),
683
+ decomp_results ["MUPulses" ][0 ][ mu_index ] = np .delete (
684
+ decomp_results ["MUPulses" ][0 ][ mu_index ],
685
+ np .argwhere (decomp_results ["MUPulses" ][0 ][ mu_index ][ 0 ] == peak ),
668
686
)
669
687
decomp_results ["MUPulses" ] = np .array (decomp_results ["MUPulses" ], dtype = object )
670
688
671
689
return decomp_results
672
690
673
691
674
- def dashboard (decomp_results , raw , mu_index = 0 , preset = "standard" , method = RMSE ):
692
+ def dashboard (decomp_results , raw , mu_index = 0 ):
675
693
"""
676
694
Parent function for creating interactive visual component of decomposition.
677
695
Dashboard consists of four plots:
@@ -692,23 +710,10 @@ def dashboard(decomp_results, raw, mu_index=0, preset="standard", method=RMSE):
692
710
mu_index: int
693
711
Currently plotted Motor Unit.
694
712
695
- method: function name
696
- Function to use for evaluating discrepency between mu_data and peak_data.
697
- Default: RMSE.
698
-
699
- preset: str
700
- Name of the preset to use
701
-
702
713
Returns
703
714
-------
704
715
panel object containing interactive altair plots
705
716
"""
706
- # A little hacky, because I don't know how to pass params to the button
707
- # Delete button uses these to pass preset and method to muap_plot
708
- global gl_preset
709
- global gl_method
710
- gl_preset = preset
711
- gl_method = method
712
717
713
718
signal = flatten_signal (raw )
714
719
signal = np .apply_along_axis (
@@ -725,94 +730,20 @@ def dashboard(decomp_results, raw, mu_index=0, preset="standard", method=RMSE):
725
730
c_sq_mean = c_sq .mean (axis = 0 )
726
731
727
732
pt = decomp_results ["MUPulses" ]
733
+ # # from raw data
734
+ # pt = raw_data["MUPulses"].squeeze()
728
735
729
736
shape_dict = muap_dict (raw , pt , l = 31 )
730
737
pulse = pulse_plot (pt , c_sq_mean , mu_index , sel_type = "interval" )
731
- pulse_pn = pn .pane .Vega (pulse , debounce = 10 )
732
- mu_charts_pn = pn .bind (
733
- select_peak ,
734
- pulse_pn .selection .param .sel_peak ,
735
- mu_index ,
736
- raw ,
737
- shape_dict ,
738
- pt ,
739
- preset ,
740
- method ,
741
- )
742
-
743
- button_del = pn .widgets .Button (
744
- name = "Delete Selected Peak" , button_type = "primary" , width = 50
745
- )
746
- button_del .on_click (b_click )
747
-
748
- res = pn .Column (
749
- button_del ,
750
- pulse_pn ,
751
- mu_charts_pn ,
752
- )
753
-
754
- return res
755
-
756
-
757
- def b_click (event ):
758
- """
759
- Function triggered by clicking "Delete Selected Peak" button on the dashboard
760
- Bound to the button widget inside dashboard() function
761
- Deletes selected peak from the output variable and reruns the dashboard
762
-
763
- Parameters
764
- ----------
765
- event: event
766
- event that triggered the funciton
767
-
768
- Returns
769
- -------
770
- Null
771
- """
772
- if selected_peak > - 1 :
773
-
774
- # Get the peak and the selected MU index
775
- ###############################
776
- peak = dash_p [1 ][0 ][1 ].object .data .iloc [selected_peak ]["Pulse" ]
777
- mu_index = dash_p [0 ][0 ].value
778
-
779
- # Change decomp_results:
780
- ###############################
781
- global output
782
- output = remove_false_peak (output , mu_index , peak )
783
-
784
- # Reconstruct the plot:
785
- ###############################
786
- raw = raw_data_dict ["SIG" ]
787
- decomp_results = output
788
- signal = flatten_signal (raw )
789
- signal = np .apply_along_axis (
790
- butter_bandpass_filter ,
791
- axis = 1 ,
792
- arr = signal ,
793
- lowcut = 10 ,
794
- highcut = 900 ,
795
- fs = 2048 ,
796
- order = 6 ,
797
- )
798
- centered = center_matrix (signal )
799
- c_sq = centered ** 2
800
- c_sq_mean = c_sq .mean (axis = 0 )
801
- pt = decomp_results ["MUPulses" ]
802
- shape_dict = muap_dict (raw , pt , l = 31 )
803
- pulse = pulse_plot (pt , c_sq_mean , mu_index , sel_type = "interval" )
804
- pulse_pn = pn .pane .Vega (pulse , debounce = 10 )
805
- dash_p [1 ][0 ][1 ] = pulse_pn
806
-
807
- # Also redo mu_charts graph so that it no longer selects the deleted peak:
808
- mu_charts_pn = pn .bind (
738
+ vega_pane = pn .pane .Vega (pulse , debounce = 10 )
739
+ return pn .Column (
740
+ vega_pane ,
741
+ pn .bind (
809
742
select_peak ,
810
- pulse_pn .selection .param .sel_peak ,
743
+ vega_pane .selection .param .sel_peak ,
811
744
mu_index ,
812
745
raw ,
813
746
shape_dict ,
814
747
pt ,
815
- preset = gl_preset ,
816
- method = gl_method ,
817
- )
818
- dash_p [1 ][0 ][2 ] = mu_charts_pn
748
+ ),
749
+ )
0 commit comments