@@ -573,9 +573,16 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
573
573
return ax .get_figure ()
574
574
575
575
576
- def plot_spikes_raster (cell_response , trial_idx = None , ax = None , show = True ,
577
- cell_types = None , colors = None ,
578
- ):
576
+ def plot_spikes_raster (
577
+ cell_response ,
578
+ trial_idx = None ,
579
+ ax = None ,
580
+ show = True ,
581
+ cell_types = None ,
582
+ colors = None ,
583
+ show_legend = True ,
584
+ marker_size = 1.0 ,
585
+ ):
579
586
"""Plot the aggregate spiking activity according to cell type.
580
587
581
588
Parameters
@@ -588,10 +595,16 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
588
595
An axis object from matplotlib. If None, a new figure is created.
589
596
show : bool
590
597
If True, show the figure.
591
- cell_types: list of str
598
+ cell_types : list of str
592
599
List of cell types to plot
593
- colors: list of str | None
600
+ colors : list of str | None
594
601
Optional custom colors to plot. Default will use the color cycler.
602
+ show_legend : bool
603
+ If True, show the legend with colors for cell types
604
+ marker_size : float
605
+ Optional marker size to use when plotting spikes. Uses
606
+ "linelengths" argument of ax.eventplot, which accepts positive
607
+ numeric values only
595
608
596
609
Returns
597
610
-------
@@ -649,6 +662,26 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
649
662
f"Got { colors .keys ()} " )
650
663
cell_colors .update (colors )
651
664
665
+ # validate show_legend argument
666
+ _validate_type (
667
+ show_legend ,
668
+ bool ,
669
+ 'show_legend' ,
670
+ 'bool'
671
+ )
672
+
673
+ # validate marker_size
674
+ _validate_type (
675
+ marker_size ,
676
+ (float , int ),
677
+ 'marker_size' ,
678
+ 'float or int'
679
+ )
680
+
681
+ # if marker_size is <= 0, set it to the default value of 1.0
682
+ if marker_size <= 0 :
683
+ marker_size = 1.0
684
+
652
685
# Extract desired trials
653
686
spike_times = np .concatenate (
654
687
np .array (cell_response ._spike_times , dtype = object )[trial_idx ])
@@ -672,19 +705,33 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
672
705
673
706
if cell_type_times :
674
707
events .append (
675
- ax .eventplot (cell_type_times , lineoffsets = cell_type_ypos ,
676
- color = color ,
677
- label = cell_type , linelengths = 1 ))
708
+ ax .eventplot (
709
+ cell_type_times ,
710
+ lineoffsets = cell_type_ypos ,
711
+ color = color ,
712
+ label = cell_type ,
713
+ linelengths = marker_size
714
+ )
715
+ )
678
716
else :
679
717
# Blank plot for no spiking
680
718
events .append (
681
719
ax .eventplot ([- 1 ], lineoffsets = [- 1 ],
682
720
color = color ,
683
721
label = cell_type , linelengths = 1 ))
684
722
723
+ # set legend
685
724
ax .legend (handles = [e [0 ] for e in events ], loc = 1 )
725
+ if not show_legend :
726
+ ax .get_legend ().remove ()
727
+
728
+ # set axis labels
686
729
ax .set_xlabel ('Time (ms)' )
687
- ax .get_yaxis ().set_visible (False )
730
+ ax .set_ylabel ('Cell ID' )
731
+
732
+ # hide y-axis ticks and tick labels
733
+ ax .set_yticklabels ([])
734
+ ax .tick_params (axis = 'y' , length = 0 )
688
735
689
736
if len (cell_response .times ) > 0 :
690
737
ax .set_xlim (left = 0 , right = cell_response .times [- 1 ])
0 commit comments