Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ def __repr__(self) -> str:
def ref(self) -> BoundReference[L]:
return self

def __hash__(self) -> int:
    """Return the hash of this BoundReference, derived from its string form.

    Hashing ``str(self)`` makes instances usable as set members / dict keys.
    NOTE(review): assumes ``str(self)`` uniquely identifies the reference and
    is consistent with equality — confirm against the class's ``__eq__``/repr.
    """
    return hash(str(self))


class UnboundTerm(Term[Any], Unbound[BoundTerm[L]], ABC):
"""Represents an unbound term."""
Expand Down
161 changes: 156 additions & 5 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@

from pyiceberg.conversions import to_bytes
from pyiceberg.exceptions import ResolveError
from pyiceberg.expressions import (
AlwaysTrue,
BooleanExpression,
BoundTerm,
)
from pyiceberg.expressions import AlwaysTrue, BooleanExpression, BoundIsNaN, BoundIsNull, BoundReference, BoundTerm, Not, Or
from pyiceberg.expressions.literals import Literal
from pyiceberg.expressions.visitors import (
BoundBooleanExpressionVisitor,
Expand Down Expand Up @@ -638,10 +634,165 @@ def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> p
return left_result | right_result


class _NullNaNUnmentionedTermsCollector(BoundBooleanExpressionVisitor[Any]):
    """Collect the bound terms of a boolean expression, categorized by null/NaN mention.

    Walking the expression buckets every ``BoundTerm`` into four sets: terms that
    appear under an ``is_null``/``is_not_null`` predicate, terms that never do,
    terms that appear under an ``is_nan``/``is_not_nan`` predicate, and terms that
    never do. ``_expression_to_complementary_pyarrow`` uses the "unmentioned" sets
    to decide which fields need an explicit null/NaN check added when inverting a
    filter.
    """

    # BoundTerms which have either is_null or is_not_null appearing at least once in the boolean expr.
    is_null_or_not_bound_terms: set[BoundTerm[Any]]
    # The remaining BoundTerms appearing in the boolean expr (never null-checked).
    null_unmentioned_bound_terms: set[BoundTerm[Any]]
    # BoundTerms which have either is_nan or is_not_nan appearing at least once in the boolean expr.
    is_nan_or_not_bound_terms: set[BoundTerm[Any]]
    # The remaining BoundTerms appearing in the boolean expr (never NaN-checked).
    nan_unmentioned_bound_terms: set[BoundTerm[Any]]

    def __init__(self) -> None:
        self.is_null_or_not_bound_terms = set()
        self.null_unmentioned_bound_terms = set()
        self.is_nan_or_not_bound_terms = set()
        self.nan_unmentioned_bound_terms = set()
        super().__init__()

    def _handle_explicit_is_null_or_not(self, term: BoundTerm[Any]) -> None:
        """Record that the term is explicitly null-checked; promote it out of the unmentioned set."""
        self.null_unmentioned_bound_terms.discard(term)
        self.is_null_or_not_bound_terms.add(term)

    def _handle_null_unmentioned(self, term: BoundTerm[Any]) -> None:
        """Record the term as null-unmentioned unless an explicit null check was already seen."""
        if term not in self.is_null_or_not_bound_terms:
            self.null_unmentioned_bound_terms.add(term)

    def _handle_explicit_is_nan_or_not(self, term: BoundTerm[Any]) -> None:
        """Record that the term is explicitly NaN-checked; promote it out of the unmentioned set."""
        self.nan_unmentioned_bound_terms.discard(term)
        self.is_nan_or_not_bound_terms.add(term)

    def _handle_nan_unmentioned(self, term: BoundTerm[Any]) -> None:
        """Record the term as NaN-unmentioned unless an explicit NaN check was already seen."""
        if term not in self.is_nan_or_not_bound_terms:
            self.nan_unmentioned_bound_terms.add(term)

    def _handle_ordinary_predicate(self, term: BoundTerm[Any]) -> None:
        """Shared handler for predicates that check neither null nor NaN."""
        self._handle_null_unmentioned(term)
        self._handle_nan_unmentioned(term)

    def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
        self._handle_ordinary_predicate(term)

    def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
        self._handle_ordinary_predicate(term)

    def visit_is_nan(self, term: BoundTerm[Any]) -> None:
        self._handle_null_unmentioned(term)
        self._handle_explicit_is_nan_or_not(term)

    def visit_not_nan(self, term: BoundTerm[Any]) -> None:
        self._handle_null_unmentioned(term)
        self._handle_explicit_is_nan_or_not(term)

    def visit_is_null(self, term: BoundTerm[Any]) -> None:
        self._handle_explicit_is_null_or_not(term)
        self._handle_nan_unmentioned(term)

    def visit_not_null(self, term: BoundTerm[Any]) -> None:
        self._handle_explicit_is_null_or_not(term)
        self._handle_nan_unmentioned(term)

    def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_ordinary_predicate(term)

    def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_ordinary_predicate(term)

    def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_ordinary_predicate(term)

    def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_ordinary_predicate(term)

    def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_ordinary_predicate(term)

    def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_ordinary_predicate(term)

    def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_ordinary_predicate(term)

    def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_ordinary_predicate(term)

    def visit_true(self) -> None:
        return

    def visit_false(self) -> None:
        return

    def visit_not(self, child_result: Any) -> None:
        return

    def visit_and(self, left_result: Any, right_result: Any) -> None:
        return

    def visit_or(self, left_result: Any, right_result: Any) -> None:
        return

    def collect(
        self,
        expr: BooleanExpression,
    ) -> None:
        """Collect the bound references categorized by having at least one is_null or is_not_null in the expr and the remaining."""
        boolean_expression_visit(expr, self)


def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
    """Translate a bound Iceberg boolean expression into a PyArrow compute expression."""
    visitor = _ConvertToArrowExpression()
    return boolean_expression_visit(expr, visitor)


def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expression:
    """Complementary filter conversion function of expression_to_pyarrow.

    Could not use expression_to_pyarrow(Not(expr)) to achieve this complementary
    effect because ~ in pyarrow.compute.Expression does not handle null: rows
    where the predicate evaluates to null are excluded by both the filter and
    its negation. To preserve such rows, explicitly OR in is_null / is_nan
    checks for every term the original expression never null/NaN-checked itself.

    Args:
        expr: A bound Iceberg boolean expression.

    Returns:
        A PyArrow compute expression selecting the complement of ``expr``,
        including rows that are null or NaN in the referenced fields.

    Raises:
        ValueError: If a collected bound term is not a ``BoundReference``.
    """
    collector = _NullNaNUnmentionedTermsCollector()
    collector.collect(expr)

    def _as_sorted_references(bound_terms: Set[BoundTerm[Any]]) -> List[BoundReference[Any]]:
        """Narrow BoundTerm -> BoundReference (mypy) and sort so the built expression is deterministic."""
        bound_refs: List[BoundReference[Any]] = []
        for t in bound_terms:
            if not isinstance(t, BoundReference):
                raise ValueError("Collected Bound Term that is not reference.")
            bound_refs.append(t)
        return sorted(bound_refs, key=lambda ref: ref.field.name)

    preserve_expr: BooleanExpression = Not(expr)
    # Rows that are null/NaN in a never-checked field would otherwise be lost by
    # both expr and Not(expr); explicitly keep them.
    for ref in _as_sorted_references(collector.null_unmentioned_bound_terms):
        preserve_expr = Or(preserve_expr, BoundIsNull(term=ref))
    for ref in _as_sorted_references(collector.nan_unmentioned_bound_terms):
        preserve_expr = Or(preserve_expr, BoundIsNaN(term=ref))
    return expression_to_pyarrow(preserve_expr)


@lru_cache
def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat:
if file_format == FileFormat.PARQUET:
Expand Down
9 changes: 6 additions & 3 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
And,
BooleanExpression,
EqualTo,
Not,
Or,
Reference,
)
Expand Down Expand Up @@ -576,7 +575,11 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
delete_filter: A boolean expression to delete rows from a table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.io.pyarrow import (
_dataframe_to_data_files,
_expression_to_complementary_pyarrow,
project_table,
)

