diff --git a/pyiceberg/table/inspect.py b/pyiceberg/table/inspect.py
index bfe2fffa56..5da343ccb6 100644
--- a/pyiceberg/table/inspect.py
+++ b/pyiceberg/table/inspect.py
@@ -285,7 +285,9 @@ def partitions(
             ]
         )
 
-        partition_record = self.tbl.metadata.specs_struct()
+        snapshot = self._get_snapshot(snapshot_id)
+        spec_ids = {manifest.partition_spec_id for manifest in snapshot.manifests(self.tbl.io)}
+        partition_record = self.tbl.metadata.specs_struct(spec_ids=spec_ids)
         has_partitions = len(partition_record.fields) > 0
 
         if has_partitions:
@@ -299,8 +301,6 @@ def partitions(
 
             table_schema = pa.unify_schemas([partitions_schema, table_schema])
 
-        snapshot = self._get_snapshot(snapshot_id)
-
         scan = DataScan(
             table_metadata=self.tbl.metadata,
             io=self.tbl.io,
diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py
index 8ae930375a..8a55f77b11 100644
--- a/pyiceberg/table/metadata.py
+++ b/pyiceberg/table/metadata.py
@@ -18,6 +18,7 @@
 import datetime
 import uuid
+from collections.abc import Iterable
 from copy import copy
 from typing import Annotated, Any, Literal
 
@@ -262,18 +263,23 @@ def specs(self) -> dict[int, PartitionSpec]:
         """Return a dict the partition specs this table."""
         return {spec.spec_id: spec for spec in self.partition_specs}
 
-    def specs_struct(self) -> StructType:
-        """Produce a struct of all the combined PartitionSpecs.
+    def specs_struct(self, spec_ids: Iterable[int] | None = None) -> StructType:
+        """Produce a struct of the combined PartitionSpecs.
 
         The partition fields should be optional:
         Partition fields may be added later, in which case not all files would have the result field, and it may be null.
 
-        :return: A StructType that represents all the combined PartitionSpecs of the table
+        Args:
+            spec_ids: Optional iterable of spec IDs to include. When not provided,
+                all table specs are used.
+
+        :return: A StructType that represents the combined PartitionSpecs of the table
         """
         specs = self.specs()
+        selected_specs = specs.values() if spec_ids is None else [specs[spec_id] for spec_id in spec_ids if spec_id in specs]
 
         # Collect all the fields
-        struct_fields = {field.field_id: field for spec in specs.values() for field in spec.fields}
+        struct_fields = {field.field_id: field for spec in selected_specs for field in spec.fields}
 
         schema = self.schema()
 
diff --git a/tests/integration/test_inspect_table.py b/tests/integration/test_inspect_table.py
index 4add18cf3f..ea0cca9bc5 100644
--- a/tests/integration/test_inspect_table.py
+++ b/tests/integration/test_inspect_table.py
@@ -18,6 +18,7 @@
 import math
 from datetime import date, datetime
+from typing import Any
 
 import pyarrow as pa
 import pytest
 
@@ -208,9 +209,18 @@ def _inspect_files_asserts(df: pa.Table, spark_df: DataFrame) -> None:
 def _check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
     lhs = df.to_pandas().sort_values("last_updated_at")
     rhs = spark_df.toPandas().sort_values("last_updated_at")
+
+    def _normalize_partition(d: dict[str, Any]) -> dict[str, Any]:
+        return {k: v for k, v in d.items() if v is not None}
+
     for column in df.column_names:
         for left, right in zip(lhs[column].to_list(), rhs[column].to_list(), strict=True):
-            assert left == right, f"Difference in column {column}: {left} != {right}"
+            if column == "partition":
+                assert _normalize_partition(left) == _normalize_partition(right), (
+                    f"Difference in column {column}: {left} != {right}"
+                )
+            else:
+                assert left == right, f"Difference in column {column}: {left} != {right}"
 
 
 @pytest.mark.integration
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 869e60f4aa..677e16dab3 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -2588,6 +2588,36 @@ def test_inspect_partition_for_nested_field(catalog: InMemoryCatalog) -> None:
     assert {part["part"] for part in partitions} == {"data-a", "data-b"}
 
 
+def test_inspect_partitions_respects_partition_evolution(catalog: InMemoryCatalog) -> None:
+    schema = Schema(
+        NestedField(id=1, name="dt", field_type=DateType(), required=False),
+        NestedField(id=2, name="category", field_type=StringType(), required=False),
+    )
+    spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="dt"))
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        "default.test_inspect_partitions_respects_partition_evolution", schema=schema, partition_spec=spec
+    )
+
+    old_spec_id = table.spec().spec_id
+    old_data = [{"dt": date(2025, 1, 1), "category": "old"}]
+    table.append(pa.Table.from_pylist(old_data, schema=table.schema().as_arrow()))
+
+    table.update_spec().add_identity("category").commit()
+    new_spec_id = table.spec().spec_id
+    assert new_spec_id != old_spec_id
+
+    partitions_table = table.inspect.partitions()
+    partitions = partitions_table["partition"].to_pylist()
+    assert all("category" not in partition for partition in partitions)
+
+    new_data = [{"dt": date(2025, 1, 2), "category": "new"}]
+    table.append(pa.Table.from_pylist(new_data, schema=table.schema().as_arrow()))
+
+    partitions_table = table.inspect.partitions()
+    assert set(partitions_table["spec_id"].to_pylist()) == {old_spec_id, new_spec_id}
+
+
 def test_identity_partition_on_multi_columns() -> None:
     test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
     test_schema = Schema(
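
A minimal sketch of the behavior the metadata.py hunk changes, for reviewers. Previously specs_struct() unioned the partition fields of every spec the table has ever had, so inspect.partitions() surfaced null-filled entries for fields no manifest in the snapshot actually uses. Plain dicts stand in for PartitionSpec objects here, and combined_fields is a hypothetical name, not pyiceberg API:

    from collections.abc import Iterable

    # spec_id -> partition fields as (field_id, name) pairs
    SPECS = {
        0: [(1000, "dt")],                      # original spec
        1: [(1000, "dt"), (1001, "category")],  # spec after partition evolution
    }

    def combined_fields(spec_ids: Iterable[int] | None = None) -> dict[int, str]:
        # Mirrors the selection in specs_struct(spec_ids=...): None means
        # "union every spec" (the old behavior); otherwise only the specs
        # whose IDs were passed in are combined.
        selected = SPECS.values() if spec_ids is None else [SPECS[sid] for sid in spec_ids if sid in SPECS]
        return {fid: name for fields in selected for fid, name in fields}

    print(combined_fields())     # {1000: 'dt', 1001: 'category'} -- every spec in metadata
    print(combined_fields({0}))  # {1000: 'dt'} -- only the specs passed in

The inspect.py hunk supplies the second case: spec_ids is computed from snapshot.manifests(self.tbl.io), so a snapshot written entirely under spec 0 no longer reports a null "category" entry in every partition struct.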