Skip to content

Commit 21990b0

Browse files
davisp and timsaucer authored
feat: Add FFI_TableProviderFactory support (#1396)
* feat: Add FFI_TableProviderFactory support. This wraps the new FFI_TableProviderFactory APIs in datafusion-ffi.
* Address PR comments.
* Add support for Python based TableProviderFactory. This adds the ability to register Python based TableProviderFactory instances to the SessionContext.
* Correction after rebase.

Co-authored-by: Tim Saucer <timsaucer@gmail.com>
1 parent 9af1681 commit 21990b0

File tree

8 files changed

+291
-5
lines changed

8 files changed

+291
-5
lines changed

crates/core/src/context.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use arrow::pyarrow::FromPyArrow;
2727
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
2828
use datafusion::arrow::pyarrow::PyArrowType;
2929
use datafusion::arrow::record_batch::RecordBatch;
30-
use datafusion::catalog::{CatalogProvider, CatalogProviderList};
30+
use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory};
3131
use datafusion::common::{ScalarValue, TableReference, exec_err};
3232
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
3333
use datafusion::datasource::file_format::parquet::ParquetFormat;
@@ -51,6 +51,7 @@ use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
5151
use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
5252
use datafusion_ffi::execution::FFI_TaskContextProvider;
5353
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
54+
use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory;
5455
use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
5556
use datafusion_python_util::{
5657
create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, get_global_ctx,
@@ -81,7 +82,7 @@ use crate::record_batch::PyRecordBatchStream;
8182
use crate::sql::logical::PyLogicalPlan;
8283
use crate::sql::util::replace_placeholders_with_strings;
8384
use crate::store::StorageContexts;
84-
use crate::table::PyTable;
85+
use crate::table::{PyTable, RustWrappedPyTableProviderFactory};
8586
use crate::udaf::PyAggregateUDF;
8687
use crate::udf::PyScalarUDF;
8788
use crate::udtf::PyTableFunction;
@@ -659,6 +660,43 @@ impl PySessionContext {
659660
Ok(())
660661
}
661662

663+
pub fn register_table_factory(
664+
&self,
665+
format: &str,
666+
mut factory: Bound<'_, PyAny>,
667+
) -> PyDataFusionResult<()> {
668+
if factory.hasattr("__datafusion_table_provider_factory__")? {
669+
let py = factory.py();
670+
let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?;
671+
factory = factory
672+
.getattr("__datafusion_table_provider_factory__")?
673+
.call1((codec_capsule,))?;
674+
}
675+
676+
let factory: Arc<dyn TableProviderFactory> =
677+
if let Ok(capsule) = factory.cast::<PyCapsule>().map_err(py_datafusion_err) {
678+
validate_pycapsule(capsule, "datafusion_table_provider_factory")?;
679+
680+
let data: NonNull<FFI_TableProviderFactory> = capsule
681+
.pointer_checked(Some(c_str!("datafusion_table_provider_factory")))?
682+
.cast();
683+
let factory = unsafe { data.as_ref() };
684+
factory.into()
685+
} else {
686+
Arc::new(RustWrappedPyTableProviderFactory::new(
687+
factory.into(),
688+
self.logical_codec.clone(),
689+
))
690+
};
691+
692+
let st = self.ctx.state_ref();
693+
let mut lock = st.write();
694+
lock.table_factories_mut()
695+
.insert(format.to_owned(), factory);
696+
697+
Ok(())
698+
}
699+
662700
pub fn register_catalog_provider_list(
663701
&self,
664702
mut provider: Bound<PyAny>,

crates/core/src/table.rs

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,24 @@ use std::sync::Arc;
2121
use arrow::datatypes::SchemaRef;
2222
use arrow::pyarrow::ToPyArrow;
2323
use async_trait::async_trait;
24-
use datafusion::catalog::Session;
24+
use datafusion::catalog::{Session, TableProviderFactory};
2525
use datafusion::common::Column;
2626
use datafusion::datasource::{TableProvider, TableType};
27-
use datafusion::logical_expr::{Expr, LogicalPlanBuilder, TableProviderFilterPushDown};
27+
use datafusion::logical_expr::{
28+
CreateExternalTable, Expr, LogicalPlanBuilder, TableProviderFilterPushDown,
29+
};
2830
use datafusion::physical_plan::ExecutionPlan;
2931
use datafusion::prelude::DataFrame;
30-
use datafusion_python_util::table_provider_from_pycapsule;
32+
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
33+
use datafusion_python_util::{create_logical_extension_capsule, table_provider_from_pycapsule};
3134
use pyo3::IntoPyObjectExt;
3235
use pyo3::prelude::*;
3336

3437
use crate::context::PySessionContext;
3538
use crate::dataframe::PyDataFrame;
3639
use crate::dataset::Dataset;
40+
use crate::errors;
41+
use crate::expr::create_external_table::PyCreateExternalTable;
3742

3843
/// This struct is used as a common method for all TableProviders,
3944
/// whether they refer to an FFI provider, an internally known
@@ -206,3 +211,51 @@ impl TableProvider for TempViewTable {
206211
Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
207212
}
208213
}
214+
215+
#[derive(Debug)]
216+
pub(crate) struct RustWrappedPyTableProviderFactory {
217+
pub(crate) table_provider_factory: Py<PyAny>,
218+
pub(crate) codec: Arc<FFI_LogicalExtensionCodec>,
219+
}
220+
221+
impl RustWrappedPyTableProviderFactory {
222+
pub fn new(table_provider_factory: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
223+
Self {
224+
table_provider_factory,
225+
codec,
226+
}
227+
}
228+
229+
fn create_inner(
230+
&self,
231+
cmd: CreateExternalTable,
232+
codec: Bound<PyAny>,
233+
) -> PyResult<Arc<dyn TableProvider>> {
234+
Python::attach(|py| {
235+
let provider = self.table_provider_factory.bind(py);
236+
let cmd = PyCreateExternalTable::from(cmd);
237+
238+
provider
239+
.call_method1("create", (cmd,))
240+
.and_then(|t| PyTable::new(t, Some(codec)))
241+
.map(|t| t.table())
242+
})
243+
}
244+
}
245+
246+
#[async_trait]
247+
impl TableProviderFactory for RustWrappedPyTableProviderFactory {
248+
async fn create(
249+
&self,
250+
_: &dyn Session,
251+
cmd: &CreateExternalTable,
252+
) -> datafusion::common::Result<Arc<dyn TableProvider>> {
253+
Python::attach(|py| {
254+
let codec = create_logical_extension_capsule(py, self.codec.as_ref())
255+
.map_err(errors::to_datafusion_err)?;
256+
257+
self.create_inner(cmd.clone(), codec.into_any())
258+
.map_err(errors::to_datafusion_err)
259+
})
260+
}
261+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from datafusion import SessionContext
21+
from datafusion_ffi_example import MyTableProviderFactory
22+
23+
24+
def test_table_provider_factory_ffi() -> None:
    """Register the FFI example factory and create/query a table through SQL."""
    session = SessionContext()
    factory = MyTableProviderFactory()
    session.register_table_factory("MY_FORMAT", factory)

    # Create a new external table
    create_stmt = """
        CREATE EXTERNAL TABLE
        foo
        STORED AS my_format
        LOCATION '';
    """
    session.sql(create_stmt).collect()

    # Query the pre-populated table
    batches = session.sql("SELECT * FROM foo;").collect()
    assert len(batches) == 2
    assert batches[0].num_columns == 2

examples/datafusion-ffi-example/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,23 @@ use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogP
2222
use crate::scalar_udf::IsNullUDF;
2323
use crate::table_function::MyTableFunction;
2424
use crate::table_provider::MyTableProvider;
25+
use crate::table_provider_factory::MyTableProviderFactory;
2526
use crate::window_udf::MyRankUDF;
2627

2728
pub(crate) mod aggregate_udf;
2829
pub(crate) mod catalog_provider;
2930
pub(crate) mod scalar_udf;
3031
pub(crate) mod table_function;
3132
pub(crate) mod table_provider;
33+
pub(crate) mod table_provider_factory;
3234
pub(crate) mod window_udf;
3335

3436
#[pymodule]
3537
fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
3638
pyo3_log::init();
3739

3840
m.add_class::<MyTableProvider>()?;
41+
m.add_class::<MyTableProviderFactory>()?;
3942
m.add_class::<MyTableFunction>()?;
4043
m.add_class::<MyCatalogProvider>()?;
4144
m.add_class::<MyCatalogProviderList>()?;
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use async_trait::async_trait;
21+
use datafusion_catalog::{Session, TableProvider, TableProviderFactory};
22+
use datafusion_common::error::Result as DataFusionResult;
23+
use datafusion_expr::CreateExternalTable;
24+
use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory;
25+
use datafusion_python_util::ffi_logical_codec_from_pycapsule;
26+
use pyo3::types::PyCapsule;
27+
use pyo3::{Bound, PyAny, PyResult, Python, pyclass, pymethods};
28+
29+
use crate::catalog_provider;
30+
31+
#[derive(Debug)]
32+
pub(crate) struct ExampleTableProviderFactory {}
33+
34+
impl ExampleTableProviderFactory {
35+
fn new() -> Self {
36+
Self {}
37+
}
38+
}
39+
40+
#[async_trait]
41+
impl TableProviderFactory for ExampleTableProviderFactory {
42+
async fn create(
43+
&self,
44+
_state: &dyn Session,
45+
_cmd: &CreateExternalTable,
46+
) -> DataFusionResult<Arc<dyn TableProvider>> {
47+
Ok(catalog_provider::my_table())
48+
}
49+
}
50+
51+
#[pyclass(
52+
name = "MyTableProviderFactory",
53+
module = "datafusion_ffi_example",
54+
subclass
55+
)]
56+
#[derive(Debug)]
57+
pub struct MyTableProviderFactory {
58+
inner: Arc<ExampleTableProviderFactory>,
59+
}
60+
61+
impl Default for MyTableProviderFactory {
62+
fn default() -> Self {
63+
let inner = Arc::new(ExampleTableProviderFactory::new());
64+
Self { inner }
65+
}
66+
}
67+
68+
#[pymethods]
69+
impl MyTableProviderFactory {
70+
#[new]
71+
pub fn new() -> Self {
72+
Self::default()
73+
}
74+
75+
pub fn __datafusion_table_provider_factory__<'py>(
76+
&self,
77+
py: Python<'py>,
78+
codec: Bound<PyAny>,
79+
) -> PyResult<Bound<'py, PyCapsule>> {
80+
let name = cr"datafusion_table_provider_factory".into();
81+
let codec = ffi_logical_codec_from_pycapsule(codec)?;
82+
let factory = Arc::clone(&self.inner) as Arc<dyn TableProviderFactory + Send>;
83+
let factory = FFI_TableProviderFactory::new_with_ffi_codec(factory, None, codec);
84+
85+
PyCapsule::new(py, factory, Some(name))
86+
}
87+
}

