|
22 | 22 | from datafusion import DataFrame, SessionContext, column, literal, udf
|
23 | 23 |
|
24 | 24 |
|
@pytest.fixture
def ctx():
    """Build and return a new ``SessionContext`` for the requesting test."""
    session = SessionContext()
    return session
| 28 | + |
| 29 | + |
25 | 30 | @pytest.fixture
|
26 | 31 | def df():
|
27 | 32 | ctx = SessionContext()
|
@@ -323,3 +328,56 @@ def test_collect_partitioned():
|
323 | 328 | )
|
324 | 329 |
|
325 | 330 | assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned()
|
| 331 | + |
| 332 | + |
def test_union(ctx):
    """Plain ``union`` must keep a row that occurs in both inputs twice."""

    def make_frame(a_values, b_values):
        # Wrap the two columns in one record batch and load it as a frame.
        record_batch = pa.RecordBatch.from_arrays(
            [pa.array(a_values), pa.array(b_values)],
            names=["a", "b"],
        )
        return ctx.create_dataframe([[record_batch]])

    left = make_frame([1, 2, 3], [4, 5, 6])
    right = make_frame([3, 4, 5], [6, 7, 8])

    # Row (3, 6) is present in both inputs, so the expected frame lists it
    # twice. Both sides are sorted on "a" so the comparison does not depend
    # on the order in which the union emits its rows.
    expected = make_frame([1, 2, 3, 3, 4, 5], [4, 5, 6, 6, 7, 8]).sort(
        column("a").sort(ascending=True)
    )
    actual = left.union(right).sort(column("a").sort(ascending=True))

    assert expected.collect() == actual.collect()
| 357 | + |
| 358 | + |
def test_union_distinct(ctx):
    """``union`` with its second argument set must drop duplicated rows.

    The inputs share the row (3, 6); the expected frame contains it once.
    """
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    df_a = ctx.create_dataframe([[batch]])

    batch = pa.RecordBatch.from_arrays(
        [pa.array([3, 4, 5]), pa.array([6, 7, 8])],
        names=["a", "b"],
    )
    df_b = ctx.create_dataframe([[batch]])

    # Expected result: the shared row (3, 6) appears only once.
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])],
        names=["a", "b"],
    )
    df_c = ctx.create_dataframe([[batch]]).sort(
        column("a").sort(ascending=True)
    )

    # NOTE(review): the positional ``True`` is presumably the "distinct"
    # flag (per the test name and the expected data) -- confirm the keyword
    # name against the binding before switching to a keyword argument.
    df_a_u_b = df_a.union(df_b, True).sort(column("a").sort(ascending=True))

    # A single comparison is enough; the original repeated this exact
    # assertion a second time, which re-ran both queries without checking
    # anything new.
    assert df_c.collect() == df_a_u_b.collect()
0 commit comments