Skip to content

Commit adbcae3

Browse files
authored
[DataFrame] - Add union and union_distinct bindings for DataFrame (#35)
* fix: conflicting * fix: python linter * fix: flake8 W503 isssue * fix: test error
1 parent 2cf6abd commit adbcae3

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

datafusion/tests/test_dataframe.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
from datafusion import DataFrame, SessionContext, column, literal, udf
2323

2424

25+
@pytest.fixture
26+
def ctx():
27+
return SessionContext()
28+
29+
2530
@pytest.fixture
2631
def df():
2732
ctx = SessionContext()
@@ -323,3 +328,56 @@ def test_collect_partitioned():
323328
)
324329

325330
assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned()
331+
332+
333+
def test_union(ctx):
334+
batch = pa.RecordBatch.from_arrays(
335+
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
336+
names=["a", "b"],
337+
)
338+
df_a = ctx.create_dataframe([[batch]])
339+
340+
batch = pa.RecordBatch.from_arrays(
341+
[pa.array([3, 4, 5]), pa.array([6, 7, 8])],
342+
names=["a", "b"],
343+
)
344+
df_b = ctx.create_dataframe([[batch]])
345+
346+
batch = pa.RecordBatch.from_arrays(
347+
[pa.array([1, 2, 3, 3, 4, 5]), pa.array([4, 5, 6, 6, 7, 8])],
348+
names=["a", "b"],
349+
)
350+
df_c = ctx.create_dataframe([[batch]]).sort(
351+
column("a").sort(ascending=True)
352+
)
353+
354+
df_a_u_b = df_a.union(df_b).sort(column("a").sort(ascending=True))
355+
356+
assert df_c.collect() == df_a_u_b.collect()
357+
358+
359+
def test_union_distinct(ctx):
360+
batch = pa.RecordBatch.from_arrays(
361+
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
362+
names=["a", "b"],
363+
)
364+
df_a = ctx.create_dataframe([[batch]])
365+
366+
batch = pa.RecordBatch.from_arrays(
367+
[pa.array([3, 4, 5]), pa.array([6, 7, 8])],
368+
names=["a", "b"],
369+
)
370+
df_b = ctx.create_dataframe([[batch]])
371+
372+
batch = pa.RecordBatch.from_arrays(
373+
[pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])],
374+
names=["a", "b"],
375+
)
376+
df_c = ctx.create_dataframe([[batch]]).sort(
377+
column("a").sort(ascending=True)
378+
)
379+
380+
df_a_u_b = df_a.union(df_b, True).sort(column("a").sort(ascending=True))
381+
382+
assert df_c.collect() == df_a_u_b.collect()
383+
assert df_c.collect() == df_a_u_b.collect()

src/dataframe.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,26 @@ impl PyDataFrame {
204204
Ok(Self::new(new_df))
205205
}
206206

207+
/// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
208+
/// two `DataFrame`s must have exactly the same schema
209+
#[args(distinct = false)]
210+
fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyResult<Self> {
211+
let new_df = if distinct {
212+
self.df.union_distinct(py_df.df)?
213+
} else {
214+
self.df.union(py_df.df)?
215+
};
216+
217+
Ok(Self::new(new_df))
218+
}
219+
220+
/// Calculate the distinct union of two `DataFrame`s. The
221+
/// two `DataFrame`s must have exactly the same schema
222+
fn union_distinct(&self, py_df: PyDataFrame) -> PyResult<Self> {
223+
let new_df = self.df.union_distinct(py_df.df)?;
224+
Ok(Self::new(new_df))
225+
}
226+
207227
/// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
208228
fn intersect(&self, py_df: PyDataFrame) -> PyResult<Self> {
209229
let new_df = self.df.intersect(py_df.df)?;

0 commit comments

Comments
 (0)