if (
self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
Expand All @@ -593,7 +596,7 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
# Check if there are any files that require an actual rewrite of a data file
if delete_snapshot.rewrites_needed is True:
bound_delete_filter = bind(self._table.schema(), delete_filter, case_sensitive=True)
preserve_row_filter = expression_to_pyarrow(Not(bound_delete_filter))
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)

files = self._scan(row_filter=delete_filter).plan_files()

Expand Down
141 changes: 140 additions & 1 deletion tests/integration/test_deletes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pyiceberg.manifest import ManifestEntryStatus
from pyiceberg.schema import Schema
from pyiceberg.table.snapshots import Operation, Summary
from pyiceberg.types import IntegerType, NestedField
from pyiceberg.types import FloatType, IntegerType, NestedField


def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None:
Expand Down Expand Up @@ -105,6 +105,41 @@ def test_partitioned_table_rewrite(spark: SparkSession, session_catalog: RestCat
assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 10], "number": [30, 30]}


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_rewrite_partitioned_table_with_null(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
    """Delete rows from a partitioned table that contains NULLs.

    Rows whose filtered column is NULL must survive the data-file rewrite:
    the complementary filter has to preserve them explicitly.
    """
    identifier = "default.table_partitioned_delete"

    # Create the table via Spark and seed two partitions; partition 11 contains a NULL.
    run_spark_commands(
        spark,
        [
            f"DROP TABLE IF EXISTS {identifier}",
            f"""
            CREATE TABLE {identifier} (
                number_partitioned int,
                number int
            )
            USING iceberg
            PARTITIONED BY (number_partitioned)
            TBLPROPERTIES('format-version' = {format_version})
            """,
            f"""
            INSERT INTO {identifier} VALUES (10, 20), (10, 30)
            """,
            f"""
            INSERT INTO {identifier} VALUES (11, 20), (11, NULL)
            """,
        ],
    )

    tbl = session_catalog.load_table(identifier)
    tbl.delete(EqualTo("number", 20))

    # We don't delete a whole partition, so there is only a overwrite
    assert [snapshot.summary.operation.value for snapshot in tbl.snapshots()] == ["append", "append", "overwrite"]
    # The NULL row in partition 11 must still be present after the rewrite.
    assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 10], "number": [None, 30]}


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_partitioned_table_no_match(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
Expand Down Expand Up @@ -417,3 +452,107 @@ def test_delete_truncate(session_catalog: RestCatalog) -> None:
assert len(entries) == 1

assert entries[0].status == ManifestEntryStatus.DELETED


@pytest.mark.integration
def test_delete_overwrite_table_with_null(session_catalog: RestCatalog) -> None:
    """Overwrite with a filter on a table containing NULLs.

    The row with a NULL cell matches neither ``ints == 2`` nor its naive
    negation; the complementary filter must keep it in the rewritten file.
    """
    arrow_schema = pa.schema([pa.field("ints", pa.int32())])
    arrow_tbl = pa.Table.from_pylist(
        [{"ints": 1}, {"ints": 2}, {"ints": None}],
        schema=arrow_schema,
    )

    iceberg_schema = Schema(NestedField(1, "ints", IntegerType()))

    tbl_identifier = "default.test_delete_overwrite_with_null"

    # Start from a clean slate; the table may linger from a previous run.
    try:
        session_catalog.drop_table(tbl_identifier)
    except NoSuchTableError:
        pass

    tbl = session_catalog.create_table(tbl_identifier, iceberg_schema)
    tbl.append(arrow_tbl)

    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [Operation.APPEND]

    arrow_tbl_overwrite = pa.Table.from_pylist(
        [
            {"ints": 3},
            {"ints": 4},
        ],
        schema=arrow_schema,
    )
    tbl.overwrite(arrow_tbl_overwrite, "ints == 2")  # Should rewrite one file

    # overwrite = delete (OVERWRITE snapshot) followed by append of new data.
    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [
        Operation.APPEND,
        Operation.OVERWRITE,
        Operation.APPEND,
    ]

    # The NULL row survives; only the ints == 2 row was removed.
    assert tbl.scan().to_arrow()["ints"].to_pylist() == [3, 4, 1, None]


@pytest.mark.integration
def test_delete_overwrite_table_with_nan(session_catalog: RestCatalog) -> None:
    """Overwrite with a ``!=`` filter on a table containing NaN.

    We want to test the _expression_to_complementary_pyarrow function can generate
    a correct complementary filter for selecting records to remain in the new
    overwritten file. Compared with test_delete_overwrite_table_with_null which
    tests rows with null cells, NaN testing is faced with a more tricky issue:
    a filter of (field == value) will not include cells of NaN but (field != value)
    will. (Interestingly, neither == nor != will include null.)

    This means if we set the test case as floats == 2.0 (equal predicate as in
    test_delete_overwrite_table_with_null), the test would pass even without the
    logic under test in _NullNaNUnmentionedTermsCollector (a helper of
    _expression_to_complementary_pyarrow handling the inversion of the Iceberg
    expressions is_null/not_null/is_nan/not_nan). Instead, we test the filter
    of !=, so that the inverse is == which exposes the issue.
    """
    from math import isnan

    arrow_schema = pa.schema([pa.field("floats", pa.float32())])

    # Create Arrow Table with NaN values
    data = [pa.array([1.0, float("nan"), 2.0], type=pa.float32())]
    arrow_tbl = pa.Table.from_arrays(
        data,
        schema=arrow_schema,
    )

    iceberg_schema = Schema(NestedField(1, "floats", FloatType()))

    tbl_identifier = "default.test_delete_overwrite_with_nan"

    # Start from a clean slate; the table may linger from a previous run.
    try:
        session_catalog.drop_table(tbl_identifier)
    except NoSuchTableError:
        pass

    tbl = session_catalog.create_table(tbl_identifier, iceberg_schema)
    tbl.append(arrow_tbl)

    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [Operation.APPEND]

    arrow_tbl_overwrite = pa.Table.from_pylist(
        [
            {"floats": 3.0},
            {"floats": 4.0},
        ],
        schema=arrow_schema,
    )
    tbl.overwrite(arrow_tbl_overwrite, "floats != 2.0")  # Should rewrite one file

    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [
        Operation.APPEND,
        Operation.OVERWRITE,
        Operation.APPEND,
    ]

    result = tbl.scan().to_arrow()["floats"].to_pylist()

    # The NaN row must survive the rewrite; 2.0 was never matched by the filter.
    assert any(isnan(e) for e in result)
    assert 2.0 in result
    assert 3.0 in result
    assert 4.0 in result
1 change: 0 additions & 1 deletion tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=protected-access,unused-argument,redefined-outer-name

Comment thread
jqin61 marked this conversation as resolved.
import os
import tempfile
import uuid
Expand Down
Loading