1
1
"""Base Multi-Table Pattern."""
2
2
3
+ import logging
4
+
5
+ import numpy as np
3
6
import pandas as pd
4
7
5
8
from sdv .errors import NotFittedError
6
9
10
+ LOGGER = logging .getLogger (__name__ )
11
+
7
12
8
13
class BasePattern :
9
14
"""Base CAG Pattern Class."""
@@ -121,6 +126,31 @@ def transform(self, data):
121
126
def _reverse_transform (self , data ):
122
127
raise NotImplementedError
123
128
129
+ def _table_as_type_by_col (self , reverse_transformed , table , table_name ):
130
+ """Cast table to given types on a column by column basis.
131
+
132
+ Args:
133
+ reverse_transformed (dict[str, pd.DataFrame])
134
+ The reverse transformed data dictionary
135
+ table (pd.DataFrame)
136
+ The reverse transformed table
137
+ table_name (str)
138
+ The name of the table
139
+ """
140
+ for col in table :
141
+ try :
142
+ reverse_transformed [table_name ][col ] = table [col ].astype (
143
+ self ._dtypes [table_name ][col ]
144
+ )
145
+ except pd .errors .IntCastingNaNError :
146
+ LOGGER .info (
147
+ "Column '%s' is being converted to float because it contains NaNs." , col
148
+ )
149
+ self ._dtypes [table_name ][col ] = np .dtype ('float64' )
150
+ reverse_transformed [table_name ][col ] = table [col ].astype (
151
+ self ._dtypes [table_name ][col ]
152
+ )
153
+
124
154
def reverse_transform (self , data ):
125
155
"""Reverse transform the data back into the original space.
126
156
@@ -135,7 +165,11 @@ def reverse_transform(self, data):
135
165
reverse_transformed = self ._reverse_transform (data )
136
166
for table_name , table in reverse_transformed .items ():
137
167
table = table [self ._original_data_columns [table_name ]]
138
- reverse_transformed [table_name ] = table .astype (self ._dtypes [table_name ])
168
+ try :
169
+ reverse_transformed [table_name ] = table .astype (self ._dtypes [table_name ])
170
+ except pd .errors .IntCastingNaNError :
171
+ # iterate over the columns and cast individually
172
+ self ._table_as_type_by_col (reverse_transformed , table , table_name )
139
173
140
174
if self ._single_table :
141
175
return reverse_transformed [self ._table_name ]
0 commit comments