python/datafusion/catalog.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from datafusion import DataFrame, SessionContext
3131
from datafusion.context import TableProviderExportable
32+
from datafusion.expr import CreateExternalTable
3233

3334
try:
3435
from warnings import deprecated # Python 3.13+
@@ -243,6 +244,24 @@ def kind(self) -> str:
243244
return self._inner.kind
244245

245246

247+
class TableProviderFactory(ABC):
    """Base class for table provider factories implemented in Python.

    Subclasses implement :py:meth:`create`, which receives the
    ``CREATE EXTERNAL TABLE`` command being executed.
    """

    @abstractmethod
    def create(self, cmd: CreateExternalTable) -> Table:
        """Build and return the table described by ``cmd``.

        Args:
            cmd: The :class:`CreateExternalTable` command to create a table for.

        Returns:
            The table to use for the command.
        """
        ...
254+
255+
256+
class TableProviderFactoryExportable(Protocol):
    """Type hint for objects exposing ``__datafusion_table_provider_factory__``.

    The method is expected to return a PyCapsule wrapping a table provider
    factory.

    https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProviderFactory.html
    """

    def __datafusion_table_provider_factory__(self, session: Any) -> object:
        """Return the PyCapsule for this factory."""
        ...
263+
264+
246265
class CatalogProviderList(ABC):
247266
"""Abstract class for defining a Python based Catalog Provider List."""
248267

