diff --git a/datapipe/meta/sql_meta.py b/datapipe/meta/sql_meta.py index 767dca39..780400b2 100644 --- a/datapipe/meta/sql_meta.py +++ b/datapipe/meta/sql_meta.py @@ -150,12 +150,25 @@ def get_metadata(self, idx: Optional[IndexDF] = None, include_deleted: bool = Fa with self.dbconn.con.begin() as con: if idx is None: sql = self._build_metadata_query(sql, idx, include_deleted) - return cast(MetadataDF, pd.read_sql_query(sql, con=con)) + return cast( + MetadataDF, + pd.read_sql_query( + sql, + con=con, + dtype_backend="pyarrow", + ), + ) for chunk_idx in self._chunk_idx_df(idx): chunk_sql = self._build_metadata_query(sql, chunk_idx, include_deleted) - res.append(pd.read_sql_query(chunk_sql, con=con)) + res.append( + pd.read_sql_query( + chunk_sql, + con=con, + dtype_backend="pyarrow", + ) + ) if len(res) > 0: return cast(MetadataDF, pd.concat(res)) @@ -223,6 +236,7 @@ def get_existing_idx(self, idx: Optional[IndexDF] = None) -> IndexDF: res_df: DataDF = pd.read_sql_query( sql, con=con, + dtype_backend="pyarrow", ) return data_to_index(res_df, self.primary_keys) @@ -358,7 +372,14 @@ def get_stale_idx( with self.dbconn.con.begin() as con: return cast( Iterator[IndexDF], - list(pd.read_sql_query(sql, con=con, chunksize=1000)), + list( + pd.read_sql_query( + sql, + con=con, + chunksize=1000, + dtype_backend="pyarrow", + ) + ), ) def get_changed_rows_count_after_timestamp( diff --git a/datapipe/step/batch_transform.py b/datapipe/step/batch_transform.py index a231ff36..4e15e7a4 100644 --- a/datapipe/step/batch_transform.py +++ b/datapipe/step/batch_transform.py @@ -177,6 +177,8 @@ def _apply_filters_to_run_config(self, run_config: Optional[RunConfig] = None) - filters = self.filters elif isinstance(self.filters, Callable): # type: ignore filters = self.filters() + else: + filters = {} if run_config is None: return RunConfig(filters=filters) @@ -260,7 +262,12 @@ def get_full_process_ids( def alter_res_df(): with ds.meta_dbconn.con.begin() as con: - for df in pd.read_sql_query(u1, con=con, chunksize=chunk_size): + for df in pd.read_sql_query( + u1, + con=con, + chunksize=chunk_size, + dtype_backend="pyarrow", + ): df = df[self.transform_keys] for k, v in extra_filters.items(): @@ -299,6 +306,7 @@ def get_change_list_process_ids( table_changes_df = pd.read_sql_query( sql, con=con, + dtype_backend="pyarrow", ) table_changes_df = table_changes_df[self.transform_keys] diff --git a/datapipe/store/database.py b/datapipe/store/database.py index f521fe14..5f5edf46 100644 --- a/datapipe/store/database.py +++ b/datapipe/store/database.py @@ -272,7 +272,11 @@ def read_rows(self, idx: Optional[IndexDF] = None) -> pd.DataFrame: with self.dbconn.con.begin() as con: for chunk_idx in self._chunk_idx_df(idx): chunk_sql = sql_apply_idx_filter_to_table(sql, self.data_table, self.primary_keys, chunk_idx) - chunk_df = pd.read_sql_query(chunk_sql, con=con) + chunk_df = pd.read_sql_query( + chunk_sql, + con=con, + dtype_backend="pyarrow", + ) res.append(chunk_df) @@ -280,7 +284,11 @@ def read_rows(self, idx: Optional[IndexDF] = None) -> pd.DataFrame: else: with self.dbconn.con.begin() as con: - return pd.read_sql_query(sql, con=con) + return pd.read_sql_query( + sql, + con=con, + dtype_backend="pyarrow", + ) def read_rows_meta_pseudo_df( self, @@ -296,4 +304,5 @@ def read_rows_meta_pseudo_df( sql, con=con, chunksize=chunksize, + dtype_backend="pyarrow", ) diff --git a/datapipe/store/tests/abstract.py b/datapipe/store/tests/abstract.py index d5a64d76..03ca9b33 100644 --- a/datapipe/store/tests/abstract.py +++ b/datapipe/store/tests/abstract.py @@ -6,7 +6,7 @@ import cloudpickle import pandas as pd import pytest -from sqlalchemy import Column, String +from sqlalchemy import Column, Integer, String from datapipe.run_config import RunConfig from datapipe.store.table_store import TableStore @@ -49,6 +49,51 @@ def test_get_schema( assert store.get_schema() == schema + def test_multiple_keys_with_none(self, store_maker: TableStoreMaker) -> None: + data_df = pd.DataFrame( + { + "id1": [1, 2, 3], + "id2": ["a", "b", "c"], + "value": [10, 20, 30], + } + ) + schema: list[Column] = [ + Column("id1", Integer, primary_key=True), + Column("id2", String(100), primary_key=True), + Column("value", Integer), + ] + + store = store_maker(schema) + store.insert_rows(data_df) + + assert_df_equal( + store.read_rows( + data_to_index( + pd.DataFrame.from_dict({"id1": [2, -1], "id2": ["b", None]}).convert_dtypes( + dtype_backend="pyarrow" + ), + ["id1", "id2"], + ) + ), + pd.DataFrame.from_dict({"id1": [2], "id2": ["b"], "value": [20]}), + index_cols=["id1", "id2"], + ) + + assert_df_equal( + store.read_rows( + data_to_index( + pd.DataFrame.from_dict({"id1": [2, None], "id2": ["b", "z"]}).convert_dtypes( + dtype_backend="pyarrow" + ), + ["id1", "id2"], + ) + ), + pd.DataFrame.from_dict({"id1": [2], "id2": ["b"], "value": [20]}), + index_cols=["id1", "id2"], + ) + + assert_ts_contains(store, data_df) + @pytest.mark.parametrize("data_df,schema", DATA_PARAMS) def test_write_read_rows( self, diff --git a/examples/datatable_batch_transform/app.py b/examples/datatable_batch_transform/app.py index c8134ce0..1b50dfd9 100644 --- a/examples/datatable_batch_transform/app.py +++ b/examples/datatable_batch_transform/app.py @@ -80,6 +80,7 @@ def count_tbl( return pd.read_sql_query( sql, con=con, + dtype_backend="pyarrow", )