Skip to content

Commit

Permalink
fixed tests. Need to add more
Browse files Browse the repository at this point in the history
  • Loading branch information
mikaelene committed Nov 5, 2019
1 parent 167f153 commit 4f27c52
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 57 deletions.
11 changes: 8 additions & 3 deletions eneel/adapters/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,12 @@ def query_columns(self, query):
character_maximum_length = None
numeric_precision = None
numeric_scale = None
if data_type in ("cx_Oracle.BLOB", "cx_Oracle.OBJECT", "cx_Oracle.BFILE", "cx_Oracle.NCLOB"):
if data_type in (
"cx_Oracle.BLOB",
"cx_Oracle.OBJECT",
"cx_Oracle.BFILE",
"cx_Oracle.NCLOB",
):
data_type = "bytes"
elif data_type in ("cx_Oracle.DATETIME", "cx_Oracle.TIMESTAMP"):
data_type = "datetime.datetime"
Expand Down Expand Up @@ -215,9 +220,9 @@ def remove_unsupported_columns(self, columns):
for column in columns:
data_type = column[2]
character_maximum_length = column[3]
if data_type == 'str' and character_maximum_length > 8000:
if data_type == "str" and character_maximum_length > 8000:
columns_to_keep.remove(column)
if data_type in ('bytearray'):
if data_type in ("bytearray"):
columns_to_keep.remove(column)
return columns_to_keep

Expand Down
4 changes: 2 additions & 2 deletions eneel/adapters/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ def remove_unsupported_columns(self, columns):
for column in columns:
data_type = column[2]
character_maximum_length = column[3]
if data_type == 'str' and character_maximum_length > 8000:
if data_type == "str" and character_maximum_length > 8000:
columns_to_keep.remove(column)
if data_type == 'bytearray':
if data_type == "bytearray":
columns_to_keep.remove(column)
return columns_to_keep

Expand Down
4 changes: 2 additions & 2 deletions eneel/adapters/sqlserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,9 @@ def remove_unsupported_columns(self, columns):
for column in columns:
data_type = column[2]
character_maximum_length = column[3]
#if data_type == 'str' and character_maximum_length > 8000:
# if data_type == 'str' and character_maximum_length > 8000:
# columns_to_keep.remove(column)
if data_type == 'bytearray':
if data_type == "bytearray":
columns_to_keep.remove(column)
return columns_to_keep

Expand Down
14 changes: 7 additions & 7 deletions eneel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ def run_cmd(cmd, envs=None):
return -1, sys.exc_info()[0]


def export_csv(rows, filename, delimiter='|'):
def export_csv(rows, filename, delimiter="|"):
try:
csv_file = open(filename, 'a', encoding="utf-8")
csv_file = open(filename, "a", encoding="utf-8")
for row in rows:
csv_row = ''
csv_row = ""
for i in range(len(row)):
col = row[i]
if col is None:
col = ''
col = ""
if col is True:
col = 1
if col is False:
Expand All @@ -100,9 +100,9 @@ def export_csv(rows, filename, delimiter='|'):
col = str(col).strip()
csv_row += col
# Replace linebreaks if any
csv_row = csv_row.replace('\n', ' ')
csv_row = csv_row.replace('\r', ' ')
csv_file.write(csv_row + '\n')
csv_row = csv_row.replace("\n", " ")
csv_row = csv_row.replace("\r", " ")
csv_file.write(csv_row + "\n")
csv_file.close()
rowcount = len(rows)
# logger.info(str(rowcount) + " rows added to " + filename)
Expand Down
32 changes: 13 additions & 19 deletions tests/unit/test_adapters_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,25 +85,6 @@ def test_get_max_column_value(self, db):
== "Not implemented for this adapter"
)

def test_export_table(self, tmpdir, db):
# columns = os.environ['ORACLE_TEST_TABLE_COLUMN'].split()
columns = [(1, "ID_COL", "NUMBER", None, 22, 0)]
path = tmpdir
file_path, delimiter, row_count = db.export_table(
os.getenv("ORACLE_TEST_SCHEMA"),
os.getenv("ORACLE_TEST_TABLE"),
columns,
path,
delimiter=",",
replication_key=None,
max_replication_key=None,
parallelization_key=None,
)

# assert row_count > 0
assert os.path.exists(file_path) == 1
assert os.stat(file_path).st_size > 0

def test_import_table(self, tmpdir, db):
assert (
db.import_table("test_target", "test1_target", "path")
Expand All @@ -127,3 +108,16 @@ def test_log(self, db):
db.log("log_schema", "log_table", project="project")
== "Not implemented for this adapter"
)

def test_query_columns(self, db):
query_columns = db.query_columns(
"select id_col, name_col, datetime_col from test.test1"
)

assert type(query_columns) == list
assert len(query_columns) > 0
assert query_columns[0][1] == "ID_COL"
assert query_columns[0][2] == "int"
assert query_columns[1][2] == "str"
assert query_columns[1][3] == 64
assert query_columns[2][2] == "datetime.datetime"
15 changes: 3 additions & 12 deletions tests/unit/test_adapters_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,7 @@ def test_get_min_max_column_value(self, db):
assert min == 1
assert max == 3

def test_export_table(self, tmpdir, db):
columns = [(1, "id_col", "integer", None, 32, 0)]
path = tmpdir
file_path, delimiter, row_count = db.export_table(
"test", "test1", columns, path
)

assert row_count == 3
assert os.path.exists(file_path) == 1

@pytest.mark.skip(reason="must create fixture for file to import")
def test_import_table(self, tmpdir, db):
columns = [(1, "id_col", "integer", None, 32, 0)]
path = tmpdir
Expand Down Expand Up @@ -151,6 +142,6 @@ def test_query_columns(self, db):
assert len(query_columns) > 0
assert query_columns[0][1] == "id_col"
assert query_columns[0][2] == "int"
assert query_columns[1][2] == "varchar"
assert query_columns[1][2] == "str"
assert query_columns[1][3] == 64
assert query_columns[2][2] == "timestamp"
assert query_columns[2][2] == "datetime.datetime"
15 changes: 3 additions & 12 deletions tests/unit/test_adapters_sqlserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,7 @@ def test_get_max_column_value(self, db):

assert db.get_max_column_value("test.test1", "id_col") == "3"

def test_export_table(self, tmpdir, db):
columns = [(1, "id_col", "integer", None, 32, 0)]
path = tmpdir
file_path, delimiter, row_count = db.export_table(
"test", "test1", columns, path
)

assert row_count == 3
assert os.path.exists(file_path) == 1

@pytest.mark.skip(reason="must create fixture for file to import")
def test_import_table(self, tmpdir, db):
columns = [(1, "id_col", "integer", None, 32, 0)]
path = tmpdir
Expand Down Expand Up @@ -152,6 +143,6 @@ def test_query_columns(self, db):
assert len(query_columns) > 0
assert query_columns[0][1] == "id_col"
assert query_columns[0][2] == "int"
assert query_columns[1][2] == "varchar"
assert query_columns[1][2] == "str"
assert query_columns[1][3] == 64
assert query_columns[2][2] == "datetime2"
assert query_columns[2][2] == "datetime.datetime"

0 comments on commit 4f27c52

Please sign in to comment.