python/datafusion/context.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
CatalogProviderExportable,
3838
CatalogProviderList,
3939
CatalogProviderListExportable,
40+
TableProviderFactory,
41+
TableProviderFactoryExportable,
4042
)
4143
from datafusion.dataframe import DataFrame
4244
from datafusion.expr import sort_list_to_raw_sort_list
@@ -830,6 +832,22 @@ def deregister_table(self, name: str) -> None:
830832
"""Remove a table from the session."""
831833
self.ctx.deregister_table(name)
832834

835+
def register_table_factory(
    self,
    format: str,
    factory: TableProviderFactory | TableProviderFactoryExportable,
) -> None:
    """Register a table provider factory for a ``STORED AS`` format.

    The registered factory can be referenced from SQL DDL statements executed
    against this context.

    Args:
        format: The value to be used in `STORED AS ${format}` clause.
        factory: Either a Python based
            :py:class:`~datafusion.catalog.TableProviderFactory` or an
            object exposing ``__datafusion_table_provider_factory__`` (see
            :py:class:`~datafusion.catalog.TableProviderFactoryExportable`).
    """
    self.ctx.register_table_factory(format, factory)
850+
833851
def catalog_names(self) -> set[str]:
834852
"""Returns the list of catalogs in this context."""
835853
return self.ctx.catalog_names()

0 commit comments

Comments
 (0)