10 changes: 5 additions & 5 deletions paimon-python/pypaimon/tests/reader_primary_key_test.py
@@ -141,12 +141,12 @@ def test_pk_multi_write_once_commit(self):
 
         read_builder = table.new_read_builder()
         actual = self._read_test_table(read_builder).sort_by('user_id')
-        # TODO support pk merge feature when multiple write
+        # Primary key deduplication keeps the latest record for each primary key
         expected = pa.Table.from_pydict({
-            'user_id': [1, 2, 2, 3, 4, 5, 7, 8],
-            'item_id': [1001, 1002, 1002, 1003, 1004, 1005, 1007, 1008],
-            'behavior': ['a', 'b', 'b-new', 'c', None, 'e', 'g', 'h'],
-            'dt': ['p1', 'p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'],
+            'user_id': [1, 2, 3, 4, 5, 7, 8],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1007, 1008],
+            'behavior': ['a', 'b-new', 'c', None, 'e', 'g', 'h'],
+            'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'],
         }, schema=self.pa_schema)
         self.assertEqual(actual, expected)
 
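
For context on the assertion change above: with primary-key deduplication, the reader returns one row per key, keeping the value written with the highest sequence number, so `user_id` 2 surfaces only `'b-new'` and the expected table shrinks from 8 rows to 7. A minimal, hedged sketch of that rule in plain PyArrow; the table below is invented for illustration, and only `_SEQUENCE_NUMBER` mirrors the writer's system column naming:

```python
import pyarrow as pa

# user_id 2 is written twice; the second write has the higher sequence number.
written = pa.table({
    'user_id': [1, 2, 2],
    'behavior': ['a', 'b', 'b-new'],
    '_SEQUENCE_NUMBER': [0, 1, 2],
})

# After merging, the surviving sequence number per key is the maximum one,
# i.e. user_id 2 keeps 'b-new'.
winners = written.group_by('user_id').aggregate([('_SEQUENCE_NUMBER', 'max')])
print(winners.to_pylist())
# e.g. [{'user_id': 1, '_SEQUENCE_NUMBER_max': 0},
#       {'user_id': 2, '_SEQUENCE_NUMBER_max': 2}]  (row order may vary)
```
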
69 changes: 68 additions & 1 deletion paimon-python/pypaimon/write/writer/key_value_data_writer.py
@@ -30,8 +30,24 @@ def _process_data(self, data: pa.RecordBatch) -> pa.Table:
         return pa.Table.from_batches([self._sort_by_primary_key(enhanced_data)])
 
     def _merge_data(self, existing_data: pa.Table, new_data: pa.Table) -> pa.Table:
+        """Merge existing data with new data and deduplicate by primary key.
+
+        The merge process:
+        1. Concatenate existing and new data
+        2. Sort by primary key fields and sequence number
+        3. Deduplicate by primary key, keeping the record with maximum sequence number
+
+        Args:
+            existing_data: Previously buffered data
+            new_data: Newly written data to be merged
+
+        Returns:
+            Deduplicated and sorted table
+        """
         combined = pa.concat_tables([existing_data, new_data])
-        return self._sort_by_primary_key(combined)
+        sorted_data = self._sort_by_primary_key(combined)
+        deduplicated_data = self._deduplicate_by_primary_key(sorted_data)
+        return deduplicated_data
 
     def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch:
         """Add system fields: _KEY_{pk_key}, _SEQUENCE_NUMBER, _VALUE_KIND."""
@@ -53,7 +69,58 @@ def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch:
 
         return enhanced_table
 
+    def _deduplicate_by_primary_key(self, data: pa.RecordBatch) -> pa.RecordBatch:
+        """Deduplicate data by primary key, keeping the record with maximum sequence number.
+
+        Prerequisite: data is sorted by (primary_keys, _SEQUENCE_NUMBER)
+
+        Algorithm: Since data is sorted by primary key and then by sequence number in ascending
+        order, for each primary key group, the last occurrence has the maximum sequence number.
+        We iterate through and track the last index of each primary key, then keep only those rows.
+
+        Args:
+            data: Sorted record batch with system fields (_KEY_*, _SEQUENCE_NUMBER, _VALUE_KIND)
+
+        Returns:
+            Deduplicated record batch with only the latest record per primary key
+        """
+        if data.num_rows <= 1:
+            return data
+
+        # Build primary key column names (prefixed with _KEY_)
+        pk_columns = [f'_KEY_{pk}' for pk in self.trimmed_primary_keys]
+
+        # First pass: find the last index for each primary key
+        last_index_for_key = {}
+        for i in range(data.num_rows):
+            current_key = tuple(
+                data.column(col)[i].as_py() for col in pk_columns
+            )
+            last_index_for_key[current_key] = i
+
+        # Second pass: collect indices to keep (maintaining original order)
+        indices_to_keep = []
+        for i in range(data.num_rows):
+            current_key = tuple(
+                data.column(col)[i].as_py() for col in pk_columns
+            )
+            # Only keep this row if it's the last occurrence of this primary key
+            if i == last_index_for_key[current_key]:
+                indices_to_keep.append(i)
+
+        # Extract kept rows using PyArrow's take operation
+        indices_array = pa.array(indices_to_keep, type=pa.int64())
+        return data.take(indices_array)
+
     def _sort_by_primary_key(self, data: pa.RecordBatch) -> pa.RecordBatch:
+        """Sort data by primary key fields and sequence number.
+
+        Args:
+            data: Record batch to sort
+
+        Returns:
+            Sorted record batch
+        """
         sort_keys = [(key, 'ascending') for key in self.trimmed_primary_keys]
         if '_SEQUENCE_NUMBER' in data.column_names:
             sort_keys.append(('_SEQUENCE_NUMBER', 'ascending'))
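
The tail of `_sort_by_primary_key` is collapsed in the diff, so the `return` in the sketch below is an assumption rather than the PR's actual line; the rest mirrors the sort-key construction shown above and demonstrates it on a throwaway table with a compound primary key:

```python
import pyarrow as pa

def sort_by_primary_key(data: pa.Table, trimmed_primary_keys: list) -> pa.Table:
    # Same sort-key construction as in the diff: every primary key field
    # ascending, then _SEQUENCE_NUMBER ascending when that column exists.
    sort_keys = [(key, 'ascending') for key in trimmed_primary_keys]
    if '_SEQUENCE_NUMBER' in data.column_names:
        sort_keys.append(('_SEQUENCE_NUMBER', 'ascending'))
    return data.sort_by(sort_keys)  # assumed tail; the actual return is collapsed above

t = pa.table({
    'dt': ['p1', 'p1', 'p2'],
    'user_id': [2, 2, 1],
    '_SEQUENCE_NUMBER': [1, 0, 0],
})
print(sort_by_primary_key(t, ['dt', 'user_id']).to_pylist())
# Rows come back ordered by (dt, user_id, _SEQUENCE_NUMBER): the two
# ('p1', 2) rows end up adjacent with sequence 0 before sequence 1.
```

Sorting `_SEQUENCE_NUMBER` last and ascending is what lets `_deduplicate_by_primary_key` treat "last occurrence per key" as "maximum sequence number per key".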