
Commit 755f00f

keep the original unsplit plan
1 parent fcabbb2 commit 755f00f

File tree: 2 files changed (+46, -124 lines)

src/dataframe.rs

Lines changed: 45 additions & 122 deletions
@@ -18,6 +18,16 @@
 use std::ffi::CString;
 use std::sync::Arc;
 
+use crate::errors::py_datafusion_err;
+use crate::expr::sort_expr::to_sort_expressions;
+use crate::physical_plan::PyExecutionPlan;
+use crate::record_batch::PyRecordBatchStream;
+use crate::sql::logical::PyLogicalPlan;
+use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
+use crate::{
+    errors::DataFusionError,
+    expr::{sort_expr::PySortExpr, PyExpr},
+};
 use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
 use arrow::compute::can_cast_types;
 use arrow::error::ArrowError;
@@ -31,12 +41,10 @@ use datafusion::common::stats::Precision;
 use datafusion::common::{DFSchema, UnnestOptions};
 use datafusion::config::{ConfigOptions, CsvOptions, TableParquetOptions};
 use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
-use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
-use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder;
 use datafusion::execution::runtime_env::RuntimeEnvBuilder;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
-use datafusion::physical_plan::{displayable, execute_stream, ExecutionPlan};
+use datafusion::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan};
 use datafusion::prelude::*;
 use datafusion_expr::registry::MemoryFunctionRegistry;
 use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
@@ -48,16 +56,6 @@ use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
 use tokio::task::JoinHandle;
-use crate::errors::py_datafusion_err;
-use crate::expr::sort_expr::to_sort_expressions;
-use crate::physical_plan::PyExecutionPlan;
-use crate::record_batch::PyRecordBatchStream;
-use crate::sql::logical::PyLogicalPlan;
-use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
-use crate::{
-    errors::DataFusionError,
-    expr::{sort_expr::PySortExpr, PyExpr},
-};
 use crate::common::df_schema::PyDFSchema;
 
 /// A PyDataFrame is a representation of a logical plan and an API to compose statements.
@@ -661,145 +659,70 @@ impl PyDataFrame {
         Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
     }
 
-    fn distributed_plan(&self, num_shards: usize, py: Python<'_>) -> PyResult<DistributedPlan> {
-        let distributed_plan = wait_for_future(py, split_physical_plan(&self.df, num_shards))
-            .map_err(py_datafusion_err)?;
-        Ok(distributed_plan)
+    fn distributed_plan(&self, parallelism: usize, py: Python<'_>) -> PyResult<DistributedPlan> {
+        let future_plan = self.df.as_ref().clone().create_physical_plan();
+        let physical_plan = wait_for_future(py, future_plan).map_err(py_datafusion_err)?;
+        DistributedPlan::try_new(physical_plan, parallelism).map_err(py_datafusion_err)
     }
 }
 
 #[pyclass(get_all)]
 #[derive(Debug, Clone)]
-pub struct Statistics {
+pub struct DistributedPlan {
+    physical_plan: Vec<u8>,
+    schema: PyDFSchema,
+    partitions: usize,
     num_bytes: Option<usize>,
     num_rows: Option<usize>,
 }
 
-impl Statistics {
-    fn new(plan: &dyn ExecutionPlan) -> Self {
+fn codec() -> &'static dyn PhysicalExtensionCodec {
+    static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
+    &CODEC
+}
+
+impl DistributedPlan {
+    fn try_new(plan: Arc<dyn ExecutionPlan>, parallelism: usize) -> Result<Self, DataFusionError> {
         fn extract(prec: Precision<usize>) -> Option<usize> {
             match prec {
-                Precision::Exact(n) | Precision::Inexact(n) => Some(n),
-                Precision::Absent => None,
+                Precision::Exact(n) => Some(n),
+                _ => None,
             }
         }
-        if let Ok(stats) = plan.statistics() {
+        let (num_bytes, num_rows) = if let Ok(stats) = plan.statistics() {
            let num_bytes = extract(stats.total_byte_size);
            let num_rows = extract(stats.num_rows);
-            Statistics { num_bytes, num_rows}
+            (num_bytes, num_rows)
         } else {
-            Statistics { num_bytes: None, num_rows: None }
-        }
-    }
-}
-
-#[pyclass(get_all)]
-#[derive(Debug, Clone)]
-pub struct Shard {
-    stats: Statistics,
-    serialized_plan: Vec<u8>,
-}
+            (None, None)
+        };
 
-impl Shard {
-    pub fn try_new(plan: &Arc<dyn ExecutionPlan>) -> Result<Self, DataFusionError> {
-        let stats = Statistics::new(plan.as_ref());
-        let serialized_plan = PhysicalPlanNode::try_from_physical_plan(plan.clone(), Self::codec())?
+        let schema = DFSchema::try_from(plan.schema())
+            .map(PyDFSchema::from)
+            .map_err(py_datafusion_err)?;
+        let plan = plan.repartitioned(parallelism, &ConfigOptions::default())
+            .map_err(py_datafusion_err)?
+            .unwrap_or(plan);
+        let partitions = plan.properties().partitioning.partition_count();
+        let physical_plan = PhysicalPlanNode::try_from_physical_plan(plan, codec())?
            .encode_to_vec();
-        Ok(Self { stats, serialized_plan })
-    }
-
-    fn codec() -> &'static dyn PhysicalExtensionCodec {
-        static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
-        &CODEC
+        Ok(Self { physical_plan, schema, partitions, num_bytes, num_rows })
    }
-}
 
-#[pyclass(get_all)]
-#[derive(Debug, Clone)]
-pub struct DistributedPlan {
-    shards: Vec<Shard>,
-    schema: PyDFSchema,
-    stats: Statistics,
-}
-
-async fn split_physical_plan(df: &DataFrame, num_shards: usize) -> Result<DistributedPlan, DataFusionError> {
-    fn split(plan: &Arc<dyn ExecutionPlan>, num_shards: usize) -> Vec<Arc<dyn ExecutionPlan>> {
-        if let Some(parquet) = plan.as_any().downcast_ref::<ParquetExec>() {
-            let parquet = if let Ok(Some(repartitioned)) = parquet.repartitioned(num_shards, &ConfigOptions::default()) {
-                repartitioned.as_any().downcast_ref::<ParquetExec>()
-                    .expect("repartitioned parquet is no longer parquet")
-                    .clone()
-            } else { // repartition failed
-                parquet.clone()
-            };
-            let config = parquet.base_config();
-            config
-                .file_groups
-                .iter()
-                .map(|shard| {
-                    FileScanConfig {
-                        object_store_url: config.object_store_url.clone(),
-                        file_schema: config.file_schema.clone(),
-                        file_groups: shard.iter().map(|file| vec![file.to_owned()]).collect(), // one partition per file
-                        statistics: config.statistics.clone(),
-                        projection: config.projection.clone(),
-                        projection_deep: config.projection_deep.clone(),
-                        limit: config.limit,
-                        table_partition_cols: config.table_partition_cols.clone(),
-                        output_ordering: config.output_ordering.clone(),
-                    }
-                })
-                .map(|config| {
-                    let mut builder = ParquetExecBuilder::new(config)
-                        .with_table_parquet_options(parquet.table_parquet_options().clone());
-                    if let Some(predicate) = parquet.predicate() {
-                        builder = builder.with_predicate(predicate.clone());
-                    }
-                    builder.build_arc()
-                })
-                .map(|shard| shard as Arc<dyn ExecutionPlan>)
-                .collect()
-        } else if plan.children().len() == 0 { // TODO: split leaf nodes other than parquet?
-            vec![plan.clone()]
-        } else if plan.children().len() == 1 {
-            plan.children().into_iter()
-                .flat_map(|child| {
-                    split(child, num_shards)
-                        .into_iter()
-                        .map(|shard| plan.clone().with_new_children(vec![shard]))
-                })
-                .collect::<Result<Vec<_>, _>>()
-                .expect("Unable to split plan")
-        } else {
-            panic!(
-                "Only leaf or single-child plans are supported, found {}",
-                displayable(plan.as_ref()).one_line()
-            )
-        }
-    }
-    let plan = df.clone().create_physical_plan().await?;
-    let shards = split(&plan, num_shards)
-        .iter()
-        .map(Shard::try_new)
-        .collect::<Result<Vec<_>, _>>()?;
-    let schema = DFSchema::try_from(plan.schema().as_ref().to_owned())?.into();
-    let stats = Statistics::new(plan.as_ref());
-    Ok(DistributedPlan { shards, schema, stats })
 }
 
 #[pyfunction]
-pub fn shard_stream(serialized_shard_plan: &[u8], py: Python) -> PyResult<PyRecordBatchStream> {
+pub fn partition_stream(serialized_plan: &[u8], partition: usize, py: Python) -> PyResult<PyRecordBatchStream> {
     deltalake::ensure_initialized();
     let registry = MemoryFunctionRegistry::default();
     let runtime = RuntimeEnvBuilder::new().build()?;
-    let codec = DeltaPhysicalCodec {};
-    let node = PhysicalPlanNode::decode(serialized_shard_plan)
+    let node = PhysicalPlanNode::decode(serialized_plan)
         .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))
         .map_err(py_datafusion_err)?;
-    let plan = node.try_into_physical_plan(&registry, &runtime, &codec)?;
+    let plan = node.try_into_physical_plan(&registry, &runtime, codec())?;
     let stream_with_runtime = get_tokio_runtime().0.spawn(async move {
-        execute_stream(plan, Arc::new(TaskContext::default()))
+        plan.execute(partition, Arc::default())
     });
     wait_for_future(py, stream_with_runtime)
        .map_err(py_datafusion_err)?
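
Taken together, the change replaces the per-file "shards" with a single serialized plan: the driver repartitions the physical plan to the requested parallelism and encodes it once, and each worker decodes the same bytes and executes exactly one partition. The sketch below shows that round trip in isolation. It is a minimal illustration, not code from this commit: it assumes a recent datafusion/datafusion-proto matching the imports above, substitutes the stock DefaultPhysicalExtensionCodec for this repo's DeltaPhysicalCodec, and the helper names serialize_plan and run_partition are made up for the example.

use std::sync::Arc;

use datafusion::config::ConfigOptions;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_expr::registry::MemoryFunctionRegistry;
use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
use datafusion_proto::protobuf::PhysicalPlanNode;
use futures::StreamExt;
use prost::Message;

// Driver side (mirrors DistributedPlan::try_new): repartition to the target
// parallelism, then serialize the plan to protobuf bytes. Returns the bytes
// plus the partition count so the caller knows how many workers to fan out to.
fn serialize_plan(plan: Arc<dyn ExecutionPlan>, parallelism: usize) -> Result<(Vec<u8>, usize)> {
    let plan = plan
        .repartitioned(parallelism, &ConfigOptions::default())?
        .unwrap_or(plan); // keep the original plan if it cannot be repartitioned
    let partitions = plan.properties().partitioning.partition_count();
    let bytes = PhysicalPlanNode::try_from_physical_plan(plan, &DefaultPhysicalExtensionCodec {})?
        .encode_to_vec();
    Ok((bytes, partitions))
}

// Worker side (mirrors partition_stream): decode the plan and execute a
// single partition of it, draining the resulting record-batch stream.
async fn run_partition(serialized_plan: &[u8], partition: usize) -> Result<()> {
    let registry = MemoryFunctionRegistry::default();
    let runtime = RuntimeEnvBuilder::new().build()?;
    let node = PhysicalPlanNode::decode(serialized_plan)
        .map_err(|e| DataFusionError::External(Box::new(e)))?;
    let plan = node.try_into_physical_plan(&registry, &runtime, &DefaultPhysicalExtensionCodec {})?;
    let mut stream = plan.execute(partition, Arc::new(TaskContext::default()))?;
    while let Some(batch) = stream.next().await {
        println!("partition {partition}: {} rows", batch?.num_rows());
    }
    Ok(())
}

A driver would call serialize_plan once, ship the bytes to as many workers as there are partitions, and have worker i call run_partition(bytes, i).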

src/lib.rs

Lines changed: 1 addition & 2 deletions
@@ -115,9 +115,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     #[cfg(feature = "substrait")]
     setup_substrait_module(py, &m)?;
 
-    m.add_class::<dataframe::Shard>()?;
     m.add_class::<dataframe::DistributedPlan>()?;
-    m.add_wrapped(wrap_pyfunction!(dataframe::shard_stream))?;
+    m.add_wrapped(wrap_pyfunction!(dataframe::partition_stream))?;
     Ok(())
 }
 
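The registration above is the standard pyo3 module-initialization pattern. For reference, a minimal self-contained sketch of the same pattern, with a hypothetical module name `example` and a toy class and function (not part of this commit):

use pyo3::prelude::*;

// A toy pyclass; get_all exposes every field as a Python attribute,
// just like DistributedPlan above.
#[pyclass(get_all)]
struct Plan {
    partitions: usize,
}

#[pymethods]
impl Plan {
    #[new]
    fn new(partitions: usize) -> Self {
        Self { partitions }
    }
}

// A toy pyfunction taking the pyclass by reference.
#[pyfunction]
fn partition_count(plan: PyRef<'_, Plan>) -> usize {
    plan.partitions
}

// Module init: the same add_class / add_wrapped calls as in _internal above.
#[pymodule]
fn example(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<Plan>()?;
    m.add_wrapped(wrap_pyfunction!(partition_count))?;
    Ok(())
}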