@@ -714,12 +714,6 @@ impl Shard {
714714 & CODEC
715715 }
716716}
717- #[ pymethods]
718- impl Shard {
719- pub fn stream ( & self ) -> PyResult < PyRecordBatchStream > {
720- shard_stream ( self . serialized_plan . as_ref ( ) )
721- }
722- }
723717
724718#[ pyclass( get_all) ]
725719#[ derive( Debug , Clone ) ]
@@ -795,7 +789,7 @@ async fn split_physical_plan(df: &DataFrame, num_shards: usize) -> Result<Distri
795789}
796790
797791#[ pyfunction]
798- pub fn shard_stream ( serialized_shard_plan : & [ u8 ] ) -> PyResult < PyRecordBatchStream > {
792+ pub fn shard_stream ( serialized_shard_plan : & [ u8 ] , py : Python ) -> PyResult < PyRecordBatchStream > {
799793 deltalake:: ensure_initialized ( ) ;
800794 let registry = MemoryFunctionRegistry :: default ( ) ;
801795 let runtime = RuntimeEnvBuilder :: new ( ) . build ( ) ?;
@@ -805,8 +799,13 @@ pub fn shard_stream(serialized_shard_plan: &[u8]) -> PyResult<PyRecordBatchStrea
805799 . map_err ( py_datafusion_err) ?;
806800 let plan = node. try_into_physical_plan ( & registry, & runtime, & codec) ?;
807801 println ! ( "Shard plan: {}" , displayable( plan. as_ref( ) ) . one_line( ) ) ;
808- let ctx = TaskContext :: default ( ) ;
809- execute_stream ( plan, Arc :: new ( ctx) ) . map ( PyRecordBatchStream :: new) . map_err ( py_datafusion_err)
802+ let stream_with_runtime = get_tokio_runtime ( ) . 0 . spawn ( async move {
803+ execute_stream ( plan, Arc :: new ( TaskContext :: default ( ) ) )
804+ } ) ;
805+ wait_for_future ( py, stream_with_runtime)
806+ . map_err ( py_datafusion_err) ?
807+ . map ( PyRecordBatchStream :: new)
808+ . map_err ( py_datafusion_err)
810809}
811810
812811/// Print DataFrame
0 commit comments