diff --git a/paimon-python/pypaimon/tests/reader_primary_key_test.py b/paimon-python/pypaimon/tests/reader_primary_key_test.py
index bcbc94bd68e7..75bd5a81da88 100644
--- a/paimon-python/pypaimon/tests/reader_primary_key_test.py
+++ b/paimon-python/pypaimon/tests/reader_primary_key_test.py
@@ -141,12 +141,12 @@ def test_pk_multi_write_once_commit(self):
         read_builder = table.new_read_builder()
         actual = self._read_test_table(read_builder).sort_by('user_id')
 
-        # TODO support pk merge feature when multiple write
+        # Primary key deduplication keeps the latest record for each primary key
         expected = pa.Table.from_pydict({
-            'user_id': [1, 2, 2, 3, 4, 5, 7, 8],
-            'item_id': [1001, 1002, 1002, 1003, 1004, 1005, 1007, 1008],
-            'behavior': ['a', 'b', 'b-new', 'c', None, 'e', 'g', 'h'],
-            'dt': ['p1', 'p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'],
+            'user_id': [1, 2, 3, 4, 5, 7, 8],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1007, 1008],
+            'behavior': ['a', 'b-new', 'c', None, 'e', 'g', 'h'],
+            'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'],
         }, schema=self.pa_schema)
         self.assertEqual(actual, expected)
 
diff --git a/paimon-python/pypaimon/write/writer/key_value_data_writer.py b/paimon-python/pypaimon/write/writer/key_value_data_writer.py
index 05cad9bca95e..dae9ceed16bb 100644
--- a/paimon-python/pypaimon/write/writer/key_value_data_writer.py
+++ b/paimon-python/pypaimon/write/writer/key_value_data_writer.py
@@ -30,8 +30,24 @@ def _process_data(self, data: pa.RecordBatch) -> pa.Table:
         return pa.Table.from_batches([self._sort_by_primary_key(enhanced_data)])
 
     def _merge_data(self, existing_data: pa.Table, new_data: pa.Table) -> pa.Table:
+        """Merge existing data with new data and deduplicate by primary key.
+
+        The merge process:
+        1. Concatenate existing and new data
+        2. Sort by primary key fields and sequence number
+        3. Deduplicate by primary key, keeping the record with maximum sequence number
+
+        Args:
+            existing_data: Previously buffered data
+            new_data: Newly written data to be merged
+
+        Returns:
+            Deduplicated and sorted table
+        """
         combined = pa.concat_tables([existing_data, new_data])
-        return self._sort_by_primary_key(combined)
+        sorted_data = self._sort_by_primary_key(combined)
+        deduplicated_data = self._deduplicate_by_primary_key(sorted_data)
+        return deduplicated_data
 
     def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch:
         """Add system fields: _KEY_{pk_key}, _SEQUENCE_NUMBER, _VALUE_KIND."""
@@ -53,7 +69,58 @@ def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch:
         return enhanced_table
 
+    def _deduplicate_by_primary_key(self, data: pa.RecordBatch) -> pa.RecordBatch:
+        """Deduplicate data by primary key, keeping the record with maximum sequence number.
+
+        Prerequisite: data is sorted by (primary_keys, _SEQUENCE_NUMBER)
+
+        Algorithm: Since data is sorted by primary key and then by sequence number in ascending
+        order, for each primary key group, the last occurrence has the maximum sequence number.
+        We iterate through and track the last index of each primary key, then keep only those rows.
+
+        Args:
+            data: Sorted record batch with system fields (_KEY_*, _SEQUENCE_NUMBER, _VALUE_KIND)
+
+        Returns:
+            Deduplicated record batch with only the latest record per primary key
+        """
+        if data.num_rows <= 1:
+            return data
+
+        # Build primary key column names (prefixed with _KEY_)
+        pk_columns = [f'_KEY_{pk}' for pk in self.trimmed_primary_keys]
+
+        # First pass: find the last index for each primary key
+        last_index_for_key = {}
+        for i in range(data.num_rows):
+            current_key = tuple(
+                data.column(col)[i].as_py() for col in pk_columns
+            )
+            last_index_for_key[current_key] = i
+
+        # Second pass: collect indices to keep (maintaining original order)
+        indices_to_keep = []
+        for i in range(data.num_rows):
+            current_key = tuple(
+                data.column(col)[i].as_py() for col in pk_columns
+            )
+            # Only keep this row if it's the last occurrence of this primary key
+            if i == last_index_for_key[current_key]:
+                indices_to_keep.append(i)
+
+        # Extract kept rows using PyArrow's take operation
+        indices_array = pa.array(indices_to_keep, type=pa.int64())
+        return data.take(indices_array)
+
     def _sort_by_primary_key(self, data: pa.RecordBatch) -> pa.RecordBatch:
+        """Sort data by primary key fields and sequence number.
+
+        Args:
+            data: Record batch to sort
+
+        Returns:
+            Sorted record batch
+        """
         sort_keys = [(key, 'ascending') for key in self.trimmed_primary_keys]
         if '_SEQUENCE_NUMBER' in data.column_names:
             sort_keys.append(('_SEQUENCE_NUMBER', 'ascending'))
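For illustration only (not part of the patch), the sketch below applies the same "keep the last row per primary key" idea to a small standalone pyarrow Table. The helper name and sample data are hypothetical; it assumes the input is already sorted by the key columns and then by _SEQUENCE_NUMBER ascending, which is the same prerequisite _deduplicate_by_primary_key documents. It also mirrors why the test's expected output now contains a single row for user_id 2 carrying the newer behavior 'b-new'.

    # Illustrative sketch, not the writer's code path: keep the last row per key
    # on an already sorted pyarrow Table.
    import pyarrow as pa

    def keep_last_per_key(table: pa.Table, key_columns: list) -> pa.Table:
        # Assumes rows are sorted by key_columns, then by sequence number ascending,
        # so the last row within each key group is the newest record.
        last_index_for_key = {}
        for i in range(table.num_rows):
            key = tuple(table.column(col)[i].as_py() for col in key_columns)
            last_index_for_key[key] = i
        indices = pa.array(sorted(last_index_for_key.values()), type=pa.int64())
        return table.take(indices)

    # Two rows share user_id=2; only the later one ('b-new', seq 5) survives.
    sorted_rows = pa.table({
        'user_id': [1, 2, 2, 3],
        'behavior': ['a', 'b', 'b-new', 'c'],
        '_SEQUENCE_NUMBER': [0, 1, 5, 2],
    })
    print(keep_last_per_key(sorted_rows, ['user_id']).to_pydict())
    # {'user_id': [1, 2, 3], 'behavior': ['a', 'b-new', 'c'], '_SEQUENCE_NUMBER': [0, 5, 2]}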