diff --git a/CHANGELOG.md b/CHANGELOG.md index c6b8fdb5..af2108b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,9 @@ See [key-mapping.md](design-docs/2025-12-key-mapping.md) for motivation * Added `keys` parameter to `InputSpec` and `ComputeInput` to support joining tables with different key names * Added `DataField` accessor for `InputSpec.keys` +* Added `OutputSpec` and `ComputeOutput.keys` to explicitly map transform keys to output table primary keys +* Fixed batch transform cleanup for aliased output keys and incomplete transform keys +* Updated `tqdm-loggable` to avoid Python 3.12 deprecation warnings ### Python3.9 support is deprecated diff --git a/datapipe/compute.py b/datapipe/compute.py index 5cfe83fd..e04387f4 100644 --- a/datapipe/compute.py +++ b/datapipe/compute.py @@ -133,6 +133,17 @@ def primary_schema(self) -> MetaSchema: return self.dt.primary_schema +@dataclass +class ComputeOutput: + dt: DataTable + + # If provided, this dict tells how transform keys map to output primary keys. + # + # Example: {"post_id": "id"} means that cleanup for transform key post_id + # should be applied to output primary key id. + keys: dict[str, str] | None = None + + class ComputeStep: """ Шаг вычислений в графе вычислений. 
@@ -152,16 +163,13 @@ class ComputeStep: def __init__( self, name: str, - input_dts: Sequence[ComputeInput | DataTable], + input_dts: Sequence[ComputeInput], output_dts: list[DataTable], labels: Labels | None = None, executor_config: ExecutorConfig | None = None, ) -> None: self._name = name - # Нормализация input_dts: автоматически оборачиваем DataTable в ComputeInput - self.input_dts = [ - inp if isinstance(inp, ComputeInput) else ComputeInput(dt=inp, join_type="full") for inp in input_dts - ] + self.input_dts = list(input_dts) self.output_dts = output_dts self._labels = labels self.executor_config = executor_config diff --git a/datapipe/meta/sql_meta.py b/datapipe/meta/sql_meta.py index 6e26f628..d18dca24 100644 --- a/datapipe/meta/sql_meta.py +++ b/datapipe/meta/sql_meta.py @@ -919,6 +919,8 @@ def _make_agg_of_agg( prev_ctes.append(cte) + sql = sql.where(sa.and_(*[key.isnot(None) for key in coalesce_keys])) + sql = sql.group_by(*coalesce_keys) return sql.cte(name=f"all__{agg_col}") diff --git a/datapipe/step/batch_transform.py b/datapipe/step/batch_transform.py index c6ee1912..1fffe571 100644 --- a/datapipe/step/batch_transform.py +++ b/datapipe/step/batch_transform.py @@ -12,12 +12,14 @@ Sequence, ) +import pandas as pd from opentelemetry import trace from tqdm_loggable.auto import tqdm from datapipe.compute import ( Catalog, ComputeInput, + ComputeOutput, ComputeStep, PipelineStep, StepStatus, @@ -32,6 +34,8 @@ InputSpec, Labels, PipelineInput, + PipelineOutput, + OutputSpec, Required, TableOrName, TransformResult, @@ -67,8 +71,8 @@ def __init__( self, ds: DataStore, name: str, - input_dts: Sequence[ComputeInput | DataTable], - output_dts: list[DataTable], + input_dts: Sequence[ComputeInput], + output_dts: Sequence[ComputeOutput], transform_keys: list[str] | None = None, chunk_size: int = 1000, labels: Labels | None = None, @@ -77,26 +81,19 @@ def __init__( order_by: list[str] | None = None, order: Literal["asc", "desc"] = "asc", ) -> None: - # Support both 
old API (List[DataTable]) and new API (List[ComputeInput]) - # Convert to new API format - compute_input_dts: list[ComputeInput] = [] - for inp in input_dts: - if isinstance(inp, ComputeInput): - # New API: ComputeInput with .dt attribute - compute_input_dts.append(inp) - else: - # Old API: DataTable passed directly - convert to new API - compute_input_dts.append(ComputeInput(dt=inp, join_type="full")) + compute_input_dts = list(input_dts) + compute_output_dts = list(output_dts) ComputeStep.__init__( self, name=name, input_dts=compute_input_dts, - output_dts=output_dts, + output_dts=[out.dt for out in compute_output_dts], labels=labels, executor_config=executor_config, ) + self.output_specs = compute_output_dts self.chunk_size = chunk_size # Force transform_keys to be a list, otherwise Pandas will not be happy @@ -118,6 +115,31 @@ def __init__( self.order_by = order_by self.order = order + @staticmethod + def _transform_idx_to_output_idx( + idx: IndexDF, + output_spec: ComputeOutput, + ) -> IndexDF | None: + res_dt = output_spec.dt + output_to_transform_keys = { + output_key: transform_key for transform_key, output_key in (output_spec.keys or {}).items() + } + columns: dict[str, Any] = {} + + for pk in res_dt.primary_keys: + if pk in idx.columns: + columns[pk] = idx[pk] + continue + + transform_key = output_to_transform_keys.get(pk) + if transform_key is not None and transform_key in idx.columns: + columns[pk] = idx[transform_key] + + if not columns: + return None + + return IndexDF(pd.DataFrame(columns)) + def _apply_filters_to_run_config(self, run_config: RunConfig | None = None) -> RunConfig | None: if self.filters is None: return run_config @@ -198,12 +220,13 @@ def store_batch_result( assert len(self.output_dts) == 1 output_dfs = [output_dfs] - for k, res_dt in enumerate(self.output_dts): + for k, output_spec in enumerate(self.output_specs): + res_dt = output_spec.dt # Берем k-ое значение функции для k-ой таблички # Добавляем результат в результирующие чанки 
change_idx = res_dt.store_chunk( data_df=output_dfs[k], - processed_idx=idx, + processed_idx=self._transform_idx_to_output_idx(idx, output_spec), now=process_ts, run_config=run_config, ) @@ -212,8 +235,13 @@ def store_batch_result( else: with tracer.start_as_current_span("delete missing data from output"): - for k, res_dt in enumerate(self.output_dts): - del_idx = res_dt.meta.get_existing_idx(idx) + for k, output_spec in enumerate(self.output_specs): + res_dt = output_spec.dt + processed_idx = self._transform_idx_to_output_idx(idx, output_spec) + if processed_idx is None: + continue + + del_idx = res_dt.meta.get_existing_idx(processed_idx) res_dt.delete_by_idx(del_idx, run_config=run_config) @@ -425,7 +453,7 @@ def build_compute(self, ds: DataStore, catalog: Catalog) -> list[ComputeStep]: name=f"{self.func.__name__}", func=self.func, input_dts=[ComputeInput(dt=inp, join_type="full") for inp in input_dts], - output_dts=output_dts, + output_dts=[ComputeOutput(dt=out) for out in output_dts], kwargs=self.kwargs, transform_keys=self.transform_keys, chunk_size=self.chunk_size, @@ -440,8 +468,8 @@ def __init__( ds: DataStore, name: str, func: DatatableBatchTransformFunc, - input_dts: list[ComputeInput], - output_dts: list[DataTable], + input_dts: Sequence[ComputeInput], + output_dts: Sequence[ComputeOutput], kwargs: dict | None = None, transform_keys: list[str] | None = None, chunk_size: int = 1000, @@ -479,7 +507,7 @@ def process_batch_dts( class BatchTransform(PipelineStep): func: BatchTransformFunc inputs: list[PipelineInput] - outputs: list[TableOrName] + outputs: list[PipelineOutput] chunk_size: int = 1000 kwargs: dict[str, Any] | None = None transform_keys: list[str] | None = None @@ -505,9 +533,18 @@ def pipeline_input_to_compute_input(self, ds: DataStore, catalog: Catalog, input else: return ComputeInput(dt=catalog.get_datatable(ds, input), join_type="full") + def pipeline_output_to_compute_output(self, ds: DataStore, catalog: Catalog, output: PipelineOutput) -> 
ComputeOutput: + if isinstance(output, OutputSpec): + return ComputeOutput( + dt=catalog.get_datatable(ds, output.table), + keys=output.keys, + ) + + return ComputeOutput(dt=catalog.get_datatable(ds, output)) + def build_compute(self, ds: DataStore, catalog: Catalog) -> list[ComputeStep]: input_dts = [self.pipeline_input_to_compute_input(ds, catalog, input) for input in self.inputs] - output_dts = [catalog.get_datatable(ds, name) for name in self.outputs] + output_dts = [self.pipeline_output_to_compute_output(ds, catalog, output) for output in self.outputs] return [ BatchTransformStep( @@ -534,8 +571,8 @@ def __init__( ds: DataStore, name: str, func: BatchTransformFunc, - input_dts: list[ComputeInput], - output_dts: list[DataTable], + input_dts: Sequence[ComputeInput], + output_dts: Sequence[ComputeOutput], kwargs: dict[str, Any] | None = None, transform_keys: list[str] | None = None, chunk_size: int = 1000, diff --git a/datapipe/types.py b/datapipe/types.py index fe553c3a..9f4a17b4 100644 --- a/datapipe/types.py +++ b/datapipe/types.py @@ -86,6 +86,20 @@ class Required(InputSpec): PipelineInput = TableOrName | InputSpec +@dataclass +class OutputSpec: + table: TableOrName + + # If provided, this dict tells how transform keys map to output primary keys. + # + # Example: {"post_id": "id"} means that cleanup for transform key post_id + # should be applied to output primary key id. 
+ keys: dict[str, str] | None = None + + +PipelineOutput = TableOrName | OutputSpec + + @dataclass class ChangeList: changes: dict[str, IndexDF] = field(default_factory=lambda: cast(dict[str, IndexDF], {})) diff --git a/pyproject.toml b/pyproject.toml index 5ac9d490..a9702d55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "PyYAML>=5.3.1", "cityhash>=0.4.2,<0.5", "Pillow>=10.0.0,<12", - "tqdm-loggable>=0.2,<0.3", + "tqdm-loggable>=0.4.1,<0.5", "traceback-with-variables>=2.0.4,<3", "opentelemetry-api>=1.8.0,<2", "opentelemetry-sdk>=1.8.0,<2", diff --git a/tests/conftest.py b/tests/conftest.py index 6f010d68..806577a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +import sys os.environ["SQLALCHEMY_WARN_20"] = "1" @@ -48,7 +49,8 @@ def assert_df_equal(a: pd.DataFrame, b: pd.DataFrame) -> bool: @pytest.fixture def dbconn(): if os.environ.get("TEST_DB_ENV") == "sqlite": - DBCONNSTR = "sqlite+pysqlite3:///:memory:" + sqlite_driver = "sqlite" if sys.platform == "darwin" else "sqlite+pysqlite3" + DBCONNSTR = f"{sqlite_driver}:///:memory:" DB_TEST_SCHEMA = None else: pg_host = os.getenv("POSTGRES_HOST", "localhost") diff --git a/tests/test_batch_transform_scheduling.py b/tests/test_batch_transform_scheduling.py index 78b5cd34..4ca58950 100644 --- a/tests/test_batch_transform_scheduling.py +++ b/tests/test_batch_transform_scheduling.py @@ -1,7 +1,7 @@ import pandas as pd from sqlalchemy import Column, Integer -from datapipe.compute import ComputeInput +from datapipe.compute import ComputeInput, ComputeOutput from datapipe.datatable import DataStore from datapipe.step.batch_transform import BatchTransformStep from datapipe.store.database import TableStoreDB @@ -40,7 +40,7 @@ def id_func(df): input_dts=[ ComputeInput(dt=tbl1, join_type="full"), ], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) count, idx_gen = step.get_full_process_ids(ds) @@ -150,7 +150,7 @@ def test_aux_input(dbconn) -> 
None: ComputeInput(dt=tbl_items2, join_type="full"), ComputeInput(dt=tbl_aux, join_type="full"), ], - output_dts=[tbl_out], + output_dts=[ComputeOutput(dt=tbl_out)], transform_keys=["id"], ) diff --git a/tests/test_complex_cross_merge_many_tables.py b/tests/test_complex_cross_merge_many_tables.py index 020e2591..e7e6b2e1 100644 --- a/tests/test_complex_cross_merge_many_tables.py +++ b/tests/test_complex_cross_merge_many_tables.py @@ -12,6 +12,7 @@ from datapipe.step.batch_transform import BatchTransform from datapipe.store.database import TableStoreDB from datapipe.tests.util import assert_datatable_equal +from datapipe.types import InputSpec, OutputSpec TEST_SCHEMA_LEFT = [ pytest.param( @@ -90,8 +91,9 @@ def get_primary_key_to_their_tables(schemas: list[list[Column]], table_names: li primary_keys = [set([x.name for x in schema if x.primary_key]) for schema in schemas] idxs = range(len(schemas)) pairs = itertools.combinations(idxs, 2) - nt = lambda a, b: primary_keys[a].intersection(primary_keys[b]) - table_name1_table_name2_to_intersection_idxs = {(table_names[t[0]], table_names[t[1]]): nt(*t) for t in pairs} + table_name1_table_name2_to_intersection_idxs = { + (table_names[t[0]], table_names[t[1]]): primary_keys[t[0]].intersection(primary_keys[t[1]]) for t in pairs + } return table_name1_table_name2_to_intersection_idxs @@ -245,3 +247,90 @@ def gen_tbl(df): for output_table_name, test_output_df in zip(output_tables_names, test_output_dfs): tbl_output = catalog.get_datatable(ds, output_table_name) assert_datatable_equal(tbl_output, test_output_df) + + +def test_complex_cross_merge_on_many_tables_with_index_aliases(dbconn): + ds = DataStore(dbconn, create_meta_table=True) + + left_schema = [ + Column("id_left", Integer, primary_key=True), + Column("a_left", Integer), + ] + center_schema = [ + Column("id_center", Integer, primary_key=True), + Column("c_center", Integer), + ] + right_schema = [ + Column("id_right", Integer, primary_key=True), + Column("b_right", 
Integer), + ] + output_schema = [ + Column("id_left", Integer, primary_key=True), + Column("id_center", Integer, primary_key=True), + Column("id_right", Integer, primary_key=True), + Column("a_left", Integer), + Column("c_center", Integer), + Column("b_right", Integer), + ] + + test_df_left = pd.DataFrame({"id_left": [1, 2], "a_left": [10, 20]}) + test_df_center = pd.DataFrame({"id_center": [3, 4], "c_center": [30, 40]}) + test_df_right = pd.DataFrame({"id_right": [5, 6], "b_right": [50, 60]}) + test_df_output = cross_merge_func( + test_df_left, + test_df_center, + test_df_right, + input_intersection_idxs=[], + output_schema_tables=[output_schema], + )[0] + + catalog = Catalog( + { + "table_left": Table(store=TableStoreDB(dbconn, "table_left", left_schema, True)), + "table_center": Table(store=TableStoreDB(dbconn, "table_center", center_schema, True)), + "table_right": Table(store=TableStoreDB(dbconn, "table_right", right_schema, True)), + "table_output": Table(store=TableStoreDB(dbconn, "table_output", output_schema, True)), + } + ) + + def gen_tbl(df): + yield df + + pipeline_case = Pipeline( + [ + BatchGenerate(func=gen_tbl, outputs=["table_left"], kwargs=dict(df=test_df_left)), + BatchGenerate(func=gen_tbl, outputs=["table_center"], kwargs=dict(df=test_df_center)), + BatchGenerate(func=gen_tbl, outputs=["table_right"], kwargs=dict(df=test_df_right)), + BatchTransform( + func=cross_merge_func, + inputs=[ + InputSpec(table="table_left", keys={"left_id": "id_left"}), + InputSpec(table="table_center", keys={"center_id": "id_center"}), + InputSpec(table="table_right", keys={"right_id": "id_right"}), + ], + outputs=[ + OutputSpec( + table="table_output", + keys={ + "left_id": "id_left", + "center_id": "id_center", + "right_id": "id_right", + }, + ) + ], + transform_keys=["left_id", "center_id", "right_id"], + chunk_size=6, + kwargs=dict( + input_intersection_idxs=[], + output_schema_tables=[output_schema], + ), + ), + ] + ) + + run_pipeline(ds, catalog, 
pipeline_case) + + assert_datatable_equal(catalog.get_datatable(ds, "table_left"), test_df_left) + assert_datatable_equal(catalog.get_datatable(ds, "table_center"), test_df_center) + assert_datatable_equal(catalog.get_datatable(ds, "table_right"), test_df_right) + assert_datatable_equal(catalog.get_datatable(ds, "table_output"), test_df_output) diff --git a/tests/test_complex_cross_merge_two_tables.py b/tests/test_complex_cross_merge_two_tables.py index a5852537..176e8588 100644 --- a/tests/test_complex_cross_merge_two_tables.py +++ b/tests/test_complex_cross_merge_two_tables.py @@ -12,6 +12,7 @@ from datapipe.step.batch_transform import BatchTransform from datapipe.store.database import TableStoreDB from datapipe.tests.util import assert_datatable_equal, assert_df_equal +from datapipe.types import InputSpec, OutputSpec TEST_SCHEMA_LEFT = [ pytest.param( @@ -250,6 +251,66 @@ def gen_tbl(df): assert_datatable_equal(tbl_left_x_right, test_df_left_x_right) +def test_complex_cross_merge_scenary_with_index_aliases(dbconn): + ds = DataStore(dbconn, create_meta_table=True) + + left_schema = [ + Column("id_left", Integer, primary_key=True), + Column("a_left", Integer), + ] + right_schema = [ + Column("id_right", Integer, primary_key=True), + Column("b_right", Integer), + ] + left_x_right_schema = left_schema + right_schema + + test_df_left = pd.DataFrame({"id_left": [1, 2], "a_left": [10, 20]}) + test_df_right = pd.DataFrame({"id_right": [3, 4], "b_right": [30, 40]}) + test_df_left_x_right = cross_merge_func(test_df_left, test_df_right) + + catalog = Catalog( + { + "tbl_left": Table(store=TableStoreDB(dbconn, "tbl_left", left_schema, True)), + "tbl_right": Table(store=TableStoreDB(dbconn, "tbl_right", right_schema, True)), + "tbl_left_x_right": Table(store=TableStoreDB(dbconn, "tbl_left_x_right", left_x_right_schema, True)), + } + ) + + def gen_tbl(df): + yield df + + pipeline_case = Pipeline( + [ + BatchGenerate(func=gen_tbl, outputs=["tbl_left"], 
kwargs=dict(df=test_df_left)), + BatchGenerate(func=gen_tbl, outputs=["tbl_right"], kwargs=dict(df=test_df_right)), + BatchTransform( + func=cross_merge_func, + inputs=[ + InputSpec(table="tbl_left", keys={"left_id": "id_left"}), + InputSpec(table="tbl_right", keys={"right_id": "id_right"}), + ], + outputs=[ + OutputSpec( + table="tbl_left_x_right", + keys={ + "left_id": "id_left", + "right_id": "id_right", + }, + ) + ], + transform_keys=["left_id", "right_id"], + chunk_size=6, + ), + ] + ) + + run_pipeline(ds, catalog, pipeline_case) + + assert_datatable_equal(catalog.get_datatable(ds, "tbl_left"), test_df_left) + assert_datatable_equal(catalog.get_datatable(ds, "tbl_right"), test_df_right) + assert_datatable_equal(catalog.get_datatable(ds, "tbl_left_x_right"), test_df_left_x_right) + + def reverse_cross_merge_func(df_left_x_right: pd.DataFrame, left_schema: list[Column], right_schema: list[Column]): df_left = df_left_x_right[[x.name for x in left_schema]].drop_duplicates() df_right = df_left_x_right[[x.name for x in right_schema]].drop_duplicates() diff --git a/tests/test_complex_pipeline.py b/tests/test_complex_pipeline.py index daf9333b..4c1f697b 100644 --- a/tests/test_complex_pipeline.py +++ b/tests/test_complex_pipeline.py @@ -11,7 +11,7 @@ from datapipe.step.batch_transform import BatchTransform from datapipe.store.database import TableStoreDB from datapipe.tests.util import assert_datatable_equal, assert_df_equal -from datapipe.types import IndexDF, Required +from datapipe.types import IndexDF, InputSpec, OutputSpec, Required TEST__ITEM = pd.DataFrame( { @@ -157,6 +157,127 @@ def complex_function(df__item, df__pipeline, df__prediction, df__keypoint, idx: assert_datatable_equal(ds.get_table("output"), TEST_RESULT) +def test_complex_pipeline_with_index_aliases(dbconn): + ds = DataStore(dbconn, create_meta_table=True) + catalog = Catalog( + { + "item": Table( + store=TableStoreDB( + dbconn, + "item", + [ + Column("item_id", String, primary_key=True), + 
Column("item__attribute", String), + ], + True, + ) + ), + "pipeline": Table( + store=TableStoreDB( + dbconn, + "pipeline", + [ + Column("pipeline_id", String, primary_key=True), + Column("pipeline__attribute", String), + ], + True, + ) + ), + "prediction": Table( + store=TableStoreDB( + dbconn, + "prediction", + [ + Column("item_id", String, primary_key=True), + Column("pipeline_id", String, primary_key=True), + Column("keypoint_name", String, primary_key=True), + Column("prediction__attribute", String), + ], + True, + ) + ), + "keypoint": Table( + store=TableStoreDB( + dbconn, + "keypoint", + [ + Column("keypoint_id", Integer, primary_key=True), + Column("keypoint_name", String, primary_key=True), + ], + True, + ) + ), + "output": Table( + store=TableStoreDB( + dbconn, + "output", + [ + Column("item_id", String, primary_key=True), + Column("pipeline_id", String, primary_key=True), + Column("attirbute", String), + ], + True, + ) + ), + } + ) + + def complex_function(df__item, df__pipeline, df__prediction, df__keypoint, idx: IndexDF): + assert idx[idx[["post_item_id", "post_pipeline_id"]].duplicated()].empty + assert len(df__keypoint) == len(TEST__KEYPOINT) + df__output = pd.merge(df__item, df__prediction, on=["item_id"]) + df__output = pd.merge(df__output, df__pipeline, on=["pipeline_id"]) + df__output = pd.merge(df__output, df__keypoint, on=["keypoint_name"]) + df__output = df__output[["item_id", "pipeline_id"]].drop_duplicates() + df__output["attirbute"] = "attribute" + return df__output + + pipeline = Pipeline( + [ + BatchTransform( + func=complex_function, + inputs=[ + InputSpec(table="item", keys={"post_item_id": "item_id"}), + InputSpec(table="pipeline", keys={"post_pipeline_id": "pipeline_id"}), + InputSpec( + table="prediction", + keys={ + "post_item_id": "item_id", + "post_pipeline_id": "pipeline_id", + }, + ), + "keypoint", + ], + outputs=[ + OutputSpec( + table="output", + keys={ + "post_item_id": "item_id", + "post_pipeline_id": "pipeline_id", + }, + ) 
+ ], + transform_keys=["post_item_id", "post_pipeline_id"], + chunk_size=50, + ), + ] + ) + steps = build_compute(ds, catalog, pipeline) + ds.get_table("item").store_chunk(TEST__ITEM) + ds.get_table("pipeline").store_chunk(TEST__PIPELINE) + ds.get_table("prediction").store_chunk(TEST__PREDICTION) + ds.get_table("keypoint").store_chunk(TEST__KEYPOINT) + test_result = complex_function( + TEST__ITEM, + TEST__PIPELINE, + TEST__PREDICTION, + TEST__KEYPOINT, + idx=cast(IndexDF, pd.DataFrame(columns=["post_item_id", "post_pipeline_id"])), + ) + run_steps(ds, steps) + assert_datatable_equal(ds.get_table("output"), test_result) + + TEST__FROZEN_DATASET = pd.DataFrame( { "frozen_dataset_id": [f"frozen_dataset_id{i}" for i in range(2)], diff --git a/tests/test_core_steps1.py b/tests/test_core_steps1.py index e2f39e3d..266ef943 100644 --- a/tests/test_core_steps1.py +++ b/tests/test_core_steps1.py @@ -8,7 +8,7 @@ from sqlalchemy import Column from sqlalchemy.sql.sqltypes import JSON, Integer -from datapipe.compute import ComputeInput +from datapipe.compute import ComputeInput, ComputeOutput from datapipe.datatable import DataStore from datapipe.step.batch_generate import do_batch_generate from datapipe.step.batch_transform import BatchTransformStep @@ -99,7 +99,7 @@ def id_func(df): name="test", func=id_func, input_dts=[ComputeInput(dt=tbl1, join_type="full")], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) step.run_full(ds) @@ -130,7 +130,7 @@ def id_func(df): name="test", func=id_func, input_dts=[ComputeInput(dt=tbl1, join_type="full")], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) step.run_full(ds) @@ -165,7 +165,7 @@ def id_func(df): name="test", func=id_func, input_dts=[ComputeInput(dt=tbl1, join_type="full")], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) step.run_full(ds) @@ -224,7 +224,7 @@ def inc_func(df): name="step_inc", func=inc_func, input_dts=[ComputeInput(dt=tbl, join_type="full")], - output_dts=[tbl1, tbl2, 
tbl3], + output_dts=[ComputeOutput(dt=tbl1), ComputeOutput(dt=tbl2), ComputeOutput(dt=tbl3)], ) step_inc.run_full(ds) @@ -250,7 +250,7 @@ def inc_func_inv(df): name="step_inc_inv", func=inc_func_inv, input_dts=[ComputeInput(dt=tbl, join_type="full")], - output_dts=[tbl3, tbl2, tbl1], + output_dts=[ComputeOutput(dt=tbl3), ComputeOutput(dt=tbl2), ComputeOutput(dt=tbl1)], ) step_inc_inv.run_full(ds) @@ -306,7 +306,7 @@ def inc_func(df1, df2): ComputeInput(dt=tbl1, join_type="full"), ComputeInput(dt=tbl2, join_type="full"), ], - output_dts=[tbl], + output_dts=[ComputeOutput(dt=tbl)], ) step.run_full(ds) @@ -390,7 +390,7 @@ def inc_func(df): name="test", func=inc_func, input_dts=[ComputeInput(dt=tbl, join_type="full")], - output_dts=[tbl_good, tbl_bad], + output_dts=[ComputeOutput(dt=tbl_good), ComputeOutput(dt=tbl_bad)], ) step.run_full(ds) @@ -438,14 +438,14 @@ def inc_func_pack(df): name="unpack", func=inc_func_unpack, input_dts=[ComputeInput(dt=tbl, join_type="full")], - output_dts=[tbl_rel], + output_dts=[ComputeOutput(dt=tbl_rel)], ) step_pack = BatchTransformStep( ds=ds, name="pack", func=inc_func_pack, input_dts=[ComputeInput(dt=tbl_rel, join_type="full")], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) step_unpack.run_full(ds) @@ -506,14 +506,14 @@ def inc_func_pack(df): name="unpack", func=inc_func_unpack, input_dts=[ComputeInput(dt=tbl, join_type="full")], - output_dts=[tbl_rel], + output_dts=[ComputeOutput(dt=tbl_rel)], ) step_pack = BatchTransformStep( ds=ds, name="pack", func=inc_func_pack, input_dts=[ComputeInput(dt=tbl_rel, join_type="full")], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) step_unpack.run_full(ds) @@ -619,7 +619,7 @@ def inc_func_good(df): name="bad", func=inc_func_bad, input_dts=[ComputeInput(dt=tbl, join_type="full")], - output_dts=[tbl_good], + output_dts=[ComputeOutput(dt=tbl_good)], chunk_size=1, ) step_bad.run_full(ds) @@ -631,7 +631,7 @@ def inc_func_good(df): name="good", func=inc_func_good, 
input_dts=[ComputeInput(dt=tbl, join_type="full")], - output_dts=[tbl_good], + output_dts=[ComputeOutput(dt=tbl_good)], chunk_size=CHUNKSIZE, ) step_good.run_full(ds) diff --git a/tests/test_core_steps2.py b/tests/test_core_steps2.py index dfa589a8..024d1722 100644 --- a/tests/test_core_steps2.py +++ b/tests/test_core_steps2.py @@ -9,7 +9,7 @@ from sqlalchemy import Column, String from sqlalchemy.sql.sqltypes import Integer -from datapipe.compute import ComputeInput +from datapipe.compute import ComputeInput, ComputeOutput from datapipe.datatable import DataStore from datapipe.run_config import RunConfig from datapipe.step.batch_generate import do_batch_generate @@ -91,7 +91,7 @@ def test_batch_transform(dbconn): name="test", func=lambda df: df, input_dts=[ComputeInput(dt=tbl1, join_type="full")], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) step.run_full(ds) @@ -125,7 +125,7 @@ def test_batch_transform_with_filter(dbconn): name="test", func=lambda df: df, input_dts=[ComputeInput(dt=tbl1, join_type="full")], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) step.run_full( ds, @@ -151,7 +151,7 @@ def test_batch_transform_with_filter_not_in_transform_index(dbconn): name="test", func=lambda df: df[["item_id", "a"]], input_dts=[ComputeInput(dt=tbl1, join_type="full")], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) step.run_full( @@ -191,7 +191,7 @@ def update_df(df1: pd.DataFrame, df2: pd.DataFrame): ComputeInput(dt=tbl1, join_type="full"), ComputeInput(dt=tbl2, join_type="full"), ], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) step.run_full(ds) @@ -246,7 +246,7 @@ def transform_func(df, context=context): name="step1", func=transform_func, input_dts=[ComputeInput(dt=tbl1, join_type="full")], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) tbl1.store_chunk( @@ -305,7 +305,7 @@ def func(df): name="test", func=func, input_dts=[ComputeInput(dt=tbl1, join_type="full")], - output_dts=[tbl2], + 
output_dts=[ComputeOutput(dt=tbl2)], ) change_list = ChangeList() @@ -354,7 +354,7 @@ def update_df(products: pd.DataFrame, items: pd.DataFrame): ComputeInput(dt=products, join_type="full"), ComputeInput(dt=items, join_type="full"), ], - output_dts=[items2], + output_dts=[ComputeOutput(dt=items2)], ) step.run_full(ds) diff --git a/tests/test_image_pipeline.py b/tests/test_image_pipeline.py index 0707f3fe..d855e783 100644 --- a/tests/test_image_pipeline.py +++ b/tests/test_image_pipeline.py @@ -7,6 +7,7 @@ from datapipe.compute import ( Catalog, ComputeInput, + ComputeOutput, Pipeline, Table, build_compute, @@ -66,7 +67,7 @@ def test_image_datatables(dbconn, tmp_dir): name="resize_images", func=resize_images, input_dts=[ComputeInput(dt=tbl1, join_type="full")], - output_dts=[tbl2], + output_dts=[ComputeOutput(dt=tbl2)], ) step.run_full(ds) diff --git a/tests/test_meta_transform_keys.py b/tests/test_meta_transform_keys.py index 6b25ca10..a270253f 100644 --- a/tests/test_meta_transform_keys.py +++ b/tests/test_meta_transform_keys.py @@ -3,12 +3,12 @@ import pandas as pd from sqlalchemy import Column, String -from datapipe.compute import ComputeInput +from datapipe.compute import Catalog, ComputeInput, ComputeOutput, Table from datapipe.datatable import DataStore -from datapipe.step.batch_transform import BatchTransformStep +from datapipe.step.batch_transform import BatchTransform, BatchTransformStep from datapipe.store.database import DBConn, TableStoreDB from datapipe.tests.util import assert_datatable_equal -from datapipe.types import DataField +from datapipe.types import DataField, InputSpec, OutputSpec def test_transform_keys(dbconn: DBConn): @@ -93,7 +93,7 @@ def transform_func(posts_df, profiles_df): func=transform_func, input_dts=[ ComputeInput( - dt=posts, + dt=posts, # [id, user_id, content] join_type="full", keys={ "post_id": "id", @@ -101,14 +101,27 @@ def transform_func(posts_df, profiles_df): }, ), ComputeInput( - dt=profiles, + dt=profiles, # [id, 
username] join_type="inner", keys={ "user_id": "id", }, ), ], - output_dts=[output_dt], + # post, profiles -> output_dt + + # 1 post [id, user_id, content] -> mapping [post_id, user_id, content] + # 2 profiles [id, username] -> mapping [user_id, username] + + # output_dt 1x2 = [post_id, user_id, content, username] -> get output [post_id, username] -> mapping [id, user_name] + output_dts=[ + ComputeOutput( + dt=output_dt, # [id, user_id, content, username] + keys={ + "post_id": "id", + } + ), + ], transform_keys=["post_id", "user_id"], ) @@ -128,37 +141,1089 @@ def transform_func(posts_df, profiles_df): ), ) - # 8. Добавим новые данные и проверим инкрементальную обработку + # 8. Изменение lookup-таблицы должно пересчитать все связанные posts. time.sleep(0.01) # Небольшая задержка для различения timestamp'ов process_ts2 = time.time() + profiles.store_chunk(pd.DataFrame([{"id": "1", "username": "alice-updated"}]), now=process_ts2) + step.run_full(ds) + + assert_datatable_equal( + output_dt, + pd.DataFrame( + [ + {"id": "1", "user_id": "1", "content": "Post 1", "username": "alice-updated"}, + {"id": "2", "user_id": "1", "content": "Post 2", "username": "alice-updated"}, + {"id": "3", "user_id": "2", "content": "Post 3", "username": "bob"}, + ] + ), + ) + + # 11. Удаление lookup-записи должно удалить все output rows для связанных posts. + time.sleep(0.01) # Небольшая задержка для различения timestamp'ов + profiles.delete_by_idx(pd.DataFrame([{"id": "1"}]), now=time.time()) + step.run_full(ds) + + assert_datatable_equal( + output_dt, + pd.DataFrame( + [ + {"id": "3", "user_id": "2", "content": "Post 3", "username": "bob"}, + ] + ), + ) + + # 9. 
def test_transform_output_spec_keys(dbconn: DBConn):
    """Check that ``OutputSpec.keys`` is exposed on the built ``BatchTransformStep``.

    Only ``build_compute`` is exercised: no rows are stored and the step is never run.
    """
    ds = DataStore(dbconn, create_meta_table=True)

    def string_table(table_name, pk_column, *value_columns):
        # One String primary key followed by plain String value columns.
        schema = [Column(pk_column, String, primary_key=True)]
        schema.extend(Column(column_name, String) for column_name in value_columns)
        return Table(store=TableStoreDB(dbconn, table_name, schema, create_table=True))

    catalog = Catalog(
        {
            "posts": string_table("posts", "id", "user_id", "content"),
            "profiles": string_table("profiles", "id", "username"),
            "posts_with_username": string_table(
                "posts_with_username", "id", "user_id", "content", "username"
            ),
        }
    )

    def transform_func(posts_df, profiles_df):
        # Both inputs carry an "id" column: the posts one keeps its name,
        # the profiles one is suffixed with "_profile".
        joined = posts_df.merge(
            profiles_df,
            left_on="user_id",
            right_on="id",
            suffixes=("", "_profile"),
        )
        output_columns = ["id", "user_id", "content", "username"]
        return joined[output_columns]

    posts_input = InputSpec(
        table="posts",
        keys={"post_id": "id", "user_id": DataField("user_id")},
    )
    profiles_input = InputSpec(table="profiles", keys={"user_id": "id"})
    output = OutputSpec(table="posts_with_username", keys={"post_id": "id"})

    pipeline_step = BatchTransform(
        func=transform_func,
        inputs=[posts_input, profiles_input],
        outputs=[output],
        transform_keys=["post_id", "user_id"],
    )
    built_steps = pipeline_step.build_compute(ds, catalog)
    step = built_steps[0]

    assert isinstance(step, BatchTransformStep)
    assert len(step.output_specs) == 1
    assert step.output_specs[0].keys == {"post_id": "id"}
