Skip to content

Commit 805183b

Browse files
authored
feat: enable list of paths for read_csv (#824)
1 parent b2982ec commit 805183b

File tree

3 files changed

+28
-10
lines changed

3 files changed

+28
-10
lines changed

python/datafusion/context.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ def read_json(
883883

884884
def read_csv(
885885
self,
886-
path: str | pathlib.Path,
886+
path: str | pathlib.Path | list[str] | list[pathlib.Path],
887887
schema: pyarrow.Schema | None = None,
888888
has_header: bool = True,
889889
delimiter: str = ",",
@@ -914,9 +914,12 @@ def read_csv(
914914
"""
915915
if table_partition_cols is None:
916916
table_partition_cols = []
917+
918+
path = [str(p) for p in path] if isinstance(path, list) else str(path)
919+
917920
return DataFrame(
918921
self.ctx.read_csv(
919-
str(path),
922+
path,
920923
schema,
921924
has_header,
922925
delimiter,

python/datafusion/tests/test_context.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,22 @@ def test_read_csv(ctx):
484484
csv_df.select(column("c1")).show()
485485

486486

487+
def test_read_csv_list(ctx):
488+
csv_df = ctx.read_csv(path=["testing/data/csv/aggregate_test_100.csv"])
489+
expected = csv_df.count() * 2
490+
491+
double_csv_df = ctx.read_csv(
492+
path=[
493+
"testing/data/csv/aggregate_test_100.csv",
494+
"testing/data/csv/aggregate_test_100.csv",
495+
]
496+
)
497+
actual = double_csv_df.count()
498+
499+
double_csv_df.select(column("c1")).show()
500+
assert actual == expected
501+
502+
487503
def test_read_csv_compressed(ctx, tmp_path):
488504
test_data_path = "testing/data/csv/aggregate_test_100.csv"
489505

src/context.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ impl PySessionContext {
805805
file_compression_type=None))]
806806
pub fn read_csv(
807807
&self,
808-
path: PathBuf,
808+
path: &Bound<'_, PyAny>,
809809
schema: Option<PyArrowType<Schema>>,
810810
has_header: bool,
811811
delimiter: &str,
@@ -815,10 +815,6 @@ impl PySessionContext {
815815
file_compression_type: Option<String>,
816816
py: Python,
817817
) -> PyResult<PyDataFrame> {
818-
let path = path
819-
.to_str()
820-
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
821-
822818
let delimiter = delimiter.as_bytes();
823819
if delimiter.len() != 1 {
824820
return Err(PyValueError::new_err(
@@ -833,13 +829,16 @@ impl PySessionContext {
833829
.file_extension(file_extension)
834830
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
835831
.file_compression_type(parse_file_compression_type(file_compression_type)?);
832+
options.schema = schema.as_ref().map(|x| &x.0);
836833

837-
if let Some(py_schema) = schema {
838-
options.schema = Some(&py_schema.0);
839-
let result = self.ctx.read_csv(path, options);
834+
if path.is_instance_of::<PyList>() {
835+
let paths = path.extract::<Vec<String>>()?;
836+
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
837+
let result = self.ctx.read_csv(paths, options);
840838
let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
841839
Ok(df)
842840
} else {
841+
let path = path.extract::<String>()?;
843842
let result = self.ctx.read_csv(path, options);
844843
let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
845844
Ok(df)

0 commit comments

Comments (0)