Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ See [key-mapping.md](design-docs/2025-12-key-mapping.md) for motivation
* Added `keys` parameter to `InputSpec` and `ComputeInput` to support
joining tables with different key names
* Added `DataField` accessor for `InputSpec.keys`
* Added `OutputSpec` and `ComputeOutput.keys` to explicitly map transform keys to output table primary keys
* Fixed batch transform cleanup for aliased output keys and incomplete transform keys
* Updated `tqdm-loggable` to avoid Python 3.12 deprecation warnings

### Python3.9 support is deprecated

Expand Down
18 changes: 13 additions & 5 deletions datapipe/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,17 @@ def primary_schema(self) -> MetaSchema:
return self.dt.primary_schema


@dataclass
class ComputeOutput:
dt: DataTable

# If provided, this dict tells how transform keys map to output primary keys.
#
# Example: {"post_id": "id"} means that cleanup for transform key post_id
# should be applied to output primary key id.
keys: dict[str, str] | None = None


class ComputeStep:
"""
Шаг вычислений в графе вычислений.
Expand All @@ -152,16 +163,13 @@ class ComputeStep:
def __init__(
self,
name: str,
input_dts: Sequence[ComputeInput | DataTable],
input_dts: Sequence[ComputeInput],
output_dts: list[DataTable],
labels: Labels | None = None,
executor_config: ExecutorConfig | None = None,
) -> None:
self._name = name
# Нормализация input_dts: автоматически оборачиваем DataTable в ComputeInput
self.input_dts = [
inp if isinstance(inp, ComputeInput) else ComputeInput(dt=inp, join_type="full") for inp in input_dts
]
self.input_dts = list(input_dts)
self.output_dts = output_dts
self._labels = labels
self.executor_config = executor_config
Expand Down
2 changes: 2 additions & 0 deletions datapipe/meta/sql_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,8 @@ def _make_agg_of_agg(

prev_ctes.append(cte)

sql = sql.where(sa.and_(*[key.isnot(None) for key in coalesce_keys]))

sql = sql.group_by(*coalesce_keys)

return sql.cte(name=f"all__{agg_col}")
Expand Down
85 changes: 61 additions & 24 deletions datapipe/step/batch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
Sequence,
)

import pandas as pd
from opentelemetry import trace
from tqdm_loggable.auto import tqdm

from datapipe.compute import (
Catalog,
ComputeInput,
ComputeOutput,
ComputeStep,
PipelineStep,
StepStatus,
Expand All @@ -32,6 +34,8 @@
InputSpec,
Labels,
PipelineInput,
PipelineOutput,
OutputSpec,
Required,
TableOrName,
TransformResult,
Expand Down Expand Up @@ -67,8 +71,8 @@ def __init__(
self,
ds: DataStore,
name: str,
input_dts: Sequence[ComputeInput | DataTable],
output_dts: list[DataTable],
input_dts: Sequence[ComputeInput],
output_dts: Sequence[ComputeOutput],
transform_keys: list[str] | None = None,
chunk_size: int = 1000,
labels: Labels | None = None,
Expand All @@ -77,26 +81,19 @@ def __init__(
order_by: list[str] | None = None,
order: Literal["asc", "desc"] = "asc",
) -> None:
# Support both old API (List[DataTable]) and new API (List[ComputeInput])
# Convert to new API format
compute_input_dts: list[ComputeInput] = []
for inp in input_dts:
if isinstance(inp, ComputeInput):
# New API: ComputeInput with .dt attribute
compute_input_dts.append(inp)
else:
# Old API: DataTable passed directly - convert to new API
compute_input_dts.append(ComputeInput(dt=inp, join_type="full"))
compute_input_dts = list(input_dts)
compute_output_dts = list(output_dts)

ComputeStep.__init__(
self,
name=name,
input_dts=compute_input_dts,
output_dts=output_dts,
output_dts=[out.dt for out in compute_output_dts],
labels=labels,
executor_config=executor_config,
)

self.output_specs = compute_output_dts
self.chunk_size = chunk_size

# Force transform_keys to be a list, otherwise Pandas will not be happy
Expand All @@ -118,6 +115,31 @@ def __init__(
self.order_by = order_by
self.order = order

@staticmethod
def _transform_idx_to_output_idx(
idx: IndexDF,
output_spec: ComputeOutput,
) -> IndexDF | None:
res_dt = output_spec.dt
output_to_transform_keys = {
output_key: transform_key for transform_key, output_key in (output_spec.keys or {}).items()
}
columns: dict[str, Any] = {}

for pk in res_dt.primary_keys:
if pk in idx.columns:
columns[pk] = idx[pk]
continue

transform_key = output_to_transform_keys.get(pk)
if transform_key is not None and transform_key in idx.columns:
columns[pk] = idx[transform_key]

if not columns:
return None

return IndexDF(pd.DataFrame(columns))

def _apply_filters_to_run_config(self, run_config: RunConfig | None = None) -> RunConfig | None:
if self.filters is None:
return run_config
Expand Down Expand Up @@ -198,12 +220,13 @@ def store_batch_result(
assert len(self.output_dts) == 1
output_dfs = [output_dfs]

for k, res_dt in enumerate(self.output_dts):
for k, output_spec in enumerate(self.output_specs):
res_dt = output_spec.dt
# Берем k-ое значение функции для k-ой таблички
# Добавляем результат в результирующие чанки
change_idx = res_dt.store_chunk(
data_df=output_dfs[k],
processed_idx=idx,
processed_idx=self._transform_idx_to_output_idx(idx, output_spec),
now=process_ts,
run_config=run_config,
)
Expand All @@ -212,8 +235,13 @@ def store_batch_result(

else:
with tracer.start_as_current_span("delete missing data from output"):
for k, res_dt in enumerate(self.output_dts):
del_idx = res_dt.meta.get_existing_idx(idx)
for k, output_spec in enumerate(self.output_specs):
res_dt = output_spec.dt
processed_idx = self._transform_idx_to_output_idx(idx, output_spec)
if processed_idx is None:
continue

del_idx = res_dt.meta.get_existing_idx(processed_idx)

res_dt.delete_by_idx(del_idx, run_config=run_config)

Expand Down Expand Up @@ -425,7 +453,7 @@ def build_compute(self, ds: DataStore, catalog: Catalog) -> list[ComputeStep]:
name=f"{self.func.__name__}",
func=self.func,
input_dts=[ComputeInput(dt=inp, join_type="full") for inp in input_dts],
output_dts=output_dts,
output_dts=[ComputeOutput(dt=out) for out in output_dts],
kwargs=self.kwargs,
transform_keys=self.transform_keys,
chunk_size=self.chunk_size,
Expand All @@ -440,8 +468,8 @@ def __init__(
ds: DataStore,
name: str,
func: DatatableBatchTransformFunc,
input_dts: list[ComputeInput],
output_dts: list[DataTable],
input_dts: Sequence[ComputeInput],
output_dts: Sequence[ComputeOutput],
kwargs: dict | None = None,
transform_keys: list[str] | None = None,
chunk_size: int = 1000,
Expand Down Expand Up @@ -479,7 +507,7 @@ def process_batch_dts(
class BatchTransform(PipelineStep):
func: BatchTransformFunc
inputs: list[PipelineInput]
outputs: list[TableOrName]
outputs: list[PipelineOutput]
chunk_size: int = 1000
kwargs: dict[str, Any] | None = None
transform_keys: list[str] | None = None
Expand All @@ -505,9 +533,18 @@ def pipeline_input_to_compute_input(self, ds: DataStore, catalog: Catalog, input
else:
return ComputeInput(dt=catalog.get_datatable(ds, input), join_type="full")

def pipeline_output_to_compute_output(self, ds: DataStore, catalog: Catalog, output: PipelineOutput) -> ComputeOutput:
if isinstance(output, OutputSpec):
return ComputeOutput(
dt=catalog.get_datatable(ds, output.table),
keys=output.keys,
)

return ComputeOutput(dt=catalog.get_datatable(ds, output))

def build_compute(self, ds: DataStore, catalog: Catalog) -> list[ComputeStep]:
input_dts = [self.pipeline_input_to_compute_input(ds, catalog, input) for input in self.inputs]
output_dts = [catalog.get_datatable(ds, name) for name in self.outputs]
output_dts = [self.pipeline_output_to_compute_output(ds, catalog, output) for output in self.outputs]

return [
BatchTransformStep(
Expand All @@ -534,8 +571,8 @@ def __init__(
ds: DataStore,
name: str,
func: BatchTransformFunc,
input_dts: list[ComputeInput],
output_dts: list[DataTable],
input_dts: Sequence[ComputeInput],
output_dts: Sequence[ComputeOutput],
kwargs: dict[str, Any] | None = None,
transform_keys: list[str] | None = None,
chunk_size: int = 1000,
Expand Down
14 changes: 14 additions & 0 deletions datapipe/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ class Required(InputSpec):
PipelineInput = TableOrName | InputSpec


@dataclass
class OutputSpec:
table: TableOrName

# If provided, this dict tells how transform keys map to output primary keys.
#
# Example: {"post_id": "id"} means that cleanup for transform key post_id
# should be applied to output primary key id.
keys: dict[str, str] | None = None


PipelineOutput = TableOrName | OutputSpec


@dataclass
class ChangeList:
changes: dict[str, IndexDF] = field(default_factory=lambda: cast(dict[str, IndexDF], {}))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies = [
"PyYAML>=5.3.1",
"cityhash>=0.4.2,<0.5",
"Pillow>=10.0.0,<12",
"tqdm-loggable>=0.2,<0.3",
"tqdm-loggable>=0.4.1,<0.5",
"traceback-with-variables>=2.0.4,<3",
"opentelemetry-api>=1.8.0,<2",
"opentelemetry-sdk>=1.8.0,<2",
Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys

os.environ["SQLALCHEMY_WARN_20"] = "1"

Expand Down Expand Up @@ -48,7 +49,8 @@ def assert_df_equal(a: pd.DataFrame, b: pd.DataFrame) -> bool:
@pytest.fixture
def dbconn():
if os.environ.get("TEST_DB_ENV") == "sqlite":
DBCONNSTR = "sqlite+pysqlite3:///:memory:"
sqlite_driver = "sqlite" if sys.platform == "darwin" else "sqlite+pysqlite3"
DBCONNSTR = f"{sqlite_driver}:///:memory:"
DB_TEST_SCHEMA = None
else:
pg_host = os.getenv("POSTGRES_HOST", "localhost")
Expand Down
6 changes: 3 additions & 3 deletions tests/test_batch_transform_scheduling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd
from sqlalchemy import Column, Integer

from datapipe.compute import ComputeInput
from datapipe.compute import ComputeInput, ComputeOutput
from datapipe.datatable import DataStore
from datapipe.step.batch_transform import BatchTransformStep
from datapipe.store.database import TableStoreDB
Expand Down Expand Up @@ -40,7 +40,7 @@ def id_func(df):
input_dts=[
ComputeInput(dt=tbl1, join_type="full"),
],
output_dts=[tbl2],
output_dts=[ComputeOutput(dt=tbl2)],
)

count, idx_gen = step.get_full_process_ids(ds)
Expand Down Expand Up @@ -150,7 +150,7 @@ def test_aux_input(dbconn) -> None:
ComputeInput(dt=tbl_items2, join_type="full"),
ComputeInput(dt=tbl_aux, join_type="full"),
],
output_dts=[tbl_out],
output_dts=[ComputeOutput(dt=tbl_out)],
transform_keys=["id"],
)

Expand Down
Loading
Loading