Skip to content

Commit b2c4085

Browse files
[WIP] Improve Raster Plot (#1017)
* add templates for layer-specific PSD and update the Drive-Dipole-Spectrogram (3x1) tempate to point to the layer-specific PSDs * add layer to legend * update plot_spikes_raster to 1) have Cell ID y-axis label, 2) accept show_legend argument that will optionally hide the legend, and 3) accept marker_size argument to change the spike marker sizes --------- Co-authored-by: Austin E. Soplata <[email protected]>
1 parent 44838f5 commit b2c4085

File tree

2 files changed

+72
-13
lines changed

2 files changed

+72
-13
lines changed

hnn_core/cell_response.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,15 @@ def mean_rates(self, tstart, tstop, gid_ranges, mean_type='all'):
293293

294294
return spike_rates
295295

296-
def plot_spikes_raster(self, trial_idx=None, ax=None, show=True,
297-
colors=None):
296+
def plot_spikes_raster(
297+
self,
298+
trial_idx=None,
299+
ax=None,
300+
show=True,
301+
colors=None,
302+
show_legend=True,
303+
marker_size=1.0,
304+
):
298305
"""Plot the aggregate spiking activity according to cell type.
299306
300307
Parameters
@@ -314,8 +321,13 @@ def plot_spikes_raster(self, trial_idx=None, ax=None, show=True,
314321
The matplotlib figure object.
315322
"""
316323
return plot_spikes_raster(
317-
cell_response=self, trial_idx=trial_idx, ax=ax, show=show,
318-
colors=colors)
324+
cell_response=self,
325+
trial_idx=trial_idx,
326+
ax=ax, show=show,
327+
colors=colors,
328+
show_legend=show_legend,
329+
marker_size=marker_size,
330+
)
319331

320332
def plot_spikes_hist(self, trial_idx=None, ax=None, spike_types=None,
321333
color=None, invert_spike_types=None, show=True,

hnn_core/viz.py

+56-9
Original file line numberDiff line numberDiff line change
@@ -573,9 +573,16 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
573573
return ax.get_figure()
574574

575575

576-
def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
577-
cell_types=None, colors=None,
578-
):
576+
def plot_spikes_raster(
577+
cell_response,
578+
trial_idx=None,
579+
ax=None,
580+
show=True,
581+
cell_types=None,
582+
colors=None,
583+
show_legend=True,
584+
marker_size=1.0,
585+
):
579586
"""Plot the aggregate spiking activity according to cell type.
580587
581588
Parameters
@@ -588,10 +595,16 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
588595
An axis object from matplotlib. If None, a new figure is created.
589596
show : bool
590597
If True, show the figure.
591-
cell_types: list of str
598+
cell_types : list of str
592599
List of cell types to plot
593-
colors: list of str | None
600+
colors : list of str | None
594601
Optional custom colors to plot. Default will use the color cycler.
602+
show_legend : bool
603+
If True, show the legend with colors for cell types
604+
marker_size : float
605+
Optional marker size to use when plotting spikes. Uses
606+
"linelengths" argument of ax.eventplot, which accepts positive
607+
numeric values only
595608
596609
Returns
597610
-------
@@ -649,6 +662,26 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
649662
f"Got {colors.keys()}")
650663
cell_colors.update(colors)
651664

665+
# validate show_legend argument
666+
_validate_type(
667+
show_legend,
668+
bool,
669+
'show_legend',
670+
'bool'
671+
)
672+
673+
# validate marker_size
674+
_validate_type(
675+
marker_size,
676+
(float, int),
677+
'marker_size',
678+
'float or int'
679+
)
680+
681+
# if marker_size is <= 0, set it to the default value of 1.0
682+
if marker_size <= 0:
683+
marker_size = 1.0
684+
652685
# Extract desired trials
653686
spike_times = np.concatenate(
654687
np.array(cell_response._spike_times, dtype=object)[trial_idx])
@@ -672,19 +705,33 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
672705

673706
if cell_type_times:
674707
events.append(
675-
ax.eventplot(cell_type_times, lineoffsets=cell_type_ypos,
676-
color=color,
677-
label=cell_type, linelengths=1))
708+
ax.eventplot(
709+
cell_type_times,
710+
lineoffsets=cell_type_ypos,
711+
color=color,
712+
label=cell_type,
713+
linelengths=marker_size
714+
)
715+
)
678716
else:
679717
# Blank plot for no spiking
680718
events.append(
681719
ax.eventplot([-1], lineoffsets=[-1],
682720
color=color,
683721
label=cell_type, linelengths=1))
684722

723+
# set legend
685724
ax.legend(handles=[e[0] for e in events], loc=1)
725+
if not show_legend:
726+
ax.get_legend().remove()
727+
728+
# set axis labels
686729
ax.set_xlabel('Time (ms)')
687-
ax.get_yaxis().set_visible(False)
730+
ax.set_ylabel('Cell ID')
731+
732+
# hide y-axis ticks and tick labels
733+
ax.set_yticklabels([])
734+
ax.tick_params(axis='y', length=0)
688735

689736
if len(cell_response.times) > 0:
690737
ax.set_xlim(left=0, right=cell_response.times[-1])

0 commit comments

Comments
 (0)