 use std::ffi::CString;
 use std::sync::Arc;
 
+use crate::errors::py_datafusion_err;
+use crate::expr::sort_expr::to_sort_expressions;
+use crate::physical_plan::PyExecutionPlan;
+use crate::record_batch::PyRecordBatchStream;
+use crate::sql::logical::PyLogicalPlan;
+use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
+use crate::{
+    errors::DataFusionError,
+    expr::{sort_expr::PySortExpr, PyExpr},
+};
 use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
 use arrow::compute::can_cast_types;
 use arrow::error::ArrowError;
@@ -31,12 +41,10 @@ use datafusion::common::stats::Precision;
 use datafusion::common::{DFSchema, UnnestOptions};
 use datafusion::config::{ConfigOptions, CsvOptions, TableParquetOptions};
 use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
-use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
-use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder;
 use datafusion::execution::runtime_env::RuntimeEnvBuilder;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
-use datafusion::physical_plan::{displayable, execute_stream, ExecutionPlan};
+use datafusion::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan};
 use datafusion::prelude::*;
 use datafusion_expr::registry::MemoryFunctionRegistry;
 use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
@@ -48,16 +56,6 @@ use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
 use tokio::task::JoinHandle;
-use crate::errors::py_datafusion_err;
-use crate::expr::sort_expr::to_sort_expressions;
-use crate::physical_plan::PyExecutionPlan;
-use crate::record_batch::PyRecordBatchStream;
-use crate::sql::logical::PyLogicalPlan;
-use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
-use crate::{
-    errors::DataFusionError,
-    expr::{sort_expr::PySortExpr, PyExpr},
-};
 use crate::common::df_schema::PyDFSchema;
 
 /// A PyDataFrame is a representation of a logical plan and an API to compose statements.
@@ -661,145 +659,70 @@ impl PyDataFrame {
         Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
     }
 
-    fn distributed_plan(&self, num_shards: usize, py: Python<'_>) -> PyResult<DistributedPlan> {
-        let distributed_plan = wait_for_future(py, split_physical_plan(&self.df, num_shards))
-            .map_err(py_datafusion_err)?;
-        Ok(distributed_plan)
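+    /// Build a [`DistributedPlan`] for this DataFrame by creating its physical
+    /// plan and repartitioning it to the requested parallelism.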
+    fn distributed_plan(&self, parallelism: usize, py: Python<'_>) -> PyResult<DistributedPlan> {
+        let future_plan = self.df.as_ref().clone().create_physical_plan();
+        let physical_plan = wait_for_future(py, future_plan).map_err(py_datafusion_err)?;
+        DistributedPlan::try_new(physical_plan, parallelism).map_err(py_datafusion_err)
     }
 
 }
 
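+/// A protobuf-serialized physical plan together with the metadata needed to
+/// distribute it: the output schema, the partition count, and optional
+/// byte/row statistics.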
 #[pyclass(get_all)]
 #[derive(Debug, Clone)]
-pub struct Statistics {
+pub struct DistributedPlan {
+    physical_plan: Vec<u8>,
+    schema: PyDFSchema,
+    partitions: usize,
     num_bytes: Option<usize>,
     num_rows: Option<usize>,
 }
 
-impl Statistics {
-    fn new(plan: &dyn ExecutionPlan) -> Self {
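+/// Shared extension codec, so that encoding and decoding use the same handling
+/// of Delta Lake-specific plan nodes.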
+fn codec() -> &'static dyn PhysicalExtensionCodec {
+    static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
+    &CODEC
+}
+
+impl DistributedPlan {
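+    /// Repartition `plan` to `parallelism` target partitions (where supported)
+    /// and capture its schema, partition count, and statistics alongside the
+    /// protobuf-encoded plan bytes.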
+    fn try_new(plan: Arc<dyn ExecutionPlan>, parallelism: usize) -> Result<Self, DataFusionError> {
         fn extract(prec: Precision<usize>) -> Option<usize> {
             match prec {
-                Precision::Exact(n) | Precision::Inexact(n) => Some(n),
-                Precision::Absent => None,
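+                // Only exact statistics are reported now; inexact estimates are dropped.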
+                Precision::Exact(n) => Some(n),
+                _ => None,
             }
         }
-        if let Ok(stats) = plan.statistics() {
+        let (num_bytes, num_rows) = if let Ok(stats) = plan.statistics() {
             let num_bytes = extract(stats.total_byte_size);
             let num_rows = extract(stats.num_rows);
-            Statistics { num_bytes, num_rows }
+            (num_bytes, num_rows)
         } else {
-            Statistics { num_bytes: None, num_rows: None }
-        }
-    }
-}
-
-#[pyclass(get_all)]
-#[derive(Debug, Clone)]
-pub struct Shard {
-    stats: Statistics,
-    serialized_plan: Vec<u8>,
-}
+            (None, None)
+        };
 
-impl Shard {
-    pub fn try_new(plan: &Arc<dyn ExecutionPlan>) -> Result<Self, DataFusionError> {
-        let stats = Statistics::new(plan.as_ref());
-        let serialized_plan = PhysicalPlanNode::try_from_physical_plan(plan.clone(), Self::codec())?
+        let schema = DFSchema::try_from(plan.schema())
+            .map(PyDFSchema::from)
+            .map_err(py_datafusion_err)?;
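+        // `repartitioned` returns Ok(None) when a node cannot change its
+        // partitioning; fall back to the original plan in that case.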
+        let plan = plan.repartitioned(parallelism, &ConfigOptions::default())
+            .map_err(py_datafusion_err)?
+            .unwrap_or(plan);
+        let partitions = plan.properties().partitioning.partition_count();
+        let physical_plan = PhysicalPlanNode::try_from_physical_plan(plan, codec())?
             .encode_to_vec();
-        Ok(Self { stats, serialized_plan })
-    }
-
-    fn codec() -> &'static dyn PhysicalExtensionCodec {
-        static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
-        &CODEC
+        Ok(Self { physical_plan, schema, partitions, num_bytes, num_rows })
     }
-}
 
-#[pyclass(get_all)]
-#[derive(Debug, Clone)]
-pub struct DistributedPlan {
-    shards: Vec<Shard>,
-    schema: PyDFSchema,
-    stats: Statistics,
-}
-
-async fn split_physical_plan(df: &DataFrame, num_shards: usize) -> Result<DistributedPlan, DataFusionError> {
-    fn split(plan: &Arc<dyn ExecutionPlan>, num_shards: usize) -> Vec<Arc<dyn ExecutionPlan>> {
-        if let Some(parquet) = plan.as_any().downcast_ref::<ParquetExec>() {
-            let parquet = if let Ok(Some(repartitioned)) = parquet.repartitioned(num_shards, &ConfigOptions::default()) {
-                repartitioned.as_any().downcast_ref::<ParquetExec>()
-                    .expect("repartitioned parquet is no longer parquet")
-                    .clone()
-            } else { // repartition failed
-                parquet.clone()
-            };
-            let config = parquet.base_config();
-            config
-                .file_groups
-                .iter()
-                .map(|shard| {
-                    FileScanConfig {
-                        object_store_url: config.object_store_url.clone(),
-                        file_schema: config.file_schema.clone(),
-                        file_groups: shard.iter().map(|file| vec![file.to_owned()]).collect(), // one partition per file
-                        statistics: config.statistics.clone(),
-                        projection: config.projection.clone(),
-                        projection_deep: config.projection_deep.clone(),
-                        limit: config.limit,
-                        table_partition_cols: config.table_partition_cols.clone(),
-                        output_ordering: config.output_ordering.clone(),
-                    }
-                })
-                .map(|config| {
-                    let mut builder = ParquetExecBuilder::new(config)
-                        .with_table_parquet_options(parquet.table_parquet_options().clone());
-                    if let Some(predicate) = parquet.predicate() {
-                        builder = builder.with_predicate(predicate.clone());
-                    }
-                    builder.build_arc()
-                })
-                .map(|shard| shard as Arc<dyn ExecutionPlan>)
-                .collect()
-        } else if plan.children().len() == 0 { // TODO: split leaf nodes other than parquet?
-            vec![plan.clone()]
-        } else if plan.children().len() == 1 {
-            plan.children().into_iter()
-                .flat_map(|child| {
-                    split(child, num_shards)
-                        .into_iter()
-                        .map(|shard| plan.clone().with_new_children(vec![shard]))
-                })
-                .collect::<Result<Vec<_>, _>>()
-                .expect("Unable to split plan")
-        } else {
-            panic!(
-                "Only leaf or single-child plans are supported, found {}",
-                displayable(plan.as_ref()).one_line()
-            )
-        }
-    }
-    let plan = df.clone().create_physical_plan().await?;
-    let shards = split(&plan, num_shards)
-        .iter()
-        .map(Shard::try_new)
-        .collect::<Result<Vec<_>, _>>()?;
-    let schema = DFSchema::try_from(plan.schema().as_ref().to_owned())?.into();
-    let stats = Statistics::new(plan.as_ref());
-    Ok(DistributedPlan { shards, schema, stats })
 }
 
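+/// Decode a serialized plan and stream the record batches of a single partition.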
 #[pyfunction]
-pub fn shard_stream(serialized_shard_plan: &[u8], py: Python) -> PyResult<PyRecordBatchStream> {
+pub fn partition_stream(serialized_plan: &[u8], partition: usize, py: Python) -> PyResult<PyRecordBatchStream> {
     deltalake::ensure_initialized();
     let registry = MemoryFunctionRegistry::default();
     let runtime = RuntimeEnvBuilder::new().build()?;
-    let codec = DeltaPhysicalCodec {};
-    let node = PhysicalPlanNode::decode(serialized_shard_plan)
+    let node = PhysicalPlanNode::decode(serialized_plan)
         .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))
         .map_err(py_datafusion_err)?;
-    let plan = node.try_into_physical_plan(&registry, &runtime, &codec)?;
+    let plan = node.try_into_physical_plan(&registry, &runtime, codec())?;
     let stream_with_runtime = get_tokio_runtime().0.spawn(async move {
-        execute_stream(plan, Arc::new(TaskContext::default()))
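+        // Execute only the requested partition; `Arc::default()` supplies a
+        // fresh default TaskContext.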
+        plan.execute(partition, Arc::default())
     });
     wait_for_future(py, stream_with_runtime)
         .map_err(py_datafusion_err)?