Skip to content

Commit 7e6d090

Browse files
committed
inject tokio runtime
1 parent 625f8fa commit 7e6d090

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

src/dataframe.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -714,12 +714,6 @@ impl Shard {
714714
&CODEC
715715
}
716716
}
717-
#[pymethods]
718-
impl Shard {
719-
pub fn stream(&self) -> PyResult<PyRecordBatchStream> {
720-
shard_stream(self.serialized_plan.as_ref())
721-
}
722-
}
723717

724718
#[pyclass(get_all)]
725719
#[derive(Debug, Clone)]
@@ -795,7 +789,7 @@ async fn split_physical_plan(df: &DataFrame, num_shards: usize) -> Result<Distri
795789
}
796790

797791
#[pyfunction]
798-
pub fn shard_stream(serialized_shard_plan: &[u8]) -> PyResult<PyRecordBatchStream> {
792+
pub fn shard_stream(serialized_shard_plan: &[u8], py: Python) -> PyResult<PyRecordBatchStream> {
799793
deltalake::ensure_initialized();
800794
let registry = MemoryFunctionRegistry::default();
801795
let runtime = RuntimeEnvBuilder::new().build()?;
@@ -805,8 +799,13 @@ pub fn shard_stream(serialized_shard_plan: &[u8]) -> PyResult<PyRecordBatchStrea
805799
.map_err(py_datafusion_err)?;
806800
let plan = node.try_into_physical_plan(&registry, &runtime, &codec)?;
807801
println!("Shard plan: {}", displayable(plan.as_ref()).one_line());
808-
let ctx = TaskContext::default();
809-
execute_stream(plan, Arc::new(ctx)).map(PyRecordBatchStream::new).map_err(py_datafusion_err)
802+
let stream_with_runtime = get_tokio_runtime().0.spawn(async move {
803+
execute_stream(plan, Arc::new(TaskContext::default()))
804+
});
805+
wait_for_future(py, stream_with_runtime)
806+
.map_err(py_datafusion_err)?
807+
.map(PyRecordBatchStream::new)
808+
.map_err(py_datafusion_err)
810809
}
811810

812811
/// Print DataFrame

0 commit comments

Comments
 (0)