From 0e52290e125dc3459ba816202cc33ec767139982 Mon Sep 17 00:00:00 2001 From: Danielshih Date: Wed, 25 Feb 2026 03:34:08 +0000 Subject: [PATCH 1/4] Implement ColumnValue and RowData types for PostgreSQL logical replication - Introduced `ColumnValue` enum to represent PostgreSQL column data as either `Null`, `Text`, or `Binary`. - Added methods for encoding and decoding `ColumnValue` to/from a binary wire format. - Implemented `RowData` struct to hold ordered pairs of column names and values, supporting efficient serialization and deserialization. - Included hex encoding and decoding utilities for binary data representation. - Added comprehensive tests for both `ColumnValue` and `RowData`, covering serialization, deserialization, and various edge cases. --- .github/workflows/ci.yml | 5 + Cargo.lock | 4 +- Cargo.toml | 12 +- benches/columnvalue_vs_json.rs | 328 +++++++ benches/rowdata_vs_hashmap.rs | 279 ------ examples/basic-streaming/Cargo.lock | 30 +- integration-tests/complex_types.rs | 1226 +++++++++++++++++++++++++ src/column_value.rs | 990 ++++++++++++++++++++ src/lib.rs | 6 +- src/protocol.rs | 86 +- src/stream.rs | 144 +-- src/types.rs | 1310 ++++++++++++++++++++++----- 12 files changed, 3744 insertions(+), 676 deletions(-) create mode 100644 benches/columnvalue_vs_json.rs delete mode 100644 benches/rowdata_vs_hashmap.rs create mode 100644 integration-tests/complex_types.rs create mode 100644 src/column_value.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 899ba56..1f0fced 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -237,6 +237,11 @@ jobs: cargo test --test safe_transaction_consumer -- --ignored --nocapture --test-threads=1 + PGPASSWORD=postgres psql -h localhost -U postgres -d test_walstream \ + -c "SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE slot_type = 'logical';" || true + + cargo test --test complex_types -- --ignored --nocapture --test-threads=1 + publish: name: Publish 
to crates.io runs-on: ubuntu-latest diff --git a/Cargo.lock b/Cargo.lock index a48cc3e..86bfdef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -126,9 +126,9 @@ checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "chrono" -version = "0.4.43" +version = "0.4.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" dependencies = [ "iana-time-zone", "js-sys", diff --git a/Cargo.toml b/Cargo.toml index 6140314..035b953 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,8 +16,7 @@ categories = ["database", "parsing", "network-programming"] tokio = { version = "1.49.0", features = ["io-util", "net", "time", "macros"] } tokio-util = { version = "0.7.18", features = ["compat"] } serde = { version = "1.0.228", features = ["derive", "rc"] } -serde_json = "1.0.149" -chrono = { version = "0.4.43", features = ["serde"] } +chrono = { version = "0.4.44", features = ["serde"] } bytes = "1.11.1" tracing = "0.1.44" libpq-sys = "0.8" @@ -27,8 +26,9 @@ thiserror = "2.0.18" default = [] [dev-dependencies] -tokio = { version = "1.47.2", features = ["full"] } +tokio = { version = "1.49.0", features = ["full"] } criterion = { version = "0.8.2", features = ["html_reports"] } +serde_json = "1.0.149" [[test]] name = "snapshot_export" @@ -42,6 +42,10 @@ path = "integration-tests/rate_limited_streaming.rs" name = "safe_transaction_consumer" path = "integration-tests/safe_transaction_consumer.rs" +[[test]] +name = "complex_types" +path = "integration-tests/complex_types.rs" + [[bench]] -name = "rowdata_vs_hashmap" +name = "columnvalue_vs_json" harness = false diff --git a/benches/columnvalue_vs_json.rs b/benches/columnvalue_vs_json.rs new file mode 100644 index 0000000..520011b --- /dev/null +++ b/benches/columnvalue_vs_json.rs @@ -0,0 +1,328 @@ +//! 
Benchmark: JSON serialization (serde_json) vs Binary serialization (ColumnValue encode/decode) +//! +//! Measures ChangeEvent performance across two serialization strategies: +//! +//! 1. **JSON (serde_json)**: `serde_json::to_vec` / `serde_json::from_slice` +//! 2. **Binary (ColumnValue)**: `ChangeEvent::encode` / `ChangeEvent::decode` +//! +//! Benchmark groups: +//! - `construct` — Build event: HashMap<String,Value> vs RowData<Arc<str>,ColumnValue> +//! - `serialize` — Encode event to bytes: serde_json vs binary +//! - `deserialize` — Decode bytes back to event: serde_json vs binary +//! - `round_trip` — Full encode → decode cycle +//! - `payload_size` — Output size comparison (printed, not timed) +//! - `pipeline` — Realistic CDC: construct → clone → lookup → serialize +//! +//! Run: +//! cargo bench --bench columnvalue_vs_json + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use pg_walstream::types::{ChangeEvent, ColumnValue, EventType, Lsn, RowData}; +use serde_json::{self, Value}; +use std::collections::HashMap; +use std::hint::black_box; +use std::sync::Arc; + +/// Old-style event: HashMap<String, Value> (pre-ColumnValue approach). +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct OldChangeEvent { + schema: String, + table: String, + relation_oid: u32, + data: HashMap<String, Value>, + lsn: u64, +} + +/// Build an old-style event using HashMap + serde_json::Value. +fn build_old_event(n_columns: usize) -> OldChangeEvent { + let mut data = HashMap::with_capacity(n_columns); + for i in 0..n_columns { + data.insert(format!("column_{i}"), serde_json::json!(i.to_string())); + } + OldChangeEvent { + schema: "public".to_string(), + table: "users".to_string(), + relation_oid: 16384, + data, + lsn: 0x16B374D848, + } +} + +/// Build a new-style ChangeEvent using RowData + ColumnValue. +/// `shared_names` simulates pre-cached `Arc<str>` column names from RelationInfo. 
+fn build_new_event(shared_names: &[Arc<str>]) -> ChangeEvent { + let n = shared_names.len(); + let mut row = RowData::with_capacity(n); + for (i, name) in shared_names.iter().enumerate() { + row.push(Arc::clone(name), ColumnValue::text(&i.to_string())); + } + ChangeEvent::insert("public", "users", 16384, row, Lsn::new(0x16B374D848)) +} + +/// Pre-create shared column names (mirrors what RelationInfo holds in production). +fn shared_column_names(n: usize) -> Vec<Arc<str>> { + (0..n) + .map(|i| Arc::from(format!("column_{i}").as_str())) + .collect() +} + +const COLUMN_COUNTS: [usize; 4] = [5, 10, 20, 50]; + +// --------------------------------------------------------------------------- +// 1. Construction: HashMap+Value vs RowData+ColumnValue +// --------------------------------------------------------------------------- + +/// Compare event construction cost. +fn bench_construct(c: &mut Criterion) { + let mut group = c.benchmark_group("construct"); + + for n_cols in COLUMN_COUNTS { + let names = shared_column_names(n_cols); + + group.bench_with_input( + BenchmarkId::new("json_hashmap", n_cols), + &n_cols, + |b, &n| { + b.iter(|| black_box(build_old_event(n))); + }, + ); + + group.bench_with_input( + BenchmarkId::new("binary_columnvalue", n_cols), + &names, + |b, names| { + b.iter(|| black_box(build_new_event(names))); + }, + ); + } + + group.finish(); +} + +// --------------------------------------------------------------------------- +// 2. Serialize: serde_json::to_vec vs ChangeEvent::encode +// --------------------------------------------------------------------------- + +/// Compare serialization: JSON vs binary encoding. 
+fn bench_serialize(c: &mut Criterion) { + let mut group = c.benchmark_group("serialize"); + + for n_cols in COLUMN_COUNTS { + let names = shared_column_names(n_cols); + let new_event = build_new_event(&names); + + // JSON serialize (new ChangeEvent via serde) + group.bench_with_input( + BenchmarkId::new("json_serde", n_cols), + &new_event, + |b, event| { + b.iter(|| black_box(serde_json::to_vec(event).unwrap())); + }, + ); + + // Binary encode (ChangeEvent::encode) + group.bench_with_input( + BenchmarkId::new("binary_encode", n_cols), + &new_event, + |b, event| { + b.iter(|| { + let mut buf = bytes::BytesMut::with_capacity(256); + event.encode(&mut buf); + black_box(buf); + }); + }, + ); + } + + group.finish(); +} + +// --------------------------------------------------------------------------- +// 3. Deserialize: serde_json::from_slice vs ChangeEvent::decode +// --------------------------------------------------------------------------- + +/// Compare deserialization: JSON vs binary decoding. 
+fn bench_deserialize(c: &mut Criterion) { + let mut group = c.benchmark_group("deserialize"); + + for n_cols in COLUMN_COUNTS { + let names = shared_column_names(n_cols); + let new_event = build_new_event(&names); + + let new_json_bytes = serde_json::to_vec(&new_event).unwrap(); + let mut binary_buf = bytes::BytesMut::with_capacity(256); + new_event.encode(&mut binary_buf); + let binary_bytes = binary_buf.freeze(); + + // JSON deserialize (new ChangeEvent via serde) + group.bench_with_input( + BenchmarkId::new("json_serde", n_cols), + &new_json_bytes, + |b, data| { + b.iter(|| { + black_box(serde_json::from_slice::<ChangeEvent>(data).unwrap()); + }); + }, + ); + + // Binary decode (ChangeEvent::decode) + group.bench_with_input( + BenchmarkId::new("binary_decode", n_cols), + &binary_bytes, + |b, data| { + b.iter(|| { + black_box(ChangeEvent::decode(data).unwrap()); + }); + }, + ); + } + + group.finish(); +} + +// --------------------------------------------------------------------------- +// 4. Round-trip: serialize → deserialize +// --------------------------------------------------------------------------- + +/// Compare full encode → decode round-trip. 
+fn bench_round_trip(c: &mut Criterion) { + let mut group = c.benchmark_group("round_trip"); + + for n_cols in COLUMN_COUNTS { + let names = shared_column_names(n_cols); + let new_event = build_new_event(&names); + + // JSON round-trip (new ChangeEvent via serde) + group.bench_with_input( + BenchmarkId::new("json_serde", n_cols), + &new_event, + |b, event| { + b.iter(|| { + let json = serde_json::to_vec(event).unwrap(); + let decoded: ChangeEvent = serde_json::from_slice(&json).unwrap(); + black_box(decoded); + }); + }, + ); + + // Binary round-trip (encode → decode) + group.bench_with_input( + BenchmarkId::new("binary_encode_decode", n_cols), + &new_event, + |b, event| { + b.iter(|| { + let mut buf = bytes::BytesMut::with_capacity(256); + event.encode(&mut buf); + let decoded = ChangeEvent::decode(&buf).unwrap(); + black_box(decoded); + }); + }, + ); + } + + group.finish(); +} + +// --------------------------------------------------------------------------- +// 5. Payload size comparison (one-shot, informational) +// --------------------------------------------------------------------------- +fn bench_payload_size(c: &mut Criterion) { + let mut group = c.benchmark_group("payload_size"); + + for n_cols in COLUMN_COUNTS { + let names = shared_column_names(n_cols); + let new_event = build_new_event(&names); + + let mut binary_buf = bytes::BytesMut::with_capacity(256); + new_event.encode(&mut binary_buf); + + // Bench building the payloads so criterion records something + group.bench_with_input( + BenchmarkId::new("json_serde", n_cols), + &new_event, + |b, event| { + b.iter(|| black_box(serde_json::to_vec(event).unwrap().len())); + }, + ); + + group.bench_with_input( + BenchmarkId::new("binary_encode", n_cols), + &new_event, + |b, event| { + b.iter(|| { + let mut buf = bytes::BytesMut::with_capacity(256); + event.encode(&mut buf); + black_box(buf.len()); + }); + }, + ); + } + + group.finish(); +} + +// 
--------------------------------------------------------------------------- +// 6. Realistic CDC pipeline: construct → clone → lookup → serialize +// --------------------------------------------------------------------------- + +/// End-to-end CDC simulation: construct event, clone it, look up 3 columns, +/// then serialize to the target format. +fn bench_pipeline(c: &mut Criterion) { + let mut group = c.benchmark_group("pipeline"); + + for n_cols in COLUMN_COUNTS { + let names = shared_column_names(n_cols); + + // New path: RowData + ColumnValue → JSON serde + group.bench_with_input( + BenchmarkId::new("json_serde", n_cols), + &names, + |b, names| { + b.iter(|| { + let event = build_new_event(names); + let cloned = event.clone(); + if let EventType::Insert { ref data, .. } = cloned.event_type { + let _ = black_box(data.get("column_0")); + let _ = black_box(data.get("column_1")); + let _ = black_box(data.get("column_2")); + } + let out = serde_json::to_vec(&cloned).unwrap(); + black_box(out); + }); + }, + ); + + // New path: RowData + ColumnValue → binary encode + group.bench_with_input( + BenchmarkId::new("binary_encode", n_cols), + &names, + |b, names| { + b.iter(|| { + let event = build_new_event(names); + let cloned = event.clone(); + if let EventType::Insert { ref data, .. } = cloned.event_type { + let _ = black_box(data.get("column_0")); + let _ = black_box(data.get("column_1")); + let _ = black_box(data.get("column_2")); + } + let mut buf = bytes::BytesMut::with_capacity(256); + cloned.encode(&mut buf); + black_box(buf); + }); + }, + ); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_construct, + bench_serialize, + bench_deserialize, + bench_round_trip, + bench_payload_size, + bench_pipeline, +); +criterion_main!(benches); diff --git a/benches/rowdata_vs_hashmap.rs b/benches/rowdata_vs_hashmap.rs deleted file mode 100644 index 729da34..0000000 --- a/benches/rowdata_vs_hashmap.rs +++ /dev/null @@ -1,279 +0,0 @@ -//! 
Benchmark: RowData (Arc + Vec) vs HashMap -//! -//! Compares the old HashMap-based approach with the new RowData/Arc approach -//! across realistic CDC workloads: -//! -//! - Event construction (simulates convert_to_change_event hot path) -//! - Column lookup by name -//! - Cloning events (simulates buffering / sending across channels) -//! - Serialization to JSON -//! -//! Run: -//! cargo bench --bench rowdata_vs_hashmap - -use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; -use pg_walstream::types::{ChangeEvent, EventType, Lsn, RowData}; -use serde_json::{self, Value}; -use std::collections::HashMap; -use std::sync::Arc; - -// --------------------------------------------------------------------------- -// Helpers — simulate the "old" HashMap code path -// --------------------------------------------------------------------------- - -/// Old-style event with HashMap and String schema/table. -#[derive(Debug, Clone, serde::Serialize)] -struct OldChangeEvent { - schema: String, - table: String, - relation_oid: u32, - data: HashMap<String, Value>, - lsn: u64, -} - -/// Build an old-style event the way the old code used to. -fn build_old_event(n_columns: usize) -> OldChangeEvent { - let mut data = HashMap::with_capacity(n_columns); - for i in 0..n_columns { - data.insert(format!("column_{i}"), serde_json::json!(i)); - } - OldChangeEvent { - schema: "public".to_string(), - table: "users".to_string(), - relation_oid: 16384, - data, - lsn: 0x16B374D848, - } -} - -/// Build a new-style ChangeEvent using Arc + RowData. -/// `shared_names` simulates the pre-cached Arc column names -/// held in RelationInfo — this is how production code works. 
-fn build_new_event(shared_names: &[Arc<str>]) -> ChangeEvent { - let n = shared_names.len(); - let mut row = RowData::with_capacity(n); - for (i, name) in shared_names.iter().enumerate() { - row.push(Arc::clone(name), serde_json::json!(i)); - } - ChangeEvent::insert("public", "users", 16384, row, Lsn::new(0x16B374D848)) -} - -/// Pre-create shared column names (what RelationInfo holds in production). -fn shared_column_names(n: usize) -> Vec<Arc<str>> { - (0..n) - .map(|i| Arc::from(format!("column_{i}").as_str())) - .collect() -} - -// --------------------------------------------------------------------------- -// Benchmarks -// --------------------------------------------------------------------------- - -fn bench_event_construction(c: &mut Criterion) { - let mut group = c.benchmark_group("event_construction"); - - for n_cols in [5, 10, 20, 50] { - let names = shared_column_names(n_cols); - - group.bench_with_input( - BenchmarkId::new("hashmap_string", n_cols), - &n_cols, - |b, &n| { - b.iter(|| black_box(build_old_event(n))); - }, - ); - - group.bench_with_input( - BenchmarkId::new("rowdata_arc", n_cols), - &names, - |b, names| { - b.iter(|| black_box(build_new_event(names))); - }, - ); - } - - group.finish(); -} - -fn bench_column_lookup(c: &mut Criterion) { - let mut group = c.benchmark_group("column_lookup"); - - for n_cols in [5, 10, 20, 50] { - // Build the data structures once - let old = build_old_event(n_cols); - let names = shared_column_names(n_cols); - let new_event = build_new_event(&names); - let row = match &new_event.event_type { - EventType::Insert { data, .. 
} => data, - _ => unreachable!(), - }; - - // Lookup the LAST column (worst case for linear scan) - let target = format!("column_{}", n_cols - 1); - - group.bench_with_input( - BenchmarkId::new("hashmap_get", n_cols), - &target, - |b, key| { - b.iter(|| black_box(old.data.get(key))); - }, - ); - - group.bench_with_input( - BenchmarkId::new("rowdata_get", n_cols), - &target, - |b, key| { - b.iter(|| black_box(row.get(key))); - }, - ); - } - - group.finish(); -} - -fn bench_clone(c: &mut Criterion) { - let mut group = c.benchmark_group("event_clone"); - - for n_cols in [5, 10, 20, 50] { - let old = build_old_event(n_cols); - let names = shared_column_names(n_cols); - let new_event = build_new_event(&names); - - group.bench_with_input( - BenchmarkId::new("hashmap_string", n_cols), - &old, - |b, event| { - b.iter(|| black_box(event.clone())); - }, - ); - - group.bench_with_input( - BenchmarkId::new("rowdata_arc", n_cols), - &new_event, - |b, event| { - b.iter(|| black_box(event.clone())); - }, - ); - } - - group.finish(); -} - -fn bench_serialize(c: &mut Criterion) { - let mut group = c.benchmark_group("json_serialize"); - - for n_cols in [5, 10, 20, 50] { - let old = build_old_event(n_cols); - let names = shared_column_names(n_cols); - let new_event = build_new_event(&names); - - group.bench_with_input( - BenchmarkId::new("hashmap_string", n_cols), - &old, - |b, event| { - b.iter(|| black_box(serde_json::to_string(event).unwrap())); - }, - ); - - group.bench_with_input( - BenchmarkId::new("rowdata_arc", n_cols), - &new_event, - |b, event| { - b.iter(|| black_box(serde_json::to_string(event).unwrap())); - }, - ); - } - - group.finish(); -} - -/// Simulate a realistic CDC pipeline: construct event → clone → lookup 3 columns → serialize. 
-fn bench_full_pipeline(c: &mut Criterion) { - let mut group = c.benchmark_group("full_pipeline"); - - for n_cols in [5, 10, 20, 50] { - let names = shared_column_names(n_cols); - - group.bench_with_input( - BenchmarkId::new("hashmap_string", n_cols), - &n_cols, - |b, &n| { - b.iter(|| { - let event = build_old_event(n); - let cloned = event.clone(); - // Simulate looking up a few columns - let _ = black_box(cloned.data.get("column_0")); - let _ = black_box(cloned.data.get("column_1")); - let _ = black_box(cloned.data.get("column_2")); - let json = serde_json::to_string(&cloned).unwrap(); - black_box(json); - }); - }, - ); - - group.bench_with_input( - BenchmarkId::new("rowdata_arc", n_cols), - &names, - |b, names| { - b.iter(|| { - let event = build_new_event(names); - let cloned = event.clone(); - if let EventType::Insert { ref data, .. } = cloned.event_type { - let _ = black_box(data.get("column_0")); - let _ = black_box(data.get("column_1")); - let _ = black_box(data.get("column_2")); - } - let json = serde_json::to_string(&cloned).unwrap(); - black_box(json); - }); - }, - ); - } - - group.finish(); -} - -/// Simulate high-throughput: construct N events from the same relation -/// (e.g., a burst of INSERT in one transaction). 
-fn bench_batch_construction(c: &mut Criterion) { - let mut group = c.benchmark_group("batch_100_events"); - - for n_cols in [5, 10, 20, 50] { - let names = shared_column_names(n_cols); - - group.bench_with_input( - BenchmarkId::new("hashmap_string", n_cols), - &n_cols, - |b, &n| { - b.iter(|| { - let events: Vec<_> = (0..100).map(|_| build_old_event(n)).collect(); - black_box(events); - }); - }, - ); - - group.bench_with_input( - BenchmarkId::new("rowdata_arc", n_cols), - &names, - |b, names| { - b.iter(|| { - let events: Vec<_> = (0..100).map(|_| build_new_event(names)).collect(); - black_box(events); - }); - }, - ); - } - - group.finish(); -} - -criterion_group!( - benches, - bench_event_construction, - bench_column_lookup, - bench_clone, - bench_serialize, - bench_full_pipeline, - bench_batch_construction, -); -criterion_main!(benches); diff --git a/examples/basic-streaming/Cargo.lock b/examples/basic-streaming/Cargo.lock index fd9fb8a..bb219f6 100644 --- a/examples/basic-streaming/Cargo.lock +++ b/examples/basic-streaming/Cargo.lock @@ -111,9 +111,9 @@ checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "chrono" -version = "0.4.43" +version = "0.4.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" dependencies = [ "iana-time-zone", "js-sys", @@ -290,12 +290,6 @@ dependencies = [ "cc", ] -[[package]] -name = "itoa" -version = "1.0.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" - [[package]] name = "js-sys" version = "0.3.85" @@ -469,7 +463,6 @@ dependencies = [ "chrono", "libpq-sys", "serde", - "serde_json", "thiserror", "tokio", "tokio-util", @@ -611,19 +604,6 @@ dependencies = [ "syn 2.0.115", ] -[[package]] -name = "serde_json" -version = 
"1.0.149" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" -dependencies = [ - "itoa", - "memchr", - "serde", - "serde_core", - "zmij", -] - [[package]] name = "sharded-slab" version = "0.1.7" @@ -1131,9 +1111,3 @@ name = "windows_x86_64_msvc" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" - -[[package]] -name = "zmij" -version = "1.0.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/integration-tests/complex_types.rs b/integration-tests/complex_types.rs new file mode 100644 index 0000000..0e75ecf --- /dev/null +++ b/integration-tests/complex_types.rs @@ -0,0 +1,1226 @@ +//! Integration tests for complex PostgreSQL data types through logical replication. +//! +//! These tests verify that the library correctly streams and represents complex +//! PostgreSQL types via the `pgoutput` logical decoding plugin: +//! +//! - **Array types**: `integer[]`, `text[]`, `boolean[]`, `float8[]`, nested arrays +//! - **JSON / JSONB**: objects, arrays, nested structures, special values +//! - **Geometric types**: `point`, `line`, `lseg`, `box`, `path`, `polygon`, `circle` +//! - **Composite / mixed**: rows containing multiple complex types together +//! +//! All complex types arrive as `ColumnValue::Text` because `pgoutput` always +//! emits data in text format (unless the `binary` option is explicitly enabled, +//! which this crate does not set by default). +//! +//! ## Prerequisites +//! +//! Same as `snapshot_export.rs` — requires a live PostgreSQL 14+ instance +//! with `wal_level = logical`. +//! +//! ## Running Locally +//! +//! ```bash +//! export DATABASE_URL="postgresql://postgres:postgres@localhost:5432/test_walstream?replication=database" +//! 
export DATABASE_URL_REGULAR="postgresql://postgres:postgres@localhost:5432/test_walstream" +//! cargo test --test complex_types -- --ignored --nocapture --test-threads=1 +//! ``` + +use pg_walstream::{ + CancellationToken, ColumnValue, EventType, LogicalReplicationStream, PgReplicationConnection, + ReplicationSlotOptions, ReplicationStreamConfig, RetryConfig, StreamingMode, +}; +use std::time::Duration; + +// ─── Helpers ───────────────────────────────────────────────────────────────── + +fn replication_conn_string() -> String { + std::env::var("DATABASE_URL").unwrap_or_else(|_| { + "postgresql://postgres:postgres@localhost:5432/test_walstream?replication=database" + .to_string() + }) +} + +fn regular_conn_string() -> String { + std::env::var("DATABASE_URL_REGULAR").unwrap_or_else(|_| { + let repl = replication_conn_string(); + repl.replace("?replication=database", "") + .replace("&replication=database", "") + }) +} + +fn drop_slot(slot_name: &str) { + if let Ok(conn) = PgReplicationConnection::connect(&replication_conn_string()) { + let _ = conn.exec(&format!( + "SELECT pg_drop_replication_slot('{slot_name}') \ + WHERE EXISTS (SELECT 1 FROM pg_replication_slots WHERE slot_name = '{slot_name}')" + )); + } +} + +fn complex_config(slot_name: &str, pub_name: &str) -> ReplicationStreamConfig { + ReplicationStreamConfig::new( + slot_name.to_string(), + pub_name.to_string(), + 2, + StreamingMode::On, + Duration::from_secs(10), + Duration::from_secs(30), + Duration::from_secs(60), + RetryConfig::default(), + ) + .with_slot_options(ReplicationSlotOptions { + temporary: true, + ..Default::default() + }) +} + +/// Collect Insert events from a single transaction. +/// +/// Returns the `ColumnValue` data for each Insert event (in arrival order). +/// Automatically cancels after `timeout` seconds. 
+async fn collect_insert_events( + stream: &mut LogicalReplicationStream, + timeout_secs: u64, + expected_inserts: usize, +) -> Vec<Vec<(String, ColumnValue)>> { + let cancel_token = CancellationToken::new(); + let cancel_clone = cancel_token.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(timeout_secs)).await; + cancel_clone.cancel(); + }); + + let mut inserts: Vec<Vec<(String, ColumnValue)>> = Vec::new(); + + loop { + match stream.next_event(&cancel_token).await { + Ok(event) => { + stream + .shared_lsn_feedback + .update_applied_lsn(event.lsn.value()); + + if let EventType::Insert { data, .. } = &event.event_type { + let cols: Vec<(String, ColumnValue)> = data + .iter() + .map(|(name, val)| (name.to_string(), val.clone())) + .collect(); + inserts.push(cols); + + if inserts.len() >= expected_inserts { + break; + } + } + } + Err(pg_walstream::ReplicationError::Cancelled(_)) => break, + Err(e) => panic!("Unexpected stream error: {e}"), + } + } + + inserts +} + +/// Look up a column value by name from a flat `(name, value)` list. +fn find_col<'a>(cols: &'a [(String, ColumnValue)], name: &str) -> &'a ColumnValue { + cols.iter() + .find(|(n, _)| n == name) + .map(|(_, v)| v) + .unwrap_or_else(|| panic!("column '{name}' not found in row")) +} + +// ─── Array Type Tests ──────────────────────────────────────────────────────── + +/// Verify that integer, text, boolean, and float arrays are streamed correctly. 
+#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_array_types_basic() { + let slot = "it_complex_arr_basic"; + let pub_name = "complex_arr_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(&regular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_arr_test (\ + id SERIAL PRIMARY KEY, \ + int_arr INTEGER[], \ + text_arr TEXT[], \ + bool_arr BOOLEAN[], \ + float_arr FLOAT8[]\ + )", + ); + let _ = regular.exec("TRUNCATE complex_arr_test RESTART IDENTITY"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_arr_test" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + // Insert a row with various array types + regular + .exec( + "INSERT INTO complex_arr_test (int_arr, text_arr, bool_arr, float_arr) \ + VALUES (\ + '{1,2,3}', \ + '{\"hello\",\"world\",\"pg\"}', \ + '{true,false,true}', \ + '{1.1,2.2,3.3}'\ + )", + ) + .expect("INSERT arrays"); + + let inserts = collect_insert_events(&mut stream, 10, 1).await; + assert_eq!(inserts.len(), 1, "expected 1 insert event"); + + let row = &inserts[0]; + + // pgoutput emits arrays as their text representation + let int_arr = find_col(row, "int_arr"); + assert_eq!( + int_arr.as_str(), + Some("{1,2,3}"), + "integer array text mismatch" + ); + + let text_arr = find_col(row, "text_arr"); + assert_eq!( + text_arr.as_str(), + Some("{hello,world,pg}"), + "text array text mismatch" + ); + + let bool_arr = find_col(row, "bool_arr"); + assert_eq!( + bool_arr.as_str(), + Some("{t,f,t}"), + "boolean array text mismatch" + ); + + let float_arr = find_col(row, "float_arr"); + let float_str = float_arr.as_str().expect("float_arr should 
be text"); + assert!( + float_str.starts_with('{') && float_str.ends_with('}'), + "float array should be delimited: {float_str}" + ); + + // Verify none of the array columns are null + assert!(!int_arr.is_null()); + assert!(!text_arr.is_null()); + assert!(!bool_arr.is_null()); + assert!(!float_arr.is_null()); + + println!("Array basic test passed: int={int_arr}, text={text_arr}, bool={bool_arr}, float={float_arr}"); +} + +/// Verify multi-dimensional (nested) arrays and arrays with NULL elements. +#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_array_types_nested_and_nulls() { + let slot = "it_complex_arr_nested"; + let pub_name = "complex_arr_nested_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(&regular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_arr_nested (\ + id SERIAL PRIMARY KEY, \ + matrix INTEGER[][], \ + nullable_arr TEXT[]\ + )", + ); + let _ = regular.exec("TRUNCATE complex_arr_nested RESTART IDENTITY"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_arr_nested" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + // 2D array and array with NULL elements + regular + .exec( + "INSERT INTO complex_arr_nested (matrix, nullable_arr) \ + VALUES (\ + '{{1,2},{3,4}}', \ + '{\"present\",NULL,\"also_present\"}'\ + )", + ) + .expect("INSERT nested arrays"); + + let inserts = collect_insert_events(&mut stream, 10, 1).await; + assert_eq!(inserts.len(), 1, "expected 1 insert event"); + + let row = &inserts[0]; + + let matrix = find_col(row, "matrix"); + assert_eq!( + matrix.as_str(), + Some("{{1,2},{3,4}}"), + "2D array text mismatch" + ); + + let 
nullable_arr = find_col(row, "nullable_arr"); + let arr_str = nullable_arr.as_str().expect("nullable_arr should be text"); + // PostgreSQL represents NULL elements as `NULL` within the array literal + assert!( + arr_str.contains("NULL"), + "nullable array should contain NULL element: {arr_str}" + ); + + println!("Nested/NULL array test passed: matrix={matrix}, nullable={nullable_arr}"); +} + +/// Verify empty arrays are streamed correctly. +#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_array_types_empty() { + let slot = "it_complex_arr_empty"; + let pub_name = "complex_arr_empty_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(&regular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_arr_empty (\ + id SERIAL PRIMARY KEY, \ + empty_int INTEGER[], \ + empty_text TEXT[]\ + )", + ); + let _ = regular.exec("TRUNCATE complex_arr_empty RESTART IDENTITY"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_arr_empty" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + regular + .exec( + "INSERT INTO complex_arr_empty (empty_int, empty_text) \ + VALUES ('{}', '{}')", + ) + .expect("INSERT empty arrays"); + + let inserts = collect_insert_events(&mut stream, 10, 1).await; + assert_eq!(inserts.len(), 1, "expected 1 insert event"); + + let row = &inserts[0]; + assert_eq!( + find_col(row, "empty_int").as_str(), + Some("{}"), + "empty integer array" + ); + assert_eq!( + find_col(row, "empty_text").as_str(), + Some("{}"), + "empty text array" + ); + + println!("Empty array test passed"); +} + +// ─── JSON / JSONB Type Tests 
───────────────────────────────────────────────── + +/// Verify JSON and JSONB objects, arrays, and nested structures. +#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_json_jsonb_basic() { + let slot = "it_complex_json_basic"; + let pub_name = "complex_json_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(®ular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_json_test (\ + id SERIAL PRIMARY KEY, \ + data_json JSON, \ + data_jsonb JSONB\ + )", + ); + let _ = regular.exec("TRUNCATE complex_json_test RESTART IDENTITY"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_json_test" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + // A JSON object with various value types + regular + .exec( + r#"INSERT INTO complex_json_test (data_json, data_jsonb) VALUES ( + '{"name": "alice", "age": 30, "active": true, "score": 9.5}', + '{"name": "alice", "age": 30, "active": true, "score": 9.5}' + )"#, + ) + .expect("INSERT json objects"); + + let inserts = collect_insert_events(&mut stream, 10, 1).await; + assert_eq!(inserts.len(), 1, "expected 1 insert event"); + + let row = &inserts[0]; + + // JSON preserves the exact input text + let json_val = find_col(row, "data_json"); + let json_str = json_val.as_str().expect("data_json should be text"); + assert!( + json_str.contains("\"name\""), + "json should contain name key" + ); + assert!( + json_str.contains("\"alice\""), + "json should contain alice value" + ); + assert!(json_str.contains("30"), "json should contain age 30"); + assert!(json_str.contains("true"), "json should contain true"); + + // JSONB normalises key 
order and whitespace + let jsonb_val = find_col(row, "data_jsonb"); + let jsonb_str = jsonb_val.as_str().expect("data_jsonb should be text"); + assert!( + jsonb_str.contains("\"name\""), + "jsonb should contain name key" + ); + assert!( + jsonb_str.contains("\"alice\""), + "jsonb should contain alice value" + ); + + // Neither should be null + assert!(!json_val.is_null()); + assert!(!jsonb_val.is_null()); + + println!("JSON/JSONB basic test passed: json={json_val}, jsonb={jsonb_val}"); +} + +/// Verify nested JSON structures and JSON arrays. +#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_json_nested_and_arrays() { + let slot = "it_complex_json_nested"; + let pub_name = "complex_json_nested_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(®ular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_json_nested (\ + id SERIAL PRIMARY KEY, \ + nested JSONB, \ + arr JSONB\ + )", + ); + let _ = regular.exec("TRUNCATE complex_json_nested RESTART IDENTITY"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_json_nested" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + regular + .exec( + r#"INSERT INTO complex_json_nested (nested, arr) VALUES ( + '{"user": {"name": "bob", "address": {"city": "NYC", "zip": "10001"}}, "tags": ["admin", "user"]}', + '[1, "two", null, true, {"key": "val"}]' + )"#, + ) + .expect("INSERT nested json"); + + let inserts = collect_insert_events(&mut stream, 10, 1).await; + assert_eq!(inserts.len(), 1, "expected 1 insert event"); + + let row = &inserts[0]; + + let nested = find_col(row, "nested"); + let nested_str = 
nested.as_str().expect("nested should be text"); + assert!( + nested_str.contains("\"city\""), + "nested json should contain city: {nested_str}" + ); + assert!( + nested_str.contains("\"NYC\"") || nested_str.contains("\"10001\""), + "nested json should contain address data: {nested_str}" + ); + assert!( + nested_str.contains("\"admin\""), + "nested json should contain tags: {nested_str}" + ); + + let arr = find_col(row, "arr"); + let arr_str = arr.as_str().expect("arr should be text"); + assert!( + arr_str.starts_with('['), + "json array should start with [: {arr_str}" + ); + assert!( + arr_str.contains("null"), + "json array should contain null: {arr_str}" + ); + + println!("Nested JSON/array test passed: nested={nested}, arr={arr}"); +} + +/// Verify JSON NULL (SQL NULL column) vs JSON `null` value. +#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_json_null_handling() { + let slot = "it_complex_json_null"; + let pub_name = "complex_json_null_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(®ular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_json_null (\ + id SERIAL PRIMARY KEY, \ + sql_null JSONB, \ + json_null JSONB, \ + json_empty_obj JSONB, \ + json_empty_arr JSONB\ + )", + ); + let _ = regular.exec("TRUNCATE complex_json_null RESTART IDENTITY"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_json_null" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + regular + .exec( + "INSERT INTO complex_json_null (sql_null, json_null, json_empty_obj, json_empty_arr) \ + VALUES (NULL, 'null', '{}', '[]')", + ) + .expect("INSERT json nulls"); + + let 
inserts = collect_insert_events(&mut stream, 10, 1).await; + assert_eq!(inserts.len(), 1, "expected 1 insert event"); + + let row = &inserts[0]; + + // SQL NULL → ColumnValue::Null + let sql_null = find_col(row, "sql_null"); + assert!( + sql_null.is_null(), + "SQL NULL column should be ColumnValue::Null" + ); + + // JSON literal `null` → ColumnValue::Text("null") + let json_null = find_col(row, "json_null"); + assert!( + !json_null.is_null(), + "JSON 'null' is a valid JSONB value, not SQL NULL" + ); + assert_eq!(json_null.as_str(), Some("null"), "JSON null text mismatch"); + + // Empty JSON object + let empty_obj = find_col(row, "json_empty_obj"); + assert_eq!(empty_obj.as_str(), Some("{}"), "empty JSONB object"); + + // Empty JSON array + let empty_arr = find_col(row, "json_empty_arr"); + assert_eq!(empty_arr.as_str(), Some("[]"), "empty JSONB array"); + + println!("JSON null handling test passed"); +} + +// ─── Geometric Type Tests ──────────────────────────────────────────────────── + +/// Verify point, line, lseg, box, circle geometric types. 
+#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_geometric_types_basic() { + let slot = "it_complex_geo_basic"; + let pub_name = "complex_geo_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(®ular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_geo_test (\ + id SERIAL PRIMARY KEY, \ + pt POINT, \ + ln LINE, \ + seg LSEG, \ + bx BOX, \ + cr CIRCLE\ + )", + ); + let _ = regular.exec("TRUNCATE complex_geo_test RESTART IDENTITY"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_geo_test" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + regular + .exec( + "INSERT INTO complex_geo_test (pt, ln, seg, bx, cr) VALUES (\ + '(1.5, 2.5)', \ + '{1, -1, 0}', \ + '((0,0),(3,4))', \ + '((3,4),(1,2))', \ + '<(1,2),5>'\ + )", + ) + .expect("INSERT geometric types"); + + let inserts = collect_insert_events(&mut stream, 10, 1).await; + assert_eq!(inserts.len(), 1, "expected 1 insert event"); + + let row = &inserts[0]; + + // Point: (x,y) + let pt = find_col(row, "pt"); + let pt_str = pt.as_str().expect("pt should be text"); + assert!( + pt_str.contains("1.5") && pt_str.contains("2.5"), + "point should contain coordinates: {pt_str}" + ); + + // Line: {A, B, C} + let ln = find_col(row, "ln"); + let ln_str = ln.as_str().expect("ln should be text"); + assert!( + ln_str.starts_with('{') && ln_str.ends_with('}'), + "line should be in {{A,B,C}} format: {ln_str}" + ); + + // Line segment: ((x1,y1),(x2,y2)) + let seg = find_col(row, "seg"); + let seg_str = seg.as_str().expect("seg should be text"); + assert!( + seg_str.contains('(') && seg_str.contains(')'), + "lseg should 
contain parens: {seg_str}" + ); + + // Box: (x1,y1),(x2,y2) + let bx = find_col(row, "bx"); + let bx_str = bx.as_str().expect("bx should be text"); + assert!( + bx_str.contains('(') && bx_str.contains(')'), + "box should contain parens: {bx_str}" + ); + + // Circle: <(x,y),r> + let cr = find_col(row, "cr"); + let cr_str = cr.as_str().expect("cr should be text"); + assert!( + cr_str.contains('<') && cr_str.contains('>'), + "circle should be in <(x,y),r> format: {cr_str}" + ); + assert!(cr_str.contains('5'), "circle radius should be 5: {cr_str}"); + + println!("Geometric basic test passed: pt={pt}, ln={ln}, seg={seg}, bx={bx}, cr={cr}"); +} + +/// Verify path and polygon geometric types. +#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_geometric_path_polygon() { + let slot = "it_complex_geo_path"; + let pub_name = "complex_geo_path_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(®ular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_geo_path (\ + id SERIAL PRIMARY KEY, \ + open_path PATH, \ + closed_path PATH, \ + poly POLYGON\ + )", + ); + let _ = regular.exec("TRUNCATE complex_geo_path RESTART IDENTITY"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_geo_path" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + regular + .exec( + "INSERT INTO complex_geo_path (open_path, closed_path, poly) VALUES (\ + '[(0,0),(1,1),(2,0)]', \ + '((0,0),(1,1),(2,0))', \ + '((0,0),(4,0),(4,3),(0,3))'\ + )", + ) + .expect("INSERT path/polygon"); + + let inserts = collect_insert_events(&mut stream, 10, 1).await; + assert_eq!(inserts.len(), 1, "expected 1 insert 
event"); + + let row = &inserts[0]; + + // Open path: [(p1),(p2),...] + let open_path = find_col(row, "open_path"); + let open_str = open_path.as_str().expect("open_path should be text"); + assert!( + open_str.starts_with('['), + "open path should start with [: {open_str}" + ); + + // Closed path: ((p1),(p2),...) + let closed_path = find_col(row, "closed_path"); + let closed_str = closed_path.as_str().expect("closed_path should be text"); + assert!( + closed_str.starts_with('('), + "closed path should start with (: {closed_str}" + ); + + // Polygon: ((p1),(p2),...) + let poly = find_col(row, "poly"); + let poly_str = poly.as_str().expect("poly should be text"); + assert!( + poly_str.starts_with('('), + "polygon should start with (: {poly_str}" + ); + assert!( + poly_str.contains("4,3") || poly_str.contains("4, 3"), + "polygon should contain vertex (4,3): {poly_str}" + ); + + println!("Path/polygon test passed: open={open_path}, closed={closed_path}, poly={poly}"); +} + +// ─── Mixed Complex Types ───────────────────────────────────────────────────── + +/// Verify a single row containing arrays, JSON, and geometric types together. 
+#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_mixed_complex_types_insert() { + let slot = "it_complex_mixed"; + let pub_name = "complex_mixed_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(®ular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_mixed (\ + id SERIAL PRIMARY KEY, \ + tags TEXT[], \ + metadata JSONB, \ + location POINT, \ + scores FLOAT8[]\ + )", + ); + let _ = regular.exec("TRUNCATE complex_mixed RESTART IDENTITY"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_mixed" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + regular + .exec( + r#"INSERT INTO complex_mixed (tags, metadata, location, scores) VALUES ( + '{"rust","postgres","cdc"}', + '{"version": 1, "features": ["streaming", "arrays"], "config": {"timeout": 30}}', + '(40.7128, -74.0060)', + '{98.5, 87.3, 95.1}' + )"#, + ) + .expect("INSERT mixed complex types"); + + let inserts = collect_insert_events(&mut stream, 10, 1).await; + assert_eq!(inserts.len(), 1, "expected 1 insert event"); + + let row = &inserts[0]; + + // Array + let tags = find_col(row, "tags"); + let tags_str = tags.as_str().expect("tags should be text"); + assert!( + tags_str.contains("rust"), + "tags should contain 'rust': {tags_str}" + ); + assert!( + tags_str.contains("postgres"), + "tags should contain 'postgres': {tags_str}" + ); + + // JSONB + let metadata = find_col(row, "metadata"); + let meta_str = metadata.as_str().expect("metadata should be text"); + assert!( + meta_str.contains("\"version\""), + "metadata should contain version: {meta_str}" + ); + assert!( + meta_str.contains("\"streaming\""), 
+ "metadata should contain streaming feature: {meta_str}" + ); + assert!( + meta_str.contains("\"timeout\""), + "metadata should contain config.timeout: {meta_str}" + ); + + // Point + let location = find_col(row, "location"); + let loc_str = location.as_str().expect("location should be text"); + assert!( + loc_str.contains("40.7128"), + "location should contain latitude: {loc_str}" + ); + + // Float array + let scores = find_col(row, "scores"); + let scores_str = scores.as_str().expect("scores should be text"); + assert!( + scores_str.starts_with('{') && scores_str.ends_with('}'), + "scores should be array: {scores_str}" + ); + + println!( + "Mixed complex types test passed: tags={tags}, meta={metadata}, loc={location}, scores={scores}" + ); +} + +/// Verify UPDATE events correctly stream complex type data (old + new). +#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_mixed_complex_types_update() { + let slot = "it_complex_update"; + let pub_name = "complex_update_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(®ular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_update_test (\ + id SERIAL PRIMARY KEY, \ + data JSONB, \ + items TEXT[]\ + )", + ); + let _ = regular.exec("TRUNCATE complex_update_test RESTART IDENTITY"); + let _ = regular.exec("ALTER TABLE complex_update_test REPLICA IDENTITY FULL"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_update_test" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + // Insert initial row, then update it + regular + .exec( + r#"INSERT INTO complex_update_test (data, items) VALUES ( + '{"status": "draft", 
"count": 0}', + '{"alpha","beta"}' + )"#, + ) + .expect("INSERT initial"); + + regular + .exec( + r#"UPDATE complex_update_test SET + data = '{"status": "published", "count": 42}', + items = '{"alpha","beta","gamma"}' + WHERE id = 1"#, + ) + .expect("UPDATE complex types"); + + let cancel_token = CancellationToken::new(); + let cancel_clone = cancel_token.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(10)).await; + cancel_clone.cancel(); + }); + + let mut saw_update = false; + let mut commit_count = 0u32; + + loop { + match stream.next_event(&cancel_token).await { + Ok(event) => { + stream + .shared_lsn_feedback + .update_applied_lsn(event.lsn.value()); + + if let EventType::Update { + old_data, new_data, .. + } = &event.event_type + { + saw_update = true; + + // Verify new data has updated values + let new_data_col = new_data.get("data").expect("new_data should have 'data'"); + let new_str = new_data_col.as_str().expect("new data should be text"); + assert!( + new_str.contains("\"published\""), + "new data should contain published: {new_str}" + ); + assert!( + new_str.contains("42"), + "new data should contain count 42: {new_str}" + ); + + let new_items = new_data.get("items").expect("new_data should have 'items'"); + let items_str = new_items.as_str().expect("items should be text"); + assert!( + items_str.contains("gamma"), + "new items should contain gamma: {items_str}" + ); + + // With REPLICA IDENTITY FULL, old_data should be present + if let Some(old) = old_data { + let old_data_col = old.get("data").expect("old_data should have 'data'"); + let old_str = old_data_col.as_str().expect("old data should be text"); + assert!( + old_str.contains("\"draft\""), + "old data should contain draft: {old_str}" + ); + } + } + + if matches!(event.event_type, EventType::Commit { .. 
}) { + commit_count += 1; + // Wait for the update transaction (2nd commit) + if commit_count >= 2 { + break; + } + } + } + Err(pg_walstream::ReplicationError::Cancelled(_)) => break, + Err(e) => panic!("Unexpected error: {e}"), + } + } + + assert!( + saw_update, + "expected an Update event with complex type data" + ); + println!("Complex type UPDATE test passed"); +} + +/// Verify DELETE events correctly stream complex type data (old row). +#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_mixed_complex_types_delete() { + let slot = "it_complex_delete"; + let pub_name = "complex_delete_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(®ular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_delete_test (\ + id SERIAL PRIMARY KEY, \ + config JSONB, \ + coords POINT\ + )", + ); + let _ = regular.exec("TRUNCATE complex_delete_test RESTART IDENTITY"); + let _ = regular.exec("ALTER TABLE complex_delete_test REPLICA IDENTITY FULL"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_delete_test" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + // Insert then delete + regular + .exec( + r#"INSERT INTO complex_delete_test (config, coords) VALUES ( + '{"key": "to_delete", "nested": {"a": 1}}', + '(10.0, 20.0)' + )"#, + ) + .expect("INSERT for delete"); + + regular + .exec("DELETE FROM complex_delete_test WHERE id = 1") + .expect("DELETE complex row"); + + let cancel_token = CancellationToken::new(); + let cancel_clone = cancel_token.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(10)).await; + cancel_clone.cancel(); + }); + + let mut 
saw_delete = false; + let mut commit_count = 0u32; + + loop { + match stream.next_event(&cancel_token).await { + Ok(event) => { + stream + .shared_lsn_feedback + .update_applied_lsn(event.lsn.value()); + + if let EventType::Delete { old_data, .. } = &event.event_type { + saw_delete = true; + + // With REPLICA IDENTITY FULL, old_data contains the deleted row + let config_col = old_data.get("config").expect("old_data should have config"); + let config_str = config_col.as_str().expect("config should be text"); + assert!( + config_str.contains("\"to_delete\""), + "deleted row config should contain to_delete: {config_str}" + ); + + let coords_col = old_data.get("coords").expect("old_data should have coords"); + let coords_str = coords_col.as_str().expect("coords should be text"); + assert!( + coords_str.contains("10") && coords_str.contains("20"), + "deleted row coords should contain (10,20): {coords_str}" + ); + } + + if matches!(event.event_type, EventType::Commit { .. }) { + commit_count += 1; + if commit_count >= 2 { + break; + } + } + } + Err(pg_walstream::ReplicationError::Cancelled(_)) => break, + Err(e) => panic!("Unexpected error: {e}"), + } + } + + assert!(saw_delete, "expected a Delete event with complex type data"); + println!("Complex type DELETE test passed"); +} + +/// Verify multiple rows with complex types in a single transaction. 
+#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_batch_insert_complex_types() { + let slot = "it_complex_batch"; + let pub_name = "complex_batch_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(®ular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_batch (\ + id SERIAL PRIMARY KEY, \ + label TEXT, \ + tags TEXT[], \ + info JSONB, \ + pos POINT\ + )", + ); + let _ = regular.exec("TRUNCATE complex_batch RESTART IDENTITY"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_batch" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + // Batch insert — single transaction with 3 rows + regular + .exec( + r#"INSERT INTO complex_batch (label, tags, info, pos) VALUES + ('row1', '{"a","b"}', '{"seq": 1}', '(0,0)'), + ('row2', '{"c","d"}', '{"seq": 2}', '(1,1)'), + ('row3', '{"e","f","g"}', '{"seq": 3, "extra": true}', '(2,2)') + "#, + ) + .expect("batch INSERT"); + + let inserts = collect_insert_events(&mut stream, 10, 3).await; + assert_eq!(inserts.len(), 3, "expected 3 insert events from batch"); + + // Verify each row + for (i, row) in inserts.iter().enumerate() { + let label = find_col(row, "label"); + assert!(!label.is_null(), "row {i} label should not be null"); + + let tags = find_col(row, "tags"); + let tags_str = tags.as_str().expect("tags should be text"); + assert!( + tags_str.starts_with('{') && tags_str.ends_with('}'), + "row {i} tags should be array: {tags_str}" + ); + + let info = find_col(row, "info"); + let info_str = info.as_str().expect("info should be text"); + assert!( + info_str.contains("\"seq\""), + "row {i} info should contain seq: 
{info_str}" + ); + + let pos = find_col(row, "pos"); + let pos_str = pos.as_str().expect("pos should be text"); + assert!( + pos_str.contains('(') && pos_str.contains(')'), + "row {i} pos should be point: {pos_str}" + ); + } + + // Third row should have extra tag + let row3_tags = find_col(&inserts[2], "tags"); + let row3_str = row3_tags.as_str().unwrap(); + assert!( + row3_str.contains('g'), + "row3 tags should contain 'g': {row3_str}" + ); + + // Third row should have extra JSON field + let row3_info = find_col(&inserts[2], "info"); + let row3_info_str = row3_info.as_str().unwrap(); + assert!( + row3_info_str.contains("\"extra\""), + "row3 info should contain extra: {row3_info_str}" + ); + + println!( + "Batch insert complex types test passed ({} rows)", + inserts.len() + ); +} + +/// Verify JSONB special numeric values: large integers, floats, negative numbers. +#[tokio::test] +#[ignore = "requires live PostgreSQL with wal_level=logical"] +async fn test_json_special_values() { + let slot = "it_complex_json_special"; + let pub_name = "complex_json_special_pub"; + drop_slot(slot); + + let regular = + PgReplicationConnection::connect(®ular_conn_string()).expect("regular connection"); + + let _ = regular.exec( + "CREATE TABLE IF NOT EXISTS complex_json_special (\ + id SERIAL PRIMARY KEY, \ + data JSONB\ + )", + ); + let _ = regular.exec("TRUNCATE complex_json_special RESTART IDENTITY"); + let _ = regular.exec(&format!("DROP PUBLICATION IF EXISTS {pub_name}")); + let _ = regular.exec(&format!( + "CREATE PUBLICATION {pub_name} FOR TABLE complex_json_special" + )); + + let config = complex_config(slot, pub_name); + let mut stream = LogicalReplicationStream::new(&replication_conn_string(), config) + .await + .expect("replication stream"); + stream.start(None).await.expect("start"); + + // Insert rows with edge-case JSON values + regular + .exec( + r#"INSERT INTO complex_json_special (data) VALUES + ('{"big": 9999999999999999}'), + ('{"neg": -42, "zero": 0}'), + 
('{"float": 3.14159265358979}'), + ('{"unicode": "café ☕ 日本語"}'), + ('{"escaped": "line1\nline2\ttab"}')"#, + ) + .expect("INSERT special json values"); + + let inserts = collect_insert_events(&mut stream, 10, 5).await; + assert_eq!(inserts.len(), 5, "expected 5 insert events"); + + // Large integer + let big = find_col(&inserts[0], "data"); + let big_str = big.as_str().unwrap(); + assert!( + big_str.contains("9999999999999999"), + "should preserve large int: {big_str}" + ); + + // Negative and zero + let neg = find_col(&inserts[1], "data"); + let neg_str = neg.as_str().unwrap(); + assert!( + neg_str.contains("-42"), + "should contain negative: {neg_str}" + ); + + // Float precision + let float_val = find_col(&inserts[2], "data"); + let float_str = float_val.as_str().unwrap(); + assert!( + float_str.contains("3.14"), + "should contain pi prefix: {float_str}" + ); + + // Unicode + let unicode = find_col(&inserts[3], "data"); + let unicode_str = unicode.as_str().unwrap(); + assert!( + unicode_str.contains("café"), + "should contain unicode text: {unicode_str}" + ); + + println!("JSON special values test passed"); +} diff --git a/src/column_value.rs b/src/column_value.rs new file mode 100644 index 0000000..3131e3b --- /dev/null +++ b/src/column_value.rs @@ -0,0 +1,990 @@ +//! Column value types for PostgreSQL logical replication +//! +//! This module provides [`ColumnValue`] and [`RowData`] — the core data types +//! used to represent column-level data from PostgreSQL's logical replication +//! protocol. Both types use zero-copy [`bytes::Bytes`] internally and support +//! a compact binary wire format for efficient serialisation. + +use crate::buffer::BufferReader; +use crate::error::{ReplicationError, Result}; +use bytes::{Bytes, BytesMut}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +/// Encode a byte slice as lowercase hex string. 
+pub(crate) fn hex_encode(bytes: &[u8]) -> String { + const LUT: &[u8; 16] = b"0123456789abcdef"; + let mut out = String::with_capacity(bytes.len() * 2); + for &byte in bytes { + out.push(LUT[(byte >> 4) as usize] as char); + out.push(LUT[(byte & 0x0f) as usize] as char); + } + out +} + +/// Decode a hex string to bytes. Returns `Err` on invalid hex. +fn hex_decode(hex: &str) -> std::result::Result, &'static str> { + if !hex.len().is_multiple_of(2) { + return Err("odd hex length"); + } + let mut out = Vec::with_capacity(hex.len() / 2); + let bytes = hex.as_bytes(); + for chunk in bytes.chunks_exact(2) { + let high = hex_nibble(chunk[0]).ok_or("invalid hex char")?; + let low = hex_nibble(chunk[1]).ok_or("invalid hex char")?; + out.push((high << 4) | low); + } + Ok(out) +} + +#[inline] +fn hex_nibble(b: u8) -> Option { + match b { + b'0'..=b'9' => Some(b - b'0'), + b'a'..=b'f' => Some(b - b'a' + 10), + b'A'..=b'F' => Some(b - b'A' + 10), + _ => None, + } +} + +// --------------------------------------------------------------------------- +// ColumnValue +// --------------------------------------------------------------------------- + +/// PostgreSQL's logical replication protocol sends column data as either text (UTF-8 encoded) or binary format. This enum preserves the raw representation with zero-copy semantics using [`bytes::Bytes`], avoiding unnecessary parsing and allocation. +/// +/// # Wire Format (binary encode/decode) +/// +/// | Tag byte | Meaning | +/// |----------|-----------------------------------| +/// | `0x00` | `Null` | +/// | `0x01` | `Text` — followed by u32-len + data | +/// | `0x02` | `Binary` — followed by u32-len + data | +/// +/// # Serde +/// +/// When serialised with [`serde`], `Text` values emit a JSON string, +/// `Binary` values emit a hex-prefixed string (`"\\xdeadbeef"`), +/// and `Null` emits JSON `null`. 
+/// +/// # Example +/// +/// ``` +/// use pg_walstream::ColumnValue; +/// use bytes::Bytes; +/// +/// let v = ColumnValue::text("hello"); +/// assert_eq!(v.as_str(), Some("hello")); +/// assert!(!v.is_null()); +/// +/// let n = ColumnValue::Null; +/// assert!(n.is_null()); +/// ``` +#[derive(Debug, Clone)] +pub enum ColumnValue { + /// SQL NULL + Null, + /// Text value from PostgreSQL (UTF-8 encoded string), Uses [`Bytes`] for zero-copy from the protocol buffer. + Text(Bytes), + /// Binary data from PostgreSQL (bytea or binary-mode columns), Stored as raw bytes. + Binary(Bytes), +} + +impl ColumnValue { + /// Wire-format tag bytes + const TAG_NULL: u8 = 0x00; + const TAG_TEXT: u8 = 0x01; + const TAG_BINARY: u8 = 0x02; + + /// Create a `Text` value from a string slice (copies into `Bytes`). + #[inline] + pub fn text(s: &str) -> Self { + Self::Text(Bytes::copy_from_slice(s.as_bytes())) + } + + /// Create a `Text` value from existing `Bytes` (zero-copy). + #[inline] + pub fn text_bytes(b: Bytes) -> Self { + Self::Text(b) + } + + /// Create a `Binary` value from existing `Bytes` (zero-copy). + #[inline] + pub fn binary_bytes(b: Bytes) -> Self { + Self::Binary(b) + } + + /// Returns `true` if this is a `Null` value. + #[inline] + pub fn is_null(&self) -> bool { + matches!(self, Self::Null) + } + + /// Get the text content as `&str`. + /// + /// Returns `Some` for `Text` values that are valid UTF-8, `None` otherwise. + #[inline] + pub fn as_str(&self) -> Option<&str> { + match self { + Self::Text(b) => std::str::from_utf8(b).ok(), + _ => None, + } + } + + /// Get raw bytes regardless of variant. + /// + /// Returns an empty slice for `Null`. + #[inline] + pub fn as_bytes(&self) -> &[u8] { + match self { + Self::Text(b) | Self::Binary(b) => b, + Self::Null => &[], + } + } + + /// Encode this value into a byte buffer. + /// + /// Format: `[1-byte tag]` then for non-null `[4-byte big-endian length][data]`. 
+ #[inline] + pub fn encode(&self, buf: &mut BytesMut) { + match self { + Self::Null => buf.extend_from_slice(&[Self::TAG_NULL]), + Self::Text(b) => { + buf.extend_from_slice(&[Self::TAG_TEXT]); + buf.extend_from_slice(&(b.len() as u32).to_be_bytes()); + buf.extend_from_slice(b); + } + Self::Binary(b) => { + buf.extend_from_slice(&[Self::TAG_BINARY]); + buf.extend_from_slice(&(b.len() as u32).to_be_bytes()); + buf.extend_from_slice(b); + } + } + } + + /// Decode a value from a [`BufferReader`]. + /// + /// Returns an error if the buffer is too short or contains an unknown tag. + #[inline] + pub fn decode(reader: &mut BufferReader) -> Result { + let tag = reader.read_u8()?; + match tag { + Self::TAG_NULL => Ok(Self::Null), + Self::TAG_TEXT => { + let len = reader.read_u32()? as usize; + let data = reader.read_bytes_buf(len)?; + Ok(Self::Text(data)) + } + Self::TAG_BINARY => { + let len = reader.read_u32()? as usize; + let data = reader.read_bytes_buf(len)?; + Ok(Self::Binary(data)) + } + _ => Err(ReplicationError::protocol(format!( + "Unknown ColumnValue tag: 0x{tag:02x}" + ))), + } + } +} + +impl std::fmt::Display for ColumnValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Null => write!(f, "NULL"), + Self::Text(b) => match std::str::from_utf8(b) { + Ok(s) => write!(f, "{s}"), + Err(_) => write!(f, "", b.len()), + }, + Self::Binary(b) => { + write!(f, "\\x")?; + for byte in b.iter() { + write!(f, "{byte:02x}")?; + } + Ok(()) + } + } + } +} + +impl PartialEq for ColumnValue { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Null, Self::Null) => true, + (Self::Text(a), Self::Text(b)) => a == b, + (Self::Binary(a), Self::Binary(b)) => a == b, + _ => false, + } + } +} + +impl Eq for ColumnValue {} + +/// Allows `column_value == "some_str"` for text comparisons. 
+impl PartialEq for ColumnValue { + fn eq(&self, other: &str) -> bool { + match self { + Self::Text(b) => b.as_ref() == other.as_bytes(), + _ => false, + } + } +} + +/// Allows `column_value == "some_str"` via `&&str`. +impl PartialEq<&str> for ColumnValue { + fn eq(&self, other: &&str) -> bool { + self == *other + } +} + +impl Serialize for ColumnValue { + fn serialize( + &self, + serializer: S, + ) -> std::result::Result { + match self { + Self::Null => serializer.serialize_none(), + Self::Text(b) => match std::str::from_utf8(b) { + Ok(s) => serializer.serialize_str(s), + Err(_) => { + // Fall back to hex for non-UTF-8 text + let hex = hex_encode(b); + serializer.serialize_str(&format!("\\x{hex}")) + } + }, + Self::Binary(b) => { + let hex = hex_encode(b); + serializer.serialize_str(&format!("\\x{hex}")) + } + } + } +} + +impl<'de> Deserialize<'de> for ColumnValue { + fn deserialize>( + deserializer: D, + ) -> std::result::Result { + struct Visitor; + + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = ColumnValue; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("a string, null, or hex-encoded binary") + } + + fn visit_none(self) -> std::result::Result { + Ok(ColumnValue::Null) + } + + fn visit_unit(self) -> std::result::Result { + Ok(ColumnValue::Null) + } + + fn visit_some>( + self, + deserializer: D, + ) -> std::result::Result { + // Recurse into the inner value + deserializer.deserialize_any(self) + } + + fn visit_str( + self, + v: &str, + ) -> std::result::Result { + if let Some(hex) = v.strip_prefix("\\x") { + // Decode hex to binary + match hex_decode(hex) { + Ok(bytes) => Ok(ColumnValue::Binary(Bytes::from(bytes))), + Err(e) => Err(E::custom(format!("invalid hex string: {e}"))), + } + } else { + Ok(ColumnValue::Text(Bytes::copy_from_slice(v.as_bytes()))) + } + } + } + + deserializer.deserialize_option(Visitor) + } +} + +/// Ordered row data: a list of `(column_name, value)` pairs. 
+/// +/// Column names are `Arc` — zero-cost clones from relation metadata. +/// Values are [`ColumnValue`] — a lightweight enum holding raw `Bytes` +/// from the PostgreSQL wire protocol without intermediate parsing. +/// +/// Serialises as a JSON object `{"col": value, …}` for wire-format compatibility. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::{RowData, ColumnValue}; +/// use std::sync::Arc; +/// +/// let mut row = RowData::with_capacity(2); +/// row.push(Arc::from("id"), ColumnValue::text("1")); +/// row.push(Arc::from("name"), ColumnValue::text("Alice")); +/// +/// assert_eq!(row.len(), 2); +/// assert_eq!(row.get("id").and_then(|v| v.as_str()), Some("1")); +/// ``` +#[derive(Debug, Clone, Eq)] +pub struct RowData { + columns: Vec<(Arc, ColumnValue)>, +} + +impl RowData { + /// Create an empty `RowData`. + #[inline] + pub fn new() -> Self { + Self { + columns: Vec::new(), + } + } + + /// Create an empty `RowData` with pre-allocated capacity. + #[inline] + pub fn with_capacity(cap: usize) -> Self { + Self { + columns: Vec::with_capacity(cap), + } + } + + /// Append a column. + #[inline] + pub fn push(&mut self, name: Arc, value: ColumnValue) { + self.columns.push((name, value)); + } + + /// Number of columns. + #[inline] + pub fn len(&self) -> usize { + self.columns.len() + } + + /// Returns `true` when there are no columns. + #[inline] + pub fn is_empty(&self) -> bool { + self.columns.is_empty() + } + + /// Look up a value by column name (linear scan — fast for typical column counts). + #[inline] + pub fn get(&self, name: &str) -> Option<&ColumnValue> { + self.columns + .iter() + .find(|(k, _)| k.as_ref() == name) + .map(|(_, v)| v) + } + + /// Iterate over `(name, value)` pairs. + #[inline] + pub fn iter(&self) -> impl Iterator, &ColumnValue)> { + self.columns.iter().map(|(k, v)| (k, v)) + } + + /// Construct from `(&str, ColumnValue)` pairs — handy for tests and literals. 
+ #[inline] + pub fn from_pairs(pairs: Vec<(&str, ColumnValue)>) -> Self { + Self { + columns: pairs.into_iter().map(|(k, v)| (Arc::from(k), v)).collect(), + } + } + + // ---- binary wire format ---- + + /// Encode this `RowData` into a byte buffer. + /// + /// Format: `[2-byte column count]` then for each column: `[2-byte name length][name bytes][ColumnValue encoding]`. + pub fn encode(&self, buf: &mut BytesMut) { + buf.extend_from_slice(&(self.columns.len() as u16).to_be_bytes()); + for (name, value) in &self.columns { + let name_bytes = name.as_bytes(); + buf.extend_from_slice(&(name_bytes.len() as u16).to_be_bytes()); + buf.extend_from_slice(name_bytes); + value.encode(buf); + } + } + + /// Decode a `RowData` from a [`BufferReader`]. + pub fn decode(reader: &mut BufferReader) -> Result { + let count = reader.read_u16()? as usize; + let mut columns = Vec::with_capacity(count); + for _ in 0..count { + let name_len = reader.read_u16()? as usize; + let name_bytes = reader.read_bytes(name_len)?; + let name = std::str::from_utf8(&name_bytes) + .map_err(|e| ReplicationError::protocol(format!("Invalid column name: {e}")))?; + let name = Arc::from(name); + let value = ColumnValue::decode(reader)?; + columns.push((name, value)); + } + Ok(Self { columns }) + } +} + +impl Default for RowData { + fn default() -> Self { + Self::new() + } +} + +// Order-sensitive equality (columns must match in the same order). 
+impl PartialEq for RowData { + fn eq(&self, other: &Self) -> bool { + self.columns == other.columns + } +} + +impl Serialize for RowData { + fn serialize( + &self, + serializer: S, + ) -> std::result::Result { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(self.columns.len()))?; + for (k, v) in &self.columns { + map.serialize_entry(k.as_ref(), v)?; + } + map.end() + } +} + +impl<'de> Deserialize<'de> for RowData { + fn deserialize>( + deserializer: D, + ) -> std::result::Result { + struct Visitor; + + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = RowData; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("a map of column names to values") + } + + fn visit_map>( + self, + mut map: M, + ) -> std::result::Result { + let mut cols = Vec::with_capacity(map.size_hint().unwrap_or(0)); + while let Some((k, v)) = map.next_entry::()? { + cols.push((Arc::from(k), v)); + } + Ok(RowData { columns: cols }) + } + } + + deserializer.deserialize_map(Visitor) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rowdata_default() { + let row = RowData::default(); + assert!(row.is_empty()); + assert_eq!(row.len(), 0); + } + + #[test] + fn test_rowdata_deserialize_invalid_type() { + // Feeding a non-object type triggers the `expecting()` method. 
+ let err = serde_json::from_str::("42").unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("a map"), + "Error should reference expecting(), got: {msg}" + ); + } + + #[test] + fn test_rowdata_deserialize_string_gives_error() { + let err = serde_json::from_str::("\"hello\"").unwrap_err(); + assert!(err.to_string().contains("a map")); + } + + #[test] + fn test_rowdata_deserialize_array_gives_error() { + let err = serde_json::from_str::("[1, 2, 3]").unwrap_err(); + assert!(err.to_string().contains("a map")); + } + + #[test] + fn test_column_value_text() { + let v = ColumnValue::text("hello"); + assert_eq!(v.as_str(), Some("hello")); + assert!(!v.is_null()); + assert_eq!(v.as_bytes(), b"hello"); + } + + #[test] + fn test_column_value_null() { + let v = ColumnValue::Null; + assert!(v.is_null()); + assert_eq!(v.as_str(), None); + assert_eq!(v.as_bytes(), &[] as &[u8]); + } + + #[test] + fn test_column_value_binary() { + let v = ColumnValue::binary_bytes(Bytes::from_static(&[0xde, 0xad])); + assert!(!v.is_null()); + assert_eq!(v.as_str(), None); + assert_eq!(v.as_bytes(), &[0xde, 0xad]); + } + + #[test] + fn test_column_value_display() { + assert_eq!(format!("{}", ColumnValue::Null), "NULL"); + assert_eq!(format!("{}", ColumnValue::text("hi")), "hi"); + assert_eq!( + format!( + "{}", + ColumnValue::binary_bytes(Bytes::from_static(&[0xca, 0xfe])) + ), + "\\xcafe" + ); + } + + #[test] + fn test_column_value_equality() { + assert_eq!(ColumnValue::Null, ColumnValue::Null); + assert_eq!(ColumnValue::text("a"), ColumnValue::text("a")); + assert_ne!(ColumnValue::text("a"), ColumnValue::text("b")); + assert_ne!(ColumnValue::text("a"), ColumnValue::Null); + // Cross-variant never equal + assert_ne!( + ColumnValue::text("a"), + ColumnValue::binary_bytes(Bytes::copy_from_slice(b"a")) + ); + } + + #[test] + fn test_column_value_partial_eq_str() { + let v = ColumnValue::text("hello"); + assert!(v == *"hello"); + assert!(v != *"world"); + assert!(ColumnValue::Null != 
*"hello"); + } + + #[test] + fn test_column_value_serde_round_trip() { + // Text + let v = ColumnValue::text("hello"); + let json = serde_json::to_string(&v).unwrap(); + let back: ColumnValue = serde_json::from_str(&json).unwrap(); + assert_eq!(v, back); + + // Null + let v = ColumnValue::Null; + let json = serde_json::to_string(&v).unwrap(); + let back: ColumnValue = serde_json::from_str(&json).unwrap(); + assert_eq!(v, back); + + // Binary + let v = ColumnValue::binary_bytes(Bytes::from_static(&[0xde, 0xad])); + let json = serde_json::to_string(&v).unwrap(); + let back: ColumnValue = serde_json::from_str(&json).unwrap(); + assert_eq!(v, back); + } + + #[test] + fn test_column_value_encode_decode_round_trip() { + use crate::buffer::BufferReader; + + let values = vec![ + ColumnValue::Null, + ColumnValue::text("hello world"), + ColumnValue::binary_bytes(Bytes::from_static(&[0x01, 0x02, 0x03])), + ]; + + let mut buf = BytesMut::new(); + for v in &values { + v.encode(&mut buf); + } + + let frozen = buf.freeze(); + let mut reader = BufferReader::new(&frozen); + + for expected in &values { + let decoded = ColumnValue::decode(&mut reader).unwrap(); + assert_eq!(&decoded, expected); + } + } + + #[test] + fn test_rowdata_encode_decode_round_trip() { + use crate::buffer::BufferReader; + + let row = RowData::from_pairs(vec![ + ("id", ColumnValue::text("42")), + ("name", ColumnValue::text("Alice")), + ( + "data", + ColumnValue::binary_bytes(Bytes::from_static(&[0xff])), + ), + ("empty", ColumnValue::Null), + ]); + + let mut buf = BytesMut::new(); + row.encode(&mut buf); + + let frozen = buf.freeze(); + let mut reader = BufferReader::new(&frozen); + let decoded = RowData::decode(&mut reader).unwrap(); + + assert_eq!(row, decoded); + } + + #[test] + fn test_rowdata_operations() { + let mut row = RowData::with_capacity(3); + assert!(row.is_empty()); + assert_eq!(row.len(), 0); + + row.push(Arc::from("id"), ColumnValue::text("1")); + row.push(Arc::from("name"), 
ColumnValue::text("Alice")); + + assert!(!row.is_empty()); + assert_eq!(row.len(), 2); + assert_eq!(row.get("id").unwrap(), "1"); + assert_eq!(row.get("name").unwrap(), "Alice"); + assert!(row.get("missing").is_none()); + + let pairs: Vec<_> = row.iter().collect(); + assert_eq!(pairs.len(), 2); + } + + #[test] + fn test_rowdata_serde_round_trip() { + let row = RowData::from_pairs(vec![ + ("id", ColumnValue::text("1")), + ("name", ColumnValue::text("Alice")), + ]); + let json = serde_json::to_string(&row).unwrap(); + let back: RowData = serde_json::from_str(&json).unwrap(); + assert_eq!(row.len(), back.len()); + + // Values should match (order may differ in JSON map round-trip) + assert_eq!(back.get("id").and_then(|v| v.as_str()), Some("1")); + assert_eq!(back.get("name").and_then(|v| v.as_str()), Some("Alice")); + } + + #[test] + fn test_hex_encode() { + assert_eq!(hex_encode(&[0x00, 0x01, 0x02]), "000102"); + assert_eq!(hex_encode(&[0xff, 0xfe, 0xfd]), "fffefd"); + assert_eq!(hex_encode(&[]), ""); + assert_eq!(hex_encode(&[0xde, 0xad, 0xbe, 0xef]), "deadbeef"); + } + + #[test] + fn test_hex_decode() { + assert_eq!(hex_decode("000102").unwrap(), vec![0x00, 0x01, 0x02]); + assert_eq!( + hex_decode("deadbeef").unwrap(), + vec![0xde, 0xad, 0xbe, 0xef] + ); + assert_eq!(hex_decode("").unwrap(), Vec::::new()); + assert!(hex_decode("0").is_err()); // odd length + assert!(hex_decode("zz").is_err()); // invalid chars + } + + // --- Additional coverage tests --- + + #[test] + fn test_hex_nibble_uppercase() { + // Exercise the b'A'..=b'F' branch in hex_nibble + assert_eq!( + hex_decode("DEADBEEF").unwrap(), + vec![0xde, 0xad, 0xbe, 0xef] + ); + assert_eq!(hex_decode("FF00").unwrap(), vec![0xff, 0x00]); + // Mixed case + assert_eq!( + hex_decode("aAbBcCdDeEfF").unwrap(), + vec![0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff] + ); + } + + #[test] + fn test_hex_decode_invalid_second_char() { + // The second nibble of a pair is invalid — exercises the low nibble error path + 
assert!(hex_decode("0z").is_err()); + assert!(hex_decode("a!").is_err()); + } + + #[test] + fn test_column_value_decode_unknown_tag() { + use crate::buffer::BufferReader; + + let data = [0xFF]; // Unknown tag byte + let mut reader = BufferReader::new(&data); + let result = ColumnValue::decode(&mut reader); + assert!(result.is_err()); + let err_msg = format!("{}", result.unwrap_err()); + assert!( + err_msg.contains("Unknown ColumnValue tag"), + "got: {err_msg}" + ); + } + + #[test] + fn test_column_value_display_invalid_utf8() { + // Text variant with invalid UTF-8 bytes exercises the Err(_) Display branch + let v = ColumnValue::Text(Bytes::from_static(&[0xff, 0xfe, 0xfd])); + let displayed = format!("{v}"); + assert!(displayed.contains("invalid utf-8"), "got: {displayed}"); + assert!(displayed.contains("3 bytes"), "got: {displayed}"); + } + + #[test] + fn test_column_value_partial_eq_ref_str() { + // Exercises the PartialEq<&str> impl (via &&str) + let v = ColumnValue::text("hello"); + assert!(v == "hello"); + assert!(v != "world"); + + let null = ColumnValue::Null; + assert!(null != "hello"); + + let binary = ColumnValue::binary_bytes(Bytes::from_static(b"hello")); + assert!(binary != "hello"); + } + + #[test] + fn test_column_value_serialize_non_utf8_text() { + // Text with invalid UTF-8 should fall back to hex encoding + let v = ColumnValue::Text(Bytes::from_static(&[0xff, 0xfe])); + let json = serde_json::to_string(&v).unwrap(); + assert_eq!(json, r#""\\xfffe""#); + + // Round-trip: it deserializes back as Binary (due to \x prefix) + let back: ColumnValue = serde_json::from_str(&json).unwrap(); + assert_eq!(back.as_bytes(), &[0xff, 0xfe]); + } + + #[test] + fn test_column_value_deserialize_invalid_hex() { + // \x prefix followed by invalid hex chars triggers the Err path in visit_str + let json = r#""\\xZZZZ""#; + let result = serde_json::from_str::(json); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + 
assert!(err_msg.contains("invalid hex string"), "got: {err_msg}"); + } + + #[test] + fn test_column_value_deserialize_expecting() { + // Feeding an unexpected type (integer) should trigger the expecting() method + let result = serde_json::from_str::("42"); + assert!(result.is_err()); + } + + #[test] + fn test_column_value_as_str_binary() { + // Binary variant returns None from as_str() + let v = ColumnValue::binary_bytes(Bytes::from_static(b"data")); + assert_eq!(v.as_str(), None); + } + + #[test] + fn test_column_value_as_str_invalid_utf8() { + // Text with invalid UTF-8 returns None from as_str() + let v = ColumnValue::Text(Bytes::from_static(&[0xff, 0xfe])); + assert_eq!(v.as_str(), None); + } + + #[test] + fn test_column_value_is_null_all_variants() { + assert!(ColumnValue::Null.is_null()); + assert!(!ColumnValue::text("x").is_null()); + assert!(!ColumnValue::binary_bytes(Bytes::from_static(b"x")).is_null()); + } + + #[test] + fn test_column_value_encode_decode_empty_text() { + use crate::buffer::BufferReader; + + let v = ColumnValue::text(""); + let mut buf = BytesMut::new(); + v.encode(&mut buf); + + let frozen = buf.freeze(); + let mut reader = BufferReader::new(&frozen); + let decoded = ColumnValue::decode(&mut reader).unwrap(); + assert_eq!(decoded, v); + assert_eq!(decoded.as_str(), Some("")); + } + + #[test] + fn test_column_value_encode_decode_empty_binary() { + use crate::buffer::BufferReader; + + let v = ColumnValue::binary_bytes(Bytes::new()); + let mut buf = BytesMut::new(); + v.encode(&mut buf); + + let frozen = buf.freeze(); + let mut reader = BufferReader::new(&frozen); + let decoded = ColumnValue::decode(&mut reader).unwrap(); + assert_eq!(decoded, v); + assert_eq!(decoded.as_bytes(), &[] as &[u8]); + } + + #[test] + fn test_column_value_clone() { + let original = ColumnValue::text("cloned"); + let cloned = original.clone(); + assert_eq!(original, cloned); + + let original = ColumnValue::binary_bytes(Bytes::from_static(&[1, 2, 3])); + let cloned 
= original.clone(); + assert_eq!(original, cloned); + + let original = ColumnValue::Null; + let cloned = original.clone(); + assert_eq!(original, cloned); + } + + #[test] + fn test_column_value_debug() { + // Exercises the derive(Debug) impl + let v = ColumnValue::text("debug_test"); + let debug = format!("{v:?}"); + assert!(debug.contains("Text"), "got: {debug}"); + + let v = ColumnValue::Null; + let debug = format!("{v:?}"); + assert!(debug.contains("Null"), "got: {debug}"); + + let v = ColumnValue::binary_bytes(Bytes::from_static(&[0xab])); + let debug = format!("{v:?}"); + assert!(debug.contains("Binary"), "got: {debug}"); + } + + #[test] + fn test_rowdata_from_pairs_empty() { + let row = RowData::from_pairs(vec![]); + assert!(row.is_empty()); + assert_eq!(row.len(), 0); + assert!(row.get("anything").is_none()); + } + + #[test] + fn test_rowdata_iter_with_values() { + let row = RowData::from_pairs(vec![ + ("a", ColumnValue::text("1")), + ("b", ColumnValue::Null), + ("c", ColumnValue::binary_bytes(Bytes::from_static(&[0xff]))), + ]); + let items: Vec<_> = row.iter().collect(); + assert_eq!(items.len(), 3); + assert_eq!(items[0].0.as_ref(), "a"); + assert_eq!(items[1].0.as_ref(), "b"); + assert!(items[1].1.is_null()); + assert_eq!(items[2].0.as_ref(), "c"); + } + + #[test] + fn test_rowdata_equality_order_sensitive() { + let row1 = RowData::from_pairs(vec![ + ("a", ColumnValue::text("1")), + ("b", ColumnValue::text("2")), + ]); + let row2 = RowData::from_pairs(vec![ + ("b", ColumnValue::text("2")), + ("a", ColumnValue::text("1")), + ]); + // Different order → not equal + assert_ne!(row1, row2); + + // Same order → equal + let row3 = RowData::from_pairs(vec![ + ("a", ColumnValue::text("1")), + ("b", ColumnValue::text("2")), + ]); + assert_eq!(row1, row3); + } + + #[test] + fn test_rowdata_encode_decode_empty() { + use crate::buffer::BufferReader; + + let row = RowData::new(); + let mut buf = BytesMut::new(); + row.encode(&mut buf); + + let frozen = buf.freeze(); + 
let mut reader = BufferReader::new(&frozen); + let decoded = RowData::decode(&mut reader).unwrap(); + assert_eq!(decoded, row); + assert!(decoded.is_empty()); + } + + #[test] + fn test_rowdata_encode_decode_with_null_values() { + use crate::buffer::BufferReader; + + let row = RowData::from_pairs(vec![ + ("id", ColumnValue::text("1")), + ("description", ColumnValue::Null), + ( + "data", + ColumnValue::binary_bytes(Bytes::from_static(&[0x01, 0x02])), + ), + ("empty_text", ColumnValue::text("")), + ]); + + let mut buf = BytesMut::new(); + row.encode(&mut buf); + + let frozen = buf.freeze(); + let mut reader = BufferReader::new(&frozen); + let decoded = RowData::decode(&mut reader).unwrap(); + assert_eq!(decoded, row); + } + + #[test] + fn test_rowdata_serde_with_null_and_binary() { + let row = RowData::from_pairs(vec![ + ("name", ColumnValue::text("Alice")), + ("middle", ColumnValue::Null), + ( + "blob", + ColumnValue::binary_bytes(Bytes::from_static(&[0xca, 0xfe])), + ), + ]); + let json = serde_json::to_string(&row).unwrap(); + let back: RowData = serde_json::from_str(&json).unwrap(); + assert_eq!(back.len(), row.len()); + assert_eq!(back.get("name").and_then(|v| v.as_str()), Some("Alice")); + assert!(back.get("middle").map(|v| v.is_null()).unwrap_or(false)); + assert_eq!( + back.get("blob").map(|v| v.as_bytes()), + Some(&[0xca, 0xfe][..]) + ); + } + + #[test] + fn test_rowdata_debug() { + let row = RowData::from_pairs(vec![("x", ColumnValue::text("y"))]); + let debug = format!("{row:?}"); + assert!(debug.contains("RowData"), "got: {debug}"); + } + + #[test] + fn test_rowdata_clone() { + let row = RowData::from_pairs(vec![ + ("id", ColumnValue::text("1")), + ("val", ColumnValue::Null), + ]); + let cloned = row.clone(); + assert_eq!(row, cloned); + } +} diff --git a/src/lib.rs b/src/lib.rs index bbe7adf..e92a11e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -95,6 +95,7 @@ // Core modules pub mod buffer; +pub mod column_value; pub mod error; pub mod types; @@ -113,11 
+114,13 @@ pub use buffer::{BufferReader, BufferWriter}; pub use error::{ReplicationError, Result}; pub use lsn::SharedLsnFeedback; +// Re-export column value types +pub use column_value::{ColumnValue, RowData}; + // Re-export type aliases and utilities pub use types::{ // Utility functions format_lsn, - format_postgres_timestamp, parse_lsn, postgres_timestamp_to_chrono, system_time_to_postgres_timestamp, @@ -129,7 +132,6 @@ pub use types::{ Oid, ReplicaIdentity, ReplicationSlotOptions, - RowData, SlotType, TimestampTz, // Type aliases matching PostgreSQL types diff --git a/src/protocol.rs b/src/protocol.rs index 95925d9..acbc904 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -6,9 +6,10 @@ //! - use crate::buffer::{BufferReader, BufferWriter}; +use crate::column_value::{ColumnValue, RowData}; use crate::error::{ReplicationError, Result}; use crate::types::{ - format_lsn, system_time_to_postgres_timestamp, Oid, RowData, TimestampTz, XLogRecPtr, Xid, + format_lsn, system_time_to_postgres_timestamp, Oid, TimestampTz, XLogRecPtr, Xid, }; use bytes::Bytes; use serde::{Deserialize, Serialize}; @@ -278,23 +279,22 @@ impl TupleData { self.columns.len() } - /// Convert to a RowData with column names from the relation + /// Convert to a [`RowData`] with column names from the relation. + /// + /// Text columns are stored as [`ColumnValue::Text`] with zero-copy `Bytes`, + /// binary columns as [`ColumnValue::Binary`], and null / unknown as + /// [`ColumnValue::Null`]. Unchanged TOAST columns are skipped. 
pub fn to_row_data(&self, relation: &RelationInfo) -> RowData { let mut data = RowData::with_capacity(self.columns.len()); for (i, col_data) in self.columns.iter().enumerate() { if let Some(column_info) = relation.get_column_by_index(i) { let value = match col_data.data_type { - 'n' => serde_json::Value::Null, - 't' | 'b' => { - // Use as_str() which returns Cow for zero-copy when possible - match col_data.as_str() { - Some(s) => serde_json::Value::String(s.into_owned()), - None => serde_json::Value::Null, - } - } + 'n' => ColumnValue::Null, + 't' => ColumnValue::text_bytes(col_data.raw_bytes()), + 'b' => ColumnValue::binary_bytes(col_data.raw_bytes()), 'u' => continue, // Skip unchanged TOAST values - _ => serde_json::Value::Null, + _ => ColumnValue::Null, }; data.push(Arc::clone(&column_info.name), value); } @@ -302,15 +302,6 @@ impl TupleData { data } - - /// Convert to a HashMap with column names as keys (legacy convenience method) - #[deprecated( - since = "0.4.0", - note = "use `to_row_data` instead for better performance" - )] - pub fn to_hash_map(&self, relation: &RelationInfo) -> HashMap { - self.to_row_data(relation).into_hash_map() - } } /// Data for a single column @@ -418,6 +409,12 @@ impl ColumnData { &self.data } + /// Get the underlying `Bytes` handle (cheap ref-counted clone). 
+ #[inline(always)] + pub fn raw_bytes(&self) -> bytes::Bytes { + self.data.clone() + } + #[inline(always)] pub fn into_bytes(self) -> bytes::Bytes { self.data @@ -1368,6 +1365,7 @@ pub fn build_hot_standby_feedback_message( #[cfg(test)] mod tests { use super::*; + use crate::column_value::ColumnValue; #[test] fn test_column_data_creation() { @@ -1808,10 +1806,10 @@ mod tests { ColumnData::text(b"Alice".to_vec()), ]); - let map = tuple.to_hash_map(&relation); - assert_eq!(map.len(), 2); - assert_eq!(map.get("id").unwrap(), "42"); - assert_eq!(map.get("name").unwrap(), "Alice"); + let row = tuple.to_row_data(&relation); + assert_eq!(row.len(), 2); + assert_eq!(row.get("id").unwrap(), "42"); + assert_eq!(row.get("name").unwrap(), "Alice"); } #[test] @@ -2031,9 +2029,9 @@ mod tests { let tuple = TupleData::new(vec![ColumnData::text(b"42".to_vec()), ColumnData::null()]); - let map = tuple.to_hash_map(&relation); - assert_eq!(map.get("id").unwrap(), "42"); - assert_eq!(map.get("name").unwrap(), &serde_json::Value::Null); + let row = tuple.to_row_data(&relation); + assert_eq!(row.get("id").unwrap(), "42"); + assert_eq!(row.get("name").unwrap(), &ColumnValue::Null); } #[test] @@ -2049,10 +2047,10 @@ mod tests { ColumnData::unchanged(), // unchanged TOAST should be skipped ]); - let map = tuple.to_hash_map(&relation); - assert_eq!(map.len(), 1); - assert_eq!(map.get("id").unwrap(), "42"); - assert!(!map.contains_key("name")); + let row = tuple.to_row_data(&relation); + assert_eq!(row.len(), 1); + assert_eq!(row.get("id").unwrap(), "42"); + assert!(row.get("name").is_none()); } #[test] @@ -2064,15 +2062,14 @@ mod tests { let tuple = TupleData::new(vec![ColumnData::binary(b"binary data".to_vec())]); - let map = tuple.to_hash_map(&relation); - // Binary valid UTF-8 gets converted via as_str() - let val = map.get("data").unwrap(); - assert!(val.is_string()); + let row = tuple.to_row_data(&relation); + let val = row.get("data").unwrap(); + assert!(matches!(val, 
ColumnValue::Binary(_))); } #[test] fn test_tuple_data_to_hash_map_text_empty_data() { - // Text column with empty data - as_str() returns None for empty data + // Text column with empty data let columns = vec![ColumnInfo::new(0, "col".to_string(), 25, -1)]; let relation = RelationInfo::new(1, "public".to_string(), "t".to_string(), b'd', columns); @@ -2081,9 +2078,10 @@ mod tests { data: bytes::Bytes::new(), }; let tuple = TupleData::new(vec![col]); - let map = tuple.to_hash_map(&relation); - // Empty text data: as_str() returns None, so Null - assert!(map.get("col").unwrap().is_null()); + let row = tuple.to_row_data(&relation); + // Empty text data yields Text(empty Bytes) + let val = row.get("col").unwrap(); + assert!(matches!(val, ColumnValue::Text(_))); } #[test] @@ -2097,8 +2095,8 @@ mod tests { data: bytes::Bytes::from_static(&[1, 2, 3]), }; let tuple = TupleData::new(vec![col]); - let map = tuple.to_hash_map(&relation); - assert!(map.get("col").unwrap().is_null()); + let row = tuple.to_row_data(&relation); + assert!(row.get("col").unwrap().is_null()); } #[test] @@ -2112,9 +2110,9 @@ mod tests { ColumnData::text(b"val1".to_vec()), ColumnData::text(b"val2".to_vec()), ]); - let map = tuple.to_hash_map(&relation); - assert_eq!(map.len(), 1); // Only the first column maps - assert_eq!(map.get("col1").unwrap(), "val1"); + let row = tuple.to_row_data(&relation); + assert_eq!(row.len(), 1); // Only the first column maps + assert_eq!(row.get("col1").unwrap(), "val1"); } #[test] diff --git a/src/stream.rs b/src/stream.rs index 8e6dfa5..e879cb7 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -13,10 +13,11 @@ //! you can easily wrap it using `futures::stream::unfold`. See the EventStream //! documentation for an example. 
+use crate::column_value::{ColumnValue, RowData}; use crate::error::{ReplicationError, Result}; use crate::lsn::SharedLsnFeedback; use crate::types::{ - ChangeEvent, EventType, Lsn, ReplicaIdentity, ReplicationSlotOptions, RowData, SlotType, + ChangeEvent, EventType, Lsn, ReplicaIdentity, ReplicationSlotOptions, SlotType, }; use crate::{ format_lsn, parse_keepalive_message, postgres_timestamp_to_chrono, BufferReader, @@ -1954,7 +1955,11 @@ async fn timeout_or_error( } } -/// Convert tuple data to a RowData for ChangeEvent +/// Convert tuple data to a [`RowData`] for [`ChangeEvent`]. +/// +/// Text columns become [`ColumnValue::Text`] (zero-copy), binary columns +/// become [`ColumnValue::Binary`], and null/unknown become [`ColumnValue::Null`]. +/// Unchanged TOAST columns are skipped. #[inline] fn tuple_to_data(tuple: &TupleData, relation: &RelationInfo) -> Result { let mut data = RowData::with_capacity(tuple.columns.len()); @@ -1965,16 +1970,13 @@ fn tuple_to_data(tuple: &TupleData, relation: &RelationInfo) -> Result } if let Some(column_info) = relation.get_column_by_index(i) { let value = if column_data.is_null() { - serde_json::Value::Null + ColumnValue::Null } else if column_data.is_text() { - let text = column_data.as_str().unwrap_or_default(); - serde_json::Value::String(text.into_owned()) + ColumnValue::text_bytes(column_data.raw_bytes()) } else if column_data.is_binary() { - // For binary data, convert to hex string - let hex_string = hex_encode(column_data.as_bytes()); - serde_json::Value::String(format!("\\x{hex_string}")) + ColumnValue::binary_bytes(column_data.raw_bytes()) } else { - serde_json::Value::Null + ColumnValue::Null }; data.push(Arc::clone(&column_info.name), value); @@ -1984,23 +1986,11 @@ fn tuple_to_data(tuple: &TupleData, relation: &RelationInfo) -> Result Ok(data) } -// Simple hex encoding implementation to avoid adding another dependency -fn hex_encode(bytes: &[u8]) -> String { - const LUT: &[u8; 16] = b"0123456789abcdef"; - let mut 
out = String::with_capacity(bytes.len() * 2); - - for &byte in bytes { - out.push(LUT[(byte >> 4) as usize] as char); - out.push(LUT[(byte & 0x0f) as usize] as char); - } - - out -} - #[cfg(test)] mod tests { use super::*; - use crate::types::{parse_lsn, ReplicaIdentity, RowData}; + use crate::column_value::{ColumnValue, RowData}; + use crate::types::{parse_lsn, ReplicaIdentity}; /// Helper function to create a test configuration fn create_test_config() -> ReplicationStreamConfig { @@ -2221,8 +2211,8 @@ mod tests { #[test] fn test_change_event_insert_creation() { let data = RowData::from_pairs(vec![ - ("id", serde_json::json!(1)), - ("name", serde_json::json!("Alice")), + ("id", ColumnValue::text("1")), + ("name", ColumnValue::text("Alice")), ]); let event = ChangeEvent::insert("public", "users", 16384, data.clone(), Lsn::new(1000)); @@ -2248,13 +2238,13 @@ mod tests { #[test] fn test_change_event_update_creation() { let old_data = RowData::from_pairs(vec![ - ("id", serde_json::json!(1)), - ("name", serde_json::json!("Alice")), + ("id", ColumnValue::text("1")), + ("name", ColumnValue::text("Alice")), ]); let new_data = RowData::from_pairs(vec![ - ("id", serde_json::json!(1)), - ("name", serde_json::json!("Bob")), + ("id", ColumnValue::text("1")), + ("name", ColumnValue::text("Bob")), ]); let event = ChangeEvent::update( @@ -2292,7 +2282,7 @@ mod tests { #[test] fn test_change_event_delete_creation() { - let old_data = RowData::from_pairs(vec![("id", serde_json::json!(1))]); + let old_data = RowData::from_pairs(vec![("id", ColumnValue::text("1"))]); let event = ChangeEvent::delete( "public", @@ -2371,14 +2361,6 @@ mod tests { assert_eq!(data.get("col1").unwrap(), ""); } - #[test] - fn test_hex_encoding() { - assert_eq!(hex_encode(&[0x00, 0x01, 0x02]), "000102"); - assert_eq!(hex_encode(&[0xff, 0xfe, 0xfd]), "fffefd"); - assert_eq!(hex_encode(&[]), ""); - assert_eq!(hex_encode(&[0x12, 0x34, 0x56, 0x78]), "12345678"); - } - #[test] fn test_cancellation_token_basic() { 
use tokio_util::sync::CancellationToken; @@ -2797,8 +2779,8 @@ mod tests { #[test] fn test_change_event_with_null_values() { let data = RowData::from_pairs(vec![ - ("id", serde_json::json!(1)), - ("nullable_field", serde_json::json!(null)), + ("id", ColumnValue::text("1")), + ("nullable_field", ColumnValue::Null), ]); let event = ChangeEvent::insert( @@ -2811,11 +2793,8 @@ mod tests { match event.event_type { EventType::Insert { data, .. } => { - assert_eq!(data.get("id").unwrap(), &serde_json::json!(1)); - assert_eq!( - data.get("nullable_field").unwrap(), - &serde_json::json!(null) - ); + assert_eq!(data.get("id").unwrap(), &ColumnValue::text("1")); + assert_eq!(data.get("nullable_field").unwrap(), &ColumnValue::Null); } _ => panic!("Expected Insert event"), } @@ -2997,26 +2976,6 @@ mod tests { assert!(d4 <= Duration::from_millis(500)); } - #[test] - fn test_hex_encode_various_inputs() { - // Empty - assert_eq!(hex_encode(&[]), ""); - - // Single byte - assert_eq!(hex_encode(&[0x00]), "00"); - assert_eq!(hex_encode(&[0xff]), "ff"); - - // Multiple bytes - assert_eq!(hex_encode(&[0x12, 0x34]), "1234"); - assert_eq!(hex_encode(&[0xde, 0xad, 0xbe, 0xef]), "deadbeef"); - - // All zeros - assert_eq!(hex_encode(&[0x00, 0x00, 0x00]), "000000"); - - // All ones - assert_eq!(hex_encode(&[0xff, 0xff, 0xff]), "ffffff"); - } - #[tokio::test] async fn test_cancellation_token_async_cancel() { let token = CancellationToken::new(); @@ -3108,9 +3067,9 @@ mod tests { fn test_full_change_event_lifecycle() { // Test complete lifecycle: insert -> update -> delete let insert_data = RowData::from_pairs(vec![ - ("id", serde_json::json!(1)), - ("name", serde_json::json!("Alice")), - ("email", serde_json::json!("alice@example.com")), + ("id", ColumnValue::text("1")), + ("name", ColumnValue::text("Alice")), + ("email", ColumnValue::text("alice@example.com")), ]); let insert_event = ChangeEvent::insert( @@ -3125,9 +3084,9 @@ mod tests { // Update event let update_data = 
RowData::from_pairs(vec![ - ("id", serde_json::json!(1)), - ("name", serde_json::json!("Alice")), - ("email", serde_json::json!("alice.new@example.com")), + ("id", ColumnValue::text("1")), + ("name", ColumnValue::text("Alice")), + ("email", ColumnValue::text("alice.new@example.com")), ]); let update_event = ChangeEvent::update( @@ -3144,7 +3103,7 @@ mod tests { assert_eq!(update_event.lsn.value(), 2000); // Delete event - let delete_key = RowData::from_pairs(vec![("id", serde_json::json!(1))]); + let delete_key = RowData::from_pairs(vec![("id", ColumnValue::text("1"))]); let delete_event = ChangeEvent::delete( "public", @@ -3283,8 +3242,8 @@ mod tests { // Create multiple event types for same relation let insert_data = RowData::from_pairs(vec![ - ("order_id", serde_json::json!(100)), - ("amount", serde_json::json!(99.99)), + ("order_id", ColumnValue::text("100")), + ("amount", ColumnValue::text("99.99")), ]); let insert = ChangeEvent::insert( @@ -3296,8 +3255,8 @@ mod tests { ); let update_data = RowData::from_pairs(vec![ - ("order_id", serde_json::json!(100)), - ("amount", serde_json::json!(89.99)), + ("order_id", ColumnValue::text("100")), + ("amount", ColumnValue::text("89.99")), ]); let update = ChangeEvent::update( @@ -3490,8 +3449,8 @@ mod tests { #[test] fn test_event_with_unicode_data() { let data = RowData::from_pairs(vec![ - ("name", serde_json::json!("Alice 中文 émoji 😀")), - ("description", serde_json::json!("Test with ñ, ü, ö")), + ("name", ColumnValue::text("Alice 中文 émoji 😀")), + ("description", ColumnValue::text("Test with ñ, ü, ö")), ]); let event = ChangeEvent::insert("public", "users", 12345, data.clone(), Lsn::new(1000)); @@ -3502,11 +3461,11 @@ mod tests { } => { assert_eq!( event_data.get("name").unwrap(), - &serde_json::json!("Alice 中文 émoji 😀") + &ColumnValue::text("Alice 中文 émoji 😀") ); assert_eq!( event_data.get("description").unwrap(), - &serde_json::json!("Test with ñ, ü, ö") + &ColumnValue::text("Test with ñ, ü, ö") ); } _ => 
panic!("Expected insert"), @@ -3556,17 +3515,6 @@ mod tests { assert!(grandchild.is_cancelled()); } - #[test] - fn test_hex_encoding_edge_cases() { - // Test with repeating patterns - assert_eq!(hex_encode(&[0xaa, 0xaa, 0xaa]), "aaaaaa"); - assert_eq!(hex_encode(&[0x55, 0x55, 0x55]), "555555"); - - // Test with single bits - assert_eq!(hex_encode(&[0x01, 0x02, 0x04, 0x08]), "01020408"); - assert_eq!(hex_encode(&[0x10, 0x20, 0x40, 0x80]), "10204080"); - } - #[test] fn test_streaming_mode_as_str_off() { assert_eq!(StreamingMode::Off.as_str(), "off"); @@ -3698,8 +3646,9 @@ mod tests { let tuple = TupleData::new(vec![ColumnData::binary(vec![0xDE, 0xAD, 0xBE, 0xEF])]); let data = tuple_to_data(&tuple, &relation).unwrap(); - let val = data.get("binary_col").unwrap().as_str().unwrap(); - assert_eq!(val, "\\xdeadbeef"); + let val = data.get("binary_col").unwrap(); + assert!(matches!(val, ColumnValue::Binary(_))); + assert_eq!(val.as_bytes(), &[0xDE, 0xAD, 0xBE, 0xEF]); } #[test] @@ -3712,7 +3661,7 @@ mod tests { let tuple = TupleData::new(vec![ColumnData::null()]); let data = tuple_to_data(&tuple, &relation).unwrap(); - assert_eq!(data.get("nullable").unwrap(), &serde_json::Value::Null); + assert_eq!(data.get("nullable").unwrap(), &ColumnValue::Null); } #[test] @@ -4938,13 +4887,6 @@ mod tests { assert!(result.is_ok()); } - #[test] - fn test_hex_encode() { - assert_eq!(super::hex_encode(&[]), ""); - assert_eq!(super::hex_encode(&[0xde, 0xad, 0xbe, 0xef]), "deadbeef"); - assert_eq!(super::hex_encode(&[0x00, 0xff]), "00ff"); - } - #[test] fn test_build_options_origin_any() { let config = ReplicationStreamConfig::new( diff --git a/src/types.rs b/src/types.rs index 3acb981..5901820 100644 --- a/src/types.rs +++ b/src/types.rs @@ -3,11 +3,14 @@ //! This module provides type aliases for PostgreSQL types and utility functions //! for working with LSN (Log Sequence Numbers) and timestamps. 
+use crate::buffer::BufferReader; use crate::error::{ReplicationError, Result}; +use crate::protocol::message_types; +use bytes::BytesMut; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::time::{SystemTime, UNIX_EPOCH}; // PostgreSQL constants /// Seconds from Unix epoch (1970-01-01) to PostgreSQL epoch (2000-01-01) @@ -131,17 +134,6 @@ pub fn system_time_to_postgres_timestamp(time: SystemTime) -> TimestampTz { unix_micros - PG_EPOCH_OFFSET_SECS * 1_000_000 } -/// Convert PostgreSQL timestamp to formatted string -pub fn format_postgres_timestamp(timestamp: TimestampTz) -> String { - let unix_micros = timestamp + PG_EPOCH_OFFSET_SECS * 1_000_000; - let unix_secs = unix_micros / 1_000_000; - - match SystemTime::UNIX_EPOCH.checked_add(Duration::from_secs(unix_secs as u64)) { - Some(_) => format!("timestamp={unix_secs}"), - None => "invalid timestamp".to_string(), - } -} - /// Convert PostgreSQL timestamp (microseconds since 2000-01-01) into `chrono::DateTime`. pub fn postgres_timestamp_to_chrono(ts: i64) -> chrono::DateTime { use chrono::{TimeZone, Utc}; @@ -471,149 +463,12 @@ impl From for u64 { } } -/// Ordered row data: a list of `(column_name, value)` pairs. -/// -/// Replaces `HashMap` for CDC event payloads. -/// Column names are `Arc` — zero-cost clones from relation metadata. -/// Serialises as a JSON object `{"col": value, …}` for wire-format compatibility. 
-/// -/// # Example -/// -/// ``` -/// use pg_walstream::RowData; -/// use std::sync::Arc; -/// -/// let mut row = RowData::with_capacity(2); -/// row.push(Arc::from("id"), serde_json::json!(1)); -/// row.push(Arc::from("name"), serde_json::json!("Alice")); -/// -/// assert_eq!(row.len(), 2); -/// assert_eq!(row.get("id"), Some(&serde_json::json!(1))); -/// ``` -#[derive(Debug, Clone, Eq)] -pub struct RowData { - columns: Vec<(Arc, serde_json::Value)>, -} - -// --- core API --- - -impl RowData { - /// Create an empty `RowData`. - #[inline] - pub fn new() -> Self { - Self { - columns: Vec::new(), - } - } - - /// Create an empty `RowData` with pre-allocated capacity. - #[inline] - pub fn with_capacity(cap: usize) -> Self { - Self { - columns: Vec::with_capacity(cap), - } - } - - /// Append a column. - #[inline] - pub fn push(&mut self, name: Arc, value: serde_json::Value) { - self.columns.push((name, value)); - } - - /// Number of columns. - #[inline] - pub fn len(&self) -> usize { - self.columns.len() - } - - /// Returns `true` when there are no columns. - #[inline] - pub fn is_empty(&self) -> bool { - self.columns.is_empty() - } - - /// Look up a value by column name (linear scan — fast for typical column counts). - #[inline] - pub fn get(&self, name: &str) -> Option<&serde_json::Value> { - self.columns - .iter() - .find(|(k, _)| k.as_ref() == name) - .map(|(_, v)| v) - } - - /// Convert to a `HashMap` (allocates — prefer `get` / `iter` for lookups). - pub fn into_hash_map(self) -> HashMap { - self.columns - .into_iter() - .map(|(k, v)| (k.to_string(), v)) - .collect() - } - - /// Construct from `(&str, Value)` pairs — handy for tests and literals. - #[inline] - pub fn from_pairs(pairs: Vec<(&str, serde_json::Value)>) -> Self { - Self { - columns: pairs.into_iter().map(|(k, v)| (Arc::from(k), v)).collect(), - } - } -} - -impl Default for RowData { - fn default() -> Self { - Self::new() - } -} - -// Order-sensitive equality (columns must match in the same order). 
-impl PartialEq for RowData { - fn eq(&self, other: &Self) -> bool { - self.columns == other.columns - } -} -// --- serde: serialise as JSON object --- - -impl Serialize for RowData { - fn serialize( - &self, - serializer: S, - ) -> std::result::Result { - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(self.columns.len()))?; - for (k, v) in &self.columns { - map.serialize_entry(k.as_ref(), v)?; - } - map.end() - } -} - -impl<'de> Deserialize<'de> for RowData { - fn deserialize>( - deserializer: D, - ) -> std::result::Result { - struct Visitor; +// Re-export ColumnValue and RowData from their dedicated module. +pub use crate::column_value::{ColumnValue, RowData}; - impl<'de> serde::de::Visitor<'de> for Visitor { - type Value = RowData; - - fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.write_str("a JSON object (map of column names to values)") - } - - fn visit_map>( - self, - mut map: M, - ) -> std::result::Result { - let mut cols = Vec::with_capacity(map.size_hint().unwrap_or(0)); - while let Some((k, v)) = map.next_entry::()? { - cols.push((Arc::from(k.as_str()), v)); - } - Ok(RowData { columns: cols }) - } - } - - deserializer.deserialize_map(Visitor) - } -} +// NOTE: The old ColumnValue enum, RowData struct, hex helpers, and their +// serde / binary-encode impls now live in `src/column_value.rs`. +// They are re-exported above so downstream code is unaffected. /// Represents the type of change event from PostgreSQL logical replication #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -688,8 +543,8 @@ pub struct ChangeEvent { /// LSN (Log Sequence Number) position pub lsn: Lsn, - /// Additional metadata - pub metadata: Option>, + /// Additional user-defined metadata (key-value string pairs). 
+ pub metadata: Option>, } impl ChangeEvent { @@ -706,12 +561,12 @@ impl ChangeEvent { /// # Example /// /// ``` - /// use pg_walstream::{ChangeEvent, Lsn, RowData}; + /// use pg_walstream::{ChangeEvent, ColumnValue, Lsn, RowData}; /// use std::sync::Arc; /// /// let data = RowData::from_pairs(vec![ - /// ("id", serde_json::json!(1)), - /// ("name", serde_json::json!("Alice")), + /// ("id", ColumnValue::text("1")), + /// ("name", ColumnValue::text("Alice")), /// ]); /// /// let event = ChangeEvent::insert( @@ -757,17 +612,17 @@ impl ChangeEvent { /// # Example /// /// ``` - /// use pg_walstream::{ChangeEvent, ReplicaIdentity, Lsn, RowData}; + /// use pg_walstream::{ChangeEvent, ColumnValue, ReplicaIdentity, Lsn, RowData}; /// use std::sync::Arc; /// /// let old_data = RowData::from_pairs(vec![ - /// ("id", serde_json::json!(1)), - /// ("name", serde_json::json!("Alice")), + /// ("id", ColumnValue::text("1")), + /// ("name", ColumnValue::text("Alice")), /// ]); /// /// let new_data = RowData::from_pairs(vec![ - /// ("id", serde_json::json!(1)), - /// ("name", serde_json::json!("Bob")), + /// ("id", ColumnValue::text("1")), + /// ("name", ColumnValue::text("Bob")), /// ]); /// /// let event = ChangeEvent::update( @@ -822,12 +677,12 @@ impl ChangeEvent { /// # Example /// /// ``` - /// use pg_walstream::{ChangeEvent, ReplicaIdentity, Lsn, RowData}; + /// use pg_walstream::{ChangeEvent, ColumnValue, ReplicaIdentity, Lsn, RowData}; /// use std::sync::Arc; /// /// let old_data = RowData::from_pairs(vec![ - /// ("id", serde_json::json!(1)), - /// ("name", serde_json::json!("Alice")), + /// ("id", ColumnValue::text("1")), + /// ("name", ColumnValue::text("Alice")), /// ]); /// /// let event = ChangeEvent::delete( @@ -982,10 +837,7 @@ impl ChangeEvent { } /// Set metadata for this event - pub fn with_metadata( - mut self, - metadata: std::collections::HashMap, - ) -> Self { + pub fn with_metadata(mut self, metadata: HashMap) -> Self { self.metadata = Some(metadata); self } 
@@ -1021,11 +873,407 @@ impl ChangeEvent { _ => "other", } } + + // ---- binary wire format (encode / decode) ---- + /// For efficient transport over the network, we provide a compact binary encoding of `ChangeEvent`. + /// + /// The format is: + /// ```text + /// [1 byte event tag] + /// [8 bytes LSN (big-endian u64)] + /// [1 byte has_metadata flag] + /// [if has_metadata: u16 entry count, then for each: u16+name, u16+value] + /// [variable event-specific payload] + /// ``` + /// + /// This is significantly faster than JSON for both encoding and decoding, + /// and produces a smaller payload. + pub fn encode(&self, buf: &mut BytesMut) { + // LSN + buf.extend_from_slice(&self.lsn.0.to_be_bytes()); + + // Metadata + match &self.metadata { + None => buf.extend_from_slice(&[0u8]), + Some(m) => { + buf.extend_from_slice(&[1u8]); + buf.extend_from_slice(&(m.len() as u16).to_be_bytes()); + for (k, v) in m { + let kb = k.as_bytes(); + buf.extend_from_slice(&(kb.len() as u16).to_be_bytes()); + buf.extend_from_slice(kb); + let vb = v.as_bytes(); + buf.extend_from_slice(&(vb.len() as u16).to_be_bytes()); + buf.extend_from_slice(vb); + } + } + } + + // Event payload + match &self.event_type { + EventType::Insert { + schema, + table, + relation_oid, + data, + } => { + buf.extend_from_slice(&[message_types::INSERT]); + encode_arc_str(buf, schema); + encode_arc_str(buf, table); + buf.extend_from_slice(&relation_oid.to_be_bytes()); + data.encode(buf); + } + EventType::Update { + schema, + table, + relation_oid, + old_data, + new_data, + replica_identity, + key_columns, + } => { + buf.extend_from_slice(&[message_types::UPDATE]); + encode_arc_str(buf, schema); + encode_arc_str(buf, table); + buf.extend_from_slice(&relation_oid.to_be_bytes()); + // old_data: present flag + data + match old_data { + None => buf.extend_from_slice(&[0u8]), + Some(d) => { + buf.extend_from_slice(&[1u8]); + d.encode(buf); + } + } + new_data.encode(buf); + 
buf.extend_from_slice(&[replica_identity.to_byte()]); + buf.extend_from_slice(&(key_columns.len() as u16).to_be_bytes()); + for kc in key_columns { + encode_arc_str(buf, kc); + } + } + EventType::Delete { + schema, + table, + relation_oid, + old_data, + replica_identity, + key_columns, + } => { + buf.extend_from_slice(&[message_types::DELETE]); + encode_arc_str(buf, schema); + encode_arc_str(buf, table); + buf.extend_from_slice(&relation_oid.to_be_bytes()); + old_data.encode(buf); + buf.extend_from_slice(&[replica_identity.to_byte()]); + buf.extend_from_slice(&(key_columns.len() as u16).to_be_bytes()); + for kc in key_columns { + encode_arc_str(buf, kc); + } + } + EventType::Truncate(tables) => { + buf.extend_from_slice(&[message_types::TRUNCATE]); + buf.extend_from_slice(&(tables.len() as u16).to_be_bytes()); + for t in tables { + encode_arc_str(buf, t); + } + } + EventType::Begin { + transaction_id, + final_lsn, + commit_timestamp, + } => { + buf.extend_from_slice(&[message_types::BEGIN]); + buf.extend_from_slice(&transaction_id.to_be_bytes()); + buf.extend_from_slice(&final_lsn.0.to_be_bytes()); + buf.extend_from_slice(&commit_timestamp.timestamp_micros().to_be_bytes()); + } + EventType::Commit { + commit_timestamp, + commit_lsn, + end_lsn, + } => { + buf.extend_from_slice(&[message_types::COMMIT]); + buf.extend_from_slice(&commit_timestamp.timestamp_micros().to_be_bytes()); + buf.extend_from_slice(&commit_lsn.0.to_be_bytes()); + buf.extend_from_slice(&end_lsn.0.to_be_bytes()); + } + EventType::StreamStart { + transaction_id, + first_segment, + } => { + buf.extend_from_slice(&[message_types::STREAM_START]); + buf.extend_from_slice(&transaction_id.to_be_bytes()); + buf.extend_from_slice(&[u8::from(*first_segment)]); + } + EventType::StreamStop => { + buf.extend_from_slice(&[message_types::STREAM_STOP]); + } + EventType::StreamCommit { + transaction_id, + commit_lsn, + end_lsn, + commit_timestamp, + } => { + buf.extend_from_slice(&[message_types::STREAM_COMMIT]); 
+ buf.extend_from_slice(&transaction_id.to_be_bytes()); + buf.extend_from_slice(&commit_lsn.0.to_be_bytes()); + buf.extend_from_slice(&end_lsn.0.to_be_bytes()); + buf.extend_from_slice(&commit_timestamp.timestamp_micros().to_be_bytes()); + } + EventType::StreamAbort { + transaction_id, + subtransaction_xid, + abort_lsn, + abort_timestamp, + } => { + buf.extend_from_slice(&[message_types::STREAM_ABORT]); + buf.extend_from_slice(&transaction_id.to_be_bytes()); + buf.extend_from_slice(&subtransaction_xid.to_be_bytes()); + match abort_lsn { + None => buf.extend_from_slice(&[0u8]), + Some(l) => { + buf.extend_from_slice(&[1u8]); + buf.extend_from_slice(&l.0.to_be_bytes()); + } + } + match abort_timestamp { + None => buf.extend_from_slice(&[0u8]), + Some(ts) => { + buf.extend_from_slice(&[1u8]); + buf.extend_from_slice(&ts.timestamp_micros().to_be_bytes()); + } + } + } + EventType::Relation => buf.extend_from_slice(&[message_types::RELATION]), + EventType::Type => buf.extend_from_slice(&[message_types::TYPE]), + EventType::Origin => buf.extend_from_slice(&[message_types::ORIGIN]), + EventType::Message => buf.extend_from_slice(&[message_types::MESSAGE]), + } + } + + /// Decode a `ChangeEvent` from binary data produced by [`encode`](Self::encode). + pub fn decode(data: &[u8]) -> Result { + let mut reader = BufferReader::new(data); + + // LSN + let lsn = Lsn(reader.read_u64()?); + + // Metadata + let has_meta = reader.read_u8()?; + let metadata = if has_meta != 0 { + let count = reader.read_u16()? 
as usize; + let mut m = HashMap::with_capacity(count); + for _ in 0..count { + let k = decode_string(&mut reader)?; + let v = decode_string(&mut reader)?; + m.insert(k, v); + } + Some(m) + } else { + None + }; + + // Event tag + let tag = reader.read_u8()?; + let event_type = match tag { + message_types::INSERT => { + let schema = decode_arc_str(&mut reader)?; + let table = decode_arc_str(&mut reader)?; + let relation_oid = reader.read_u32()?; + let data = RowData::decode(&mut reader)?; + EventType::Insert { + schema, + table, + relation_oid, + data, + } + } + message_types::UPDATE => { + let schema = decode_arc_str(&mut reader)?; + let table = decode_arc_str(&mut reader)?; + let relation_oid = reader.read_u32()?; + let has_old = reader.read_u8()?; + let old_data = if has_old != 0 { + Some(RowData::decode(&mut reader)?) + } else { + None + }; + let new_data = RowData::decode(&mut reader)?; + let ri_byte = reader.read_u8()?; + let replica_identity = ReplicaIdentity::from_byte(ri_byte).ok_or_else(|| { + ReplicationError::protocol(format!( + "Unknown replica identity byte: 0x{ri_byte:02x}" + )) + })?; + let kc_count = reader.read_u16()? as usize; + let mut key_columns = Vec::with_capacity(kc_count); + for _ in 0..kc_count { + key_columns.push(decode_arc_str(&mut reader)?); + } + EventType::Update { + schema, + table, + relation_oid, + old_data, + new_data, + replica_identity, + key_columns, + } + } + message_types::DELETE => { + let schema = decode_arc_str(&mut reader)?; + let table = decode_arc_str(&mut reader)?; + let relation_oid = reader.read_u32()?; + let old_data = RowData::decode(&mut reader)?; + let ri_byte = reader.read_u8()?; + let replica_identity = ReplicaIdentity::from_byte(ri_byte).ok_or_else(|| { + ReplicationError::protocol(format!( + "Unknown replica identity byte: 0x{ri_byte:02x}" + )) + })?; + let kc_count = reader.read_u16()? 
as usize; + let mut key_columns = Vec::with_capacity(kc_count); + for _ in 0..kc_count { + key_columns.push(decode_arc_str(&mut reader)?); + } + EventType::Delete { + schema, + table, + relation_oid, + old_data, + replica_identity, + key_columns, + } + } + message_types::TRUNCATE => { + let count = reader.read_u16()? as usize; + let mut tables = Vec::with_capacity(count); + for _ in 0..count { + tables.push(decode_arc_str(&mut reader)?); + } + EventType::Truncate(tables) + } + message_types::BEGIN => { + let transaction_id = reader.read_u32()?; + let final_lsn = Lsn(reader.read_u64()?); + let ts_micros = reader.read_i64()?; + let commit_timestamp = micros_to_chrono(ts_micros); + EventType::Begin { + transaction_id, + final_lsn, + commit_timestamp, + } + } + message_types::COMMIT => { + let ts_micros = reader.read_i64()?; + let commit_timestamp = micros_to_chrono(ts_micros); + let commit_lsn = Lsn(reader.read_u64()?); + let end_lsn = Lsn(reader.read_u64()?); + EventType::Commit { + commit_timestamp, + commit_lsn, + end_lsn, + } + } + message_types::STREAM_START => { + let transaction_id = reader.read_u32()?; + let first_segment = reader.read_u8()? 
!= 0; + EventType::StreamStart { + transaction_id, + first_segment, + } + } + message_types::STREAM_STOP => EventType::StreamStop, + message_types::STREAM_COMMIT => { + let transaction_id = reader.read_u32()?; + let commit_lsn = Lsn(reader.read_u64()?); + let end_lsn = Lsn(reader.read_u64()?); + let ts_micros = reader.read_i64()?; + let commit_timestamp = micros_to_chrono(ts_micros); + EventType::StreamCommit { + transaction_id, + commit_lsn, + end_lsn, + commit_timestamp, + } + } + message_types::STREAM_ABORT => { + let transaction_id = reader.read_u32()?; + let subtransaction_xid = reader.read_u32()?; + let has_lsn = reader.read_u8()?; + let abort_lsn = if has_lsn != 0 { + Some(Lsn(reader.read_u64()?)) + } else { + None + }; + let has_ts = reader.read_u8()?; + let abort_timestamp = if has_ts != 0 { + Some(micros_to_chrono(reader.read_i64()?)) + } else { + None + }; + EventType::StreamAbort { + transaction_id, + subtransaction_xid, + abort_lsn, + abort_timestamp, + } + } + message_types::RELATION => EventType::Relation, + message_types::TYPE => EventType::Type, + message_types::ORIGIN => EventType::Origin, + message_types::MESSAGE => EventType::Message, + _ => { + return Err(ReplicationError::protocol(format!( + "Unknown ChangeEvent tag: 0x{tag:02x}" + ))); + } + }; + + Ok(Self { + event_type, + lsn, + metadata, + }) + } +} + +#[inline] +fn encode_arc_str(buf: &mut BytesMut, s: &Arc) { + let b = s.as_bytes(); + buf.extend_from_slice(&(b.len() as u16).to_be_bytes()); + buf.extend_from_slice(b); +} + +fn decode_arc_str(reader: &mut BufferReader) -> Result> { + let len = reader.read_u16()? as usize; + let bytes = reader.read_bytes(len)?; + let s = std::str::from_utf8(&bytes) + .map_err(|e| ReplicationError::protocol(format!("Invalid UTF-8 in wire format: {e}")))?; + Ok(Arc::from(s)) +} + +fn decode_string(reader: &mut BufferReader) -> Result { + let len = reader.read_u16()? 
as usize; + let bytes = reader.read_bytes(len)?; + String::from_utf8(bytes) + .map_err(|e| ReplicationError::protocol(format!("Invalid UTF-8 in wire format: {e}"))) +} + +/// Convert Unix timestamp microseconds to chrono DateTime. +fn micros_to_chrono(micros: i64) -> chrono::DateTime { + use chrono::{TimeZone, Utc}; + let secs = micros.div_euclid(1_000_000); + let subsec_nanos = (micros.rem_euclid(1_000_000) as u32) * 1000; + Utc.timestamp_opt(secs, subsec_nanos) + .single() + .unwrap_or_else(|| Utc.timestamp_opt(0, 0).unwrap()) } #[cfg(test)] mod tests { use super::*; + use bytes::Bytes; use chrono::{TimeZone, Utc}; use std::time::{SystemTime, UNIX_EPOCH}; @@ -1091,14 +1339,6 @@ mod tests { assert!(diff < 2, "Round trip difference too large: {diff}"); } - #[test] - fn test_format_postgres_timestamp() { - let ts = 0; // PostgreSQL epoch - let formatted = format_postgres_timestamp(ts); - assert!(formatted.contains("timestamp=")); - assert!(formatted.contains("946684800")); // Unix timestamp for 2000-01-01 - } - #[test] fn test_replica_identity_from_byte() { assert_eq!( @@ -1181,8 +1421,8 @@ mod tests { #[test] fn test_change_event_insert() { let data = RowData::from_pairs(vec![ - ("id", serde_json::json!(1)), - ("name", serde_json::json!("test")), + ("id", ColumnValue::text("1")), + ("name", ColumnValue::text("test")), ]); let event = ChangeEvent::insert("public", "users", 12345, data, Lsn::new(0x16B374D848)); @@ -1243,6 +1483,7 @@ mod tests { } #[test] + #[allow(clippy::clone_on_copy)] fn test_slot_type_clone_copy_eq() { let s1 = SlotType::Logical; let s2 = s1; // Copy @@ -1254,10 +1495,10 @@ mod tests { #[test] fn test_change_event_update() { - let old_data = RowData::from_pairs(vec![("id", serde_json::json!(1))]); + let old_data = RowData::from_pairs(vec![("id", ColumnValue::text("1"))]); let new_data = RowData::from_pairs(vec![ - ("id", serde_json::json!(1)), - ("name", serde_json::json!("updated")), + ("id", ColumnValue::text("1")), + ("name", 
ColumnValue::text("updated")), ]); let event = ChangeEvent::update( @@ -1296,7 +1537,7 @@ mod tests { #[test] fn test_change_event_delete() { - let old_data = RowData::from_pairs(vec![("id", serde_json::json!(1))]); + let old_data = RowData::from_pairs(vec![("id", ColumnValue::text("1"))]); let event = ChangeEvent::delete( "public", @@ -1412,15 +1653,15 @@ mod tests { let event = ChangeEvent::insert("public", "test", 1, data, Lsn::new(100)); assert!(event.metadata.is_none()); - let mut metadata = std::collections::HashMap::new(); - metadata.insert("source".to_string(), serde_json::json!("test")); - metadata.insert("version".to_string(), serde_json::json!(2)); + let mut metadata = HashMap::new(); + metadata.insert("source".to_string(), "test".to_string()); + metadata.insert("version".to_string(), "2".to_string()); let event = event.with_metadata(metadata.clone()); assert!(event.metadata.is_some()); let m = event.metadata.unwrap(); - assert_eq!(m.get("source").unwrap(), &serde_json::json!("test")); - assert_eq!(m.get("version").unwrap(), &serde_json::json!(2)); + assert_eq!(m.get("source").unwrap(), "test"); + assert_eq!(m.get("version").unwrap(), "2"); } #[test] @@ -1561,14 +1802,6 @@ mod tests { assert_eq!(type_event.event_type_str(), "other"); } - #[test] - fn test_format_postgres_timestamp_invalid() { - // Very large negative timestamp - let result = format_postgres_timestamp(i64::MIN / 2); - // Either we get a valid formatted timestamp or "invalid timestamp" - assert!(!result.is_empty()); - } - #[test] fn test_lsn_serialize_deserialize() { let lsn = Lsn::new(0x16B374D848); @@ -1593,7 +1826,7 @@ mod tests { #[test] fn test_change_event_serialize_deserialize() { - let data = RowData::from_pairs(vec![("id", serde_json::json!(42))]); + let data = RowData::from_pairs(vec![("id", ColumnValue::text("42"))]); let event = ChangeEvent::insert("public", "test", 12345, data, Lsn::new(1000)); @@ -1616,35 +1849,680 @@ mod tests { assert_eq!(&*padded, "hello world"); } - // 
---- RowData::default coverage ---- + // --- Encode / Decode round-trip tests --- + + /// Helper: encode a ChangeEvent, then decode it and assert equality. + fn assert_encode_decode_round_trip(event: &ChangeEvent) { + let mut buf = BytesMut::new(); + event.encode(&mut buf); + let decoded = ChangeEvent::decode(&buf).expect("decode failed"); + assert_eq!(decoded.lsn, event.lsn); + assert_eq!(decoded.event_type, event.event_type); + // Compare metadata (HashMap doesn't impl Eq but we can check the content) + assert_eq!( + decoded.metadata.is_some(), + event.metadata.is_some(), + "metadata presence mismatch" + ); + if let (Some(a), Some(b)) = (&decoded.metadata, &event.metadata) { + assert_eq!(a.len(), b.len()); + for (k, v) in b { + assert_eq!(a.get(k), Some(v)); + } + } + } #[test] - fn test_rowdata_default() { - let row = RowData::default(); - assert!(row.is_empty()); - assert_eq!(row.len(), 0); + fn test_encode_decode_insert() { + let data = RowData::from_pairs(vec![ + ("id", ColumnValue::text("42")), + ("name", ColumnValue::text("Alice")), + ("bio", ColumnValue::Null), + ]); + let event = ChangeEvent::insert("public", "users", 12345, data, Lsn::new(0x100)); + assert_encode_decode_round_trip(&event); } #[test] - fn test_rowdata_deserialize_invalid_type() { - // Feeding a non-object type triggers the `expecting()` method. 
- let err = serde_json::from_str::("42").unwrap_err(); - let msg = err.to_string(); + fn test_encode_decode_insert_with_metadata() { + let data = RowData::from_pairs(vec![("x", ColumnValue::text("1"))]); + let mut meta = HashMap::new(); + meta.insert("source".to_string(), "unit-test".to_string()); + meta.insert("version".to_string(), "3".to_string()); + let event = + ChangeEvent::insert("myschema", "mytable", 99, data, Lsn::new(500)).with_metadata(meta); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_update_with_old_data() { + let old = RowData::from_pairs(vec![ + ("id", ColumnValue::text("1")), + ("val", ColumnValue::text("old")), + ]); + let new = RowData::from_pairs(vec![ + ("id", ColumnValue::text("1")), + ("val", ColumnValue::text("new")), + ]); + let event = ChangeEvent::update( + "public", + "items", + 555, + Some(old), + new, + ReplicaIdentity::Full, + vec![Arc::from("id")], + Lsn::new(2000), + ); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_update_without_old_data() { + let new = RowData::from_pairs(vec![ + ("id", ColumnValue::text("1")), + ("val", ColumnValue::text("updated")), + ]); + let event = ChangeEvent::update( + "public", + "items", + 555, + None, + new, + ReplicaIdentity::Default, + vec![Arc::from("id"), Arc::from("tenant")], + Lsn::new(2100), + ); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_delete() { + let old = RowData::from_pairs(vec![ + ("id", ColumnValue::text("99")), + ("name", ColumnValue::text("deleted")), + ]); + let event = ChangeEvent::delete( + "public", + "users", + 777, + old, + ReplicaIdentity::Index, + vec![Arc::from("id")], + Lsn::new(3000), + ); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_truncate() { + let tables = vec![Arc::from("public.a"), Arc::from("public.b")]; + let event = ChangeEvent::truncate(tables, Lsn::new(4000)); + assert_encode_decode_round_trip(&event); + } + + 
#[test] + fn test_encode_decode_truncate_empty() { + let event = ChangeEvent::truncate(vec![], Lsn::new(4100)); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_begin() { + let ts = Utc.with_ymd_and_hms(2024, 6, 15, 12, 30, 45).unwrap(); + let event = ChangeEvent::begin(12345, Lsn::new(5000), ts, Lsn::new(4900)); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_commit() { + let ts = Utc.with_ymd_and_hms(2024, 6, 15, 12, 31, 0).unwrap(); + let event = ChangeEvent::commit(ts, Lsn::new(6000), Lsn::new(5900), Lsn::new(6100)); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_stream_start() { + let event = ChangeEvent { + event_type: EventType::StreamStart { + transaction_id: 42, + first_segment: true, + }, + lsn: Lsn::new(7000), + metadata: None, + }; + assert_encode_decode_round_trip(&event); + + // second segment + let event2 = ChangeEvent { + event_type: EventType::StreamStart { + transaction_id: 42, + first_segment: false, + }, + lsn: Lsn::new(7001), + metadata: None, + }; + assert_encode_decode_round_trip(&event2); + } + + #[test] + fn test_encode_decode_stream_stop() { + let event = ChangeEvent { + event_type: EventType::StreamStop, + lsn: Lsn::new(7500), + metadata: None, + }; + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_stream_commit() { + let ts = Utc.with_ymd_and_hms(2024, 8, 1, 0, 0, 0).unwrap(); + let event = ChangeEvent { + event_type: EventType::StreamCommit { + transaction_id: 99, + commit_lsn: Lsn::new(8000), + end_lsn: Lsn::new(8100), + commit_timestamp: ts, + }, + lsn: Lsn::new(7900), + metadata: None, + }; + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_stream_abort_with_all_fields() { + let ts = Utc.with_ymd_and_hms(2024, 9, 1, 12, 0, 0).unwrap(); + let event = ChangeEvent { + event_type: EventType::StreamAbort { + transaction_id: 50, + subtransaction_xid: 51, + abort_lsn: 
Some(Lsn::new(9000)), + abort_timestamp: Some(ts), + }, + lsn: Lsn::new(8900), + metadata: None, + }; + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_stream_abort_without_optional_fields() { + let event = ChangeEvent { + event_type: EventType::StreamAbort { + transaction_id: 50, + subtransaction_xid: 0, + abort_lsn: None, + abort_timestamp: None, + }, + lsn: Lsn::new(8950), + metadata: None, + }; + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_relation() { + let event = ChangeEvent::relation(Lsn::new(10000)); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_type() { + let event = ChangeEvent::type_event(Lsn::new(10100)); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_origin() { + let event = ChangeEvent::origin(Lsn::new(10200)); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_encode_decode_message() { + let event = ChangeEvent::message(Lsn::new(10300)); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_decode_unknown_event_tag() { + // Build a minimal buffer with valid LSN, no-metadata, then unknown tag + let mut buf = BytesMut::new(); + buf.extend_from_slice(&100u64.to_be_bytes()); // LSN + buf.extend_from_slice(&[0u8]); // no metadata + buf.extend_from_slice(&[0xFE]); // unknown event tag + let result = ChangeEvent::decode(&buf); + assert!(result.is_err()); + let err_msg = format!("{}", result.unwrap_err()); assert!( - msg.contains("a JSON object"), - "Error should reference expecting(), got: {msg}" + err_msg.contains("Unknown ChangeEvent tag"), + "got: {err_msg}" ); } #[test] - fn test_rowdata_deserialize_string_gives_error() { - let err = serde_json::from_str::("\"hello\"").unwrap_err(); - assert!(err.to_string().contains("a JSON object")); + fn test_encode_decode_insert_with_binary_data() { + let data = RowData::from_pairs(vec![ + ("id", ColumnValue::text("1")), + ( + "blob", + 
ColumnValue::binary_bytes(Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef])), + ), + ("empty", ColumnValue::Null), + ]); + let event = ChangeEvent::insert("public", "blobs", 999, data, Lsn::new(11000)); + assert_encode_decode_round_trip(&event); } #[test] - fn test_rowdata_deserialize_array_gives_error() { - let err = serde_json::from_str::("[1, 2, 3]").unwrap_err(); - assert!(err.to_string().contains("a JSON object")); + fn test_encode_decode_update_all_replica_identities() { + for ri in [ + ReplicaIdentity::Default, + ReplicaIdentity::Nothing, + ReplicaIdentity::Full, + ReplicaIdentity::Index, + ] { + let new = RowData::from_pairs(vec![("x", ColumnValue::text("1"))]); + let event = + ChangeEvent::update("s", "t", 1, None, new, ri.clone(), vec![], Lsn::new(12000)); + assert_encode_decode_round_trip(&event); + } + } + + #[test] + fn test_encode_decode_delete_all_replica_identities() { + for ri in [ + ReplicaIdentity::Default, + ReplicaIdentity::Nothing, + ReplicaIdentity::Full, + ReplicaIdentity::Index, + ] { + let old = RowData::from_pairs(vec![("k", ColumnValue::text("v"))]); + let event = ChangeEvent::delete( + "s", + "t", + 1, + old, + ri.clone(), + vec![Arc::from("k")], + Lsn::new(13000), + ); + assert_encode_decode_round_trip(&event); + } + } + + #[test] + fn test_encode_decode_metadata_empty_hashmap() { + let data = RowData::from_pairs(vec![("a", ColumnValue::text("b"))]); + let event = + ChangeEvent::insert("s", "t", 1, data, Lsn::new(14000)).with_metadata(HashMap::new()); + assert_encode_decode_round_trip(&event); + } + + #[test] + fn test_micros_to_chrono_zero() { + // Zero micros = Unix epoch + let dt = micros_to_chrono(0); + assert_eq!(dt, Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap()); + } + + #[test] + fn test_micros_to_chrono_negative() { + // Negative micros = before Unix epoch + let dt = micros_to_chrono(-1_000_000); + assert_eq!(dt, Utc.with_ymd_and_hms(1969, 12, 31, 23, 59, 59).unwrap()); + } + + #[test] + fn 
test_micros_to_chrono_with_subsecond() { + let dt = micros_to_chrono(1_500_000); // 1.5 seconds + assert_eq!(dt.timestamp(), 1); + assert_eq!(dt.timestamp_subsec_micros(), 500_000); + } + + #[test] + fn test_change_event_serde_round_trip_update() { + let old = RowData::from_pairs(vec![("id", ColumnValue::text("1"))]); + let new = RowData::from_pairs(vec![ + ("id", ColumnValue::text("1")), + ("v", ColumnValue::text("updated")), + ]); + let event = ChangeEvent::update( + "public", + "t", + 1, + Some(old), + new, + ReplicaIdentity::Full, + vec![Arc::from("id")], + Lsn::new(100), + ); + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(back.event_type, event.event_type); + } + + #[test] + fn test_change_event_serde_round_trip_delete() { + let old = RowData::from_pairs(vec![("id", ColumnValue::text("1"))]); + let event = ChangeEvent::delete( + "public", + "t", + 1, + old, + ReplicaIdentity::Index, + vec![Arc::from("id")], + Lsn::new(200), + ); + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(back.event_type, event.event_type); + } + + #[test] + fn test_change_event_serde_round_trip_begin() { + let ts = Utc.with_ymd_and_hms(2024, 3, 15, 8, 0, 0).unwrap(); + let event = ChangeEvent::begin(1, Lsn::new(300), ts, Lsn::new(300)); + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(back.event_type, event.event_type); + } + + #[test] + fn test_change_event_serde_round_trip_commit() { + let ts = Utc.with_ymd_and_hms(2024, 3, 15, 8, 0, 0).unwrap(); + let event = ChangeEvent::commit(ts, Lsn::new(400), Lsn::new(400), Lsn::new(410)); + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(back.event_type, event.event_type); + } + + #[test] + fn 
test_change_event_serde_round_trip_truncate() { + let event = ChangeEvent::truncate(vec![Arc::from("t1"), Arc::from("t2")], Lsn::new(500)); + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(back.event_type, event.event_type); + } + + #[test] + fn test_change_event_serde_round_trip_streaming_events() { + let ts = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + + // StreamStart + let event = ChangeEvent { + event_type: EventType::StreamStart { + transaction_id: 10, + first_segment: true, + }, + lsn: Lsn::new(600), + metadata: None, + }; + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(back.event_type, event.event_type); + + // StreamStop + let event = ChangeEvent { + event_type: EventType::StreamStop, + lsn: Lsn::new(700), + metadata: None, + }; + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(back.event_type, event.event_type); + + // StreamCommit + let event = ChangeEvent { + event_type: EventType::StreamCommit { + transaction_id: 10, + commit_lsn: Lsn::new(800), + end_lsn: Lsn::new(810), + commit_timestamp: ts, + }, + lsn: Lsn::new(800), + metadata: None, + }; + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(back.event_type, event.event_type); + + // StreamAbort with optional fields + let event = ChangeEvent { + event_type: EventType::StreamAbort { + transaction_id: 10, + subtransaction_xid: 11, + abort_lsn: Some(Lsn::new(900)), + abort_timestamp: Some(ts), + }, + lsn: Lsn::new(900), + metadata: None, + }; + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(back.event_type, event.event_type); + + // StreamAbort without optional fields + let event = ChangeEvent { + 
event_type: EventType::StreamAbort { + transaction_id: 10, + subtransaction_xid: 0, + abort_lsn: None, + abort_timestamp: None, + }, + lsn: Lsn::new(950), + metadata: None, + }; + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(back.event_type, event.event_type); + } + + #[test] + fn test_change_event_serde_round_trip_with_metadata() { + let data = RowData::from_pairs(vec![("id", ColumnValue::text("1"))]); + let mut meta = HashMap::new(); + meta.insert("key1".to_string(), "val1".to_string()); + meta.insert("key2".to_string(), "val2".to_string()); + let event = ChangeEvent::insert("s", "t", 1, data, Lsn::new(1000)).with_metadata(meta); + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert!(back.metadata.is_some()); + let m = back.metadata.unwrap(); + assert_eq!(m.get("key1").unwrap(), "val1"); + assert_eq!(m.get("key2").unwrap(), "val2"); + } + + #[test] + fn test_change_event_serde_round_trip_simple_events() { + for event in [ + ChangeEvent::relation(Lsn::new(1100)), + ChangeEvent::type_event(Lsn::new(1200)), + ChangeEvent::origin(Lsn::new(1300)), + ChangeEvent::message(Lsn::new(1400)), + ] { + let json = serde_json::to_string(&event).unwrap(); + let back: ChangeEvent = serde_json::from_str(&json).unwrap(); + assert_eq!( + back.event_type, event.event_type, + "Failed for {:?}", + event.event_type + ); + } + } + + #[test] + fn test_lsn_zero() { + let lsn = Lsn::new(0); + assert_eq!(lsn.value(), 0); + assert_eq!(format!("{lsn}"), "0/0"); + + let parsed: Lsn = "0/0".parse().unwrap(); + assert_eq!(parsed, lsn); + } + + #[test] + fn test_lsn_max() { + let lsn = Lsn::new(u64::MAX); + let formatted = format!("{lsn}"); + let parsed: Lsn = formatted.parse().unwrap(); + assert_eq!(parsed, lsn); + } + + #[test] + fn test_lsn_serde_zero_and_max() { + for val in [0u64, 1, u64::MAX / 2, u64::MAX] { + let lsn = Lsn::new(val); + 
let json = serde_json::to_string(&lsn).unwrap(); + let back: Lsn = serde_json::from_str(&json).unwrap(); + assert_eq!(lsn, back); + } + } + + #[test] + fn test_lsn_equality_and_hash() { + use std::collections::HashSet; + let a = Lsn::new(100); + let b = Lsn::new(100); + let c = Lsn::new(200); + assert_eq!(a, b); + assert_ne!(a, c); + + // Lsn doesn't impl Hash, but does impl Copy + Eq + let x = a; + assert_eq!(x, a); + let _ = HashSet::::new(); // type check only + } + + #[test] + fn test_replica_identity_debug_clone() { + let ri = ReplicaIdentity::Full; + let cloned = ri.clone(); + assert_eq!(ri, cloned); + let debug = format!("{ri:?}"); + assert!(debug.contains("Full"), "got: {debug}"); + } + + #[test] + fn test_replica_identity_round_trip_byte() { + for byte in [b'd', b'n', b'f', b'i'] { + let ri = ReplicaIdentity::from_byte(byte).unwrap(); + assert_eq!(ri.to_byte(), byte); + } + } + + #[test] + fn test_base_backup_options_default() { + let opts = BaseBackupOptions::default(); + assert!(opts.label.is_none()); + assert!(opts.target.is_none()); + assert!(!opts.progress); + assert!(!opts.wal); + assert!(!opts.wait); + assert!(opts.compression.is_none()); + assert!(opts.max_rate.is_none()); + assert!(!opts.tablespace_map); + assert!(!opts.verify_checksums); + assert!(opts.manifest.is_none()); + assert!(!opts.incremental); + } + + #[test] + fn test_replication_slot_options_default() { + let opts = ReplicationSlotOptions::default(); + assert!(!opts.temporary); + assert!(!opts.two_phase); + assert!(!opts.reserve_wal); + assert!(opts.snapshot.is_none()); + assert!(!opts.failover); + } + + #[test] + fn test_change_event_clone() { + let data = RowData::from_pairs(vec![("id", ColumnValue::text("1"))]); + let event = ChangeEvent::insert("s", "t", 1, data, Lsn::new(100)); + let cloned = event.clone(); + assert_eq!(cloned.lsn, event.lsn); + assert_eq!(cloned.event_type, event.event_type); + } + + #[test] + fn test_change_event_debug() { + let data = 
RowData::from_pairs(vec![("id", ColumnValue::text("1"))]); + let event = ChangeEvent::insert("s", "t", 1, data, Lsn::new(100)); + let debug = format!("{event:?}"); + assert!(debug.contains("Insert"), "got: {debug}"); + assert!(debug.contains("Lsn"), "got: {debug}"); + } + + #[test] + fn test_event_type_debug_all_variants() { + let data = RowData::new(); + let ts = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + + let variants: Vec = vec![ + EventType::Insert { + schema: Arc::from("s"), + table: Arc::from("t"), + relation_oid: 1, + data: data.clone(), + }, + EventType::Update { + schema: Arc::from("s"), + table: Arc::from("t"), + relation_oid: 1, + old_data: None, + new_data: data.clone(), + replica_identity: ReplicaIdentity::Default, + key_columns: vec![], + }, + EventType::Delete { + schema: Arc::from("s"), + table: Arc::from("t"), + relation_oid: 1, + old_data: data.clone(), + replica_identity: ReplicaIdentity::Default, + key_columns: vec![], + }, + EventType::Truncate(vec![]), + EventType::Begin { + transaction_id: 1, + final_lsn: Lsn::new(1), + commit_timestamp: ts, + }, + EventType::Commit { + commit_timestamp: ts, + commit_lsn: Lsn::new(1), + end_lsn: Lsn::new(2), + }, + EventType::StreamStart { + transaction_id: 1, + first_segment: true, + }, + EventType::StreamStop, + EventType::StreamCommit { + transaction_id: 1, + commit_lsn: Lsn::new(1), + end_lsn: Lsn::new(2), + commit_timestamp: ts, + }, + EventType::StreamAbort { + transaction_id: 1, + subtransaction_xid: 0, + abort_lsn: None, + abort_timestamp: None, + }, + EventType::Relation, + EventType::Type, + EventType::Origin, + EventType::Message, + ]; + + for v in &variants { + let debug = format!("{v:?}"); + assert!(!debug.is_empty()); + } } } From ecbd5bdf004b994371f74c2598f60ae0d1f0b1d1 Mon Sep 17 00:00:00 2001 From: Danielshih Date: Wed, 25 Feb 2026 09:13:18 +0000 Subject: [PATCH 2/4] modified README and benchmarks to replace serde_json with ColumnValue for performance enhancement --- README.md | 
10 +++++--- benches/columnvalue_vs_json.rs | 42 +--------------------------------- src/column_value.rs | 2 +- 3 files changed, 9 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index e1c0e16..f3af19b 100644 --- a/README.md +++ b/README.md @@ -137,11 +137,13 @@ async fn main() -> Result<(), Box> { ### Working with Event Data -Events carry row data as `RowData` — an ordered list of `(Arc, Value)` pairs. +Events carry row data as [`RowData`] — an ordered list of `(Arc, ColumnValue)` pairs. +[`ColumnValue`] is a lightweight enum (`Null | Text(Bytes) | Binary(Bytes)`) that preserves +the raw PostgreSQL wire representation with zero-copy semantics. Schema, table, and column names are `Arc` (reference-counted, zero-cost cloning): ```rust -use pg_walstream::{EventType, RowData}; +use pg_walstream::{EventType, RowData, ColumnValue}; // Pattern match on event types match &event.event_type { @@ -367,7 +369,9 @@ The library supports all PostgreSQL logical replication message types: - **Zero-Copy**: Uses `bytes::Bytes` for efficient buffer management - **Arc-shared column metadata**: Column names, schema, and table names use `Arc` — cloning is a single atomic increment instead of a heap allocation per event -- **RowData (ordered Vec)**: Row payloads use `RowData` (a `Vec<(Arc, Value)>`) instead of `HashMap`, eliminating per-event hashing overhead and extra allocations +- **RowData (ordered Vec)**: Row payloads use `RowData` (a `Vec<(Arc, ColumnValue)>`) instead of `HashMap`, eliminating per-event hashing overhead and extra allocations +- **ColumnValue (Null | Text | Binary)**: Preserves the raw PostgreSQL wire representation without intermediate JSON parsing or allocation. 
Each variant holds zero-copy `Bytes` +- **Binary Wire Format**: `ChangeEvent::encode` / `ChangeEvent::decode` provide a compact binary serialization that is significantly faster and smaller than `serde_json`, ideal for inter-process or network transport - **Atomic Operations**: Thread-safe LSN tracking with minimal overhead - **Connection Pooling**: Reusable connection with automatic retry - **Streaming Support**: Handle large transactions without memory issues diff --git a/benches/columnvalue_vs_json.rs b/benches/columnvalue_vs_json.rs index 520011b..067d77a 100644 --- a/benches/columnvalue_vs_json.rs +++ b/benches/columnvalue_vs_json.rs @@ -10,7 +10,6 @@ //! - `serialize` — Encode event to bytes: serde_json vs binary //! - `deserialize` — Decode bytes back to event: serde_json vs binary //! - `round_trip` — Full encode → decode cycle -//! - `payload_size` — Output size comparison (printed, not timed) //! - `pipeline` — Realistic CDC: construct → clone → lookup → serialize //! //! Run: @@ -224,45 +223,7 @@ fn bench_round_trip(c: &mut Criterion) { } // --------------------------------------------------------------------------- -// 5. 
Payload size comparison (one-shot, informational) -// --------------------------------------------------------------------------- -fn bench_payload_size(c: &mut Criterion) { - let mut group = c.benchmark_group("payload_size"); - - for n_cols in COLUMN_COUNTS { - let names = shared_column_names(n_cols); - let new_event = build_new_event(&names); - - let mut binary_buf = bytes::BytesMut::with_capacity(256); - new_event.encode(&mut binary_buf); - - // Bench building the payloads so criterion records something - group.bench_with_input( - BenchmarkId::new("json_serde", n_cols), - &new_event, - |b, event| { - b.iter(|| black_box(serde_json::to_vec(event).unwrap().len())); - }, - ); - - group.bench_with_input( - BenchmarkId::new("binary_encode", n_cols), - &new_event, - |b, event| { - b.iter(|| { - let mut buf = bytes::BytesMut::with_capacity(256); - event.encode(&mut buf); - black_box(buf.len()); - }); - }, - ); - } - - group.finish(); -} - -// --------------------------------------------------------------------------- -// 6. Realistic CDC pipeline: construct → clone → lookup → serialize +// 5. Realistic CDC pipeline: construct → clone → lookup → serialize // --------------------------------------------------------------------------- /// End-to-end CDC simulation: construct event, clone it, look up 3 columns, @@ -322,7 +283,6 @@ criterion_group!( bench_serialize, bench_deserialize, bench_round_trip, - bench_payload_size, bench_pipeline, ); criterion_main!(benches); diff --git a/src/column_value.rs b/src/column_value.rs index 3131e3b..af9b280 100644 --- a/src/column_value.rs +++ b/src/column_value.rs @@ -24,7 +24,7 @@ pub(crate) fn hex_encode(bytes: &[u8]) -> String { /// Decode a hex string to bytes. Returns `Err` on invalid hex. 
fn hex_decode(hex: &str) -> std::result::Result, &'static str> { - if !hex.len().is_multiple_of(2) { + if hex.len() % 2 != 0 { return Err("odd hex length"); } let mut out = Vec::with_capacity(hex.len() / 2); From 9c5b8ca554c86f961eff92a025cd05b49e5000ce Mon Sep 17 00:00:00 2001 From: Danielshih Date: Wed, 25 Feb 2026 10:27:03 +0000 Subject: [PATCH 3/4] non-UTF-8 text and improve hex handling --- src/column_value.rs | 139 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 115 insertions(+), 24 deletions(-) diff --git a/src/column_value.rs b/src/column_value.rs index af9b280..cb236d4 100644 --- a/src/column_value.rs +++ b/src/column_value.rs @@ -10,6 +10,7 @@ use crate::error::{ReplicationError, Result}; use bytes::{Bytes, BytesMut}; use serde::{Deserialize, Serialize}; use std::sync::Arc; +use serde::ser::SerializeMap; /// Encode a byte slice as lowercase hex string. pub(crate) fn hex_encode(bytes: &[u8]) -> String { @@ -61,12 +62,12 @@ fn hex_nibble(b: u8) -> Option { /// | `0x01` | `Text` — followed by u32-len + data | /// | `0x02` | `Binary` — followed by u32-len + data | /// -/// # Serde -/// /// When serialised with [`serde`], `Text` values emit a JSON string, -/// `Binary` values emit a hex-prefixed string (`"\\xdeadbeef"`), +/// `Binary` values emit a tagged JSON object `{"$binary": "deadbeef"}`, /// and `Null` emits JSON `null`. /// +/// The tagged-object format is unambiguous: a `Text` value whose content happens to look like hex will always round-trip correctly. +/// /// # Example /// /// ``` @@ -246,14 +247,16 @@ impl Serialize for ColumnValue { Self::Text(b) => match std::str::from_utf8(b) { Ok(s) => serializer.serialize_str(s), Err(_) => { - // Fall back to hex for non-UTF-8 text - let hex = hex_encode(b); - serializer.serialize_str(&format!("\\x{hex}")) + // Non-UTF-8 text cannot be represented as a JSON string, Emit the tagged binary form so the bytes survive round-trip. 
+ let mut map = serializer.serialize_map(Some(1))?; + map.serialize_entry("$binary", &hex_encode(b))?; + map.end() } }, Self::Binary(b) => { - let hex = hex_encode(b); - serializer.serialize_str(&format!("\\x{hex}")) + let mut map = serializer.serialize_map(Some(1))?; + map.serialize_entry("$binary", &hex_encode(b))?; + map.end() } } } @@ -269,7 +272,7 @@ impl<'de> Deserialize<'de> for ColumnValue { type Value = ColumnValue; fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.write_str("a string, null, or hex-encoded binary") + f.write_str(r#"a string, null, or {"$binary": "hex..."}"#) } fn visit_none(self) -> std::result::Result { @@ -284,7 +287,6 @@ impl<'de> Deserialize<'de> for ColumnValue { self, deserializer: D, ) -> std::result::Result { - // Recurse into the inner value deserializer.deserialize_any(self) } @@ -292,15 +294,26 @@ impl<'de> Deserialize<'de> for ColumnValue { self, v: &str, ) -> std::result::Result { - if let Some(hex) = v.strip_prefix("\\x") { - // Decode hex to binary - match hex_decode(hex) { - Ok(bytes) => Ok(ColumnValue::Binary(Bytes::from(bytes))), - Err(e) => Err(E::custom(format!("invalid hex string: {e}"))), - } - } else { - Ok(ColumnValue::Text(Bytes::copy_from_slice(v.as_bytes()))) + Ok(ColumnValue::Text(Bytes::copy_from_slice(v.as_bytes()))) + } + + fn visit_map>( + self, + mut map: M, + ) -> std::result::Result { + use serde::de::Error; + let key: String = map + .next_key()? 
+ .ok_or_else(|| M::Error::custom("expected \"$binary\" key in tagged object"))?; + if key != "$binary" { + return Err(M::Error::custom(format!( + r#"unknown key "{key}", expected "$binary""# + ))); } + let hex: String = map.next_value()?; + let bytes = hex_decode(&hex) + .map_err(|e| M::Error::custom(format!("invalid hex in $binary: {e}")))?; + Ok(ColumnValue::Binary(Bytes::from(bytes))) } } @@ -759,24 +772,24 @@ mod tests { #[test] fn test_column_value_serialize_non_utf8_text() { - // Text with invalid UTF-8 should fall back to hex encoding + // Text with invalid UTF-8 falls back to tagged binary object let v = ColumnValue::Text(Bytes::from_static(&[0xff, 0xfe])); let json = serde_json::to_string(&v).unwrap(); - assert_eq!(json, r#""\\xfffe""#); + assert_eq!(json, r#"{"$binary":"fffe"}"#); - // Round-trip: it deserializes back as Binary (due to \x prefix) + // Round-trip: deserializes back as Binary (raw bytes preserved) let back: ColumnValue = serde_json::from_str(&json).unwrap(); assert_eq!(back.as_bytes(), &[0xff, 0xfe]); } #[test] fn test_column_value_deserialize_invalid_hex() { - // \x prefix followed by invalid hex chars triggers the Err path in visit_str - let json = r#""\\xZZZZ""#; + // $binary with invalid hex chars triggers an error + let json = r#"{"$binary":"ZZZZ"}"#; let result = serde_json::from_str::(json); assert!(result.is_err()); let err_msg = result.unwrap_err().to_string(); - assert!(err_msg.contains("invalid hex string"), "got: {err_msg}"); + assert!(err_msg.contains("invalid hex"), "got: {err_msg}"); } #[test] @@ -971,6 +984,84 @@ mod tests { ); } + #[test] + fn test_text_starting_with_backslash_x_round_trips_as_text() { + // Text that happens to start with literal `\x` followed by valid hex must survive a JSON round-trip as Text, not be silently reinterpreted as Binary. 
+ let original = ColumnValue::text(r"\x4142"); + let json = serde_json::to_string(&original).unwrap(); + let back: ColumnValue = serde_json::from_str(&json).unwrap(); + + // The variant must stay Text, not become Binary + assert_eq!( + back.as_str(), + Some(r"\x4142"), + "Text was corrupted into Binary on JSON round-trip" + ); + assert_eq!(original, back); + } + + #[test] + fn test_text_with_hex_prefix_and_odd_length_round_trips() { + // Odd-length hex after `\x` — still valid text content + let original = ColumnValue::text(r"\xABC"); + let json = serde_json::to_string(&original).unwrap(); + let back: ColumnValue = serde_json::from_str(&json).unwrap(); + assert_eq!(back.as_str(), Some(r"\xABC")); + assert_eq!(original, back); + } + + #[test] + fn test_binary_round_trips_unambiguously() { + // Binary values must round-trip as Binary, not collide with Text + let original = ColumnValue::binary_bytes(Bytes::from_static(&[0x41, 0x42])); + let json = serde_json::to_string(&original).unwrap(); + let back: ColumnValue = serde_json::from_str(&json).unwrap(); + assert_eq!(back.as_bytes(), &[0x41, 0x42]); + assert_eq!(original, back); + } + + #[test] + fn test_binary_and_text_do_not_collide_in_json() { + // The _serialized_ forms of Binary([0x41, 0x42]) and Text(r"\x4142") + // must be different JSON values so they decode to the correct variant. + let binary = ColumnValue::binary_bytes(Bytes::from_static(&[0x41, 0x42])); + let text = ColumnValue::text(r"\x4142"); + + let binary_json = serde_json::to_string(&binary).unwrap(); + let text_json = serde_json::to_string(&text).unwrap(); + + assert_ne!( + binary_json, text_json, + "Binary and Text produce identical JSON — deserialization will be ambiguous" + ); + } + + #[test] + fn test_rowdata_with_hex_like_text_round_trips() { + // End-to-end: a RowData containing a text column that looks like hex must survive JSON round-trip without corruption. 
+ let row = RowData::from_pairs(vec![ + ("hash", ColumnValue::text(r"\xdeadbeef")), + ( + "blob", + ColumnValue::binary_bytes(Bytes::from_static(&[0xca, 0xfe])), + ), + ("name", ColumnValue::text("Alice")), + ]); + let json = serde_json::to_string(&row).unwrap(); + let back: RowData = serde_json::from_str(&json).unwrap(); + + assert_eq!( + back.get("hash").and_then(|v| v.as_str()), + Some(r"\xdeadbeef"), + "Text column 'hash' was corrupted to Binary" + ); + assert_eq!( + back.get("blob").map(|v| v.as_bytes()), + Some(&[0xca, 0xfe][..]) + ); + assert_eq!(back.get("name").and_then(|v| v.as_str()), Some("Alice")); + } + #[test] fn test_rowdata_debug() { let row = RowData::from_pairs(vec![("x", ColumnValue::text("y"))]); From ebf1292ca07799709095ea3cfb91f1b19fc4bcc0 Mon Sep 17 00:00:00 2001 From: Danielshih Date: Wed, 25 Feb 2026 10:39:36 +0000 Subject: [PATCH 4/4] Refactor hex encoding and decoding functions for clarity and performance --- src/column_value.rs | 2 +- src/types.rs | 36 +++++++++++++++++++++++++++--------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/column_value.rs b/src/column_value.rs index cb236d4..d276cec 100644 --- a/src/column_value.rs +++ b/src/column_value.rs @@ -8,9 +8,9 @@ use crate::buffer::BufferReader; use crate::error::{ReplicationError, Result}; use bytes::{Bytes, BytesMut}; +use serde::ser::SerializeMap; use serde::{Deserialize, Serialize}; use std::sync::Arc; -use serde::ser::SerializeMap; /// Encode a byte slice as lowercase hex string. 
pub(crate) fn hex_encode(bytes: &[u8]) -> String { diff --git a/src/types.rs b/src/types.rs index 5901820..bdd8448 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1157,7 +1157,7 @@ impl ChangeEvent { let transaction_id = reader.read_u32()?; let final_lsn = Lsn(reader.read_u64()?); let ts_micros = reader.read_i64()?; - let commit_timestamp = micros_to_chrono(ts_micros); + let commit_timestamp = micros_to_chrono(ts_micros)?; EventType::Begin { transaction_id, final_lsn, @@ -1166,7 +1166,7 @@ impl ChangeEvent { } message_types::COMMIT => { let ts_micros = reader.read_i64()?; - let commit_timestamp = micros_to_chrono(ts_micros); + let commit_timestamp = micros_to_chrono(ts_micros)?; let commit_lsn = Lsn(reader.read_u64()?); let end_lsn = Lsn(reader.read_u64()?); EventType::Commit { @@ -1189,7 +1189,7 @@ impl ChangeEvent { let commit_lsn = Lsn(reader.read_u64()?); let end_lsn = Lsn(reader.read_u64()?); let ts_micros = reader.read_i64()?; - let commit_timestamp = micros_to_chrono(ts_micros); + let commit_timestamp = micros_to_chrono(ts_micros)?; EventType::StreamCommit { transaction_id, commit_lsn, @@ -1208,7 +1208,7 @@ impl ChangeEvent { }; let has_ts = reader.read_u8()?; let abort_timestamp = if has_ts != 0 { - Some(micros_to_chrono(reader.read_i64()?)) + Some(micros_to_chrono(reader.read_i64()?)?) } else { None }; @@ -1261,13 +1261,19 @@ fn decode_string(reader: &mut BufferReader) -> Result { } /// Convert Unix timestamp microseconds to chrono DateTime. -fn micros_to_chrono(micros: i64) -> chrono::DateTime { +/// +/// Returns an error if the value is outside the representable range +/// (e.g. from a corrupted binary message) instead of silently falling +/// back to the Unix epoch. 
+fn micros_to_chrono(micros: i64) -> Result> { use chrono::{TimeZone, Utc}; let secs = micros.div_euclid(1_000_000); let subsec_nanos = (micros.rem_euclid(1_000_000) as u32) * 1000; Utc.timestamp_opt(secs, subsec_nanos) .single() - .unwrap_or_else(|| Utc.timestamp_opt(0, 0).unwrap()) + .ok_or_else(|| { + ReplicationError::protocol(format!("timestamp {micros} µs out of representable range")) + }) } #[cfg(test)] @@ -2164,24 +2170,36 @@ mod tests { #[test] fn test_micros_to_chrono_zero() { // Zero micros = Unix epoch - let dt = micros_to_chrono(0); + let dt = micros_to_chrono(0).unwrap(); assert_eq!(dt, Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap()); } #[test] fn test_micros_to_chrono_negative() { // Negative micros = before Unix epoch - let dt = micros_to_chrono(-1_000_000); + let dt = micros_to_chrono(-1_000_000).unwrap(); assert_eq!(dt, Utc.with_ymd_and_hms(1969, 12, 31, 23, 59, 59).unwrap()); } #[test] fn test_micros_to_chrono_with_subsecond() { - let dt = micros_to_chrono(1_500_000); // 1.5 seconds + let dt = micros_to_chrono(1_500_000).unwrap(); // 1.5 seconds assert_eq!(dt.timestamp(), 1); assert_eq!(dt.timestamp_subsec_micros(), 500_000); } + #[test] + fn test_micros_to_chrono_out_of_range() { + // i64::MAX µs is far beyond what chrono can represent — must return Err + let result = micros_to_chrono(i64::MAX); + assert!(result.is_err(), "expected Err for i64::MAX, got {result:?}"); + let msg = result.unwrap_err().to_string(); + assert!( + msg.contains("out of representable range"), + "unexpected error message: {msg}" + ); + } + #[test] fn test_change_event_serde_round_trip_update() { let old = RowData::from_pairs(vec![("id", ColumnValue::text("1"))]);