Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MNT] Solving moabb and braindecode compatibility #669

Merged
4 changes: 3 additions & 1 deletion docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ Bugs
- Fix Stieger2021 dataset bugs (:gh:`651` by `Martin Wimpff`_)
- Unpinning major version Scikit-learn and numpy (:gh:`652` by `Bruno Aristimunha`_)
- Replacing the func:`numpy.string_` to func:`numpy.bytes_` (:gh:`665` by `Bruno Aristimunha`_)
- Fixing the set_download_dir that was not working when we tried to set the dir more than 10 times at the same time (:gh:`668` by `Bruno Aristimunha`_)
- Fixing the set_download_dir that was not working when we tried to set the dir more than 10 times at the same time (:gh:`668` by `Bruno Aristimunha`_)
- Creating stimulus channels in :class:`moabb.datasets.Zhou2016` and :class:`moabb.datasets.PhysionetMI` to allow braindecode compatibility (:gh:`669` by `Bruno Aristimunha`_)


API changes
~~~~~~~~~~~
Expand Down
6 changes: 5 additions & 1 deletion moabb/datasets/Zhou2016.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .base import BaseDataset
from .download import get_dataset_path
from .utils import stim_channels_with_selected_ids


DATA_PATH = "https://ndownloader.figshare.com/files/3662952"
Expand Down Expand Up @@ -88,6 +89,7 @@ def __init__(self):
paradigm="imagery",
doi="10.1371/journal.pone.0162657",
)
self.events = dict(left_hand=1, right_hand=2, feet=3)

def _get_single_subject_data(self, subject):
"""Return data for a single subject."""
Expand All @@ -105,7 +107,9 @@ def _get_single_subject_data(self, subject):
stim[stim == "2"] = "right_hand"
stim[stim == "3"] = "feet"
raw.annotations.description = stim
out[sess_key][run_key] = raw
out[sess_key][run_key] = stim_channels_with_selected_ids(
raw, desired_event_id=self.events
)
out[sess_key][run_key].set_montage(make_standard_montage("standard_1005"))
return out

Expand Down
49 changes: 46 additions & 3 deletions moabb/datasets/physionet_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from moabb.datasets.base import BaseDataset
from moabb.datasets.download import data_dl, get_dataset_path
from moabb.datasets.utils import stim_channels_with_selected_ids


BASE_URL = "https://physionet.org/files/eegmmidb/1.0.0/"
Expand Down Expand Up @@ -79,7 +80,7 @@ def __init__(self, imagined=True, executed=False):
paradigm="imagery",
doi="10.1109/TBME.2004.827072",
)

self.events = dict(left_hand=2, right_hand=3, feet=5, hands=4, rest=1)
self.imagined = imagined
self.executed = executed
self.feet_runs = []
Expand Down Expand Up @@ -123,7 +124,9 @@ def _get_single_subject_data(self, subject):
stim[stim == "T1"] = "left_hand"
stim[stim == "T2"] = "right_hand"
raw.annotations.description = stim
data[str(idx)] = raw
data[str(idx)] = stim_channels_with_selected_ids(
raw, desired_event_id=self.events
)
idx += 1

# feet runs
Expand All @@ -136,7 +139,9 @@ def _get_single_subject_data(self, subject):
stim[stim == "T1"] = "hands"
stim[stim == "T2"] = "feet"
raw.annotations.description = stim
data[str(idx)] = raw
data[str(idx)] = stim_channels_with_selected_ids(
raw, desired_event_id=self.events
)
idx += 1

return {"0": data}
Expand Down Expand Up @@ -171,3 +176,41 @@ def _load_data(self, subject, runs, path=None, force_update=False, verbose=None)
p = data_dl(url, sign, path, force_update, verbose)
data_paths.append(p)
return data_paths

def _create_stim_channels(self, raw):
# Define a consistent mapping from event descriptions to integer IDs
desired_event_id = self.events

# Get events using the consistent event_id mapping
events, _ = mne.events_from_annotations(raw, event_id=desired_event_id)

# Filter the events array to include only desired events
desired_event_ids = list(desired_event_id.values())
filtered_events = events[np.isin(events[:, 2], desired_event_ids)]

# Create annotations from filtered events using the inverted mapping
event_desc = {v: k for k, v in desired_event_id.items()}
annot_from_events = mne.annotations_from_events(
events=filtered_events,
event_desc=event_desc,
sfreq=raw.info["sfreq"],
orig_time=raw.info["meas_date"],
)
raw.set_annotations(annot_from_events)

# Create the stim channel data array
stim_channs = np.zeros((1, raw.n_times))
for event in filtered_events:
sample_index = event[0]
event_code = event[2] # Consistent event IDs
stim_channs[0, sample_index] = event_code

# Create the stim channel and add it to raw
stim_channel_name = "STIM"
stim_info = mne.create_info(
[stim_channel_name], sfreq=raw.info["sfreq"], ch_types=["stim"]
)
stim_raw = mne.io.RawArray(stim_channs, stim_info, verbose=False)
raw_with_stim = raw.copy().add_channels([stim_raw], force_update_info=True)

return raw_with_stim
50 changes: 50 additions & 0 deletions moabb/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import inspect

import mne
import numpy as np
from mne import create_info
from mne.io import RawArray
Expand Down Expand Up @@ -273,3 +274,52 @@ def add_stim_channel_epoch(
)
raw = raw.add_channels([RawArray(data=stim_chan, info=info, verbose=False)])
return raw


def stim_channels_with_selected_ids(
raw: mne.io.BaseRaw, desired_event_id: dict, stim_channel_name="STIM"
):
"""
Add a stimulus channel with filtering and renaming based on events_ids.

Parameters
----------
raw: mne.Raw
The raw object to add the stimulus channel to.
desired_event_id: dict
Dictionary with events
"""

# Get events using the consistent event_id mapping
events, _ = mne.events_from_annotations(raw, event_id=desired_event_id)

# Filter the events array to include only desired events
desired_event_ids = list(desired_event_id.values())
filtered_events = events[np.isin(events[:, 2], desired_event_ids)]

# Create annotations from filtered events using the inverted mapping
event_desc = {v: k for k, v in desired_event_id.items()}
annot_from_events = mne.annotations_from_events(
events=filtered_events,
event_desc=event_desc,
sfreq=raw.info["sfreq"],
orig_time=raw.info["meas_date"],
)
raw.set_annotations(annot_from_events)

# Create the stim channel data array
stim_channs = np.zeros((1, raw.n_times))
for event in filtered_events:
sample_index = event[0]
event_code = event[2] # Consistent event IDs
stim_channs[0, sample_index] = event_code

# Create the stim channel and add it to raw

stim_info = mne.create_info(
[stim_channel_name], sfreq=raw.info["sfreq"], ch_types=["stim"]
)
stim_raw = mne.io.RawArray(stim_channs, stim_info, verbose=False)
raw_with_stim = raw.copy().add_channels([stim_raw], force_update_info=True)

return raw_with_stim
Loading