From 7e1d7592bd5963c4d7cccc7afb8e627d63989e0e Mon Sep 17 00:00:00 2001 From: kaori-seasons Date: Wed, 3 Dec 2025 13:55:05 +0800 Subject: [PATCH 1/2] enhance: support primary key collapse --- .../pypaimon/tests/reader_primary_key_test.py | 10 +-- .../write/writer/key_value_data_writer.py | 69 ++++++++++++++++++- 2 files changed, 73 insertions(+), 6 deletions(-) diff --git a/paimon-python/pypaimon/tests/reader_primary_key_test.py b/paimon-python/pypaimon/tests/reader_primary_key_test.py index b992595fc9b5..60c5cde41b5d 100644 --- a/paimon-python/pypaimon/tests/reader_primary_key_test.py +++ b/paimon-python/pypaimon/tests/reader_primary_key_test.py @@ -139,12 +139,12 @@ def test_pk_multi_write_once_commit(self): read_builder = table.new_read_builder() actual = self._read_test_table(read_builder).sort_by('user_id') - # TODO support pk merge feature when multiple write + # Primary key deduplication keeps the latest record for each primary key expected = pa.Table.from_pydict({ - 'user_id': [1, 2, 2, 3, 4, 5, 7, 8], - 'item_id': [1001, 1002, 1002, 1003, 1004, 1005, 1007, 1008], - 'behavior': ['a', 'b', 'b-new', 'c', None, 'e', 'g', 'h'], - 'dt': ['p1', 'p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'], + 'user_id': [1, 2, 3, 4, 5, 7, 8], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1007, 1008], + 'behavior': ['a', 'b-new', 'c', None, 'e', 'g', 'h'], + 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'], }, schema=self.pa_schema) self.assertEqual(actual, expected) diff --git a/paimon-python/pypaimon/write/writer/key_value_data_writer.py b/paimon-python/pypaimon/write/writer/key_value_data_writer.py index fb929710e8b2..e0cea7dee5f7 100644 --- a/paimon-python/pypaimon/write/writer/key_value_data_writer.py +++ b/paimon-python/pypaimon/write/writer/key_value_data_writer.py @@ -30,8 +30,24 @@ def _process_data(self, data: pa.RecordBatch) -> pa.Table: return pa.Table.from_batches([self._sort_by_primary_key(enhanced_data)]) def _merge_data(self, existing_data: pa.Table, new_data: pa.Table) -> pa.Table: + """Merge existing data with new data and deduplicate by primary key. + + The merge process: + 1. Concatenate existing and new data + 2. Sort by primary key fields and sequence number + 3. Deduplicate by primary key, keeping the record with maximum sequence number + + Args: + existing_data: Previously buffered data + new_data: Newly written data to be merged + + Returns: + Deduplicated and sorted table + """ combined = pa.concat_tables([existing_data, new_data]) - return self._sort_by_primary_key(combined) + sorted_data = self._sort_by_primary_key(combined) + deduplicated_data = self._deduplicate_by_primary_key(sorted_data) + return deduplicated_data def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch: """Add system fields: _KEY_{pk_key}, _SEQUENCE_NUMBER, _VALUE_KIND.""" @@ -53,7 +69,58 @@ def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch: return enhanced_table + def _deduplicate_by_primary_key(self, data: pa.RecordBatch) -> pa.RecordBatch: + """Deduplicate data by primary key, keeping the record with maximum sequence number. + + Prerequisite: data is sorted by (primary_keys, _SEQUENCE_NUMBER) + + Algorithm: Since data is sorted by primary key and then by sequence number in ascending + order, for each primary key group, the last occurrence has the maximum sequence number. + We iterate through and track the last index of each primary key, then keep only those rows. 
+ + Args: + data: Sorted record batch with system fields (_KEY_*, _SEQUENCE_NUMBER, _VALUE_KIND) + + Returns: + Deduplicated record batch with only the latest record per primary key + """ + if data.num_rows <= 1: + return data + + # Build primary key column names (prefixed with _KEY_) + pk_columns = [f'_KEY_{pk}' for pk in self.trimmed_primary_key] + + # First pass: find the last index for each primary key + last_index_for_key = {} + for i in range(data.num_rows): + current_key = tuple( + data.column(col)[i].as_py() for col in pk_columns + ) + last_index_for_key[current_key] = i + + # Second pass: collect indices to keep (maintaining original order) + indices_to_keep = [] + for i in range(data.num_rows): + current_key = tuple( + data.column(col)[i].as_py() for col in pk_columns + ) + # Only keep this row if it's the last occurrence of this primary key + if i == last_index_for_key[current_key]: + indices_to_keep.append(i) + + # Extract kept rows using PyArrow's take operation + indices_array = pa.array(indices_to_keep, type=pa.int64()) + return data.take(indices_array) + def _sort_by_primary_key(self, data: pa.RecordBatch) -> pa.RecordBatch: + """Sort data by primary key fields and sequence number. + + Args: + data: Record batch to sort + + Returns: + Sorted record batch + """ sort_keys = [(key, 'ascending') for key in self.trimmed_primary_key] if '_SEQUENCE_NUMBER' in data.column_names: sort_keys.append(('_SEQUENCE_NUMBER', 'ascending')) From dbc27244d9ca599d9f868ca8d1b1509e0ed93e40 Mon Sep 17 00:00:00 2001 From: kaori-seasons Date: Wed, 3 Dec 2025 14:23:06 +0800 Subject: [PATCH 2/2] chore: fix checkstyle --- .../write/writer/key_value_data_writer.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/paimon-python/pypaimon/write/writer/key_value_data_writer.py b/paimon-python/pypaimon/write/writer/key_value_data_writer.py index d6c219e3540a..dae9ceed16bb 100644 --- a/paimon-python/pypaimon/write/writer/key_value_data_writer.py +++ b/paimon-python/pypaimon/write/writer/key_value_data_writer.py @@ -31,16 +31,16 @@ def _process_data(self, data: pa.RecordBatch) -> pa.Table: def _merge_data(self, existing_data: pa.Table, new_data: pa.Table) -> pa.Table: """Merge existing data with new data and deduplicate by primary key. - + The merge process: 1. Concatenate existing and new data 2. Sort by primary key fields and sequence number 3. Deduplicate by primary key, keeping the record with maximum sequence number - + Args: existing_data: Previously buffered data new_data: Newly written data to be merged - + Returns: Deduplicated and sorted table """ @@ -71,25 +71,25 @@ def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch: def _deduplicate_by_primary_key(self, data: pa.RecordBatch) -> pa.RecordBatch: """Deduplicate data by primary key, keeping the record with maximum sequence number. - + Prerequisite: data is sorted by (primary_keys, _SEQUENCE_NUMBER) - + Algorithm: Since data is sorted by primary key and then by sequence number in ascending order, for each primary key group, the last occurrence has the maximum sequence number. We iterate through and track the last index of each primary key, then keep only those rows. 
- + Args: data: Sorted record batch with system fields (_KEY_*, _SEQUENCE_NUMBER, _VALUE_KIND) - + Returns: Deduplicated record batch with only the latest record per primary key """ if data.num_rows <= 1: return data - + # Build primary key column names (prefixed with _KEY_) pk_columns = [f'_KEY_{pk}' for pk in self.trimmed_primary_keys] - + # First pass: find the last index for each primary key last_index_for_key = {} for i in range(data.num_rows): @@ -97,7 +97,7 @@ def _deduplicate_by_primary_key(self, data: pa.RecordBatch) -> pa.RecordBatch: data.column(col)[i].as_py() for col in pk_columns ) last_index_for_key[current_key] = i - + # Second pass: collect indices to keep (maintaining original order) indices_to_keep = [] for i in range(data.num_rows): @@ -107,17 +107,17 @@ def _deduplicate_by_primary_key(self, data: pa.RecordBatch) -> pa.RecordBatch: # Only keep this row if it's the last occurrence of this primary key if i == last_index_for_key[current_key]: indices_to_keep.append(i) - + # Extract kept rows using PyArrow's take operation indices_array = pa.array(indices_to_keep, type=pa.int64()) return data.take(indices_array) def _sort_by_primary_key(self, data: pa.RecordBatch) -> pa.RecordBatch: """Sort data by primary key fields and sequence number. - + Args: data: Record batch to sort - + Returns: Sorted record batch """