#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import struct
import sys
import unittest
import difflib
import faulthandler
import functools
from decimal import Decimal
from time import time, sleep
import signal
from typing import (
Any,
Optional,
Union,
Dict,
List,
Callable,
)
from itertools import zip_longest
from pyspark import SparkConf
from pyspark.errors import PySparkAssertionError, PySparkException, PySparkTypeError
from pyspark.errors.exceptions.base import QueryContextType
from pyspark.sql.dataframe import DataFrame
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, VariantVal
from pyspark.sql.functions import col, when
__all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
def have_package(name: str) -> bool:
import importlib.util  # importlib.util is a submodule; import it explicitly
return importlib.util.find_spec(name) is not None
have_numpy = have_package("numpy")
numpy_requirement_message = "" if have_numpy else "No module named 'numpy'"
have_scipy = have_package("scipy")
scipy_requirement_message = "" if have_scipy else "No module named 'scipy'"
have_sklearn = have_package("sklearn")
sklearn_requirement_message = "" if have_sklearn else "No module named 'sklearn'"
have_torch = have_package("torch")
torch_requirement_message = "" if have_torch else "No module named 'torch'"
have_torcheval = have_package("torcheval")
torcheval_requirement_message = "" if have_torcheval else "No module named 'torcheval'"
have_deepspeed = have_package("deepspeed")
deepspeed_requirement_message = "" if have_deepspeed else "No module named 'deepspeed'"
have_plotly = have_package("plotly")
plotly_requirement_message = "" if have_plotly else "No module named 'plotly'"
have_matplotlib = have_package("matplotlib")
matplotlib_requirement_message = "" if have_matplotlib else "No module named 'matplotlib'"
have_tabulate = have_package("tabulate")
tabulate_requirement_message = "" if have_tabulate else "No module named 'tabulate'"
have_graphviz = have_package("graphviz")
graphviz_requirement_message = "" if have_graphviz else "No module named 'graphviz'"
have_flameprof = have_package("flameprof")
flameprof_requirement_message = "" if have_flameprof else "No module named 'flameprof'"
have_jinja2 = have_package("jinja2")
jinja2_requirement_message = "" if have_jinja2 else "No module named 'jinja2'"
have_openpyxl = have_package("openpyxl")
openpyxl_requirement_message = "" if have_openpyxl else "No module named 'openpyxl'"
have_yaml = have_package("yaml")
yaml_requirement_message = "" if have_yaml else "No module named 'yaml'"
have_grpc = have_package("grpc")
grpc_requirement_message = "" if have_grpc else "No module named 'grpc'"
have_grpc_status = have_package("grpc_status")
grpc_status_requirement_message = "" if have_grpc_status else "No module named 'grpc_status'"
have_zstandard = have_package("zstandard")
zstandard_requirement_message = "" if have_zstandard else "No module named 'zstandard'"
googleapis_common_protos_requirement_message = ""
try:
from google.rpc import error_details_pb2 # noqa: F401
except ImportError as e:
googleapis_common_protos_requirement_message = str(e)
have_googleapis_common_protos = not googleapis_common_protos_requirement_message
pandas_requirement_message = ""
try:
from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()
except Exception as e:
# If Pandas version requirement is not satisfied, skip related tests.
pandas_requirement_message = str(e)
have_pandas = not pandas_requirement_message
pyarrow_requirement_message = ""
try:
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
except Exception as e:
# If Arrow version requirement is not satisfied, skip related tests.
pyarrow_requirement_message = str(e)
have_pyarrow = not pyarrow_requirement_message
connect_requirement_message = (
pandas_requirement_message
or pyarrow_requirement_message
or grpc_requirement_message
or googleapis_common_protos_requirement_message
or grpc_status_requirement_message
or zstandard_requirement_message
)
should_test_connect = not connect_requirement_message
is_ansi_mode_test = True
if os.environ.get("SPARK_ANSI_SQL_MODE") == "false":
is_ansi_mode_test = False
ansi_mode_not_supported_message = "ANSI mode is not supported" if is_ansi_mode_test else ""
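# Illustrative usage of the optional-dependency flags above (a sketch, not part of
# this module's API): test classes typically pair a `have_<pkg>` flag with the
# matching `<pkg>_requirement_message` in a unittest skip marker, e.g.
#
#     @unittest.skipIf(not have_numpy, numpy_requirement_message)
#     class NumpyDependentTests(unittest.TestCase):   # hypothetical test class
#         ...
#
# and pair `is_ansi_mode_test` with `ansi_mode_not_supported_message` for tests
# that cannot run under ANSI SQL mode.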
def read_int(b):
return struct.unpack("!i", b)[0]
def write_int(i):
return struct.pack("!i", i)
def timeout(timeout):
def decorator(func):
def handler(signum, frame):
raise TimeoutError(f"Function {func.__name__} timed out after {timeout} seconds")
def wrapper(*args, **kwargs):
signal.alarm(0)
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
try:
result = func(*args, **kwargs)
finally:
signal.alarm(0)
return result
return wrapper
return decorator
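# Illustrative use of the `timeout` decorator above (a sketch; the test body is
# hypothetical). It relies on SIGALRM, so it only works on the main thread of a
# Unix-like system, and the timeout must be an integer number of seconds:
#
#     @timeout(5)
#     def test_finishes_quickly():
#         run_job_that_should_not_hang()   # hypothetical helper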
def eventually(
timeout=30.0,
catch_assertions=False,
catch_timeout=False,
quiet=True,
interval=0.1,
expected_exceptions=tuple(),
):
"""
Wait a given amount of time for a condition to pass, else fail with an error.
This is a helper utility for PySpark tests.
Parameters
----------
condition : function
Function that checks for termination conditions. condition() can return:
- True or None: Conditions met. Return without error.
- other value: Conditions not met yet. Continue. Upon timeout,
include last such value in error message.
Note that this method may be called at any time during
streaming execution (e.g., even before any results
have been created).
timeout : int
Number of seconds to wait. Default 30 seconds.
catch_assertions : bool
If False (default), do not catch AssertionErrors.
If True, catch AssertionErrors; continue, but save
error to throw upon timeout.
catch_timeout : bool
If False (default), do not catch TimeoutError.
If True, catch TimeoutError; continue, but save
error to throw upon timeout.
quiet : bool
If True (default), do not print any output.
If False, print output.
interval : float
Number of seconds to wait between attempts. Default 0.1 seconds.
"""
assert timeout > 0
assert isinstance(catch_assertions, bool)
assert isinstance(catch_timeout, bool)
assert isinstance(quiet, bool)
assert isinstance(interval, float)
assert isinstance(expected_exceptions, (tuple, list))
expected_exceptions = list(expected_exceptions)
if catch_assertions:
expected_exceptions.append(AssertionError)
if catch_timeout:
expected_exceptions.append(TimeoutError)
expected_exceptions = tuple(expected_exceptions)
def decorator(condition: Callable) -> Callable:
assert isinstance(condition, Callable)
@functools.wraps(condition)
def wrapper(*args: Any, **kwargs: Any) -> Any:
start_time = time()
lastValue = None
numTries = 0
while time() - start_time < timeout:
numTries += 1
try:
lastValue = condition(*args, **kwargs)
except expected_exceptions as e:
lastValue = e
if lastValue is True or lastValue is None:
return
if not quiet:
print(f"\nAttempt #{numTries} failed!\n{lastValue}")
sleep(interval)
if isinstance(lastValue, expected_exceptions):
raise lastValue
else:
raise AssertionError(
"Test failed due to timeout after %g sec, with last condition returning: %s"
% (timeout, lastValue)
)
return wrapper
return decorator
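# Illustrative use of `eventually` (a sketch; `query` is a hypothetical streaming
# query). The decorated condition is retried every `interval` seconds until it
# returns True/None or the timeout elapses; with catch_assertions=True the last
# AssertionError is re-raised on timeout instead of a generic failure:
#
#     @eventually(timeout=60.0, interval=0.5, catch_assertions=True)
#     def condition():
#         assert query.lastProgress is not None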
class QuietTest:
def __init__(self, sc):
self.log4j = sc._jvm.org.apache.log4j
def __enter__(self):
self.old_level = self.log4j.LogManager.getRootLogger().getLevel()
self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL)
def __exit__(self, exc_type, exc_val, exc_tb):
self.log4j.LogManager.getRootLogger().setLevel(self.old_level)
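# Illustrative use of QuietTest (a sketch): it temporarily raises the JVM root log
# level to FATAL so that expected errors do not clutter test output:
#
#     with QuietTest(sc):
#         with self.assertRaises(Exception):
#             df.select(udf_that_raises("value")).collect()   # hypothetical UDF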
class PySparkBaseTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
if os.environ.get("PYSPARK_TEST_TIMEOUT"):
faulthandler.register(signal.SIGTERM, file=sys.__stderr__, all_threads=True)
@classmethod
def tearDownClass(cls):
if os.environ.get("PYSPARK_TEST_TIMEOUT"):
faulthandler.unregister(signal.SIGTERM)
class PySparkTestCase(PySparkBaseTestCase):
def setUp(self):
from pyspark import SparkContext
self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
self.sc = SparkContext("local[4]", class_name)
def tearDown(self):
self.sc.stop()
sys.path = self._old_sys_path
class ReusedPySparkTestCase(PySparkBaseTestCase):
@classmethod
def conf(cls):
"""
Override this in subclasses to supply a more specific conf
"""
return SparkConf()
@classmethod
def setUpClass(cls):
super().setUpClass()
from pyspark import SparkContext
cls.sc = SparkContext(cls.master(), cls.__name__, conf=cls.conf())
@classmethod
def master(cls):
return "local[4]"
@classmethod
def tearDownClass(cls):
try:
cls.sc.stop()
finally:
super().tearDownClass()
def test_assert_classic_mode(self):
from pyspark.sql import is_remote
self.assertFalse(is_remote())
def quiet(self):
from pyspark.testing.utils import QuietTest
return QuietTest(self.sc)
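# Illustrative subclassing of ReusedPySparkTestCase (a sketch): the SparkContext is
# created once per test class, and conf() can be overridden to customize it:
#
#     class MyRDDTests(ReusedPySparkTestCase):   # hypothetical test class
#         @classmethod
#         def conf(cls):
#             return SparkConf().set("spark.default.parallelism", "2")
#
#         def test_count(self):
#             self.assertEqual(self.sc.parallelize(range(10)).count(), 10)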
class ByteArrayOutput:
def __init__(self):
self.buffer = bytearray()
def write(self, b):
self.buffer += b
def close(self):
pass
def _terminal_color_support():
try:
# determine if environment supports color
script = "$(test $(tput colors)) && $(test $(tput colors) -ge 8) && echo true || echo false"
# the script echoes "true" or "false"; compare so callers get a real boolean
return os.popen(script).read().strip() == "true"
except Exception:
return False
def _context_diff(actual: List[str], expected: List[str], n: int = 3):
"""
Modified from difflib context_diff API,
see original code here: https://github.com/python/cpython/blob/main/Lib/difflib.py#L1180
"""
def red(s: str) -> str:
red_color = "\033[31m"
no_color = "\033[0m"
return red_color + str(s) + no_color
prefix = dict(insert="+ ", delete="- ", replace="! ", equal=" ")
for group in difflib.SequenceMatcher(None, actual, expected).get_grouped_opcodes(n):
yield "*** actual ***"
if any(tag in {"replace", "delete"} for tag, _, _, _, _ in group):
for tag, i1, i2, _, _ in group:
for line in actual[i1:i2]:
if tag != "equal" and _terminal_color_support():
yield red(prefix[tag] + str(line))
else:
yield prefix[tag] + str(line)
yield "\n"
yield "*** expected ***"
if any(tag in {"replace", "insert"} for tag, _, _, _, _ in group):
for tag, _, _, j1, j2 in group:
for line in expected[j1:j2]:
if tag != "equal" and _terminal_color_support():
yield red(prefix[tag] + str(line))
else:
yield prefix[tag] + str(line)
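# Example of the shape of _context_diff output (illustrative only): differing lines
# are prefixed with "!", "+", or "-" per difflib opcodes, and colored red when the
# terminal supports it, e.g.
#
#     *** actual ***
#     ! Row(id='1', amount=1000.0)
#     *** expected ***
#     ! Row(id='1', amount=1001.0)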
class PySparkErrorTestUtils:
"""
This util provides functions for accurate and consistent error testing
based on PySpark error classes.
"""
def check_error(
self,
exception: PySparkException,
errorClass: str,
messageParameters: Optional[Dict[str, str]] = None,
query_context_type: Optional[QueryContextType] = None,
fragment: Optional[str] = None,
matchPVals: bool = False,
match_exact_condition_and_parameters: bool = False,
):
"""
Check that the exception has the expected error condition and (optionally) parameters.
By default, condition matches if the exception's condition equals the expected condition
or is a subcondition (e.g. CONDITION.SUBCONDITION when expecting CONDITION). When
passing only the main condition, message parameters are not required to be full (subset
check). Use match_exact_condition_and_parameters=True when a test must require the
exact condition with no subcondition and exact parameters.
"""
query_context = exception.getQueryContext()
assert bool(query_context) == (query_context_type is not None), (
"`query_context_type` is required when QueryContext exists. "
f"QueryContext: {query_context}."
)
# Test if given error is an instance of PySparkException.
self.assertIsInstance(
exception,
PySparkException,
f"checkError requires 'PySparkException', got '{exception.__class__.__name__}'.",
)
# Test error class (exact or prefix match by default)
expected = errorClass
actual_condition = exception.getCondition() or ""
if match_exact_condition_and_parameters:
condition_matches = actual_condition == expected
else:
condition_matches = actual_condition == expected or (
expected and actual_condition.startswith(expected + ".")
)
self.assertTrue(
condition_matches,
f"Expected error class was '{expected}' "
f"(match_exact={match_exact_condition_and_parameters}), got '{actual_condition}'.",
)
# Test message parameters
actual_params = exception.getMessageParameters() or {}
is_prefix_match = not match_exact_condition_and_parameters and actual_condition.startswith(
expected + "."
)
if is_prefix_match:
# When matching by main condition only, only require that passed parameters match.
if messageParameters:
for key, value in messageParameters.items():
self.assertIn(
key,
actual_params,
f"Expected message parameter key '{key}' not found in {actual_params}",
)
if matchPVals:
self.assertRegex(
actual_params[key],
value,
f"Parameter '{key}' value '{actual_params[key]}' "
f"does not match pattern '{value}'",
)
else:
self.assertEqual(
actual_params[key],
value,
f"Parameter '{key}': expected '{value}', "
f"got '{actual_params[key]}'",
)
else:
expected_params = messageParameters if messageParameters is not None else {}
if matchPVals:
self.assertEqual(
len(expected_params),
len(actual_params),
"Expected message parameters count does not match actual message "
f"parameters count: {len(expected_params)}, {len(actual_params)}.",
)
for key, value in expected_params.items():
self.assertIn(
key,
actual_params,
f"Expected message parameter key '{key}' was not found "
"in actual message parameters.",
)
self.assertRegex(
actual_params[key],
value,
f"Expected message parameter value '{value}' does not match "
f"actual message parameter value '{actual_params[key]}'.",
)
else:
self.assertEqual(
expected_params,
actual_params,
f"Expected message parameters was '{expected_params}', got '{actual_params}'",
)
# Test query context
if query_context:
expected = query_context_type
actual_contexts = exception.getQueryContext()
for actual_context in actual_contexts:
actual = actual_context.contextType()
self.assertEqual(
expected, actual, f"Expected QueryContext was '{expected}', got '{actual}'"
)
if actual == QueryContextType.DataFrame:
assert (
fragment is not None
), "`fragment` is required when QueryContextType is DataFrame."
expected = fragment
actual = actual_context.fragment()
self.assertEqual(
expected,
actual,
f"Expected PySpark fragment was '{expected}', got '{actual}'",
)
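# Illustrative use of check_error in a test (a sketch; the triggering operation,
# error condition, and parameters are placeholders, not a guaranteed real error):
#
#     with self.assertRaises(PySparkTypeError) as pe:
#         df.select(col(10))                     # hypothetical misuse
#     self.check_error(
#         exception=pe.exception,
#         errorClass="NOT_COLUMN_OR_STR",        # placeholder condition
#         messageParameters={"arg_name": "col", "arg_type": "int"},
#     )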
def assertSchemaEqual(
actual: StructType,
expected: StructType,
ignoreNullable: bool = True,
ignoreColumnOrder: bool = False,
ignoreColumnName: bool = False,
):
r"""
A util function to assert equality between DataFrame schemas `actual` and `expected`.
.. versionadded:: 3.5.0
Parameters
----------
actual : StructType
The DataFrame schema that is being compared or tested.
expected : StructType
The expected schema, for comparison with the actual schema.
ignoreNullable : bool, default True
Specifies whether a column's nullable property is included when checking for
schema equality.
When set to `True` (default), the nullable property of the columns being compared
is not taken into account and the columns will be considered equal even if they have
different nullable settings.
When set to `False`, columns are considered equal only if they have the same nullable
setting.
.. versionadded:: 4.0.0
ignoreColumnOrder : bool, default False
Specifies whether to compare columns in the order they appear in the DataFrame or by
column name.
If set to `False` (default), columns are compared in the order they appear in the
DataFrames.
When set to `True`, a column in the expected DataFrame is compared to the column with the
same name in the actual DataFrame.
.. versionadded:: 4.0.0
ignoreColumnName : bool, default False
Specifies whether to fail the initial schema equality check if the column names in the two
DataFrames are different.
When set to `False` (default), column names are checked and the function fails if they are
different.
When set to `True`, the function will succeed even if column names are different.
Column data types are compared for columns in the order they appear in the DataFrames.
.. versionadded:: 4.0.0
Notes
-----
When assertSchemaEqual fails, the error message uses the Python `difflib` library to display
a diff log of the `actual` and `expected` schemas.
Examples
--------
>>> from pyspark.sql.types import StructType, StructField, ArrayType, IntegerType, DoubleType
>>> s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
>>> s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
>>> assertSchemaEqual(s1, s2) # pass, schemas are identical
Different schemas with `ignoreNullable=False` would fail.
>>> s3 = StructType([StructField("names", ArrayType(DoubleType(), True), False)])
>>> assertSchemaEqual(s1, s3, ignoreNullable=False) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
--- actual
+++ expected
- StructType([StructField('names', ArrayType(DoubleType(), True), True)])
? ^^^
+ StructType([StructField('names', ArrayType(DoubleType(), True), False)])
? ^^^^
>>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "number"])
>>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 5000)], schema=["id", "amount"])
>>> assertSchemaEqual(df1.schema, df2.schema) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
--- actual
+++ expected
- StructType([StructField('id', LongType(), True), StructField('number', LongType(), True)])
? ^^ ^^^^^
+ StructType([StructField('id', StringType(), True), StructField('amount', LongType(), True)])
? ^^^^ ++++ ^
Compare two schemas ignoring the column order.
>>> s1 = StructType(
... [StructField("a", IntegerType(), True), StructField("b", DoubleType(), True)]
... )
>>> s2 = StructType(
... [StructField("b", DoubleType(), True), StructField("a", IntegerType(), True)]
... )
>>> assertSchemaEqual(s1, s2, ignoreColumnOrder=True)
Compare two schemas ignoring the column names.
>>> s1 = StructType(
... [StructField("a", IntegerType(), True), StructField("c", DoubleType(), True)]
... )
>>> s2 = StructType(
... [StructField("b", IntegerType(), True), StructField("d", DoubleType(), True)]
... )
>>> assertSchemaEqual(s1, s2, ignoreColumnName=True)
"""
if not isinstance(actual, StructType):
raise PySparkTypeError(
errorClass="NOT_STRUCT",
messageParameters={"arg_name": "actual", "arg_type": type(actual).__name__},
)
if not isinstance(expected, StructType):
raise PySparkTypeError(
errorClass="NOT_STRUCT",
messageParameters={"arg_name": "expected", "arg_type": type(expected).__name__},
)
def compare_schemas_ignore_nullable(s1: StructType, s2: StructType):
if len(s1) != len(s2):
return False
zipped = zip_longest(s1, s2)
for sf1, sf2 in zipped:
if not compare_structfields_ignore_nullable(sf1, sf2):
return False
return True
def compare_structfields_ignore_nullable(actualSF: StructField, expectedSF: StructField):
if actualSF is None and expectedSF is None:
return True
elif actualSF is None or expectedSF is None:
return False
if actualSF.name != expectedSF.name:
return False
else:
return compare_datatypes_ignore_nullable(actualSF.dataType, expectedSF.dataType)
def compare_datatypes_ignore_nullable(dt1: Any, dt2: Any):
# checks datatype equality, using recursion to ignore nullable
if dt1.typeName() == dt2.typeName():
if dt1.typeName() == "array":
return compare_datatypes_ignore_nullable(dt1.elementType, dt2.elementType)
elif dt1.typeName() == "map":
return compare_datatypes_ignore_nullable(
dt1.keyType, dt2.keyType
) and compare_datatypes_ignore_nullable(dt1.valueType, dt2.valueType)
elif dt1.typeName() == "decimal":
# Fix for SPARK-51062: Compare precision and scale for decimal types
return dt1.precision == dt2.precision and dt1.scale == dt2.scale
elif dt1.typeName() == "struct":
return compare_schemas_ignore_nullable(dt1, dt2)
else:
return True
else:
return False
if ignoreColumnOrder:
actual = StructType(sorted(actual, key=lambda x: x.name))
expected = StructType(sorted(expected, key=lambda x: x.name))
if ignoreColumnName:
actual = StructType(
[StructField(str(i), field.dataType, field.nullable) for i, field in enumerate(actual)]
)
expected = StructType(
[
StructField(str(i), field.dataType, field.nullable)
for i, field in enumerate(expected)
]
)
if (ignoreNullable and not compare_schemas_ignore_nullable(actual, expected)) or (
not ignoreNullable and actual != expected
):
generated_diff = difflib.ndiff(str(actual).splitlines(), str(expected).splitlines())
error_msg = "\n".join(generated_diff)
raise PySparkAssertionError(
errorClass="DIFFERENT_SCHEMA",
messageParameters={"error_msg": error_msg},
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import pandas
import pyspark.pandas
def assertDataFrameEqual(
actual: Union[DataFrame, "pandas.DataFrame", "pyspark.pandas.DataFrame", List[Row]],
expected: Union[DataFrame, "pandas.DataFrame", "pyspark.pandas.DataFrame", List[Row]],
checkRowOrder: bool = False,
rtol: float = 1e-5,
atol: float = 1e-8,
ignoreNullable: bool = True,
ignoreColumnOrder: bool = False,
ignoreColumnName: bool = False,
ignoreColumnType: bool = False,
maxErrors: Optional[int] = None,
showOnlyDiff: bool = False,
includeDiffRows: bool = False,
):
r"""
A util function to assert equality between `actual` and `expected`
(DataFrames or lists of Rows), with optional parameters `checkRowOrder`, `rtol`, and `atol`.
Supports Spark, Spark Connect, pandas, and pandas-on-Spark DataFrames.
For more information about pandas-on-Spark DataFrame equality, see the docs for
`assertPandasOnSparkEqual`.
.. versionadded:: 3.5.0
Parameters
----------
actual : DataFrame (Spark, Spark Connect, pandas, or pandas-on-Spark) or list of Rows
The DataFrame that is being compared or tested.
expected : DataFrame (Spark, Spark Connect, pandas, or pandas-on-Spark) or list of Rows
The expected result of the operation, for comparison with the actual result.
checkRowOrder : bool, optional
A flag indicating whether the order of rows should be considered in the comparison.
If set to `False` (default), the row order is not taken into account.
If set to `True`, the order of rows is important and will be checked during comparison.
(See Notes)
rtol : float, optional
The relative tolerance, used in asserting approximate equality for float values in actual
and expected. Set to 1e-5 by default. (See Notes)
atol : float, optional
The absolute tolerance, used in asserting approximate equality for float values in actual
and expected. Set to 1e-8 by default. (See Notes)
ignoreNullable : bool, default True
Specifies whether a column's nullable property is included when checking for
schema equality.
When set to `True` (default), the nullable property of the columns being compared
is not taken into account and the columns will be considered equal even if they have
different nullable settings.
When set to `False`, columns are considered equal only if they have the same nullable
setting.
.. versionadded:: 4.0.0
ignoreColumnOrder : bool, default False
Specifies whether to compare columns in the order they appear in the DataFrame or by
column name.
If set to `False` (default), columns are compared in the order they appear in the
DataFrames.
When set to `True`, a column in the expected DataFrame is compared to the column with the
same name in the actual DataFrame.
.. versionadded:: 4.0.0
ignoreColumnName : bool, default False
Specifies whether to fail the initial schema equality check if the column names in the two
DataFrames are different.
When set to `False` (default), column names are checked and the function fails if they are
different.
When set to `True`, the function will succeed even if column names are different.
Column data types are compared for columns in the order they appear in the DataFrames.
.. versionadded:: 4.0.0
ignoreColumnType : bool, default False
Specifies whether to ignore the data type of the columns when comparing.
When set to `False` (default), column data types are checked and the function fails if they
are different.
When set to `True`, the schema equality check will succeed even if column data types are
different and the function will attempt to compare rows.
.. versionadded:: 4.0.0
maxErrors : int, optional
The maximum number of row comparison failures to encounter before returning.
Once this many row comparisons have failed, the function returns regardless of how
many rows have been compared.
Set to None by default, which means all rows are compared regardless of the number
of failures.
.. versionadded:: 4.0.0
showOnlyDiff : bool, default False
If set to `True`, the error message will only include rows that are different.
If set to `False` (default), the error message will include all rows
(when there is at least one row that is different).
.. versionadded:: 4.0.0
includeDiffRows : bool, default False
If set to `True`, the unequal rows are included in PySparkAssertionError for further
debugging. If set to `False` (default), the unequal rows are not returned as a data set.
.. versionadded:: 4.0.0
Notes
-----
When `assertDataFrameEqual` fails, the error message uses the Python `difflib` library to
display a diff log of each row that differs in `actual` and `expected`.
For `checkRowOrder`, note that PySpark DataFrame ordering is non-deterministic, unless
explicitly sorted.
Note that schema equality is checked only when `expected` is a DataFrame (not a list of Rows).
For DataFrames with float/decimal values, assertDataFrameEqual asserts approximate equality.
Two float/decimal values a and b are approximately equal if the following equation is True:
``absolute(a - b) <= (atol + rtol * absolute(b))``.
`ignoreColumnOrder` and `ignoreColumnName` cannot both be set to `True`.
Examples
--------
>>> df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
>>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
>>> assertDataFrameEqual(df1, df2) # pass, DataFrames are identical
>>> df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"])
>>> df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"])
>>> assertDataFrameEqual(df1, df2, rtol=1e-1) # pass, DataFrames are approx equal by rtol
>>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "amount"])
>>> list_of_rows = [Row(1, 1000), Row(2, 3000)]
>>> assertDataFrameEqual(df1, list_of_rows) # pass, actual and expected data are equal
>>> import pyspark.pandas as ps # doctest: +SKIP
>>> df1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}) # doctest: +SKIP
>>> df2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}) # doctest: +SKIP
>>> # pass, pandas-on-Spark DataFrames are equal
>>> assertDataFrameEqual(df1, df2) # doctest: +SKIP
>>> df1 = spark.createDataFrame(
... data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
>>> df2 = spark.createDataFrame(
... data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], schema=["id", "amount"])
>>> assertDataFrameEqual(df1, df2) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % )
*** actual ***
! Row(id='1', amount=1000.0)
Row(id='2', amount=3000.0)
! Row(id='3', amount=2000.0)
*** expected ***
! Row(id='1', amount=1001.0)
Row(id='2', amount=3000.0)
! Row(id='3', amount=2003.0)
Example for ignoreNullable
>>> from pyspark.sql.types import StructType, StructField, StringType, LongType
>>> df1_nullable = spark.createDataFrame(
... data=[(1000, "1"), (5000, "2")],
... schema=StructType(
... [StructField("amount", LongType(), True), StructField("id", StringType(), True)]
... )
... )
>>> df2_nullable = spark.createDataFrame(
... data=[(1000, "1"), (5000, "2")],
... schema=StructType(
... [StructField("amount", LongType(), True), StructField("id", StringType(), False)]
... )
... )
>>> assertDataFrameEqual(df1_nullable, df2_nullable, ignoreNullable=True) # pass
>>> assertDataFrameEqual(
... df1_nullable, df2_nullable, ignoreNullable=False
... ) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
--- actual
+++ expected
- StructType([StructField('amount', LongType(), True), StructField('id', StringType(), True)])
? ^^^
+ StructType([StructField('amount', LongType(), True), StructField('id', StringType(), False)])
? ^^^^
Example for ignoreColumnOrder
>>> df1_col_order = spark.createDataFrame(
... data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
... )
>>> df2_col_order = spark.createDataFrame(
... data=[("1", 1000), ("2", 5000)], schema=["id", "amount"]
... )
>>> assertDataFrameEqual(df1_col_order, df2_col_order, ignoreColumnOrder=True)
Example for ignoreColumnName
>>> df1_col_names = spark.createDataFrame(
... data=[(1000, "1"), (5000, "2")], schema=["amount", "identity"]
... )
>>> df2_col_names = spark.createDataFrame(
... data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
... )
>>> assertDataFrameEqual(df1_col_names, df2_col_names, ignoreColumnName=True)
Example for ignoreColumnType
>>> df1_col_types = spark.createDataFrame(
... data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
... )
>>> df2_col_types = spark.createDataFrame(
... data=[(1000.0, "1"), (5000.0, "2")], schema=["amount", "id"]
... )
>>> assertDataFrameEqual(df1_col_types, df2_col_types, ignoreColumnType=True)
Example for maxErrors (will only report the first mismatching row)
>>> df1 = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")])
>>> df2 = spark.createDataFrame([(1, "A"), (2, "X"), (3, "Y")])
>>> assertDataFrameEqual(df1, df2, maxErrors=1) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 33.33333 % )
*** actual ***
Row(_1=1, _2='A')
! Row(_1=2, _2='B')
*** expected ***
Row(_1=1, _2='A')
! Row(_1=2, _2='X')
Example for showOnlyDiff (will only report the mismatching rows)
>>> df1 = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")])
>>> df2 = spark.createDataFrame([(1, "A"), (2, "X"), (3, "Y")])
>>> assertDataFrameEqual(df1, df2, showOnlyDiff=True) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % )
*** actual ***
! Row(_1=2, _2='B')
! Row(_1=3, _2='C')
*** expected ***
! Row(_1=2, _2='X')
! Row(_1=3, _2='Y')
The `includeDiffRows` parameter can be used to include the rows that did not match
in the PySparkAssertionError. This can be useful for debugging or further analysis.
>>> df1 = spark.createDataFrame(
... data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
>>> df2 = spark.createDataFrame(
... data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], schema=["id", "amount"])
>>> try:
... assertDataFrameEqual(df1, df2, includeDiffRows=True)
... except PySparkAssertionError as e:
... spark.createDataFrame(e.data).show() # doctest: +NORMALIZE_WHITESPACE
+-----------+-----------+
| _1| _2|
+-----------+-----------+
|{1, 1000.0}|{1, 1001.0}|
|{3, 2000.0}|{3, 2003.0}|
+-----------+-----------+
"""
if actual is None and expected is None:
return True
elif actual is None:
raise PySparkAssertionError(
errorClass="INVALID_TYPE_DF_EQUALITY_ARG",
messageParameters={
"expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
"arg_name": "actual",
"actual_type": None,
},
)
elif expected is None:
raise PySparkAssertionError(
errorClass="INVALID_TYPE_DF_EQUALITY_ARG",
messageParameters={
"expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
"arg_name": "expected",
"actual_type": None,
},
)