Skip to content

Separate sync as its own stream in SpikeGLXRawIO #1683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions neo/rawio/spikeglxrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pathlib import Path
import os
import re
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -109,6 +110,12 @@ def __init__(self, dirname="", load_sync_channel=False, load_channel_location=Fa
BaseRawWithBufferApiIO.__init__(self)
self.dirname = dirname
self.load_sync_channel = load_sync_channel
if load_sync_channel:
warn(
"The load_sync_channel=True option is deprecated and will be removed in version 0.15. "
"Use load_sync_channel=False instead, which will add sync channels as separate streams.",
DeprecationWarning, stacklevel=2
)
self.load_channel_location = load_channel_location

def _source_name(self):
Expand Down Expand Up @@ -152,6 +159,8 @@ def _parse_header(self):
signal_buffers = []
signal_streams = []
signal_channels = []
sync_stream_id_to_buffer_id = {}

for stream_name in stream_names:
# take first segment
info = self.signals_info_dict[0, stream_name]
Expand All @@ -168,6 +177,16 @@ def _parse_header(self):
for local_chan in range(info["num_chan"]):
chan_name = info["channel_names"][local_chan]
chan_id = f"{stream_name}#{chan_name}"

# Sync channel
if "nidq" not in stream_name and "SY0" in chan_name and not self.load_sync_channel and local_chan == info["num_chan"] - 1:
# This is a sync channel and should be added as its own stream
sync_stream_id = f"{stream_name}-SYNC"
sync_stream_id_to_buffer_id[sync_stream_id] = buffer_id
stream_id_for_chan = sync_stream_id
else:
stream_id_for_chan = stream_id

signal_channels.append(
(
chan_name,
Expand All @@ -177,25 +196,33 @@ def _parse_header(self):
info["units"],
info["channel_gains"][local_chan],
info["channel_offsets"][local_chan],
stream_id,
stream_id_for_chan,
buffer_id,
)
)

# all channel by dafult unless load_sync_channel=False
# all channel by default unless load_sync_channel=False
self._stream_buffer_slice[stream_id] = None

# check sync channel validity
if "nidq" not in stream_name:
if not self.load_sync_channel and info["has_sync_trace"]:
# the last channel is remove from the stream but not from the buffer
last_chan = signal_channels[-1]
last_chan = last_chan[:-2] + ("", buffer_id)
signal_channels = signal_channels[:-1] + [last_chan]
# the last channel is removed from the stream but not from the buffer
self._stream_buffer_slice[stream_id] = slice(0, -1)

# Add a buffer slice for the sync channel
sync_stream_id = f"{stream_name}-SYNC"
self._stream_buffer_slice[sync_stream_id] = slice(-1, None)

if self.load_sync_channel and not info["has_sync_trace"]:
raise ValueError("SYNC channel is not present in the recording. " "Set load_sync_channel to False")

signal_buffers = np.array(signal_buffers, dtype=_signal_buffer_dtype)

# Add sync channels as their own streams
for sync_stream_id, buffer_id in sync_stream_id_to_buffer_id.items():
signal_streams.append((sync_stream_id, sync_stream_id, buffer_id))

signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)

Expand Down Expand Up @@ -237,6 +264,14 @@ def _parse_header(self):
t_start = frame_start / sampling_frequency

self._t_starts[stream_name][seg_index] = t_start

# This need special logic because sync not present in stream_names
if f"{stream_name}-SYNC" in signal_streams["name"]:
sync_stream_name = f"{stream_name}-SYNC"
if sync_stream_name not in self._t_starts:
self._t_starts[sync_stream_name] = {}
self._t_starts[sync_stream_name][seg_index] = t_start

t_stop = info["sample_length"] / info["sampling_rate"]
self._t_stops[seg_index] = max(self._t_stops[seg_index], t_stop)

Expand Down Expand Up @@ -265,7 +300,11 @@ def _parse_header(self):
if self.load_channel_location:
# need probeinterface to be installed
import probeinterface


# Skip for sync streams
if "SYNC" in stream_name:
continue

info = self.signals_info_dict[seg_index, stream_name]
if "imroTbl" in info["meta"] and info["stream_kind"] == "ap":
# only for ap channel
Expand Down
26 changes: 25 additions & 1 deletion neo/test/rawiotest/test_spikeglxrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_loading_only_one_probe_in_multi_probe_scenario(self):
rawio = SpikeGLXRawIO(probe_folder_path)
rawio.parse_header()

expected_stream_names = ["imec1.ap", "imec1.lf"]
expected_stream_names = ["imec1.ap", "imec1.lf", "imec1.ap-SYNC", "imec1.lf-SYNC"]
actual_stream_names = rawio.header["signal_streams"]["name"].tolist()
assert (
actual_stream_names == expected_stream_names
Expand Down Expand Up @@ -130,6 +130,30 @@ def test_nidq_digital_channel(self):
atol = 0.001
assert np.allclose(on_diff, 1, atol=atol)

def test_sync_channel_as_separate_stream(self):
"""Test that sync channel is added as its own stream when load_sync_channel=False."""
import warnings

# Test with load_sync_channel=False (default)
rawio_no_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=False)
rawio_no_sync.parse_header()

# Get stream names
stream_names = rawio_no_sync.header["signal_streams"]["name"].tolist()

# Check if there's a sync channel stream (should contain "SY0" or "SYNC" in the name)
sync_streams = [name for name in stream_names if "SY0" in name or "SYNC" in name]
assert len(sync_streams) > 0, "No sync channel stream found when load_sync_channel=False"

# Test deprecation warning when load_sync_channel=True
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
rawio_with_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=True)

# Check if deprecation warning was raised
assert any(issubclass(warning.category, DeprecationWarning) for warning in w), "No deprecation warning raised"
assert any("will be removed in version 0.15" in str(warning.message) for warning in w), "Deprecation warning message is incorrect"

def test_t_start_reading(self):
"""Test that t_start values are correctly read for all streams and segments."""

Expand Down
Loading