Commit 9bdd113

Merge pull request #232 from basf/master
Hotfix: data dimension for predict data
2 parents: 79e81ad + 4dc64e3

5 files changed: +314 -350 lines changed


mambular/__version__.py (+1 -1)
@@ -16,4 +16,4 @@
 #

 # The following line *must* be the last in the module, exactly as formatted:
-__version__ = "1.2.0"
+__version__ = "1.2.1"

mambular/data_utils/datamodule.py (+33 -95)
@@ -130,9 +130,7 @@ def preprocess_data(
                 embeddings_val = [embeddings_val]

             split_data += embeddings_train
-            split_result = train_test_split(
-                *split_data, test_size=val_size, random_state=random_state
-            )
+            split_result = train_test_split(*split_data, test_size=val_size, random_state=random_state)

             self.X_train, self.X_val, self.y_train, self.y_val = split_result[:4]
             self.embeddings_train = split_result[4::2]
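
Note on the unpacking at the end of this hunk: when train_test_split receives several arrays, it returns a flat list of (train, val) pairs in input order, so after the four X/y entries the even offsets hold the training embeddings and the odd offsets the validation embeddings. A minimal sketch of that layout, with toy arrays that are illustrative only:

import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(10).reshape(5, 2)    # features
y = np.arange(5)                   # targets
emb = np.arange(15).reshape(5, 3)  # one embedding array

# Multiple inputs come back interleaved as (train, val) pairs:
# [X_train, X_val, y_train, y_val, emb_train, emb_val, ...]
result = train_test_split(X, y, emb, test_size=0.2, random_state=0)

X_train, X_val, y_train, y_val = result[:4]
emb_train_list = result[4::2]  # every other entry -> training splits
emb_val_list = result[5::2]    # offset by one    -> validation splits

assert emb_train_list[0].shape[0] == X_train.shape[0]
assert emb_val_list[0].shape[0] == X_val.shape[0]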
@@ -161,37 +159,31 @@ def preprocess_data(
             self.embeddings_val = None

         # Fit the preprocessor on the combined training and validation data
-        combined_X = pd.concat([self.X_train, self.X_val], axis=0).reset_index(
-            drop=True
-        )
+        combined_X = pd.concat([self.X_train, self.X_val], axis=0).reset_index(drop=True)
         combined_y = np.concatenate((self.y_train, self.y_val), axis=0)

         if self.embeddings_train is not None and self.embeddings_val is not None:
             combined_embeddings = [
                 np.concatenate((emb_train, emb_val), axis=0)
-                for emb_train, emb_val in zip(
-                    self.embeddings_train, self.embeddings_val
-                )
+                for emb_train, emb_val in zip(self.embeddings_train, self.embeddings_val, strict=False)
             ]
         else:
             combined_embeddings = None

         self.preprocessor.fit(combined_X, combined_y, combined_embeddings)

         # Update feature info based on the actual processed data
-        (self.num_feature_info, self.cat_feature_info, self.embedding_feature_info) = (
-            self.preprocessor.get_feature_info()
-        )
+        (
+            self.num_feature_info,
+            self.cat_feature_info,
+            self.embedding_feature_info,
+        ) = self.preprocessor.get_feature_info()

     def setup(self, stage: str):
         """Transform the data and create DataLoaders."""
         if stage == "fit":
-            train_preprocessed_data = self.preprocessor.transform(
-                self.X_train, self.embeddings_train
-            )
-            val_preprocessed_data = self.preprocessor.transform(
-                self.X_val, self.embeddings_val
-            )
+            train_preprocessed_data = self.preprocessor.transform(self.X_train, self.embeddings_train)
+            val_preprocessed_data = self.preprocessor.transform(self.X_val, self.embeddings_val)

             # Initialize lists for tensors
             train_cat_tensors = []
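
The only behavioural addition in this hunk is strict=False on zip. The keyword (Python 3.10+) spells out zip's default truncating behaviour explicitly, which linters such as Ruff request for bare zip calls; strict=True would instead raise if the train and validation embedding lists ever differed in length. For illustration:

# zip(..., strict=False) silently truncates to the shorter input,
# while strict=True raises ValueError on a length mismatch (Python 3.10+).
pairs = list(zip([1, 2, 3], ["a", "b"], strict=False))
assert pairs == [(1, "a"), (2, "b")]

try:
    list(zip([1, 2, 3], ["a", "b"], strict=True))
except ValueError:
    print("length mismatch detected")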
@@ -205,75 +197,40 @@ def setup(self, stage: str):
             for key in self.cat_feature_info:  # type: ignore
                 dtype = (
                     torch.float32
-                    if any(
-                        x in self.cat_feature_info[key]["preprocessing"]
-                        for x in ["onehot", "pretrained"]
-                    )
+                    if any(x in self.cat_feature_info[key]["preprocessing"] for x in ["onehot", "pretrained"])  # type: ignore
                     else torch.long
                 )

-                cat_key = "cat_" + str(
-                    key
-                )  # Assuming categorical keys are prefixed with 'cat_'
+                cat_key = "cat_" + str(key)  # Assuming categorical keys are prefixed with 'cat_'
                 if cat_key in train_preprocessed_data:
-                    train_cat_tensors.append(
-                        torch.tensor(train_preprocessed_data[cat_key], dtype=dtype)
-                    )
+                    train_cat_tensors.append(torch.tensor(train_preprocessed_data[cat_key], dtype=dtype))
                 if cat_key in val_preprocessed_data:
-                    val_cat_tensors.append(
-                        torch.tensor(val_preprocessed_data[cat_key], dtype=dtype)
-                    )
+                    val_cat_tensors.append(torch.tensor(val_preprocessed_data[cat_key], dtype=dtype))

                 binned_key = "num_" + str(key)  # for binned features
                 if binned_key in train_preprocessed_data:
-                    train_cat_tensors.append(
-                        torch.tensor(train_preprocessed_data[binned_key], dtype=dtype)
-                    )
+                    train_cat_tensors.append(torch.tensor(train_preprocessed_data[binned_key], dtype=dtype))

                 if binned_key in val_preprocessed_data:
-                    val_cat_tensors.append(
-                        torch.tensor(val_preprocessed_data[binned_key], dtype=dtype)
-                    )
+                    val_cat_tensors.append(torch.tensor(val_preprocessed_data[binned_key], dtype=dtype))

             # Populate tensors for numerical features, if present in processed data
             for key in self.num_feature_info:  # type: ignore
-                num_key = "num_" + str(
-                    key
-                )  # Assuming numerical keys are prefixed with 'num_'
+                num_key = "num_" + str(key)  # Assuming numerical keys are prefixed with 'num_'
                 if num_key in train_preprocessed_data:
-                    train_num_tensors.append(
-                        torch.tensor(
-                            train_preprocessed_data[num_key], dtype=torch.float32
-                        )
-                    )
+                    train_num_tensors.append(torch.tensor(train_preprocessed_data[num_key], dtype=torch.float32))
                 if num_key in val_preprocessed_data:
-                    val_num_tensors.append(
-                        torch.tensor(
-                            val_preprocessed_data[num_key], dtype=torch.float32
-                        )
-                    )
+                    val_num_tensors.append(torch.tensor(val_preprocessed_data[num_key], dtype=torch.float32))

             if self.embedding_feature_info is not None:
                 for key in self.embedding_feature_info:
                     if key in train_preprocessed_data:
-                        train_emb_tensors.append(
-                            torch.tensor(
-                                train_preprocessed_data[key], dtype=torch.float32
-                            )
-                        )
+                        train_emb_tensors.append(torch.tensor(train_preprocessed_data[key], dtype=torch.float32))
                     if key in val_preprocessed_data:
-                        val_emb_tensors.append(
-                            torch.tensor(
-                                val_preprocessed_data[key], dtype=torch.float32
-                            )
-                        )
-
-            train_labels = torch.tensor(
-                self.y_train, dtype=self.labels_dtype
-            ).unsqueeze(dim=1)
-            val_labels = torch.tensor(self.y_val, dtype=self.labels_dtype).unsqueeze(
-                dim=1
-            )
+                        val_emb_tensors.append(torch.tensor(val_preprocessed_data[key], dtype=torch.float32))
+
+            train_labels = torch.tensor(self.y_train, dtype=self.labels_dtype).unsqueeze(dim=1)
+            val_labels = torch.tensor(self.y_val, dtype=self.labels_dtype).unsqueeze(dim=1)

             self.train_dataset = MambularDataset(
                 train_cat_tensors,
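
Both this hunk and the one below in preprocess_new_data compact the same dtype rule: categorical columns with a dense encoding ("onehot" or "pretrained") become float32 tensors, while integer-coded categories stay torch.long so they can serve as indices into embedding layers. A standalone restatement of that rule; the helper name cat_dtype is ours for illustration, not part of the module:

import torch

def cat_dtype(preprocessing: str) -> torch.dtype:
    # Dense encodings are real-valued; integer codes must stay long
    # so nn.Embedding can consume them as indices.
    if any(x in preprocessing for x in ["onehot", "pretrained"]):
        return torch.float32
    return torch.long

assert cat_dtype("onehot") is torch.float32
assert cat_dtype("int") is torch.long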
@@ -300,42 +257,27 @@ def preprocess_new_data(self, X, embeddings):
         for key in self.cat_feature_info:  # type: ignore
             dtype = (
                 torch.float32
-                if any(
-                    x in self.cat_feature_info[key]["preprocessing"]
-                    for x in ["onehot", "pretrained"]
-                )
+                if any(x in self.cat_feature_info[key]["preprocessing"] for x in ["onehot", "pretrained"])  # type: ignore
                 else torch.long
             )
-            cat_key = "cat_" + str(
-                key
-            )  # Assuming categorical keys are prefixed with 'cat_'
+            cat_key = "cat_" + str(key)  # Assuming categorical keys are prefixed with 'cat_'
             if cat_key in preprocessed_data:
-                cat_tensors.append(
-                    torch.tensor(preprocessed_data[cat_key], dtype=dtype)
-                )
+                cat_tensors.append(torch.tensor(preprocessed_data[cat_key], dtype=dtype))

             binned_key = "num_" + str(key)  # for binned features
             if binned_key in preprocessed_data:
-                cat_tensors.append(
-                    torch.tensor(preprocessed_data[binned_key], dtype=dtype)
-                )
+                cat_tensors.append(torch.tensor(preprocessed_data[binned_key], dtype=dtype))

         # Populate tensors for numerical features, if present in processed data
         for key in self.num_feature_info:  # type: ignore
-            num_key = "num_" + str(
-                key
-            )  # Assuming numerical keys are prefixed with 'num_'
+            num_key = "num_" + str(key)  # Assuming numerical keys are prefixed with 'num_'
             if num_key in preprocessed_data:
-                num_tensors.append(
-                    torch.tensor(preprocessed_data[num_key], dtype=torch.float32)
-                )
+                num_tensors.append(torch.tensor(preprocessed_data[num_key], dtype=torch.float32))

         if self.embedding_feature_info is not None:
             for key in self.embedding_feature_info:
                 if key in preprocessed_data:
-                    emb_tensors.append(
-                        torch.tensor(preprocessed_data[key], dtype=torch.float32)
-                    )
+                    emb_tensors.append(torch.tensor(preprocessed_data[key], dtype=torch.float32))

         return MambularDataset(
             cat_tensors,
@@ -374,9 +316,7 @@ def val_dataloader(self):
             DataLoader: DataLoader instance for the validation dataset.
         """
         if hasattr(self, "val_dataset"):
-            return DataLoader(
-                self.val_dataset, batch_size=self.batch_size, **self.dataloader_kwargs
-            )
+            return DataLoader(self.val_dataset, batch_size=self.batch_size, **self.dataloader_kwargs)
         else:
             raise ValueError("No validation dataset provided!")

@@ -387,9 +327,7 @@ def test_dataloader(self):
             DataLoader: DataLoader instance for the test dataset.
         """
         if hasattr(self, "test_dataset"):
-            return DataLoader(
-                self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs
-            )
+            return DataLoader(self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs)
         else:
             raise ValueError("No test dataset provided!")

mambular/preprocessing/preprocessor.py (+4 -0)
@@ -582,6 +582,10 @@ def transform(self, X, embeddings=None):
             raise NotFittedError(
                 "The preprocessor must be fitted before transforming new data. Use .fit or .fit_transform"
             )
+        if isinstance(X, np.ndarray):
+            X = pd.DataFrame(X)
+        else:
+            X = X.copy()
         transformed_X = self.column_transformer.transform(X)  # type: ignore

         # Now let's convert this into a dictionary of arrays, one per column
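
These four added lines are the hotfix named in the commit title: predict-time callers may hand transform a raw NumPy array, which is now promoted to a DataFrame before reaching the fitted column transformer, while DataFrame input is defensively copied so the caller's frame stays untouched. A sketch of the guard in isolation; normalize_input is a hypothetical stand-in for illustration, not part of the library:

import numpy as np
import pandas as pd

def normalize_input(X):
    # NumPy input is promoted to a DataFrame; DataFrame input is
    # copied so downstream mutation cannot leak back to the caller.
    if isinstance(X, np.ndarray):
        return pd.DataFrame(X)
    return X.copy()

X_np = np.array([[1.0, 2.0], [3.0, 4.0]])
X_df = normalize_input(X_np)
assert isinstance(X_df, pd.DataFrame) and X_df.shape == (2, 2)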
