Skip to content

Commit 867874d

Browse files
committed
edge concat interop
1 parent f0eb1bf commit 867874d

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

graphistry/feature_utils.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,13 +1482,16 @@ def process_edge_dataframes(
14821482
" and is empty"
14831483
)
14841484

1485-
if feature_engine in ["none", "pandas"]:
1485+
if feature_engine in ["none", "pandas", "cudf"]:
14861486

14871487
X_enc, y_enc, data_encoder, label_encoder = get_numeric_transformers(
14881488
other_df, y
14891489
)
14901490
# add the two datasets together
1491-
X_enc = pd.concat([T, X_enc], axis=1)
1491+
if feature_engine == 'pandas':
1492+
X_enc = pd.concat([T, X_enc], axis=1)
1493+
elif feature_engine == 'cudf':
1494+
X_enc = cudf.concat([T, X_enc], axis=1)
14921495
# then scale them
14931496
X_encs, y_encs, scaling_pipeline, scaling_pipeline_target = smart_scaler( # noqa
14941497
X_enc,
@@ -1556,10 +1559,20 @@ def process_edge_dataframes(
15561559
logger.debug("-" * 60)
15571560
logger.debug("<= Found Edges and Dirty_cat encoding =>")
15581561
T_type = str(getmodule(T))
1559-
if 'cudf' in T_type:
1562+
X_type = str(getmodule(X_enc))
1563+
if 'cudf' in T_type and 'cudf' in X_type:
15601564
X_enc = cudf.concat([T, X_enc], axis=1)
1561-
else:
1565+
elif 'pd' in T_type and 'pd' in X_type:
15621566
X_enc = pd.concat([T, X_enc], axis=1)
1567+
else:
1568+
try:
1569+
X_enc = cudf.concat([cudf.from_pandas(T), X_enc], axis=1)
1570+
except:
1571+
pass
1572+
try:
1573+
X_enc = cudf.concat([T, cudf.from_pandas(X_enc)], axis=1)
1574+
except:
1575+
pass
15631576
elif not T.empty and X_enc.empty:
15641577
logger.debug("-" * 60)
15651578
logger.debug("<= Found only Edges =>")

0 commit comments

Comments
 (0)