Skip to content

Commit aec881f

Browse files
committed
add support for non-jagged trials
1 parent 8c55cec commit aec881f

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

neuro_py/ensemble/decoding/pipeline.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -210,26 +210,36 @@ def zscore_trial_segs(
210210
train_nan_cols = ~train_notnan_cols
211211
if is_2D:
212212
normed_train = np.divide(train-train_mean, train_std, where=train_notnan_cols)
213-
normed_train.loc[:, train_nan_cols] = 0
213+
# if train is not jagged, it gets converted completely to object
214+
# np.ndarray. Hence, cannot exclusively use normed_train.loc
215+
if isinstance(normed_train, pd.DataFrame):
216+
normed_train = normed_train.loc
217+
normed_train[:, train_nan_cols] = 0
214218
else:
215219
normed_train = np.empty_like(train)
216220
for i, nsvstseg in enumerate(train):
217221
zscored = np.divide(nsvstseg-train_mean, train_std, where=train_notnan_cols)
218-
zscored.loc[:, train_nan_cols] = 0
222+
if isinstance(zscored, pd.DataFrame):
223+
zscored = zscored.loc
224+
zscored[:, train_nan_cols] = 0
219225
normed_train[i] = zscored
220226

221227
normed_rest_feats = []
222228
if rest_feats is not None:
223229
for feats in rest_feats:
224230
if is_2D:
225231
normed_feats = np.divide(feats-train_mean, train_std, where=train_notnan_cols)
226-
normed_feats.loc[:, train_nan_cols] = 0
232+
if isinstance(normed_feats, pd.DataFrame):
233+
normed_feats = normed_feats.loc
234+
normed_feats[:, train_nan_cols] = 0
227235
normed_rest_feats.append(normed_feats)
228236
else:
229237
normed_feats = np.empty_like(feats)
230238
for i, trialSegROI in enumerate(feats):
231239
zscored = np.divide(feats[i]-train_mean, train_std, where=train_notnan_cols)
232-
zscored.loc[:, train_nan_cols] = 0
240+
if isinstance(zscored, pd.DataFrame):
241+
zscored = zscored.loc
242+
zscored[:, train_nan_cols] = 0
233243
normed_feats[i] = zscored
234244
normed_rest_feats.append(normed_feats)
235245

neuro_py/ensemble/decoding/preprocess.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def partition_sets(
8787
state vectors and behavioral variables.
8888
"""
8989
partitions = []
90-
is_2D = nsv_trial_segs.ndim == 1
90+
is_2D = nsv_trial_segs[0].ndim == 1
9191
for (train_indices, val_indices, test_indices) in partitions_indices:
9292
if is_2D:
9393
if isinstance(nsv_trial_segs, pd.DataFrame):

0 commit comments

Comments
 (0)