diff --git a/datapipe/store/pandas.py b/datapipe/store/pandas.py index 228e49d8..8b45e715 100644 --- a/datapipe/store/pandas.py +++ b/datapipe/store/pandas.py @@ -17,7 +17,19 @@ def load_file(self) -> pd.DataFrame | None: if of.fs.exists(of.path): dtypes = sql_schema_to_dtype(self.primary_schema) - df = pd.read_excel(of.open(), engine="openpyxl", dtype=dtypes) + plain_dtypes = {k: v for k, v in dtypes.items() if v not in _DATETIME_PYTHON_TYPES} + + df = pd.read_excel(of.open(), engine="openpyxl", dtype=plain_dtypes) + + for col, py_type in dtypes.items(): + if col not in df.columns or py_type not in _DATETIME_PYTHON_TYPES: + continue + if py_type == datetime.datetime: + df[col] = pd.to_datetime(df[col]) + elif py_type == datetime.date: + df[col] = pd.to_datetime(df[col]).dt.date + elif py_type == datetime.time: + df[col] = pd.to_datetime(df[col]).dt.time return df else: diff --git a/tests/test_table_store_json_line.py b/tests/test_table_store_json_line.py index 2dbb6021..c6c63119 100644 --- a/tests/test_table_store_json_line.py +++ b/tests/test_table_store_json_line.py @@ -5,20 +5,20 @@ import pandas as pd import pytest from sqlalchemy import ( + BigInteger, + Boolean, Column, + Date, + DateTime, + Float, + Integer, + Numeric, + SmallInteger, String, Text, + Time, Unicode, UnicodeText, - Integer, - BigInteger, - SmallInteger, - Float, - Numeric, - Boolean, - DateTime, - Date, - Time, ) from datapipe.compute import Catalog, Pipeline, Table, build_compute, run_steps @@ -27,6 +27,23 @@ from datapipe.step.update_external_table import UpdateExternalTable from datapipe.store.pandas import TableStoreExcel, TableStoreJsonLine from datapipe.store.tests.abstract import AbstractBaseStoreTests +from datapipe.tests.util import assert_df_equal + +DTYPE_SCHEMA = [ + Column("dtype_String", String), + Column("dtype_Text", Text), + Column("dtype_Unicode", Unicode), + Column("dtype_UnicodeText", UnicodeText), + Column("dtype_Integer", Integer), + Column("dtype_BigInteger", BigInteger), + Column("dtype_SmallInteger", SmallInteger), + Column("dtype_Float", Float), + Column("dtype_Numeric", Numeric), + Column("dtype_Boolean", Boolean), + Column("dtype_DateTime", DateTime), + Column("dtype_Date", Date), + Column("dtype_Time", Time), +] def test_table_store_json_line_reading(tmp_dir): @@ -36,8 +53,9 @@ def test_table_store_json_line_reading(tmp_dir): store = TableStoreJsonLine(filename=test_fname) df = store.load_file() - assert all(df.reset_index(drop=False)["id"].values == test_df["id"].values) - assert all(df["record"].values == test_df["record"].values) + + assert df is not None + assert_df_equal(df, test_df) def make_file1(file): @@ -97,23 +115,20 @@ def test_table_store_json_line_with_deleting(dbconn, tmp_dir): assert len(catalog.get_datatable(ds, "transfomed_data").get_data()) == 2 -def test_dtype_mapping(tmp_dir): - schema = [ - Column("dtype_String", String), - Column("dtype_Text", Text), - Column("dtype_Unicode", Unicode), - Column("dtype_UnicodeText", UnicodeText), - Column("dtype_Integer", Integer), - Column("dtype_BigInteger", BigInteger), - Column("dtype_SmallInteger", SmallInteger), - Column("dtype_Float", Float), - Column("dtype_Numeric", Numeric), - Column("dtype_Boolean", Boolean), - Column("dtype_DateTime", DateTime), - Column("dtype_Date", Date), - Column("dtype_Time", Time), - ] - store = TableStoreJsonLine(filename=tmp_dir / "dtypes.json", primary_schema=schema) +@pytest.mark.parametrize( + "store_factory", + [ + pytest.param( + lambda tmp_dir: TableStoreExcel(filename=tmp_dir / "dtypes.xlsx", primary_schema=DTYPE_SCHEMA), id="excel" + ), + pytest.param( + lambda tmp_dir: TableStoreJsonLine(filename=tmp_dir / "dtypes.json", primary_schema=DTYPE_SCHEMA), + id="json_line", + ), + ], +) +def test_dtype_mapping(tmp_dir, store_factory): + store = store_factory(tmp_dir) now = datetime.datetime(2024, 1, 15, 12, 0, 0) df = pd.DataFrame( @@ -153,35 +168,19 @@ def test_dtype_mapping(tmp_dir): def test_table_store_json_line_with_dtype_mapping(dbconn, tmp_dir): - schema = [ - Column("dtype_String", String), - Column("dtype_Text", Text), - Column("dtype_Unicode", Unicode), - Column("dtype_UnicodeText", UnicodeText), - Column("dtype_Integer", Integer), - Column("dtype_BigInteger", BigInteger), - Column("dtype_SmallInteger", SmallInteger), - Column("dtype_Float", Float), - Column("dtype_Numeric", Numeric), - Column("dtype_Boolean", Boolean), - Column("dtype_DateTime", DateTime), - Column("dtype_Date", Date), - Column("dtype_Time", Time), - ] - ds = DataStore(dbconn, create_meta_table=True) catalog = Catalog( { "input_dtypes": Table( store=TableStoreJsonLine( filename=tmp_dir / "dtypes.json", - primary_schema=schema, + primary_schema=DTYPE_SCHEMA, ), ), "transfomed_dtypes": Table( store=TableStoreJsonLine( filename=tmp_dir / "dtypes_transformed.json", - primary_schema=schema, + primary_schema=DTYPE_SCHEMA, ), ), }