def test_batch_transform_outputs_with_different_key_mappings(dbconn: DBConn):
    """Run one transform with two outputs, each keyed by a different transform key.

    ``post_cards`` is declared with ``post_id -> id`` and ``user_cards`` with
    ``user_id -> id``. Both outputs are checked after the first full run; then
    profile ``u1`` is deleted and, after a second run, only the rows derived from
    profile ``u2`` are expected to remain in both outputs.
    """
    ds = DataStore(dbconn, create_meta_table=True)

    def string_table(table_name, pk_column, *value_columns):
        # One String primary key followed by plain String value columns.
        schema = [Column(pk_column, String, primary_key=True)]
        schema.extend(Column(column_name, String) for column_name in value_columns)
        return Table(store=TableStoreDB(dbconn, table_name, schema, create_table=True))

    catalog = Catalog(
        {
            "posts": string_table("posts", "id", "user_id", "content"),
            "profiles": string_table("profiles", "id", "username"),
            "post_cards": string_table("post_cards", "id", "username", "content"),
            "user_cards": string_table("user_cards", "id", "username", "posts_count"),
        }
    )

    posts = catalog.get_datatable(ds, "posts")
    profiles = catalog.get_datatable(ds, "profiles")
    post_cards = catalog.get_datatable(ds, "post_cards")
    user_cards = catalog.get_datatable(ds, "user_cards")

    # Initial data: two posts by u1, one post by u2.
    process_ts = time.time()
    initial_posts = [
        {"id": "p1", "user_id": "u1", "content": "Post 1"},
        {"id": "p2", "user_id": "u1", "content": "Post 2"},
        {"id": "p3", "user_id": "u2", "content": "Post 3"},
    ]
    posts.store_chunk(pd.DataFrame(initial_posts), now=process_ts)
    initial_profiles = [
        {"id": "u1", "username": "alice"},
        {"id": "u2", "username": "bob"},
    ]
    profiles.store_chunk(pd.DataFrame(initial_profiles), now=process_ts)

    def transform_func(posts_df, profiles_df):
        # The clashing "id" columns become "id_post" / "id_profile" after the join.
        joined = posts_df.merge(
            profiles_df,
            left_on="user_id",
            right_on="id",
            suffixes=("_post", "_profile"),
        )

        # First output: one row per post, keyed by the post id.
        post_cards_df = joined.rename(columns={"id_post": "id"})[["id", "username", "content"]]

        # Second output: one row per profile with the number of joined posts.
        per_user = joined.groupby(["id_profile", "username"], as_index=False)
        counted = per_user.agg(posts_count=("id_post", "count"))
        user_cards_df = counted.rename(columns={"id_profile": "id"})
        # The output column is declared as String, so the count is stored as text.
        user_cards_df["posts_count"] = user_cards_df["posts_count"].astype(str)
        return post_cards_df, user_cards_df

    posts_input = InputSpec(
        table="posts",
        keys={"post_id": "id", "user_id": DataField("user_id")},
    )
    profiles_input = InputSpec(table="profiles", keys={"user_id": "id"})
    pipeline_step = BatchTransform(
        func=transform_func,
        inputs=[posts_input, profiles_input],
        outputs=[
            OutputSpec(table="post_cards", keys={"post_id": "id"}),
            OutputSpec(table="user_cards", keys={"user_id": "id"}),
        ],
        transform_keys=["post_id", "user_id"],
    )
    step = pipeline_step.build_compute(ds, catalog)[0]

    step.run_full(ds)

    expected_post_cards = [
        {"id": "p1", "username": "alice", "content": "Post 1"},
        {"id": "p2", "username": "alice", "content": "Post 2"},
        {"id": "p3", "username": "bob", "content": "Post 3"},
    ]
    assert_datatable_equal(post_cards, pd.DataFrame(expected_post_cards))
    expected_user_cards = [
        {"id": "u1", "username": "alice", "posts_count": "2"},
        {"id": "u2", "username": "bob", "posts_count": "1"},
    ]
    assert_datatable_equal(user_cards, pd.DataFrame(expected_user_cards))

    # Delete profile u1; the short sleep keeps the new timestamp distinct from process_ts.
    time.sleep(0.01)
    deleted_profiles = pd.DataFrame([{"id": "u1"}])
    profiles.delete_by_idx(deleted_profiles, now=time.time())
    step.run_full(ds)

    remaining_post_cards = [
        {"id": "p3", "username": "bob", "content": "Post 3"},
    ]
    assert_datatable_equal(post_cards, pd.DataFrame(remaining_post_cards))
    remaining_user_cards = [
        {"id": "u2", "username": "bob", "posts_count": "1"},
    ]
    assert_datatable_equal(user_cards, pd.DataFrame(remaining_user_cards))
