Skip to content

Commit 85b506c

Browse files
committed
move _table_as_type_by_col from enterprise to sdv public
1 parent f2afcdf commit 85b506c

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

sdv/cag/base.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
"""Base Multi-Table Pattern."""
22

3+
import logging
4+
5+
import numpy as np
36
import pandas as pd
47

58
from sdv.errors import NotFittedError
69

10+
LOGGER = logging.getLogger(__name__)
11+
712

813
class BasePattern:
914
"""Base CAG Pattern Class."""
@@ -121,6 +126,31 @@ def transform(self, data):
121126
def _reverse_transform(self, data):
122127
raise NotImplementedError
123128

129+
def _table_as_type_by_col(self, reverse_transformed, table, table_name):
    """Cast table to given types on a column by column basis.

    Each column of ``table`` is cast to the dtype stored in
    ``self._dtypes[table_name]``. If a column cannot be cast because it
    contains NaNs, the stored dtype for that column is replaced with
    ``float64`` and the column is cast to that instead.

    The cast result is stored in ``reverse_transformed[table_name]`` and
    contains exactly the columns of ``table``, in the same order, matching
    what a whole-table ``table.astype(...)`` would have produced.

    Args:
        reverse_transformed (dict[str, pd.DataFrame])
            The reverse transformed data dictionary. The entry for
            ``table_name`` is replaced with the cast table.
        table (pd.DataFrame)
            The reverse transformed table. It is not modified.
        table_name (str)
            The name of the table
    """
    column_dtypes = self._dtypes[table_name]
    # Work on a copy of ``table`` rather than on the frame currently held in
    # ``reverse_transformed``: that frame may carry columns (or a column
    # order) that ``table`` has already been restricted away from.
    casted = table.copy()
    for col in table:
        try:
            casted[col] = table[col].astype(column_dtypes[col])
        except pd.errors.IntCastingNaNError:
            LOGGER.info(
                "Column '%s' is being converted to float because it contains NaNs.", col
            )
            # Remember the relaxed dtype so later casts of this column
            # use float64 directly.
            column_dtypes[col] = np.dtype('float64')
            casted[col] = table[col].astype(column_dtypes[col])

    reverse_transformed[table_name] = casted
153+
124154
def reverse_transform(self, data):
125155
"""Reverse transform the data back into the original space.
126156
@@ -135,7 +165,11 @@ def reverse_transform(self, data):
135165
reverse_transformed = self._reverse_transform(data)
136166
for table_name, table in reverse_transformed.items():
137167
table = table[self._original_data_columns[table_name]]
138-
reverse_transformed[table_name] = table.astype(self._dtypes[table_name])
168+
try:
169+
reverse_transformed[table_name] = table.astype(self._dtypes[table_name])
170+
except pd.errors.IntCastingNaNError:
171+
# iterate over the columns and cast individually
172+
self._table_as_type_by_col(reverse_transformed, table, table_name)
139173

140174
if self._single_table:
141175
return reverse_transformed[self._table_name]

0 commit comments

Comments
 (0)