Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions datapipe/meta/sql_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,25 @@ def get_metadata(self, idx: Optional[IndexDF] = None, include_deleted: bool = Fa
with self.dbconn.con.begin() as con:
if idx is None:
sql = self._build_metadata_query(sql, idx, include_deleted)
return cast(MetadataDF, pd.read_sql_query(sql, con=con))
return cast(
MetadataDF,
pd.read_sql_query(
sql,
con=con,
dtype_backend="pyarrow",
),
)

for chunk_idx in self._chunk_idx_df(idx):
chunk_sql = self._build_metadata_query(sql, chunk_idx, include_deleted)

res.append(pd.read_sql_query(chunk_sql, con=con))
res.append(
pd.read_sql_query(
chunk_sql,
con=con,
dtype_backend="pyarrow",
)
)

if len(res) > 0:
return cast(MetadataDF, pd.concat(res))
Expand Down Expand Up @@ -223,6 +236,7 @@ def get_existing_idx(self, idx: Optional[IndexDF] = None) -> IndexDF:
res_df: DataDF = pd.read_sql_query(
sql,
con=con,
dtype_backend="pyarrow",
)

return data_to_index(res_df, self.primary_keys)
Expand Down Expand Up @@ -358,7 +372,14 @@ def get_stale_idx(
with self.dbconn.con.begin() as con:
return cast(
Iterator[IndexDF],
list(pd.read_sql_query(sql, con=con, chunksize=1000)),
list(
pd.read_sql_query(
sql,
con=con,
chunksize=1000,
dtype_backend="pyarrow",
)
),
)

def get_changed_rows_count_after_timestamp(
Expand Down
10 changes: 9 additions & 1 deletion datapipe/step/batch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def _apply_filters_to_run_config(self, run_config: Optional[RunConfig] = None) -
filters = self.filters
elif isinstance(self.filters, Callable): # type: ignore
filters = self.filters()
else:
filters = {}

if run_config is None:
return RunConfig(filters=filters)
Expand Down Expand Up @@ -260,7 +262,12 @@ def get_full_process_ids(

def alter_res_df():
with ds.meta_dbconn.con.begin() as con:
for df in pd.read_sql_query(u1, con=con, chunksize=chunk_size):
for df in pd.read_sql_query(
u1,
con=con,
chunksize=chunk_size,
dtype_backend="pyarrow",
):
df = df[self.transform_keys]

for k, v in extra_filters.items():
Expand Down Expand Up @@ -299,6 +306,7 @@ def get_change_list_process_ids(
table_changes_df = pd.read_sql_query(
sql,
con=con,
dtype_backend="pyarrow",
)
table_changes_df = table_changes_df[self.transform_keys]

Expand Down
13 changes: 11 additions & 2 deletions datapipe/store/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,15 +272,23 @@ def read_rows(self, idx: Optional[IndexDF] = None) -> pd.DataFrame:
with self.dbconn.con.begin() as con:
for chunk_idx in self._chunk_idx_df(idx):
chunk_sql = sql_apply_idx_filter_to_table(sql, self.data_table, self.primary_keys, chunk_idx)
chunk_df = pd.read_sql_query(chunk_sql, con=con)
chunk_df = pd.read_sql_query(
chunk_sql,
con=con,
dtype_backend="pyarrow",
)

res.append(chunk_df)

return pd.concat(res, ignore_index=True)

else:
with self.dbconn.con.begin() as con:
return pd.read_sql_query(sql, con=con)
return pd.read_sql_query(
sql,
con=con,
dtype_backend="pyarrow",
)

def read_rows_meta_pseudo_df(
self,
Expand All @@ -296,4 +304,5 @@ def read_rows_meta_pseudo_df(
sql,
con=con,
chunksize=chunksize,
dtype_backend="pyarrow",
)
47 changes: 46 additions & 1 deletion datapipe/store/tests/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import cloudpickle
import pandas as pd
import pytest
from sqlalchemy import Column, String
from sqlalchemy import Column, Integer, String

from datapipe.run_config import RunConfig
from datapipe.store.table_store import TableStore
Expand Down Expand Up @@ -49,6 +49,51 @@ def test_get_schema(

assert store.get_schema() == schema

def test_multiple_keys_with_none(self, store_maker: TableStoreMaker) -> None:
    """Read by a composite-key index whose rows include a missing key part.

    Three rows keyed by ``(id1, id2)`` are inserted. The store is then read
    twice with a two-row index: the first row is the existing key
    ``(2, "b")``, the second row has ``None`` in one of the key columns.
    Each read is expected to return only the ``(2, "b")`` row, and the
    store must still contain all of the inserted data afterwards.
    """
    columns: list[Column] = [
        Column("id1", Integer, primary_key=True),
        Column("id2", String(100), primary_key=True),
        Column("value", Integer),
    ]
    rows = pd.DataFrame(
        {
            "id1": [1, 2, 3],
            "id2": ["a", "b", "c"],
            "value": [10, 20, 30],
        }
    )

    store = store_maker(columns)
    store.insert_rows(rows)

    key_cols = ["id1", "id2"]
    # One payload with the None in the string key, one with the None in the
    # integer key; both share the same expected result.
    lookups = (
        {"id1": [2, -1], "id2": ["b", None]},
        {"id1": [2, None], "id2": ["b", "z"]},
    )

    for lookup in lookups:
        raw_idx = pd.DataFrame.from_dict(lookup)
        idx = data_to_index(
            raw_idx.convert_dtypes(dtype_backend="pyarrow"),
            key_cols,
        )
        found = store.read_rows(idx)
        wanted = pd.DataFrame.from_dict({"id1": [2], "id2": ["b"], "value": [20]})

        assert_df_equal(found, wanted, index_cols=key_cols)

    assert_ts_contains(store, rows)

@pytest.mark.parametrize("data_df,schema", DATA_PARAMS)
def test_write_read_rows(
self,
Expand Down
1 change: 1 addition & 0 deletions examples/datatable_batch_transform/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def count_tbl(
return pd.read_sql_query(
sql,
con=con,
dtype_backend="pyarrow",
)


Expand Down