"tenant_pk", "user_id", "payload", "tenant_name"] + ] + summary_df = df.rename(columns={"event_id": "summary_event_id", "tenant_id": "summary_tenant_id"})[ + ["summary_event_id", "summary_tenant_id"] + ].copy() + summary_df["summary"] = [ + f"{payload}@{tenant_name}" for payload, tenant_name in zip(df["payload"], df["tenant_name"]) + ] + return enriched_df, summary_df + + step = BatchTransformStep( + ds=ds, + name="test_composite_aliases", + func=transform_func, + input_dts=[ + ComputeInput( + dt=events, + join_type="full", + keys={ + "task_event_id": "event_id", + "task_tenant_id": "tenant_id", + }, + ), + ComputeInput( + dt=tenants, + join_type="inner", + keys={ + "task_tenant_id": "id", + }, + ), + ], + output_dts=[ + ComputeOutput( + dt=enriched_events, + keys={ + "task_event_id": "event_pk", + "task_tenant_id": "tenant_pk", + }, + ), + ComputeOutput( + dt=event_summaries, + keys={ + "task_event_id": "summary_event_id", + "task_tenant_id": "summary_tenant_id", + }, + ), + ], + transform_keys=["task_event_id", "task_tenant_id"], + ) + + step.run_full(ds) + + assert_datatable_equal( + enriched_events, + pd.DataFrame( + [ + { + "event_pk": "e1", + "tenant_pk": "t1", + "user_id": "u1", + "payload": "payload-1", + "tenant_name": "tenant-one", + }, + { + "event_pk": "e2", + "tenant_pk": "t1", + "user_id": "u2", + "payload": "payload-2", + "tenant_name": "tenant-one", + }, + { + "event_pk": "e3", + "tenant_pk": "t2", + "user_id": "u3", + "payload": "payload-3", + "tenant_name": "tenant-two", + }, + ] + ), + ) + assert_datatable_equal( + event_summaries, + pd.DataFrame( + [ + {"summary_event_id": "e1", "summary_tenant_id": "t1", "summary": "payload-1@tenant-one"}, + {"summary_event_id": "e2", "summary_tenant_id": "t1", "summary": "payload-2@tenant-one"}, + {"summary_event_id": "e3", "summary_tenant_id": "t2", "summary": "payload-3@tenant-two"}, + ] + ), + ) + + time.sleep(0.01) + tenants.store_chunk(pd.DataFrame([{"id": "t1", "tenant_name": "tenant-one-updated"}]), 
def test_transform_keys_with_same_column_names_and_different_aliases(dbconn: DBConn):
    """Join three inputs that all use ``id``/``name`` columns under distinct transform keys.

    ``users``, ``teams`` and ``roles`` are aliased to the transform keys ``user_id``,
    ``team_id`` and ``role_id``; the ``memberships`` output maps those keys onto its
    three-column primary key. The output is checked after the initial run, after
    renaming team ``t1``, and after deleting role ``r1``.
    """
    ds = DataStore(dbconn, create_meta_table=True)

    def create_string_table(table_name, key_columns, value_columns):
        # All columns are String; key_columns form the primary key.
        schema = [Column(column_name, String, primary_key=True) for column_name in key_columns]
        schema += [Column(column_name, String) for column_name in value_columns]
        return ds.create_table(
            table_name,
            TableStoreDB(dbconn, table_name, schema, create_table=True),
        )

    users = create_string_table("users", ["id"], ["name", "team_id", "role_id"])
    teams = create_string_table("teams", ["id"], ["name"])
    roles = create_string_table("roles", ["id"], ["name"])
    memberships = create_string_table(
        "memberships",
        ["member_id", "member_team_id", "member_role_id"],
        ["user_name", "team_name", "role_name"],
    )

    def member_row(member_id, team_id, role_id, user_name, team_name, role_name):
        # Expected-row builder; key order matches the memberships schema.
        return {
            "member_id": member_id,
            "member_team_id": team_id,
            "member_role_id": role_id,
            "user_name": user_name,
            "team_name": team_name,
            "role_name": role_name,
        }

    process_ts = time.time()
    initial_users = [
        {"id": "u1", "name": "Alice", "team_id": "t1", "role_id": "r1"},
        {"id": "u2", "name": "Bob", "team_id": "t1", "role_id": "r2"},
        {"id": "u3", "name": "Eve", "team_id": "t2", "role_id": "r1"},
    ]
    users.store_chunk(pd.DataFrame(initial_users), now=process_ts)
    initial_teams = [
        {"id": "t1", "name": "Core"},
        {"id": "t2", "name": "ML"},
    ]
    teams.store_chunk(pd.DataFrame(initial_teams), now=process_ts)
    initial_roles = [
        {"id": "r1", "name": "Admin"},
        {"id": "r2", "name": "Reviewer"},
    ]
    roles.store_chunk(pd.DataFrame(initial_roles), now=process_ts)

    def transform_func(users_df, teams_df, roles_df):
        # users/teams both have "id" and "name"; the suffixes disambiguate them.
        with_team = users_df.merge(
            teams_df,
            left_on="team_id",
            right_on="id",
            suffixes=("_user", "_team"),
        )
        # After the first merge nothing is called plain "id"/"name" any more,
        # so the roles columns join in unsuffixed.
        with_role = with_team.merge(roles_df, left_on="role_id", right_on="id")
        source_by_output = {
            "member_id": "id_user",
            "member_team_id": "team_id",
            "member_role_id": "role_id",
            "user_name": "name_user",
            "team_name": "name_team",
            "role_name": "name",
        }
        return pd.DataFrame(
            {output_column: with_role[source_column] for output_column, source_column in source_by_output.items()}
        )

    users_input = ComputeInput(
        dt=users,
        join_type="full",
        keys={
            "user_id": "id",
            "team_id": DataField("team_id"),
            "role_id": DataField("role_id"),
        },
    )
    teams_input = ComputeInput(dt=teams, join_type="inner", keys={"team_id": "id"})
    roles_input = ComputeInput(dt=roles, join_type="inner", keys={"role_id": "id"})
    memberships_output = ComputeOutput(
        dt=memberships,
        keys={
            "user_id": "member_id",
            "team_id": "member_team_id",
            "role_id": "member_role_id",
        },
    )

    step = BatchTransformStep(
        ds=ds,
        name="test_same_column_names_aliases",
        func=transform_func,
        input_dts=[users_input, teams_input, roles_input],
        output_dts=[memberships_output],
        transform_keys=["user_id", "team_id", "role_id"],
    )

    # Initial run: every user joins to its team and role.
    step.run_full(ds)

    expected_initial = [
        member_row("u1", "t1", "r1", "Alice", "Core", "Admin"),
        member_row("u2", "t1", "r2", "Bob", "Core", "Reviewer"),
        member_row("u3", "t2", "r1", "Eve", "ML", "Admin"),
    ]
    assert_datatable_equal(memberships, pd.DataFrame(expected_initial))

    # Rename team t1: both of its members should pick up the new team name.
    time.sleep(0.01)
    renamed_team = pd.DataFrame([{"id": "t1", "name": "Core Platform"}])
    teams.store_chunk(renamed_team, now=time.time())
    step.run_full(ds)

    expected_after_rename = [
        member_row("u1", "t1", "r1", "Alice", "Core Platform", "Admin"),
        member_row("u2", "t1", "r2", "Bob", "Core Platform", "Reviewer"),
        member_row("u3", "t2", "r1", "Eve", "ML", "Admin"),
    ]
    assert_datatable_equal(memberships, pd.DataFrame(expected_after_rename))

    # Delete role r1: only the member holding role r2 is expected to remain.
    time.sleep(0.01)
    deleted_role = pd.DataFrame([{"id": "r1"}])
    roles.delete_by_idx(deleted_role, now=time.time())
    step.run_full(ds)

    expected_after_delete = [
        member_row("u2", "t1", "r2", "Bob", "Core Platform", "Reviewer"),
    ]
    assert_datatable_equal(memberships, pd.DataFrame(expected_after_delete))
Column("id", String, primary_key=True), + Column("name", String), + Column("role_id", String), + ], + create_table=True, + ), + ) + roles = ds.create_table( + "roles", + TableStoreDB( + dbconn, + "roles", + [ + Column("id", String, primary_key=True), + Column("name", String), + ], + create_table=True, + ), + ) + user_cards = ds.create_table( + "user_cards", + TableStoreDB( + dbconn, + "user_cards", + [ + Column("id", String, primary_key=True), + Column("user_name", String), + Column("role_name", String), + ], + create_table=True, + ), + ) + role_cards = ds.create_table( + "role_cards", + TableStoreDB( + dbconn, + "role_cards", + [ + Column("id", String, primary_key=True), + Column("role_name", String), + Column("users_count", String), + ], + create_table=True, + ), + ) + + process_ts = time.time() + users.store_chunk( + pd.DataFrame( + [ + {"id": "u1", "name": "Alice", "role_id": "r1"}, + {"id": "u2", "name": "Bob", "role_id": "r2"}, + {"id": "u3", "name": "Eve", "role_id": "r1"}, + ] + ), + now=process_ts, + ) + roles.store_chunk( + pd.DataFrame( + [ + {"id": "r1", "name": "Admin"}, + {"id": "r2", "name": "Reviewer"}, + ] + ), + now=process_ts, + ) + + def transform_func(users_df, roles_df): + df = users_df.merge(roles_df, left_on="role_id", right_on="id", suffixes=("_user", "_role")) + user_cards_df = pd.DataFrame( + { + "id": df["id_user"], + "user_name": df["name_user"], + "role_name": df["name_role"], + } + ) + role_cards_df = ( + df.groupby(["id_role", "name_role"], as_index=False) + .agg(users_count=("id_user", "count")) + .rename(columns={"id_role": "id", "name_role": "role_name"}) + ) + role_cards_df["users_count"] = role_cards_df["users_count"].astype(str) + return user_cards_df, role_cards_df + + step = BatchTransformStep( + ds=ds, + name="test_same_output_columns_aliases", + func=transform_func, + input_dts=[ + ComputeInput( + dt=users, + join_type="full", + keys={ + "user_id": "id", + "role_id": DataField("role_id"), + }, + ), + ComputeInput( + 
dt=roles, + join_type="inner", + keys={ + "role_id": "id", + }, + ), + ], + output_dts=[ + ComputeOutput( + dt=user_cards, + keys={ + "user_id": "id", + }, + ), + ComputeOutput( + dt=role_cards, + keys={ + "role_id": "id", + }, + ), + ], + transform_keys=["user_id", "role_id"], + ) + + step.run_full(ds) + + assert_datatable_equal( + user_cards, + pd.DataFrame( + [ + {"id": "u1", "user_name": "Alice", "role_name": "Admin"}, + {"id": "u2", "user_name": "Bob", "role_name": "Reviewer"}, + {"id": "u3", "user_name": "Eve", "role_name": "Admin"}, + ] + ), + ) + assert_datatable_equal( + role_cards, + pd.DataFrame( + [ + {"id": "r1", "role_name": "Admin", "users_count": "2"}, + {"id": "r2", "role_name": "Reviewer", "users_count": "1"}, + ] + ), + ) + + time.sleep(0.01) + roles.store_chunk(pd.DataFrame([{"id": "r1", "name": "Administrator"}]), now=time.time()) + step.run_full(ds) + + assert_datatable_equal( + user_cards, + pd.DataFrame( + [ + {"id": "u1", "user_name": "Alice", "role_name": "Administrator"}, + {"id": "u2", "user_name": "Bob", "role_name": "Reviewer"}, + {"id": "u3", "user_name": "Eve", "role_name": "Administrator"}, + ] + ), + ) + assert_datatable_equal( + role_cards, + pd.DataFrame( + [ + {"id": "r1", "role_name": "Administrator", "users_count": "2"}, + {"id": "r2", "role_name": "Reviewer", "users_count": "1"}, + ] + ), + ) + + time.sleep(0.01) + roles.delete_by_idx(pd.DataFrame([{"id": "r1"}]), now=time.time()) + step.run_full(ds) + + assert_datatable_equal( + user_cards, + pd.DataFrame( + [ + {"id": "u2", "user_name": "Bob", "role_name": "Reviewer"}, + ] + ), + ) + assert_datatable_equal( + role_cards, + pd.DataFrame( + [ + {"id": "r2", "role_name": "Reviewer", "users_count": "1"}, ] ), ) diff --git a/tests/test_table_store_filedir.py b/tests/test_table_store_filedir.py index 7c5fe483..10b8b687 100644 --- a/tests/test_table_store_filedir.py +++ b/tests/test_table_store_filedir.py @@ -1,6 +1,7 @@ import base64 import io import json +from pathlib import 
def canonical_local_path(path):
    """Return *path* as a string with any local filesystem part resolved.

    * ``file://...`` values keep the ``file://`` prefix and have the remainder
      passed through ``Path.resolve()``.
    * Any other value containing ``://`` is returned unchanged.
    * Everything else is treated as a local path and resolved.
    """
    text = str(path)
    file_scheme = "file://"

    if text.startswith(file_scheme):
        local_part = text[len(file_scheme):]
        return file_scheme + str(Path(local_part).resolve())

    if "://" not in text:
        return str(Path(text).resolve())

    # Some other URL scheme: leave it alone.
    return text