Skip to content

Commit a59ecf7

Browse files
committed
fix breaking dnn input formatting tests
1 parent aec881f commit a59ecf7

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

neuro_py/ensemble/decoding/pipeline.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import random
44

5+
from re import X
56
from typing import List, Tuple, Dict, Optional, Any
67

78
import sklearn.preprocessing
@@ -202,8 +203,11 @@ def zscore_trial_segs(
202203
Normalized train data, normalized rest features, and normalization parameters.
203204
"""
204205
is_2D = train[0].ndim == 1
205-
concat_train = train if is_2D else np.concatenate(train)
206-
train_mean = normparams['X_train_mean'] if normparams is not None else bn.nanmean(concat_train, axis=0)
206+
concat_train = train if is_2D else np.concatenate(train).astype(float)
207+
train_mean = (
208+
normparams['X_train_mean'] if normparams is not None
209+
else bn.nanmean(concat_train, axis=0)
210+
)
207211
train_std = normparams['X_train_std'] if normparams is not None else bn.nanstd(concat_train, axis=0)
208212

209213
train_notnan_cols = train_std != 0
@@ -213,15 +217,17 @@ def zscore_trial_segs(
213217
# if train is not jagged, it gets converted completely to object
214218
# np.ndarray. Hence, cannot exclusively use normed_train.loc
215219
if isinstance(normed_train, pd.DataFrame):
216-
normed_train = normed_train.loc
217-
normed_train[:, train_nan_cols] = 0
220+
normed_train.loc[:, train_nan_cols] = 0
221+
else:
222+
normed_train[:, train_nan_cols] = 0
218223
else:
219224
normed_train = np.empty_like(train)
220225
for i, nsvstseg in enumerate(train):
221226
zscored = np.divide(nsvstseg-train_mean, train_std, where=train_notnan_cols)
222227
if isinstance(zscored, pd.DataFrame):
223-
zscored = zscored.loc
224-
zscored[:, train_nan_cols] = 0
228+
zscored.loc[:, train_nan_cols] = 0
229+
else:
230+
zscored[:, train_nan_cols] = 0
225231
normed_train[i] = zscored
226232

227233
normed_rest_feats = []
@@ -230,16 +236,18 @@ def zscore_trial_segs(
230236
if is_2D:
231237
normed_feats = np.divide(feats-train_mean, train_std, where=train_notnan_cols)
232238
if isinstance(normed_feats, pd.DataFrame):
233-
normed_feats = normed_feats.loc
234-
normed_feats[:, train_nan_cols] = 0
239+
normed_feats.loc[:, train_nan_cols] = 0
240+
else:
241+
normed_feats[:, train_nan_cols] = 0
235242
normed_rest_feats.append(normed_feats)
236243
else:
237244
normed_feats = np.empty_like(feats)
238245
for i, trialSegROI in enumerate(feats):
239246
zscored = np.divide(feats[i]-train_mean, train_std, where=train_notnan_cols)
240247
if isinstance(zscored, pd.DataFrame):
241-
zscored = zscored.loc
242-
zscored[:, train_nan_cols] = 0
248+
zscored.loc[:, train_nan_cols] = 0
249+
else:
250+
zscored[:, train_nan_cols] = 0
243251
normed_feats[i] = zscored
244252
normed_rest_feats.append(normed_feats)
245253

@@ -351,6 +359,13 @@ def minibatchify(
351359
"""
352360
g_seed = torch.Generator()
353361
g_seed.manual_seed(seed)
362+
if Xtrain.ndim == 2: # handle object arrays
363+
Xtrain = Xtrain.astype(np.float32)
364+
Xval = Xval.astype(np.float32)
365+
Xtest = Xtest.astype(np.float32)
366+
ytrain = ytrain.astype(np.float32)
367+
yval = yval.astype(np.float32)
368+
ytest = ytest.astype(np.float32)
354369
train = torch.utils.data.TensorDataset(
355370
torch.from_numpy(Xtrain).type(torch.float32),
356371
torch.from_numpy(ytrain).type(torch.float32))

0 commit comments

Comments
 (0)