diff --git a/Cargo.lock b/Cargo.lock index 9e293fd2b8..f65f0ac59e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4946,9 +4946,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.0" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] name = "num-derive" @@ -11003,30 +11003,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.47" +version = "0.3.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde_core", + "serde", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.8" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" [[package]] name = "time-macros" -version = "0.2.27" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" dependencies = [ "num-conv", "time-core", diff --git a/forester-utils/src/address_staging_tree.rs b/forester-utils/src/address_staging_tree.rs index 786ddb6ac0..a6b1aa89bd 100644 --- a/forester-utils/src/address_staging_tree.rs +++ b/forester-utils/src/address_staging_tree.rs @@ -121,7 +121,7 @@ impl AddressStagingTree { low_element_next_values: &[[u8; 32]], low_element_indices: &[u64], low_element_next_indices: &[u64], - low_element_proofs: &[Vec<[u8; 32]>], + low_element_proofs: &[[[u8; 32]; HEIGHT]], leaves_hashchain: [u8; 32], zkp_batch_size: usize, epoch: u64, @@ -145,15 +145,12 @@ impl AddressStagingTree { let inputs = get_batch_address_append_circuit_inputs::( next_index, old_root, - low_element_values.to_vec(), - low_element_next_values.to_vec(), - low_element_indices.iter().map(|v| *v as usize).collect(), - low_element_next_indices - .iter() - .map(|v| *v as usize) - .collect(), - low_element_proofs.to_vec(), - addresses.to_vec(), + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + addresses, &mut self.sparse_tree, leaves_hashchain, zkp_batch_size, diff --git a/forester/src/forester_status.rs b/forester/src/forester_status.rs index 80c4539075..d8f958134d 100644 --- a/forester/src/forester_status.rs +++ b/forester/src/forester_status.rs @@ -670,20 +670,22 @@ fn parse_tree_status( let fullness = next_index as f64 / capacity as f64 * 100.0; let (queue_len, queue_cap) = queue_account - .map(|acc| { - unsafe { parse_hash_set_from_bytes::(&acc.data) } - .ok() - .map(|hs| { + .map( + |acc| match unsafe { parse_hash_set_from_bytes::(&acc.data) } { + Ok(hs) => { let len = hs .iter() .filter(|(_, cell)| cell.sequence_number.is_none()) .count() as u64; let cap = hs.get_capacity() as u64; - (len, cap) - }) - .unwrap_or((0, 0)) - }) - .map(|(l, c)| (Some(l), Some(c))) + (Some(len), Some(cap)) + } + Err(error) => { + warn!(?error, "Failed to parse StateV1 queue hash set"); + (None, None) + } + }, + ) .unwrap_or((None, None)); ( @@ -725,20 +727,22 @@ fn parse_tree_status( let fullness = next_index as f64 / capacity as f64 * 100.0; let (queue_len, queue_cap) = queue_account - .map(|acc| { - unsafe { parse_hash_set_from_bytes::(&acc.data) } - .ok() - .map(|hs| { + .map( + |acc| match unsafe { parse_hash_set_from_bytes::(&acc.data) } { + Ok(hs) => { let len = hs .iter() .filter(|(_, cell)| cell.sequence_number.is_none()) .count() as u64; let cap = hs.get_capacity() as u64; - (len, cap) - }) - .unwrap_or((0, 0)) - }) - .map(|(l, c)| (Some(l), Some(c))) + (Some(len), Some(cap)) + } + Err(error) => { + warn!(?error, "Failed to parse AddressV1 queue hash set"); + (None, None) + } + }, + ) .unwrap_or((None, None)); ( diff --git a/forester/src/processor/v2/helpers.rs b/forester/src/processor/v2/helpers.rs index ed135cb6a4..dd79ec9901 100644 --- a/forester/src/processor/v2/helpers.rs +++ b/forester/src/processor/v2/helpers.rs @@ -9,6 +9,7 @@ use light_client::{ indexer::{AddressQueueData, Indexer, QueueElementsV2Options, StateQueueData}, rpc::Rpc, }; +use light_hasher::hash_chain::create_hash_chain_from_slice; use crate::processor::v2::{common::clamp_to_u16, BatchContext}; @@ -22,6 +23,17 @@ pub(crate) fn lock_recover<'a, T>(mutex: &'a Mutex, name: &'static str) -> Mu } } +#[derive(Debug, Clone)] +pub struct AddressBatchSnapshot { + pub addresses: Vec<[u8; 32]>, + pub low_element_values: Vec<[u8; 32]>, + pub low_element_next_values: Vec<[u8; 32]>, + pub low_element_indices: Vec, + pub low_element_next_indices: Vec, + pub low_element_proofs: Vec<[[u8; 32]; HEIGHT]>, + pub leaves_hashchain: [u8; 32], +} + pub async fn fetch_zkp_batch_size(context: &BatchContext) -> crate::Result { let rpc = context.rpc_pool.get_connection().await?; let mut account = rpc @@ -474,20 +486,96 @@ impl StreamingAddressQueue { } } - pub fn get_batch_data(&self, start: usize, end: usize) -> Option { + pub fn get_batch_snapshot( + &self, + start: usize, + end: usize, + hashchain_idx: usize, + ) -> crate::Result>> { let available = self.wait_for_batch(end); - if start >= available { - return None; + if available < end || start >= end { + return Ok(None); } - let actual_end = end.min(available); let data = lock_recover(&self.data, "streaming_address_queue.data"); - Some(BatchDataSlice { - addresses: data.addresses[start..actual_end].to_vec(), - low_element_values: data.low_element_values[start..actual_end].to_vec(), - low_element_next_values: data.low_element_next_values[start..actual_end].to_vec(), - low_element_indices: data.low_element_indices[start..actual_end].to_vec(), - low_element_next_indices: data.low_element_next_indices[start..actual_end].to_vec(), - }) + let Some(addresses) = data.addresses.get(start..end).map(|slice| slice.to_vec()) else { + return Ok(None); + }; + if addresses.is_empty() { + return Ok(None); + } + let expected_len = addresses.len(); + let Some(low_element_values) = data + .low_element_values + .get(start..end) + .map(|slice| slice.to_vec()) + else { + return Ok(None); + }; + let Some(low_element_next_values) = data + .low_element_next_values + .get(start..end) + .map(|slice| slice.to_vec()) + else { + return Ok(None); + }; + let Some(low_element_indices) = data + .low_element_indices + .get(start..end) + .map(|slice| slice.to_vec()) + else { + return Ok(None); + }; + let Some(low_element_next_indices) = data + .low_element_next_indices + .get(start..end) + .map(|slice| slice.to_vec()) + else { + return Ok(None); + }; + if [ + low_element_values.len(), + low_element_next_values.len(), + low_element_indices.len(), + low_element_next_indices.len(), + ] + .iter() + .any(|&len| len != expected_len) + { + return Ok(None); + } + let low_element_proofs = match data.reconstruct_proofs::(start..end) { + Ok(proofs) if proofs.len() == expected_len => proofs, + Ok(_) | Err(_) => return Ok(None), + }; + + let leaves_hashchain = match data.leaves_hash_chains.get(hashchain_idx).copied() { + Some(hashchain) => hashchain, + None => { + tracing::debug!( + "Missing leaves_hash_chain for batch {} (available: {}), deriving from addresses", + hashchain_idx, + data.leaves_hash_chains.len() + ); + create_hash_chain_from_slice(&addresses).map_err(|error| { + anyhow!( + "Failed to derive leaves_hash_chain for batch {} from {} addresses: {}", + hashchain_idx, + addresses.len(), + error + ) + })? + } + }; + + Ok(Some(AddressBatchSnapshot { + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + addresses, + leaves_hashchain, + })) } pub fn into_data(self) -> AddressQueueData { @@ -522,6 +610,10 @@ impl StreamingAddressQueue { lock_recover(&self.data, "streaming_address_queue.data").start_index } + pub fn tree_next_insertion_index(&self) -> u64 { + lock_recover(&self.data, "streaming_address_queue.data").tree_next_insertion_index + } + pub fn subtrees(&self) -> Vec<[u8; 32]> { lock_recover(&self.data, "streaming_address_queue.data") .subtrees @@ -553,15 +645,6 @@ impl StreamingAddressQueue { } } -#[derive(Debug, Clone)] -pub struct BatchDataSlice { - pub addresses: Vec<[u8; 32]>, - pub low_element_values: Vec<[u8; 32]>, - pub low_element_next_values: Vec<[u8; 32]>, - pub low_element_indices: Vec, - pub low_element_next_indices: Vec, -} - pub async fn fetch_streaming_address_batches( context: &BatchContext, total_elements: u64, diff --git a/forester/src/processor/v2/processor.rs b/forester/src/processor/v2/processor.rs index 3de6dea860..372a800e0e 100644 --- a/forester/src/processor/v2/processor.rs +++ b/forester/src/processor/v2/processor.rs @@ -132,7 +132,7 @@ where } if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); + let job_tx = spawn_proof_workers(&self.context.prover_config)?; self.worker_pool = Some(WorkerPool { job_tx }); } @@ -532,7 +532,7 @@ where ((queue_size / self.zkp_batch_size) as usize).min(self.context.max_batches_per_tree); if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); + let job_tx = spawn_proof_workers(&self.context.prover_config)?; self.worker_pool = Some(WorkerPool { job_tx }); } @@ -561,7 +561,7 @@ where let max_batches = max_batches.min(self.context.max_batches_per_tree); if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); + let job_tx = spawn_proof_workers(&self.context.prover_config)?; self.worker_pool = Some(WorkerPool { job_tx }); } diff --git a/forester/src/processor/v2/proof_worker.rs b/forester/src/processor/v2/proof_worker.rs index b7afeacf0b..603fa3f19b 100644 --- a/forester/src/processor/v2/proof_worker.rs +++ b/forester/src/processor/v2/proof_worker.rs @@ -132,27 +132,27 @@ struct ProofClients { } impl ProofClients { - fn new(config: &ProverConfig) -> Self { - Self { + fn new(config: &ProverConfig) -> crate::Result { + Ok(Self { append_client: ProofClient::with_config( config.append_url.clone(), config.polling_interval, config.max_wait_time, config.api_key.clone(), - ), + )?, nullify_client: ProofClient::with_config( config.update_url.clone(), config.polling_interval, config.max_wait_time, config.api_key.clone(), - ), + )?, address_append_client: ProofClient::with_config( config.address_append_url.clone(), config.polling_interval, config.max_wait_time, config.api_key.clone(), - ), - } + )?, + }) } fn get_client(&self, input: &ProofInput) -> &ProofClient { @@ -164,11 +164,13 @@ impl ProofClients { } } -pub fn spawn_proof_workers(config: &ProverConfig) -> async_channel::Sender { +pub fn spawn_proof_workers( + config: &ProverConfig, +) -> crate::Result> { let (job_tx, job_rx) = async_channel::bounded::(256); - let clients = Arc::new(ProofClients::new(config)); + let clients = Arc::new(ProofClients::new(config)?); tokio::spawn(async move { run_proof_pipeline(job_rx, clients).await }); - job_tx + Ok(job_tx) } async fn run_proof_pipeline( diff --git a/forester/src/processor/v2/strategy/address.rs b/forester/src/processor/v2/strategy/address.rs index 06e94d5500..51236c389b 100644 --- a/forester/src/processor/v2/strategy/address.rs +++ b/forester/src/processor/v2/strategy/address.rs @@ -14,11 +14,10 @@ use tracing::{debug, info, instrument}; use crate::processor::v2::{ batch_job_builder::BatchJobBuilder, - common::get_leaves_hashchain, errors::V2Error, helpers::{ fetch_address_zkp_batch_size, fetch_onchain_address_root, fetch_streaming_address_batches, - lock_recover, StreamingAddressQueue, + AddressBatchSnapshot, StreamingAddressQueue, }, proof_worker::ProofInput, root_guard::{reconcile_alignment, AlignmentDecision}, @@ -168,7 +167,7 @@ impl TreeStrategy for AddressTreeStrategy { } let initial_root = streaming_queue.initial_root(); - let start_index = streaming_queue.start_index(); + let start_index = streaming_queue.tree_next_insertion_index(); let subtrees_arr: [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize] = subtrees.try_into().map_err(|v: Vec<[u8; 32]>| { @@ -267,9 +266,23 @@ impl BatchJobBuilder for AddressQueueData { let batch_end = start + zkp_batch_size_usize; - let batch_data = self - .streaming_queue - .get_batch_data(start, batch_end) + let streaming_queue = &self.streaming_queue; + let staging_tree = &mut self.staging_tree; + let hashchain_idx = start / zkp_batch_size_usize; + let AddressBatchSnapshot { + addresses, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + leaves_hashchain, + } = streaming_queue + .get_batch_snapshot::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( + start, + batch_end, + hashchain_idx, + )? .ok_or_else(|| { anyhow!( "Batch data not available: start={}, end={}, available={}", @@ -278,31 +291,21 @@ impl BatchJobBuilder for AddressQueueData { self.streaming_queue.available_batches() * zkp_batch_size_usize ) })?; - - let addresses = &batch_data.addresses; let zkp_batch_size_actual = addresses.len(); - - if zkp_batch_size_actual == 0 { - return Err(anyhow!("Empty batch at start={}", start)); - } - - let low_element_values = &batch_data.low_element_values; - let low_element_next_values = &batch_data.low_element_next_values; - let low_element_indices = &batch_data.low_element_indices; - let low_element_next_indices = &batch_data.low_element_next_indices; - - let low_element_proofs: Vec> = { - let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); - (start..start + zkp_batch_size_actual) - .map(|i| data.reconstruct_proof(i, DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as u8)) - .collect::, _>>()? - }; - - let hashchain_idx = start / zkp_batch_size_usize; - let leaves_hashchain = { - let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); - get_leaves_hashchain(&data.leaves_hash_chains, hashchain_idx)? - }; + let result = staging_tree + .process_batch( + &addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + leaves_hashchain, + zkp_batch_size_actual, + epoch, + tree, + ) + .map_err(|err| map_address_staging_error(tree, err))?; let tree_batch = tree_next_index / zkp_batch_size_usize; let absolute_index = data_start + start; @@ -318,24 +321,6 @@ impl BatchJobBuilder for AddressQueueData { self.streaming_queue.is_complete() ); - let result = self.staging_tree.process_batch( - addresses, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - &low_element_proofs, - leaves_hashchain, - zkp_batch_size_actual, - epoch, - tree, - ); - - let result = match result { - Ok(r) => r, - Err(err) => return Err(map_address_staging_error(tree, err)), - }; - Ok(Some(( ProofInput::AddressAppend(result.circuit_inputs), result.new_root, diff --git a/program-tests/utils/src/actions/legacy/instructions/transfer2.rs b/program-tests/utils/src/actions/legacy/instructions/transfer2.rs index 1ff92eeda9..b006824302 100644 --- a/program-tests/utils/src/actions/legacy/instructions/transfer2.rs +++ b/program-tests/utils/src/actions/legacy/instructions/transfer2.rs @@ -169,13 +169,23 @@ pub async fn create_generic_transfer2_instruction( payer: Pubkey, should_filter_zero_outputs: bool, ) -> Result { - // // Get a single shared output queue for ALL compress/compress-and-close operations - // // This prevents reordering issues caused by the sort_by_key at the end - // let shared_output_queue = rpc - // .get_random_state_tree_info() - // .unwrap() - // .get_output_pubkey() - // .unwrap(); + // Transfer2 supports a single output queue per instruction. Legacy helpers accept + // per-action queues, but normalize them down to one shared queue for the IX. + let mut explicit_output_queue = None; + for action in &actions { + let candidate = match action { + Transfer2InstructionType::Compress(input) => Some(input.output_queue), + Transfer2InstructionType::CompressAndClose(input) => Some(input.output_queue), + Transfer2InstructionType::Decompress(_) + | Transfer2InstructionType::Transfer(_) + | Transfer2InstructionType::Approve(_) => None, + }; + if let Some(candidate) = candidate { + if explicit_output_queue.is_none() { + explicit_output_queue = Some(candidate); + } + } + } let mut hashes = Vec::new(); actions.iter().for_each(|account| match account { @@ -210,24 +220,16 @@ pub async fn create_generic_transfer2_instruction( .value; let mut packed_tree_accounts = PackedAccounts::default(); - // tree infos must be packed before packing the token input accounts - let packed_tree_infos = rpc_proof_result.pack_tree_infos(&mut packed_tree_accounts); + // Pack only input state tree infos. Grouped transfer2 proofs can span multiple output trees. + let packed_tree_infos = rpc_proof_result.pack_state_tree_infos(&mut packed_tree_accounts); - // We use a single shared output queue for all compress/compress-and-close operations to avoid ordering failures. - let shared_output_queue = if packed_tree_infos.address_trees.is_empty() { - let shared_output_queue = rpc - .get_random_state_tree_info() + let shared_output_queue = explicit_output_queue.unwrap_or_else(|| { + rpc.get_random_state_tree_info() .unwrap() .get_output_pubkey() - .unwrap(); - packed_tree_accounts.insert_or_get(shared_output_queue) - } else { - packed_tree_infos - .state_trees - .as_ref() .unwrap() - .output_tree_index - }; + }); + let shared_output_queue = packed_tree_accounts.insert_or_get(shared_output_queue); let mut inputs_offset = 0; let mut in_lamports = Vec::new(); @@ -242,14 +244,7 @@ pub async fn create_generic_transfer2_instruction( if let Some(ref input_token_account) = input.compressed_token_account { let token_data = input_token_account .iter() - .zip( - packed_tree_infos - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[inputs_offset..] - .iter(), - ) + .zip(packed_tree_infos[inputs_offset..].iter()) .map(|(account, rpc_account)| { if input.to != account.token.owner { return Err(TokenSdkError::InvalidCompressInputOwner); @@ -391,14 +386,7 @@ pub async fn create_generic_transfer2_instruction( let token_data = input .compressed_token_account .iter() - .zip( - packed_tree_infos - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[inputs_offset..] - .iter(), - ) + .zip(packed_tree_infos[inputs_offset..].iter()) .map(|(account, rpc_account)| { pack_input_token_account( account, @@ -460,14 +448,7 @@ pub async fn create_generic_transfer2_instruction( let token_data = input .compressed_token_account .iter() - .zip( - packed_tree_infos - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[inputs_offset..] - .iter(), - ) + .zip(packed_tree_infos[inputs_offset..].iter()) .map(|(account, rpc_account)| { pack_input_token_account( account, @@ -542,14 +523,7 @@ pub async fn create_generic_transfer2_instruction( let token_data = input .compressed_token_account .iter() - .zip( - packed_tree_infos - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[inputs_offset..] - .iter(), - ) + .zip(packed_tree_infos[inputs_offset..].iter()) .map(|(account, rpc_account)| { pack_input_token_account( account, diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index a16a7925a7..c2c415ec9e 100644 --- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -73,7 +73,6 @@ use account_compression::{ use anchor_lang::{prelude::AccountMeta, AnchorSerialize, Discriminator}; use create_address_test_program::create_invoke_cpi_instruction; use forester_utils::{ - account_zero_copy::AccountZeroCopy, address_merkle_tree_config::{address_tree_ready_for_rollover, state_tree_ready_for_rollover}, forester_epoch::{Epoch, Forester, TreeAccounts}, utils::airdrop_lamports, @@ -194,6 +193,7 @@ use crate::{ }, test_batch_forester::{perform_batch_append, perform_batch_nullify}, test_forester::{empty_address_queue_test, nullify_compressed_accounts}, + AccountZeroCopy, }; pub struct User { @@ -748,70 +748,67 @@ where .with_address_queue(None, Some(batch.batch_size as u16)); let result = self .indexer - .get_queue_elements(merkle_tree_pubkey.to_bytes(), options, None) + .get_queue_elements( + merkle_tree_pubkey.to_bytes(), + options, + None, + ) .await .unwrap(); - let addresses = result - .value - .address_queue - .map(|aq| aq.addresses) - .unwrap_or_default(); + let address_queue = result.value.address_queue.unwrap(); + let low_element_proofs = address_queue + .reconstruct_all_proofs::<{ + DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize + }>() + .unwrap(); // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_array(&addresses); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = merkle_tree.next_index as usize; + let start_index = address_queue.tree_next_insertion_index as usize; assert!( start_index >= 2, "start index should be greater than 2 else tree is not inited" ); let current_root = *merkle_tree.root_history.last().unwrap(); - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); - let non_inclusion_proofs = self - .indexer - .get_multiple_new_address_proofs( - merkle_tree_pubkey.to_bytes(), - addresses.clone(), - None, - ) - .await - .unwrap(); - for non_inclusion_proof in &non_inclusion_proofs.value.items { - low_element_values.push(non_inclusion_proof.low_address_value); - low_element_indices - .push(non_inclusion_proof.low_address_index as usize); - low_element_next_indices - .push(non_inclusion_proof.low_address_next_index as usize); - low_element_next_values - .push(non_inclusion_proof.low_address_next_value); - - low_element_proofs - .push(non_inclusion_proof.low_address_proof.to_vec()); - } - - let subtrees = self.indexer - .get_subtrees(merkle_tree_pubkey.to_bytes(), None) - .await - .unwrap(); - let mut sparse_merkle_tree = SparseMerkleTree::::new(<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]>::try_from(subtrees.value.items).unwrap(), start_index); + assert_eq!(address_queue.initial_root, current_root); + let light_client::indexer::AddressQueueData { + addresses, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + subtrees, + .. + } = address_queue; + let mut sparse_merkle_tree = SparseMerkleTree::< + Poseidon, + { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, + >::new( + subtrees.as_slice().try_into().unwrap(), + start_index, + ); - let mut changelog: Vec> = Vec::new(); - let mut indexed_changelog: Vec> = Vec::new(); + let mut changelog: Vec< + ChangelogEntry<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>, + > = Vec::new(); + let mut indexed_changelog: Vec< + IndexedChangelogEntry< + usize, + { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, + >, + > = Vec::new(); let inputs = get_batch_address_append_circuit_inputs::< { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, >( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &addresses, &mut sparse_merkle_tree, leaves_hash_chain, batch.zkp_batch_size as usize, @@ -834,9 +831,13 @@ where if response_result.status().is_success() { let body = response_result.text().await.unwrap(); - let proof_json = deserialize_gnark_proof_json(&body).unwrap(); - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); + let proof_json = deserialize_gnark_proof_json(&body) + .map_err(|error| RpcError::CustomError(error.to_string())) + .unwrap(); + let (proof_a, proof_b, proof_c) = + proof_from_json_struct(proof_json).unwrap(); + let (proof_a, proof_b, proof_c) = + compress_proof(&proof_a, &proof_b, &proof_c).unwrap(); let instruction_data = InstructionDataBatchNullifyInputs { new_root: circuit_inputs_new_root, compressed_proof: CompressedProof { diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 4458aa03b3..f3ad76cdbe 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -132,7 +132,8 @@ impl MockBatchedForester { assert_eq!(computed_new_root, self.merkle_tree.root()); - let proof_result = match ProofClient::local() + let proof_client = ProofClient::local()?; + let proof_result = match proof_client .generate_batch_append_proof(circuit_inputs) .await { @@ -207,9 +208,8 @@ impl MockBatchedForester { batch_size, &[], )?; - let proof_result = ProofClient::local() - .generate_batch_update_proof(inputs) - .await?; + let proof_client = ProofClient::local()?; + let proof_result = proof_client.generate_batch_update_proof(inputs).await?; let new_root = self.merkle_tree.root(); let proof = CompressedProof { a: proof_result.0.proof.a, @@ -260,7 +260,7 @@ impl MockBatchedAddressForester { let mut low_element_indices = Vec::new(); let mut low_element_next_indices = Vec::new(); let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); + let mut low_element_proofs: Vec<[[u8; 32]; HEIGHT]> = Vec::new(); for new_element_value in &new_element_values { let non_inclusion_proof = self .merkle_tree @@ -270,7 +270,18 @@ impl MockBatchedAddressForester { low_element_indices.push(non_inclusion_proof.leaf_index); low_element_next_indices.push(non_inclusion_proof.next_index); low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push(non_inclusion_proof.merkle_proof.as_slice().to_vec()); + let proof = non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .map_err(|_| { + ProverClientError::InvalidProofData(format!( + "invalid low element proof length: expected {}, got {}", + HEIGHT, + non_inclusion_proof.merkle_proof.len() + )) + })?; + low_element_proofs.push(proof); } let subtrees = self.merkle_tree.merkle_tree.get_subtrees(); let mut merkle_tree = match <[[u8; 32]; HEIGHT]>::try_from(subtrees) { @@ -287,12 +298,12 @@ impl MockBatchedAddressForester { let inputs = match get_batch_address_append_circuit_inputs::( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - new_element_values.clone(), + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &new_element_values, &mut merkle_tree, leaves_hashchain, zkp_batch_size as usize, @@ -307,7 +318,8 @@ impl MockBatchedAddressForester { ))); } }; - let proof_result = match ProofClient::local() + let proof_client = ProofClient::local()?; + let proof_result = match proof_client .generate_batch_address_append_proof(inputs) .await { diff --git a/program-tests/utils/src/test_batch_forester.rs b/program-tests/utils/src/test_batch_forester.rs index 8e6909704f..8cec32757f 100644 --- a/program-tests/utils/src/test_batch_forester.rs +++ b/program-tests/utils/src/test_batch_forester.rs @@ -164,7 +164,7 @@ pub async fn create_append_batch_ix_data( bigint_to_be_bytes_array::<32>(&circuit_inputs.new_root.to_biguint().unwrap()).unwrap(), bundle.merkle_tree.root() ); - let proof_client = ProofClient::local(); + let proof_client = ProofClient::local().unwrap(); let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string(); match proof_client.generate_proof(inputs_json).await { @@ -293,7 +293,7 @@ pub async fn get_batched_nullify_ix_data( &[], ) .unwrap(); - let proof_client = ProofClient::local(); + let proof_client = ProofClient::local().unwrap(); let circuit_inputs_new_root = bigint_to_be_bytes_array::<32>(&inputs.new_root.to_biguint().unwrap()).unwrap(); let inputs_json = update_inputs_string(&inputs); @@ -319,13 +319,13 @@ pub async fn get_batched_nullify_ix_data( }) } -use forester_utils::{ - account_zero_copy::AccountZeroCopy, instructions::create_account::create_account_instruction, -}; +use forester_utils::instructions::create_account::create_account_instruction; use light_client::indexer::{Indexer, QueueElementsV2Options}; use light_program_test::indexer::state_tree::StateMerkleTreeBundle; use light_sparse_merkle_tree::SparseMerkleTree; +use crate::AccountZeroCopy; + pub async fn assert_registry_created_batched_state_merkle_tree( rpc: &mut R, payer_pubkey: Pubkey, @@ -663,50 +663,33 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof() + .unwrap(); // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_slice(addresses.as_slice()).unwrap(); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = merkle_tree.next_index as usize; + let start_index = address_queue.tree_next_insertion_index as usize; assert!( start_index >= 1, "start index should be greater than 2 else tree is not inited" ); let current_root = *merkle_tree.root_history.last().unwrap(); - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); - let non_inclusion_proofs = indexer - .get_multiple_new_address_proofs(merkle_tree_pubkey.to_bytes(), addresses.clone(), None) - .await - .unwrap(); - for non_inclusion_proof in &non_inclusion_proofs.value.items { - low_element_values.push(non_inclusion_proof.low_address_value); - low_element_indices.push(non_inclusion_proof.low_address_index as usize); - low_element_next_indices.push(non_inclusion_proof.low_address_next_index as usize); - low_element_next_values.push(non_inclusion_proof.low_address_next_value); - - low_element_proofs.push(non_inclusion_proof.low_address_proof.to_vec()); - } - - let subtrees = indexer - .get_subtrees(merkle_tree_pubkey.to_bytes(), None) - .await - .unwrap(); + assert_eq!(address_queue.initial_root, current_root); + let light_client::indexer::AddressQueueData { + addresses, + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + subtrees, + .. + } = address_queue; let mut sparse_merkle_tree = SparseMerkleTree::< Poseidon, { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, - >::new( - <[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]>::try_from(subtrees.value.items) - .unwrap(), - start_index, - ); + >::new(subtrees.as_slice().try_into().unwrap(), start_index); let mut changelog: Vec> = Vec::new(); @@ -718,12 +701,12 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &addresses, &mut sparse_merkle_tree, leaves_hash_chain, batch.zkp_batch_size as usize, @@ -732,7 +715,7 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof(&inputs.new_root).unwrap(); let inputs_json = to_json(&inputs); diff --git a/prover/client/src/constants.rs b/prover/client/src/constants.rs index 18a5c05a45..151bf87918 100644 --- a/prover/client/src/constants.rs +++ b/prover/client/src/constants.rs @@ -1,4 +1,4 @@ -pub const SERVER_ADDRESS: &str = "http://localhost:3001"; +pub const SERVER_ADDRESS: &str = "http://127.0.0.1:3001"; pub const HEALTH_CHECK: &str = "/health"; pub const PROVE_PATH: &str = "/prove"; diff --git a/prover/client/src/errors.rs b/prover/client/src/errors.rs index 85c1bc8fbe..859cae32b8 100644 --- a/prover/client/src/errors.rs +++ b/prover/client/src/errors.rs @@ -37,6 +37,8 @@ pub enum ProverClientError { #[error("Invalid proof data: {0}")] InvalidProofData(String), + #[error("Integer conversion failed: {0}")] + IntegerConversion(String), #[error("Hashchain mismatch: computed {computed:?} != expected {expected:?} (batch_size={batch_size}, next_index={next_index})")] HashchainMismatch { computed: [u8; 32], diff --git a/prover/client/src/helpers.rs b/prover/client/src/helpers.rs index 6ea223e79f..9a20b8958e 100644 --- a/prover/client/src/helpers.rs +++ b/prover/client/src/helpers.rs @@ -6,6 +6,8 @@ use num_bigint::{BigInt, BigUint}; use num_traits::{Num, ToPrimitive}; use serde::Serialize; +use crate::errors::ProverClientError; + pub fn get_project_root() -> Option { let output = Command::new("git") .args(["rev-parse", "--show-toplevel"]) @@ -47,23 +49,23 @@ pub fn bigint_to_u8_32(n: &BigInt) -> Result<[u8; 32], Box( leaf: [u8; 32], path_elements: &[[u8; 32]; HEIGHT], - path_index: u32, -) -> ([u8; 32], ChangelogEntry) { - let mut changelog_entry = ChangelogEntry::default_with_index(path_index as usize); + path_index: usize, +) -> Result<([u8; 32], ChangelogEntry), ProverClientError> { + let mut changelog_entry = ChangelogEntry::default_with_index(path_index); let mut current_hash = leaf; let mut current_index = path_index; for (level, path_element) in path_elements.iter().enumerate() { changelog_entry.path[level] = Some(current_hash); if current_index.is_multiple_of(2) { - current_hash = Poseidon::hashv(&[¤t_hash, path_element]).unwrap(); + current_hash = Poseidon::hashv(&[¤t_hash, path_element])?; } else { - current_hash = Poseidon::hashv(&[path_element, ¤t_hash]).unwrap(); + current_hash = Poseidon::hashv(&[path_element, ¤t_hash])?; } current_index /= 2; } - (current_hash, changelog_entry) + Ok((current_hash, changelog_entry)) } pub fn big_uint_to_string(big_uint: &BigUint) -> String { diff --git a/prover/client/src/proof.rs b/prover/client/src/proof.rs index c415a4d108..c5b5847815 100644 --- a/prover/client/src/proof.rs +++ b/prover/client/src/proof.rs @@ -12,6 +12,9 @@ use solana_bn254::compression::prelude::{ convert_endianness, }; +pub type CompressedProofBytes = ([u8; 32], [u8; 64], [u8; 32]); +pub type UncompressedProofBytes = ([u8; 64], [u8; 128], [u8; 64]); + #[derive(Debug, Clone, Copy)] pub struct ProofCompressed { pub a: [u8; 32], @@ -66,16 +69,27 @@ pub fn deserialize_gnark_proof_json(json_data: &str) -> serde_json::Result [u8; 32] { - let trimmed_str = hex_str.trim_start_matches("0x"); - let big_int = num_bigint::BigInt::from_str_radix(trimmed_str, 16).unwrap(); - let big_int_bytes = big_int.to_bytes_be().1; - if big_int_bytes.len() < 32 { +pub fn deserialize_hex_string_to_be_bytes(hex_str: &str) -> Result<[u8; 32], ProverClientError> { + let trimmed_str = hex_str + .strip_prefix("0x") + .or_else(|| hex_str.strip_prefix("0X")) + .unwrap_or(hex_str); + let big_uint = num_bigint::BigUint::from_str_radix(trimmed_str, 16) + .map_err(|error| ProverClientError::InvalidHexString(format!("{hex_str}: {error}")))?; + let big_uint_bytes = big_uint.to_bytes_be(); + if big_uint_bytes.len() > 32 { + return Err(ProverClientError::InvalidHexString(format!( + "{hex_str}: exceeds 32 bytes" + ))); + } + if big_uint_bytes.len() < 32 { let mut result = [0u8; 32]; - result[32 - big_int_bytes.len()..].copy_from_slice(&big_int_bytes); - result + result[32 - big_uint_bytes.len()..].copy_from_slice(&big_uint_bytes); + Ok(result) } else { - big_int_bytes.try_into().unwrap() + big_uint_bytes.try_into().map_err(|_| { + ProverClientError::InvalidHexString(format!("{hex_str}: invalid 32-byte encoding")) + }) } } @@ -83,47 +97,79 @@ pub fn compress_proof( proof_a: &[u8; 64], proof_b: &[u8; 128], proof_c: &[u8; 64], -) -> ([u8; 32], [u8; 64], [u8; 32]) { - let proof_a = alt_bn128_g1_compress(proof_a).unwrap(); - let proof_b = alt_bn128_g2_compress(proof_b).unwrap(); - let proof_c = alt_bn128_g1_compress(proof_c).unwrap(); - (proof_a, proof_b, proof_c) +) -> Result { + let proof_a = alt_bn128_g1_compress(proof_a)?; + let proof_b = alt_bn128_g2_compress(proof_b)?; + let proof_c = alt_bn128_g1_compress(proof_c)?; + Ok((proof_a, proof_b, proof_c)) } -pub fn proof_from_json_struct(json: GnarkProofJson) -> ([u8; 64], [u8; 128], [u8; 64]) { - let proof_a_x = deserialize_hex_string_to_be_bytes(&json.ar[0]); - let proof_a_y = deserialize_hex_string_to_be_bytes(&json.ar[1]); - let proof_a: [u8; 64] = [proof_a_x, proof_a_y].concat().try_into().unwrap(); - let proof_a = negate_g1(&proof_a); - let proof_b_x_0 = deserialize_hex_string_to_be_bytes(&json.bs[0][0]); - let proof_b_x_1 = deserialize_hex_string_to_be_bytes(&json.bs[0][1]); - let proof_b_y_0 = deserialize_hex_string_to_be_bytes(&json.bs[1][0]); - let proof_b_y_1 = deserialize_hex_string_to_be_bytes(&json.bs[1][1]); +pub fn proof_from_json_struct( + json: GnarkProofJson, +) -> Result { + let proof_a_x = deserialize_hex_string_to_be_bytes(json.ar.first().ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof A x coordinate".to_string()) + })?)?; + let proof_a_y = deserialize_hex_string_to_be_bytes(json.ar.get(1).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof A y coordinate".to_string()) + })?)?; + let proof_a: [u8; 64] = [proof_a_x, proof_a_y] + .concat() + .try_into() + .map_err(|_| ProverClientError::InvalidProofData("invalid proof A length".to_string()))?; + let proof_a = negate_g1(&proof_a)?; + let proof_b_x_0 = deserialize_hex_string_to_be_bytes( + json.bs.first().and_then(|row| row.first()).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B x0 coordinate".to_string()) + })?, + )?; + let proof_b_x_1 = deserialize_hex_string_to_be_bytes( + json.bs.first().and_then(|row| row.get(1)).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B x1 coordinate".to_string()) + })?, + )?; + let proof_b_y_0 = deserialize_hex_string_to_be_bytes( + json.bs.get(1).and_then(|row| row.first()).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B y0 coordinate".to_string()) + })?, + )?; + let proof_b_y_1 = + deserialize_hex_string_to_be_bytes(json.bs.get(1).and_then(|row| row.get(1)).ok_or_else( + || ProverClientError::InvalidProofData("missing proof B y1 coordinate".to_string()), + )?)?; let proof_b: [u8; 128] = [proof_b_x_0, proof_b_x_1, proof_b_y_0, proof_b_y_1] .concat() .try_into() - .unwrap(); + .map_err(|_| ProverClientError::InvalidProofData("invalid proof B length".to_string()))?; - let proof_c_x = deserialize_hex_string_to_be_bytes(&json.krs[0]); - let proof_c_y = deserialize_hex_string_to_be_bytes(&json.krs[1]); - let proof_c: [u8; 64] = [proof_c_x, proof_c_y].concat().try_into().unwrap(); - (proof_a, proof_b, proof_c) + let proof_c_x = deserialize_hex_string_to_be_bytes(json.krs.first().ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof C x coordinate".to_string()) + })?)?; + let proof_c_y = deserialize_hex_string_to_be_bytes(json.krs.get(1).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof C y coordinate".to_string()) + })?)?; + let proof_c: [u8; 64] = [proof_c_x, proof_c_y] + .concat() + .try_into() + .map_err(|_| ProverClientError::InvalidProofData("invalid proof C length".to_string()))?; + Ok((proof_a, proof_b, proof_c)) } -pub fn negate_g1(g1_be: &[u8; 64]) -> [u8; 64] { +pub fn negate_g1(g1_be: &[u8; 64]) -> Result<[u8; 64], ProverClientError> { let g1_le = convert_endianness::<32, 64>(g1_be); - let g1: G1 = G1::deserialize_with_mode(g1_le.as_slice(), Compress::No, Validate::No).unwrap(); + let g1: G1 = G1::deserialize_with_mode(g1_le.as_slice(), Compress::No, Validate::Yes) + .map_err(|error| ProverClientError::InvalidProofData(error.to_string()))?; let g1_neg = g1.neg(); let mut g1_neg_be = [0u8; 64]; g1_neg .x .serialize_with_mode(&mut g1_neg_be[..32], Compress::No) - .unwrap(); + .map_err(|error| ProverClientError::InvalidProofData(error.to_string()))?; g1_neg .y .serialize_with_mode(&mut g1_neg_be[32..], Compress::No) - .unwrap(); + .map_err(|error| ProverClientError::InvalidProofData(error.to_string()))?; let g1_neg_be: [u8; 64] = convert_endianness::<32, 64>(&g1_neg_be); - g1_neg_be + Ok(g1_neg_be) } diff --git a/prover/client/src/proof_client.rs b/prover/client/src/proof_client.rs index 1d557407bd..b3dacf295f 100644 --- a/prover/client/src/proof_client.rs +++ b/prover/client/src/proof_client.rs @@ -6,7 +6,7 @@ use tokio::time::sleep; use tracing::{debug, error, info, trace, warn}; use crate::{ - constants::PROVE_PATH, + constants::{PROVE_PATH, SERVER_ADDRESS}, errors::ProverClientError, proof::{ compress_proof, deserialize_gnark_proof_json, proof_from_json_struct, ProofCompressed, @@ -17,14 +17,13 @@ use crate::{ batch_append::{BatchAppendInputsJson, BatchAppendsCircuitInputs}, batch_update::{update_inputs_string, BatchUpdateCircuitInputs}, }, + prover::build_http_client, }; const MAX_RETRIES: u32 = 10; const BASE_RETRY_DELAY_SECS: u64 = 1; const DEFAULT_POLLING_INTERVAL_MS: u64 = 100; const DEFAULT_MAX_WAIT_TIME_SECS: u64 = 600; -const DEFAULT_LOCAL_SERVER: &str = "http://localhost:3001"; - const INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS: u64 = 200; const INITIAL_POLL_DELAY_LARGE_CIRCUIT_MS: u64 = 200; @@ -68,15 +67,15 @@ pub struct ProofClient { } impl ProofClient { - pub fn local() -> Self { - Self { - client: Client::new(), - server_address: DEFAULT_LOCAL_SERVER.to_string(), + pub fn local() -> Result { + Ok(Self { + client: build_http_client()?, + server_address: SERVER_ADDRESS.to_string(), polling_interval: Duration::from_millis(DEFAULT_POLLING_INTERVAL_MS), max_wait_time: Duration::from_secs(DEFAULT_MAX_WAIT_TIME_SECS), api_key: None, initial_poll_delay: Duration::from_millis(INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS), - } + }) } #[allow(unused)] @@ -85,21 +84,21 @@ impl ProofClient { polling_interval: Duration, max_wait_time: Duration, api_key: Option, - ) -> Self { + ) -> Result { let initial_poll_delay = if api_key.is_some() { Duration::from_millis(INITIAL_POLL_DELAY_LARGE_CIRCUIT_MS) } else { Duration::from_millis(INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS) }; - Self { - client: Client::new(), + Ok(Self { + client: build_http_client()?, server_address, polling_interval, max_wait_time, api_key, initial_poll_delay, - } + }) } #[allow(unused)] @@ -109,15 +108,15 @@ impl ProofClient { max_wait_time: Duration, api_key: Option, initial_poll_delay: Duration, - ) -> Self { - Self { - client: Client::new(), + ) -> Result { + Ok(Self { + client: build_http_client()?, server_address, polling_interval, max_wait_time, api_key, initial_poll_delay, - } + }) } pub async fn submit_proof_async( @@ -655,8 +654,8 @@ impl ProofClient { ProverClientError::ProverServerError(format!("Failed to deserialize proof JSON: {}", e)) })?; - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); + let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json)?; + let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c)?; Ok(ProofResult { proof: ProofCompressed { diff --git a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs index f80e8d49e4..f11b6fb0cb 100644 --- a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Debug}; use light_hasher::{ bigint::bigint_to_be_bytes_array, @@ -187,21 +187,49 @@ impl BatchAddressAppendInputs { pub fn get_batch_address_append_circuit_inputs( next_index: usize, current_root: [u8; 32], - low_element_values: Vec<[u8; 32]>, - low_element_next_values: Vec<[u8; 32]>, - low_element_indices: Vec, - low_element_next_indices: Vec, - low_element_proofs: Vec>, - new_element_values: Vec<[u8; 32]>, + low_element_values: &[[u8; 32]], + low_element_next_values: &[[u8; 32]], + low_element_indices: &[impl Copy + TryInto + Debug], + low_element_next_indices: &[impl Copy + TryInto + Debug], + low_element_proofs: &[[[u8; 32]; HEIGHT]], + new_element_values: &[[u8; 32]], sparse_merkle_tree: &mut SparseMerkleTree, leaves_hashchain: [u8; 32], zkp_batch_size: usize, changelog: &mut Vec>, indexed_changelog: &mut Vec>, ) -> Result { - let new_element_values = new_element_values[0..zkp_batch_size].to_vec(); + let batch_len = zkp_batch_size; + for (name, len) in [ + ("new_element_values", new_element_values.len()), + ("low_element_values", low_element_values.len()), + ("low_element_next_values", low_element_next_values.len()), + ("low_element_indices", low_element_indices.len()), + ("low_element_next_indices", low_element_next_indices.len()), + ("low_element_proofs", low_element_proofs.len()), + ] { + if len < batch_len { + return Err(ProverClientError::GenericError(format!( + "truncated batch from indexer: {} len {} < required batch size {}", + name, len, batch_len + ))); + } + } - let computed_hashchain = create_hash_chain_from_slice(&new_element_values).map_err(|e| { + let new_element_values = &new_element_values[..batch_len]; + let mut staged_changelog = changelog.clone(); + let mut staged_indexed_changelog = indexed_changelog.clone(); + let mut staged_sparse_merkle_tree = sparse_merkle_tree.clone(); + let initial_changelog_len = staged_changelog.len(); + let mut new_root = [0u8; 32]; + let mut low_element_circuit_merkle_proofs = Vec::with_capacity(batch_len); + let mut new_element_circuit_merkle_proofs = Vec::with_capacity(batch_len); + let mut patched_low_element_next_values = Vec::with_capacity(batch_len); + let mut patched_low_element_next_indices = Vec::with_capacity(batch_len); + let mut patched_low_element_values = Vec::with_capacity(batch_len); + let mut patched_low_element_indices = Vec::with_capacity(batch_len); + + let computed_hashchain = create_hash_chain_from_slice(new_element_values).map_err(|e| { ProverClientError::GenericError(format!("Failed to compute hashchain: {}", e)) })?; if computed_hashchain != leaves_hashchain { @@ -229,42 +257,45 @@ pub fn get_batch_address_append_circuit_inputs( next_index ); - let mut new_root = [0u8; 32]; - let mut low_element_circuit_merkle_proofs = vec![]; - let mut new_element_circuit_merkle_proofs = vec![]; - - let mut patched_low_element_next_values: Vec<[u8; 32]> = Vec::new(); - let mut patched_low_element_next_indices: Vec = Vec::new(); - let mut patched_low_element_values: Vec<[u8; 32]> = Vec::new(); - let mut patched_low_element_indices: Vec = Vec::new(); + let mut patcher = ChangelogProofPatcher::new::(&staged_changelog); - let mut patcher = ChangelogProofPatcher::new::(changelog); - - let is_first_batch = indexed_changelog.is_empty(); + let is_first_batch = staged_indexed_changelog.is_empty(); let mut expected_root_for_low = current_root; - for i in 0..new_element_values.len() { + for i in 0..batch_len { let mut changelog_index = 0; + let low_element_index = low_element_indices[i].try_into().map_err(|_| { + ProverClientError::IntegerConversion(format!( + "low element index {:?} does not fit into usize", + low_element_indices[i] + )) + })?; + let low_element_next_index = low_element_next_indices[i].try_into().map_err(|_| { + ProverClientError::IntegerConversion(format!( + "low element next index {:?} does not fit into usize", + low_element_next_indices[i] + )) + })?; let new_element_index = next_index + i; let mut low_element: IndexedElement = IndexedElement { - index: low_element_indices[i], + index: low_element_index, value: BigUint::from_bytes_be(&low_element_values[i]), - next_index: low_element_next_indices[i], + next_index: low_element_next_index, }; let mut new_element: IndexedElement = IndexedElement { index: new_element_index, value: BigUint::from_bytes_be(&new_element_values[i]), - next_index: low_element_next_indices[i], + next_index: low_element_next_index, }; - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof = low_element_proofs[i]; let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); patch_indexed_changelogs( 0, &mut changelog_index, - indexed_changelog, + &mut staged_indexed_changelog, &mut low_element, &mut new_element, &mut low_element_next_value, @@ -293,18 +324,10 @@ pub fn get_batch_address_append_circuit_inputs( next_value: bigint_to_be_bytes_array::<32>(&new_element.value)?, index: new_low_element.index, }; + let low_element_changelog_proof = low_element_proof; let intermediate_root = { - let mut low_element_proof_arr: [[u8; 32]; HEIGHT] = low_element_proof - .clone() - .try_into() - .map_err(|v: Vec<[u8; 32]>| { - ProverClientError::ProofPatchFailed(format!( - "low element proof length mismatch: expected {}, got {}", - HEIGHT, - v.len() - )) - })?; + let mut low_element_proof_arr = low_element_changelog_proof; patcher.update_proof::(low_element.index(), &mut low_element_proof_arr); let merkle_proof = low_element_proof_arr; @@ -320,8 +343,8 @@ pub fn get_batch_address_append_circuit_inputs( let (computed_root, _) = compute_root_from_merkle_proof::( old_low_leaf_hash, &merkle_proof, - low_element.index as u32, - ); + low_element.index, + )?; if computed_root != expected_root_for_low { let low_value_bytes = bigint_to_be_bytes_array::<32>(&low_element.value) .map_err(|e| { @@ -361,10 +384,10 @@ pub fn get_batch_address_append_circuit_inputs( compute_root_from_merkle_proof::( new_low_leaf_hash, &merkle_proof, - new_low_element.index as u32, - ); + new_low_element.index, + )?; - patcher.push_changelog_entry::(changelog, changelog_entry); + patcher.push_changelog_entry::(&mut staged_changelog, changelog_entry); low_element_circuit_merkle_proofs.push( merkle_proof .iter() @@ -376,17 +399,11 @@ pub fn get_batch_address_append_circuit_inputs( }; let low_element_changelog_entry = IndexedChangelogEntry { element: new_low_element_raw, - proof: low_element_proof.as_slice()[..HEIGHT] - .try_into() - .map_err(|_| { - ProverClientError::ProofPatchFailed( - "low_element_proof slice conversion failed".to_string(), - ) - })?, - changelog_index: indexed_changelog.len(), //change_log_index, + proof: low_element_changelog_proof, + changelog_index: staged_indexed_changelog.len(), //change_log_index, }; - indexed_changelog.push(low_element_changelog_entry); + staged_indexed_changelog.push(low_element_changelog_entry); { let new_element_next_value = low_element_next_value; @@ -396,10 +413,10 @@ pub fn get_batch_address_append_circuit_inputs( ProverClientError::GenericError(format!("Failed to hash new element: {}", e)) })?; - let sparse_root_before = sparse_merkle_tree.root(); - let sparse_next_idx_before = sparse_merkle_tree.get_next_index(); + let sparse_root_before = staged_sparse_merkle_tree.root(); + let sparse_next_idx_before = staged_sparse_merkle_tree.get_next_index(); - let mut merkle_proof_array = sparse_merkle_tree.append(new_element_leaf_hash); + let mut merkle_proof_array = staged_sparse_merkle_tree.append(new_element_leaf_hash); let current_index = next_index + i; @@ -408,10 +425,10 @@ pub fn get_batch_address_append_circuit_inputs( let (updated_root, changelog_entry) = compute_root_from_merkle_proof( new_element_leaf_hash, &merkle_proof_array, - current_index as u32, - ); + current_index, + )?; - if i == 0 && changelog.len() == 1 { + if i == 0 && staged_changelog.len() == initial_changelog_len + 1 { if sparse_next_idx_before != current_index { return Err(ProverClientError::GenericError(format!( "sparse index mismatch: sparse tree next_index={} but expected current_index={}", @@ -435,8 +452,8 @@ pub fn get_batch_address_append_circuit_inputs( let (root_with_zero, _) = compute_root_from_merkle_proof::( zero_hash, &merkle_proof_array, - current_index as u32, - ); + current_index, + )?; if root_with_zero != intermediate_root { tracing::error!( "ELEMENT {} NEW_PROOF MISMATCH: proof + ZERO = {:?}[..4] but expected \ @@ -470,7 +487,7 @@ pub fn get_batch_address_append_circuit_inputs( new_root = updated_root; - patcher.push_changelog_entry::(changelog, changelog_entry); + patcher.push_changelog_entry::(&mut staged_changelog, changelog_entry); new_element_circuit_merkle_proofs.push( merkle_proof_array .iter() @@ -488,9 +505,9 @@ pub fn get_batch_address_append_circuit_inputs( let new_element_changelog_entry = IndexedChangelogEntry { element: new_element_raw, proof: merkle_proof_array, - changelog_index: indexed_changelog.len(), + changelog_index: staged_indexed_changelog.len(), }; - indexed_changelog.push(new_element_changelog_entry); + staged_indexed_changelog.push(new_element_changelog_entry); } } @@ -526,18 +543,18 @@ pub fn get_batch_address_append_circuit_inputs( patcher.hits, patcher.misses, patcher.overwrites, - changelog.len(), - indexed_changelog.len() + staged_changelog.len(), + staged_indexed_changelog.len() ); - if patcher.hits == 0 && !changelog.is_empty() { + if patcher.hits == 0 && !staged_changelog.is_empty() { tracing::warn!( "Address proof patcher had 0 cache hits despite non-empty changelog (changelog_len={}, indexed_changelog_len={})", - changelog.len(), - indexed_changelog.len() + staged_changelog.len(), + staged_indexed_changelog.len() ); } - Ok(BatchAddressAppendInputs { + let inputs = BatchAddressAppendInputs { batch_size: patched_low_element_values.len(), hashchain_hash: BigUint::from_bytes_be(&leaves_hashchain), low_element_values: patched_low_element_values @@ -557,7 +574,7 @@ pub fn get_batch_address_append_circuit_inputs( .map(|v| BigUint::from_bytes_be(v)) .collect(), low_element_proofs: low_element_circuit_merkle_proofs, - new_element_values: new_element_values[0..] + new_element_values: new_element_values .iter() .map(|v| BigUint::from_bytes_be(v)) .collect(), @@ -567,5 +584,11 @@ pub fn get_batch_address_append_circuit_inputs( public_input_hash: BigUint::from_bytes_be(&public_input_hash), start_index: next_index, tree_height: HEIGHT, - }) + }; + + *changelog = staged_changelog; + *indexed_changelog = staged_indexed_changelog; + *sparse_merkle_tree = staged_sparse_merkle_tree; + + Ok(inputs) } diff --git a/prover/client/src/proof_types/batch_append/proof_inputs.rs b/prover/client/src/proof_types/batch_append/proof_inputs.rs index ef0327ac1d..41a6dcfcd6 100644 --- a/prover/client/src/proof_types/batch_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_append/proof_inputs.rs @@ -187,8 +187,11 @@ pub fn get_batch_append_inputs( }; // Update the root based on the current proof and nullifier - let (updated_root, changelog_entry) = - compute_root_from_merkle_proof(final_leaf, &merkle_proof_array, start_index + i as u32); + let (updated_root, changelog_entry) = compute_root_from_merkle_proof( + final_leaf, + &merkle_proof_array, + start_index as usize + i, + )?; new_root = updated_root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/src/proof_types/batch_update/proof_inputs.rs b/prover/client/src/proof_types/batch_update/proof_inputs.rs index 2136d01d10..f5467184aa 100644 --- a/prover/client/src/proof_types/batch_update/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_update/proof_inputs.rs @@ -31,8 +31,12 @@ pub struct BatchUpdateCircuitInputs { } impl BatchUpdateCircuitInputs { - pub fn public_inputs_arr(&self) -> [u8; 32] { - bigint_to_u8_32(&self.public_input_hash).unwrap() + pub fn public_inputs_arr(&self) -> Result<[u8; 32], ProverClientError> { + bigint_to_u8_32(&self.public_input_hash).map_err(|error| { + ProverClientError::GenericError(format!( + "failed to serialize batch update public input: {error}" + )) + }) } pub fn new( @@ -112,9 +116,17 @@ impl BatchUpdateCircuitInputs { pub struct BatchUpdateInputs<'a>(pub &'a [BatchUpdateCircuitInputs]); impl BatchUpdateInputs<'_> { - pub fn public_inputs(&self) -> Vec<[u8; 32]> { - // Concatenate all public inputs into a single flat vector - vec![self.0[0].public_inputs_arr()] + pub fn public_inputs(&self) -> Result, ProverClientError> { + if self.0.is_empty() { + return Err(ProverClientError::GenericError( + "batch update inputs cannot be empty".to_string(), + )); + } + + self.0 + .iter() + .map(BatchUpdateCircuitInputs::public_inputs_arr) + .collect() } } @@ -175,7 +187,7 @@ pub fn get_batch_update_inputs( index_bytes[28..].copy_from_slice(&(*index).to_be_bytes()); let nullifier = Poseidon::hashv(&[leaf, &index_bytes, &tx_hashes[i]]).unwrap(); let (root, changelog_entry) = - compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index); + compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index as usize)?; new_root = root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index 3bf1bab785..49b9f5aceb 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -1,72 +1,223 @@ use std::{ - process::Command, + io::{Read, Write}, + net::{TcpStream, ToSocketAddrs}, + process::{Child, Command}, sync::atomic::{AtomicBool, Ordering}, - thread::sleep, time::Duration, }; +use tokio::time::sleep; use tracing::info; use crate::{ constants::{HEALTH_CHECK, SERVER_ADDRESS}, + errors::ProverClientError, helpers::get_project_root, }; static IS_LOADING: AtomicBool = AtomicBool::new(false); +const STARTUP_HEALTH_CHECK_RETRIES: usize = 300; + +fn has_http_ok_status(response: &[u8]) -> bool { + response + .split(|&byte| byte == b'\n') + .next() + .map(|status_line| { + status_line.starts_with(b"HTTP/") + && status_line.windows(5).any(|window| window == b" 200 ") + }) + .unwrap_or(false) +} + +pub(crate) fn build_http_client() -> Result { + reqwest::Client::builder() + .no_proxy() + .build() + .map_err(|error| { + ProverClientError::GenericError(format!("failed to build HTTP client: {error}")) + }) +} + +fn health_check_once(timeout: Duration) -> bool { + let endpoint = SERVER_ADDRESS + .strip_prefix("http://") + .or_else(|| SERVER_ADDRESS.strip_prefix("https://")) + .unwrap_or(SERVER_ADDRESS); + let addr = match endpoint + .to_socket_addrs() + .ok() + .and_then(|mut addrs| addrs.next()) + { + Some(addr) => addr, + None => return false, + }; + + let mut stream = match TcpStream::connect_timeout(&addr, timeout) { + Ok(stream) => stream, + Err(error) => { + tracing::debug!(?error, endpoint, "prover health TCP connect failed"); + return health_check_once_with_curl(timeout); + } + }; + + let _ = stream.set_read_timeout(Some(timeout)); + let _ = stream.set_write_timeout(Some(timeout)); + + let host = endpoint.split(':').next().unwrap_or("127.0.0.1"); + let request = format!( + "GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", + HEALTH_CHECK, host + ); + if let Err(error) = stream.write_all(request.as_bytes()) { + tracing::debug!(?error, "failed to write prover health request"); + return health_check_once_with_curl(timeout); + } + + let mut response = [0_u8; 512]; + let bytes_read = match stream.read(&mut response) { + Ok(bytes_read) => bytes_read, + Err(error) => { + tracing::debug!(?error, "failed to read prover health response"); + return health_check_once_with_curl(timeout); + } + }; + + bytes_read > 0 + && (has_http_ok_status(&response[..bytes_read]) || health_check_once_with_curl(timeout)) +} + +fn health_check_once_with_curl(timeout: Duration) -> bool { + let timeout_secs = timeout.as_secs().max(1).to_string(); + let url = format!("{}{}", SERVER_ADDRESS, HEALTH_CHECK); + match Command::new("curl") + .args(["-sS", "-m", timeout_secs.as_str(), url.as_str()]) + .output() + { + Ok(output) => { + output.status.success() + && String::from_utf8_lossy(&output.stdout).contains("{\"status\":\"ok\"}") + } + Err(error) => { + tracing::debug!(?error, "failed to execute curl prover health check"); + false + } + } +} + +async fn wait_for_prover_health( + retries: usize, + timeout: Duration, + child: &mut Child, +) -> Result<(), String> { + for attempt in 0..retries { + if health_check_once(timeout) { + return Ok(()); + } + + match child.try_wait() { + Ok(Some(status)) => { + return Err(format!( + "prover process exited before health check succeeded with status {status}" + )); + } + Ok(None) => {} + Err(error) => { + return Err(format!("failed to poll prover process status: {error}")); + } + } + + if attempt + 1 < retries { + sleep(timeout).await; + } + } + + Err(format!( + "prover health check failed after {} attempts", + retries + )) +} + +fn monitor_prover_child(mut child: Child) { + std::thread::spawn(move || match child.wait() { + Ok(status) => tracing::debug!(?status, "prover launcher exited"), + Err(error) => tracing::warn!(?error, "failed to wait on prover launcher"), + }); +} pub async fn spawn_prover() { if let Some(_project_root) = get_project_root() { - let prover_path: &str = { + let prover_path = { #[cfg(feature = "devenv")] { - &format!("{}/{}", _project_root.trim(), "cli/test_bin/run") + format!("{}/{}", _project_root.trim(), "cli/test_bin/run") } #[cfg(not(feature = "devenv"))] { println!("Running in production mode, using prover binary"); - "light" + "light".to_string() } }; - if !health_check(10, 1).await && !IS_LOADING.load(Ordering::Relaxed) { - IS_LOADING.store(true, Ordering::Relaxed); + if health_check(10, 1).await { + return; + } + + if IS_LOADING + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_err() + { + if health_check(STARTUP_HEALTH_CHECK_RETRIES, 1).await { + return; + } + panic!("Failed to start prover, health check failed."); + } - let command = Command::new(prover_path) + let spawn_result = async { + let mut child = Command::new(&prover_path) .arg("start-prover") .spawn() - .expect("Failed to start prover process"); - - let _ = command.wait_with_output(); + .unwrap_or_else(|error| panic!("Failed to start prover process: {error}")); - let health_result = health_check(120, 1).await; - if health_result { - info!("Prover started successfully"); - } else { - panic!("Failed to start prover, health check failed."); + match wait_for_prover_health( + STARTUP_HEALTH_CHECK_RETRIES, + Duration::from_secs(1), + &mut child, + ) + .await + { + Ok(()) => { + monitor_prover_child(child); + info!("Prover started successfully"); + } + Err(error) => { + let _ = child.kill(); + let _ = child.wait(); + panic!("Failed to start prover: {error}"); + } } } + .await; + + IS_LOADING.store(false, Ordering::Release); + spawn_result } else { panic!("Failed to find project root."); }; } pub async fn health_check(retries: usize, timeout: usize) -> bool { - let client = reqwest::Client::new(); - let mut result = false; - for _ in 0..retries { - match client - .get(format!("{}{}", SERVER_ADDRESS, HEALTH_CHECK)) - .send() - .await - { - Ok(_) => { - result = true; - break; - } - Err(_) => { - sleep(Duration::from_secs(timeout as u64)); - } + let timeout = Duration::from_secs(timeout as u64); + let retry_delay = timeout; + + for attempt in 0..retries { + if health_check_once(timeout) { + return true; + } + + if attempt + 1 < retries { + sleep(retry_delay).await; } } - result + + false } diff --git a/prover/client/tests/batch_address_append.rs b/prover/client/tests/batch_address_append.rs index 22f58d5362..7b8ceaa5f9 100644 --- a/prover/client/tests/batch_address_append.rs +++ b/prover/client/tests/batch_address_append.rs @@ -26,46 +26,54 @@ async fn prove_batch_address_append() { spawn_prover().await; // Initialize test data - let mut new_element_values = vec![]; - let zkp_batch_size = 10; - for i in 1..zkp_batch_size + 1 { - new_element_values.push(num_bigint::ToBigUint::to_biguint(&i).unwrap()); - } + let total_batch_size = 10usize; + let warmup_batch_size = 1usize; + let prior_value = 999_u32.to_biguint().unwrap(); + let new_element_values = (1..=total_batch_size) + .map(|i| num_bigint::ToBigUint::to_biguint(&i).unwrap()) + .collect::>(); // Initialize indexing structures - let relayer_merkle_tree = + let mut relayer_merkle_tree = IndexedMerkleTree::::new(DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize, 0) .unwrap(); - let start_index = relayer_merkle_tree.merkle_tree.rightmost_index; - let current_root = relayer_merkle_tree.root(); + let collect_non_inclusion_data = |tree: &IndexedMerkleTree, + values: &[BigUint]| { + let mut low_element_values = Vec::with_capacity(values.len()); + let mut low_element_indices = Vec::with_capacity(values.len()); + let mut low_element_next_indices = Vec::with_capacity(values.len()); + let mut low_element_next_values = Vec::with_capacity(values.len()); + let mut low_element_proofs: Vec<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]> = + Vec::with_capacity(values.len()); - // Prepare proof components - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); + for new_element_value in values { + let non_inclusion_proof = tree.get_non_inclusion_proof(new_element_value).unwrap(); - // Generate non-inclusion proofs for each element - for new_element_value in &new_element_values { - let non_inclusion_proof = relayer_merkle_tree - .get_non_inclusion_proof(new_element_value) - .unwrap(); + low_element_values.push(non_inclusion_proof.leaf_lower_range_value); + low_element_indices.push(non_inclusion_proof.leaf_index); + low_element_next_indices.push(non_inclusion_proof.next_index); + low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); + low_element_proofs.push( + non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .unwrap(), + ); + } - low_element_values.push(non_inclusion_proof.leaf_lower_range_value); - low_element_indices.push(non_inclusion_proof.leaf_index); - low_element_next_indices.push(non_inclusion_proof.next_index); - low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push(non_inclusion_proof.merkle_proof.as_slice().to_vec()); - } + ( + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + low_element_proofs, + ) + }; - // Convert big integers to byte arrays - let new_element_values = new_element_values - .iter() - .map(|v| bigint_to_be_bytes_array::<32>(v).unwrap()) - .collect::>(); - let hash_chain = create_hash_chain_from_slice(&new_element_values).unwrap(); + let initial_start_index = relayer_merkle_tree.merkle_tree.rightmost_index; + let initial_root = relayer_merkle_tree.root(); let subtrees: [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize] = relayer_merkle_tree .merkle_tree @@ -75,7 +83,7 @@ async fn prove_batch_address_append() { let mut sparse_merkle_tree = SparseMerkleTree::< Poseidon, { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, - >::new(subtrees, start_index); + >::new(subtrees, initial_start_index); let mut changelog: Vec> = Vec::new(); @@ -83,19 +91,68 @@ async fn prove_batch_address_append() { IndexedChangelogEntry, > = Vec::new(); + let warmup_values = vec![prior_value.clone()]; + let ( + warmup_low_element_values, + warmup_low_element_indices, + warmup_low_element_next_indices, + warmup_low_element_next_values, + warmup_low_element_proofs, + ) = collect_non_inclusion_data(&relayer_merkle_tree, &warmup_values); + let warmup_values = warmup_values + .iter() + .map(|v| bigint_to_be_bytes_array::<32>(v).unwrap()) + .collect::>(); + let warmup_hash_chain = create_hash_chain_from_slice(&warmup_values).unwrap(); + + get_batch_address_append_circuit_inputs::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( + initial_start_index, + initial_root, + &warmup_low_element_values, + &warmup_low_element_next_values, + &warmup_low_element_indices, + &warmup_low_element_next_indices, + &warmup_low_element_proofs, + &warmup_values, + &mut sparse_merkle_tree, + warmup_hash_chain, + warmup_batch_size, + &mut changelog, + &mut indexed_changelog, + ) + .unwrap(); + + relayer_merkle_tree.append(&prior_value).unwrap(); + + let remaining_values = &new_element_values[..]; + let ( + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + low_element_proofs, + ) = collect_non_inclusion_data(&relayer_merkle_tree, remaining_values); + let new_element_values = remaining_values + .iter() + .map(|v| bigint_to_be_bytes_array::<32>(v).unwrap()) + .collect::>(); + let hash_chain = create_hash_chain_from_slice(&new_element_values).unwrap(); + let start_index = relayer_merkle_tree.merkle_tree.rightmost_index; + let current_root = relayer_merkle_tree.root(); + let inputs = get_batch_address_append_circuit_inputs::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - new_element_values, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &new_element_values, &mut sparse_merkle_tree, hash_chain, - zkp_batch_size, + total_batch_size, &mut changelog, &mut indexed_changelog, ) diff --git a/sdk-libs/client/src/indexer/photon_indexer.rs b/sdk-libs/client/src/indexer/photon_indexer.rs index 26d16ae235..eb8890d6b7 100644 --- a/sdk-libs/client/src/indexer/photon_indexer.rs +++ b/sdk-libs/client/src/indexer/photon_indexer.rs @@ -2,7 +2,7 @@ use std::{fmt::Debug, time::Duration}; use async_trait::async_trait; use bs58; -use light_sdk_types::constants::STATE_MERKLE_TREE_CANOPY_DEPTH; +use light_sdk_types::constants::{STATE_MERKLE_TREE_CANOPY_DEPTH, STATE_MERKLE_TREE_HEIGHT}; use photon_api::apis::configuration::Configuration; use solana_pubkey::Pubkey; use tracing::{error, trace, warn}; @@ -1142,17 +1142,26 @@ impl Indexer for PhotonIndexer { .value .iter() .map(|x| { - let mut proof_vec = x.proof.clone(); - if proof_vec.len() < STATE_MERKLE_TREE_CANOPY_DEPTH { + let expected_siblings = + STATE_MERKLE_TREE_HEIGHT - STATE_MERKLE_TREE_CANOPY_DEPTH; + let expected_total = STATE_MERKLE_TREE_CANOPY_DEPTH + expected_siblings; + if x.proof.len() != expected_total { return Err(IndexerError::InvalidParameters(format!( - "Merkle proof length ({}) is less than canopy depth ({})", - proof_vec.len(), - STATE_MERKLE_TREE_CANOPY_DEPTH, + "Merkle proof length ({}) does not match expected total proof length ({})", + x.proof.len(), + expected_total, + ))); + } + let proof_len = x.proof.len() - STATE_MERKLE_TREE_CANOPY_DEPTH; + if proof_len != expected_siblings { + return Err(IndexerError::InvalidParameters(format!( + "Merkle proof sibling count ({}) does not match expected sibling count ({})", + proof_len, + expected_siblings, ))); } - proof_vec.truncate(proof_vec.len() - STATE_MERKLE_TREE_CANOPY_DEPTH); - let proof = proof_vec + let proof = x.proof[..proof_len] .iter() .map(|s| Hash::from_base58(s)) .collect::, IndexerError>>() @@ -1682,6 +1691,7 @@ impl Indexer for PhotonIndexer { .map(|h| super::base58::decode_base58_to_fixed_array(&h.0)) .collect::, _>>()?, start_index: aq.start_index, + tree_next_insertion_index: aq.start_index, root_seq: aq.root_seq, }) } else { @@ -1703,15 +1713,13 @@ impl Indexer for PhotonIndexer { async fn get_subtrees( &self, - _merkle_tree_pubkey: [u8; 32], + merkle_tree_pubkey: [u8; 32], _config: Option, ) -> Result>, IndexerError> { - #[cfg(not(feature = "v2"))] - unimplemented!(); - #[cfg(feature = "v2")] - { - todo!(); - } + Err(IndexerError::NotImplemented(format!( + "PhotonIndexer::get_subtrees is not implemented for merkle tree {}", + solana_pubkey::Pubkey::new_from_array(merkle_tree_pubkey) + ))) } } diff --git a/sdk-libs/client/src/indexer/types/proof.rs b/sdk-libs/client/src/indexer/types/proof.rs index 0b45e00986..4a9f396732 100644 --- a/sdk-libs/client/src/indexer/types/proof.rs +++ b/sdk-libs/client/src/indexer/types/proof.rs @@ -189,41 +189,57 @@ pub struct PackedTreeInfos { } impl ValidityProofWithContext { - pub fn pack_tree_infos(&self, packed_accounts: &mut PackedAccounts) -> PackedTreeInfos { - let mut packed_tree_infos = Vec::new(); - let mut address_trees = Vec::new(); - let mut output_tree_index = None; - for account in self.accounts.iter() { - // Pack TreeInfo - let merkle_tree_pubkey_index = packed_accounts.insert_or_get(account.tree_info.tree); - let queue_pubkey_index = packed_accounts.insert_or_get(account.tree_info.queue); - let tree_info_packed = PackedStateTreeInfo { - root_index: account.root_index.root_index, - merkle_tree_pubkey_index, - queue_pubkey_index, + pub fn pack_state_tree_infos( + &self, + packed_accounts: &mut PackedAccounts, + ) -> Vec { + self.accounts + .iter() + .map(|account| PackedStateTreeInfo { + root_index: account.root_index.root_index().unwrap_or_default(), + merkle_tree_pubkey_index: packed_accounts.insert_or_get(account.tree_info.tree), + queue_pubkey_index: packed_accounts.insert_or_get(account.tree_info.queue), leaf_index: account.leaf_index as u32, prove_by_index: account.root_index.proof_by_index(), - }; - packed_tree_infos.push(tree_info_packed); + }) + .collect() + } + pub fn pack_tree_infos( + &self, + packed_accounts: &mut PackedAccounts, + ) -> Result { + let packed_tree_infos = self.pack_state_tree_infos(packed_accounts); + let mut address_trees = Vec::new(); + let mut output_tree_index = None; + for account in self.accounts.iter() { // If a next Merkle tree exists the Merkle tree is full -> use the next Merkle tree for new state. // Else use the current Merkle tree for new state. if let Some(next) = account.tree_info.next_tree_info { // SAFETY: account will always have a state Merkle tree context. // pack_output_tree_index only panics on an address Merkle tree context. - let index = next.pack_output_tree_index(packed_accounts).unwrap(); - if output_tree_index.is_none() { - output_tree_index = Some(index); + let index = next.pack_output_tree_index(packed_accounts)?; + match output_tree_index { + Some(existing) if existing != index => { + return Err(IndexerError::InvalidParameters(format!( + "mixed output tree indices in state proof: {existing} != {index}" + ))); + } + Some(_) => {} + None => output_tree_index = Some(index), } } else { // SAFETY: account will always have a state Merkle tree context. // pack_output_tree_index only panics on an address Merkle tree context. - let index = account - .tree_info - .pack_output_tree_index(packed_accounts) - .unwrap(); - if output_tree_index.is_none() { - output_tree_index = Some(index); + let index = account.tree_info.pack_output_tree_index(packed_accounts)?; + match output_tree_index { + Some(existing) if existing != index => { + return Err(IndexerError::InvalidParameters(format!( + "mixed output tree indices in state proof: {existing} != {index}" + ))); + } + Some(_) => {} + None => output_tree_index = Some(index), } } } @@ -244,13 +260,17 @@ impl ValidityProofWithContext { } else { Some(PackedStateTreeInfos { packed_tree_infos, - output_tree_index: output_tree_index.unwrap(), + output_tree_index: output_tree_index.ok_or_else(|| { + IndexerError::InvalidParameters( + "missing output tree index for non-empty state proof".to_string(), + ) + })?, }) }; - PackedTreeInfos { + Ok(PackedTreeInfos { state_trees: packed_tree_infos, address_trees, - } + }) } pub fn from_api_model( diff --git a/sdk-libs/client/src/indexer/types/queue.rs b/sdk-libs/client/src/indexer/types/queue.rs index 40e7cc0f6e..79fcb45cf1 100644 --- a/sdk-libs/client/src/indexer/types/queue.rs +++ b/sdk-libs/client/src/indexer/types/queue.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use super::super::IndexerError; #[derive(Debug, Clone, PartialEq, Default)] @@ -59,18 +61,22 @@ pub struct AddressQueueData { pub initial_root: [u8; 32], pub leaves_hash_chains: Vec<[u8; 32]>, pub subtrees: Vec<[u8; 32]>, + /// Pagination offset for the returned queue slice. pub start_index: u64, + /// Sparse tree insertion point / next index used to initialize staging trees. + pub tree_next_insertion_index: u64, pub root_seq: u64, } impl AddressQueueData { + const ADDRESS_TREE_HEIGHT: usize = 40; + /// Reconstruct a merkle proof for a given low_element_index from the deduplicated nodes. - /// The tree_height is needed to know how many levels to traverse. - pub fn reconstruct_proof( + pub fn reconstruct_proof( &self, address_idx: usize, - tree_height: u8, - ) -> Result, IndexerError> { + ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + self.validate_proof_height::()?; let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { IndexerError::MissingResult { context: "reconstruct_proof".to_string(), @@ -81,10 +87,10 @@ impl AddressQueueData { ), } })?; - let mut proof = Vec::with_capacity(tree_height as usize); + let mut proof = [[0u8; 32]; HEIGHT]; let mut pos = leaf_index; - for level in 0..tree_height { + for (level, proof_element) in proof.iter_mut().enumerate() { let sibling_pos = if pos.is_multiple_of(2) { pos + 1 } else { @@ -114,28 +120,242 @@ impl AddressQueueData { self.node_hashes.len(), ), })?; - proof.push(*hash); + *proof_element = *hash; pos /= 2; } Ok(proof) } + /// Reconstruct a contiguous batch of proofs while reusing a single node lookup table. + pub fn reconstruct_proofs( + &self, + address_range: std::ops::Range, + ) -> Result, IndexerError> { + self.validate_proof_height::()?; + let available = self.proof_count(); + if address_range.start > address_range.end { + return Err(IndexerError::InvalidParameters(format!( + "invalid address proof range {}..{}", + address_range.start, address_range.end + ))); + } + if address_range.end > available { + return Err(IndexerError::InvalidParameters(format!( + "address proof range {}..{} exceeds available proofs {}", + address_range.start, address_range.end, available + ))); + } + let node_lookup = self.build_node_lookup(); + let mut proofs = Vec::with_capacity(address_range.len()); + + for address_idx in address_range { + proofs.push(self.reconstruct_proof_with_lookup::(address_idx, &node_lookup)?); + } + + Ok(proofs) + } + /// Reconstruct all proofs for all addresses - pub fn reconstruct_all_proofs( + pub fn reconstruct_all_proofs( + &self, + ) -> Result, IndexerError> { + self.validate_proof_height::()?; + self.reconstruct_proofs::(0..self.addresses.len()) + } + + fn build_node_lookup(&self) -> HashMap { + let mut lookup = HashMap::with_capacity(self.nodes.len()); + for (idx, node) in self.nodes.iter().copied().enumerate() { + lookup.entry(node).or_insert(idx); + } + lookup + } + + fn proof_count(&self) -> usize { + self.addresses.len().min(self.low_element_indices.len()) + } + + fn reconstruct_proof_with_lookup( &self, - tree_height: u8, - ) -> Result>, IndexerError> { - (0..self.addresses.len()) - .map(|i| self.reconstruct_proof(i, tree_height)) - .collect() + address_idx: usize, + node_lookup: &HashMap, + ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + self.validate_proof_height::()?; + let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { + IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "address_idx {} out of bounds for low_element_indices (len {})", + address_idx, + self.low_element_indices.len(), + ), + } + })?; + let mut proof = [[0u8; 32]; HEIGHT]; + let mut pos = leaf_index; + + for (level, proof_element) in proof.iter_mut().enumerate() { + let sibling_pos = if pos.is_multiple_of(2) { + pos + 1 + } else { + pos - 1 + }; + let sibling_idx = Self::encode_node_index(level, sibling_pos); + let hash_idx = node_lookup.get(&sibling_idx).copied().ok_or_else(|| { + IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "Missing proof node at level {} position {} (encoded: {})", + level, sibling_pos, sibling_idx + ), + } + })?; + let hash = + self.node_hashes + .get(hash_idx) + .ok_or_else(|| IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "node_hashes index {} out of bounds (len {})", + hash_idx, + self.node_hashes.len(), + ), + })?; + *proof_element = *hash; + pos /= 2; + } + + Ok(proof) } /// Encode node index: (level << 56) | position #[inline] - fn encode_node_index(level: u8, position: u64) -> u64 { + fn encode_node_index(level: usize, position: u64) -> u64 { ((level as u64) << 56) | position } + + fn validate_proof_height(&self) -> Result<(), IndexerError> { + if HEIGHT == Self::ADDRESS_TREE_HEIGHT { + return Ok(()); + } + + Err(IndexerError::InvalidParameters(format!( + "address queue proofs require HEIGHT={} but got HEIGHT={}", + Self::ADDRESS_TREE_HEIGHT, + HEIGHT + ))) + } +} + +#[cfg(test)] +mod tests { + use std::{collections::BTreeMap, hint::black_box, time::Instant}; + + use super::AddressQueueData; + + fn hash_from_node(node_index: u64) -> [u8; 32] { + let mut hash = [0u8; 32]; + hash[..8].copy_from_slice(&node_index.to_le_bytes()); + hash[8..16].copy_from_slice(&node_index.rotate_left(17).to_le_bytes()); + hash[16..24].copy_from_slice(&node_index.rotate_right(9).to_le_bytes()); + hash[24..32].copy_from_slice(&(node_index ^ 0xA5A5_A5A5_A5A5_A5A5).to_le_bytes()); + hash + } + + fn build_queue_data(num_addresses: usize) -> AddressQueueData { + let low_element_indices = (0..num_addresses) + .map(|i| (i as u64).saturating_mul(2)) + .collect::>(); + let mut nodes = BTreeMap::new(); + + for &leaf_index in &low_element_indices { + let mut pos = leaf_index; + for level in 0..HEIGHT { + let sibling_pos = if pos.is_multiple_of(2) { + pos + 1 + } else { + pos - 1 + }; + let node_index = ((level as u64) << 56) | sibling_pos; + nodes + .entry(node_index) + .or_insert_with(|| hash_from_node(node_index)); + pos /= 2; + } + } + + let (nodes, node_hashes): (Vec<_>, Vec<_>) = nodes.into_iter().unzip(); + + AddressQueueData { + addresses: vec![[0u8; 32]; num_addresses], + low_element_values: vec![[1u8; 32]; num_addresses], + low_element_next_values: vec![[2u8; 32]; num_addresses], + low_element_indices, + low_element_next_indices: (0..num_addresses).map(|i| (i as u64) + 1).collect(), + nodes, + node_hashes, + initial_root: [9u8; 32], + leaves_hash_chains: vec![[3u8; 32]; num_addresses.max(1)], + subtrees: vec![[4u8; 32]; HEIGHT], + start_index: 0, + tree_next_insertion_index: 0, + root_seq: 0, + } + } + + #[test] + fn batched_reconstruction_matches_individual_reconstruction() { + let queue = build_queue_data::<40>(128); + + let expected = (0..queue.addresses.len()) + .map(|i| queue.reconstruct_proof::<40>(i).unwrap()) + .collect::>(); + let actual = queue + .reconstruct_proofs::<40>(0..queue.addresses.len()) + .unwrap(); + + assert_eq!(actual, expected); + } + + #[test] + #[ignore = "profiling helper"] + fn profile_reconstruct_proofs_batch() { + const HEIGHT: usize = 40; + const NUM_ADDRESSES: usize = 2_048; + const ITERS: usize = 25; + + let queue = build_queue_data::(NUM_ADDRESSES); + + let baseline_start = Instant::now(); + for _ in 0..ITERS { + let proofs = (0..queue.addresses.len()) + .map(|i| queue.reconstruct_proof::(i).unwrap()) + .collect::>(); + black_box(proofs); + } + let baseline = baseline_start.elapsed(); + + let batched_start = Instant::now(); + for _ in 0..ITERS { + black_box( + queue + .reconstruct_proofs::(0..queue.addresses.len()) + .unwrap(), + ); + } + let batched = batched_start.elapsed(); + + println!( + "queue reconstruction profile: addresses={}, height={}, iters={}, individual={:?}, batched={:?}, speedup={:.2}x", + NUM_ADDRESSES, + HEIGHT, + ITERS, + baseline, + batched, + baseline.as_secs_f64() / batched.as_secs_f64(), + ); + } } /// V2 Queue Elements Result with deduplicated node data diff --git a/sdk-libs/client/src/interface/initialize_config.rs b/sdk-libs/client/src/interface/initialize_config.rs index 7b5919cdb1..9fbeacfe89 100644 --- a/sdk-libs/client/src/interface/initialize_config.rs +++ b/sdk-libs/client/src/interface/initialize_config.rs @@ -7,6 +7,8 @@ use borsh::{BorshDeserialize as AnchorDeserialize, BorshSerialize as AnchorSeria use solana_instruction::{AccountMeta, Instruction}; use solana_pubkey::Pubkey; +use crate::interface::instructions::INITIALIZE_COMPRESSION_CONFIG_DISCRIMINATOR; + /// Default address tree v2 pubkey. pub const ADDRESS_TREE_V2: Pubkey = solana_pubkey::pubkey!("amt2kaJA14v3urZbZvnc5v2np8jqvc4Z8zDep5wbtzx"); @@ -115,16 +117,14 @@ impl InitializeRentFreeConfig { address_space: self.address_space, }; - // Anchor discriminator for "initialize_compression_config" - // SHA256("global:initialize_compression_config")[..8] - const DISCRIMINATOR: [u8; 8] = [133, 228, 12, 169, 56, 76, 222, 61]; - let serialized_data = instruction_data .try_to_vec() .expect("Failed to serialize instruction data"); - let mut data = Vec::with_capacity(DISCRIMINATOR.len() + serialized_data.len()); - data.extend_from_slice(&DISCRIMINATOR); + let mut data = Vec::with_capacity( + INITIALIZE_COMPRESSION_CONFIG_DISCRIMINATOR.len() + serialized_data.len(), + ); + data.extend_from_slice(&INITIALIZE_COMPRESSION_CONFIG_DISCRIMINATOR); data.extend_from_slice(&serialized_data); let instruction = Instruction { diff --git a/sdk-libs/client/src/interface/instructions.rs b/sdk-libs/client/src/interface/instructions.rs index f6d754b9b1..bb4056ceae 100644 --- a/sdk-libs/client/src/interface/instructions.rs +++ b/sdk-libs/client/src/interface/instructions.rs @@ -8,7 +8,7 @@ use light_account::{ CompressedAccountData, InitializeLightConfigParams, Pack, UpdateLightConfigParams, }; use light_sdk::instruction::{ - account_meta::CompressedAccountMetaNoLamportsNoAddress, PackedAccounts, + account_meta::CompressedAccountMetaNoLamportsNoAddress, PackedAccounts, PackedStateTreeInfo, SystemAccountMetaConfig, ValidityProof, }; use light_token::constants::{ @@ -234,12 +234,7 @@ where let output_queue = get_output_queue(&cold_accounts[0].0.tree_info); let output_state_tree_index = remaining_accounts.insert_or_get(output_queue); - let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts); - let tree_infos = &packed_tree_infos - .state_trees - .as_ref() - .ok_or("missing state_trees in packed_tree_infos")? - .packed_tree_infos; + let tree_infos = proof.pack_state_tree_infos(&mut remaining_accounts); let mut accounts = program_account_metas.to_vec(); let mut typed_accounts = Vec::with_capacity(cold_accounts.len()); @@ -247,11 +242,15 @@ where // Process PDAs first, then tokens, to match on-chain split_at(token_accounts_offset). for &i in pda_indices.iter().chain(token_indices.iter()) { let (acc, data) = &cold_accounts[i]; - let _queue_index = remaining_accounts.insert_or_get(acc.tree_info.queue); - let tree_info = tree_infos + let proof_tree_info = tree_infos .get(i) .copied() .ok_or("tree info index out of bounds")?; + let queue_index = remaining_accounts.insert_or_get(acc.tree_info.queue); + let tree_info = PackedStateTreeInfo { + queue_pubkey_index: queue_index, + ..proof_tree_info + }; let packed_data = data.pack(&mut remaining_accounts)?; typed_accounts.push(CompressedAccountData { @@ -309,14 +308,9 @@ pub fn build_compress_accounts_idempotent( let output_queue = get_output_queue(&proof.accounts[0].tree_info); let output_state_tree_index = remaining_accounts.insert_or_get(output_queue); - let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts); - let tree_infos = packed_tree_infos - .state_trees - .as_ref() - .ok_or("missing state_trees in packed_tree_infos")?; + let tree_infos = proof.pack_state_tree_infos(&mut remaining_accounts); let cold_metas: Vec<_> = tree_infos - .packed_tree_infos .iter() .map(|tree_info| CompressedAccountMetaNoLamportsNoAddress { tree_info: *tree_info, diff --git a/sdk-libs/client/src/interface/load_accounts.rs b/sdk-libs/client/src/interface/load_accounts.rs index 061ad5074b..01d91364c1 100644 --- a/sdk-libs/client/src/interface/load_accounts.rs +++ b/sdk-libs/client/src/interface/load_accounts.rs @@ -1,5 +1,6 @@ //! Load cold accounts API. +use futures::{stream, StreamExt, TryStreamExt}; use light_account::{derive_rent_sponsor_pda, Pack}; use light_compressed_account::{ compressed_account::PackedMerkleContext, instruction_data::compressed_proof::ValidityProof, @@ -53,6 +54,9 @@ pub enum LoadAccountsError { #[error("Cold PDA at index {index} (pubkey {pubkey}) missing data")] MissingPdaCompressed { index: usize, pubkey: Pubkey }, + #[error("Cold PDA (pubkey {pubkey}) missing data")] + MissingPdaCompressedData { pubkey: Pubkey }, + #[error("Cold ATA at index {index} (pubkey {pubkey}) missing data")] MissingAtaCompressed { index: usize, pubkey: Pubkey }, @@ -67,6 +71,8 @@ pub enum LoadAccountsError { } const MAX_ATAS_PER_IX: usize = 8; +const MAX_PDAS_PER_IX: usize = 8; +const PROOF_FETCH_CONCURRENCY: usize = 8; /// Build load instructions for cold accounts. Returns empty vec if all hot. /// @@ -113,14 +119,23 @@ where }) .collect(); - let pda_hashes = collect_pda_hashes(&cold_pdas)?; + let pda_groups = group_pda_specs(&cold_pdas, MAX_PDAS_PER_IX); + let mut pda_offset = 0usize; + let pda_hashes = pda_groups + .iter() + .map(|group| { + let hashes = collect_pda_hashes(group, pda_offset)?; + pda_offset += group.len(); + Ok::<_, LoadAccountsError>(hashes) + }) + .collect::, _>>()?; let ata_hashes = collect_ata_hashes(&cold_atas)?; let mint_hashes = collect_mint_hashes(&cold_mints)?; let (pda_proofs, ata_proofs, mint_proofs) = futures::join!( - fetch_proofs(&pda_hashes, indexer), + fetch_proof_batches(&pda_hashes, indexer), fetch_proofs_batched(&ata_hashes, MAX_ATAS_PER_IX, indexer), - fetch_proofs(&mint_hashes, indexer), + fetch_individual_proofs(&mint_hashes, indexer), ); let pda_proofs = pda_proofs?; @@ -136,9 +151,9 @@ where // 2. DecompressAccountsIdempotent for all cold PDAs (including token PDAs). // Token PDAs are created on-chain via CPI inside DecompressVariant. - for (spec, proof) in cold_pdas.iter().zip(pda_proofs) { + for (group, proof) in pda_groups.into_iter().zip(pda_proofs) { out.push(build_pda_load( - &[spec], + &group, proof, fee_payer, compression_config, @@ -146,21 +161,23 @@ where } // 3. ATA loads (CreateAssociatedTokenAccount + Transfer2) - requires mint to exist - let ata_chunks: Vec<_> = cold_atas.chunks(MAX_ATAS_PER_IX).collect(); - for (chunk, proof) in ata_chunks.into_iter().zip(ata_proofs) { + for (chunk, proof) in cold_atas.chunks(MAX_ATAS_PER_IX).zip(ata_proofs) { out.extend(build_ata_load(chunk, proof, fee_payer)?); } Ok(out) } -fn collect_pda_hashes(specs: &[&PdaSpec]) -> Result, LoadAccountsError> { +fn collect_pda_hashes( + specs: &[&PdaSpec], + start_index: usize, +) -> Result, LoadAccountsError> { specs .iter() .enumerate() .map(|(i, s)| { s.hash().ok_or(LoadAccountsError::MissingPdaCompressed { - index: i, + index: start_index + i, pubkey: s.address(), }) }) @@ -195,23 +212,83 @@ fn collect_mint_hashes(ifaces: &[&AccountInterface]) -> Result, Lo .collect() } -async fn fetch_proofs( +/// Groups already-ordered PDA specs into contiguous runs of the same program id. +/// +/// This preserves input order rather than globally regrouping by program. Callers that +/// want maximal batching across interleaved program ids should sort before calling. +fn group_pda_specs<'a, V>( + specs: &[&'a PdaSpec], + max_per_group: usize, +) -> Vec>> { + debug_assert!(max_per_group > 0, "max_per_group must be non-zero"); + if specs.is_empty() { + return Vec::new(); + } + + let mut groups = Vec::new(); + let mut current = Vec::with_capacity(max_per_group); + let mut current_program: Option = None; + + for spec in specs { + let program_id = spec.program_id(); + let should_split = current_program + .map(|existing| existing != program_id || current.len() >= max_per_group) + .unwrap_or(false); + + if should_split { + groups.push(current); + current = Vec::with_capacity(max_per_group); + } + + current_program = Some(program_id); + current.push(*spec); + } + + if !current.is_empty() { + groups.push(current); + } + + groups +} + +async fn fetch_individual_proofs( hashes: &[[u8; 32]], indexer: &I, ) -> Result, IndexerError> { if hashes.is_empty() { return Ok(vec![]); } - let mut proofs = Vec::with_capacity(hashes.len()); - for hash in hashes { - proofs.push( + + stream::iter(hashes.iter().copied()) + .map(|hash| async move { indexer - .get_validity_proof(vec![*hash], vec![], None) - .await? - .value, - ); + .get_validity_proof(vec![hash], vec![], None) + .await + .map(|response| response.value) + }) + .buffered(PROOF_FETCH_CONCURRENCY) + .try_collect() + .await +} + +async fn fetch_proof_batches( + hash_batches: &[Vec<[u8; 32]>], + indexer: &I, +) -> Result, IndexerError> { + if hash_batches.is_empty() { + return Ok(vec![]); } - Ok(proofs) + + stream::iter(hash_batches.iter().cloned()) + .map(|hashes| async move { + indexer + .get_validity_proof(hashes, vec![], None) + .await + .map(|response| response.value) + }) + .buffered(PROOF_FETCH_CONCURRENCY) + .try_collect() + .await } async fn fetch_proofs_batched( @@ -222,16 +299,13 @@ async fn fetch_proofs_batched( if hashes.is_empty() { return Ok(vec![]); } - let mut proofs = Vec::with_capacity(hashes.len().div_ceil(batch_size)); - for chunk in hashes.chunks(batch_size) { - proofs.push( - indexer - .get_validity_proof(chunk.to_vec(), vec![], None) - .await? - .value, - ); - } - Ok(proofs) + + let hash_batches = hashes + .chunks(batch_size) + .map(|chunk| chunk.to_vec()) + .collect::>(); + + fetch_proof_batches(&hash_batches, indexer).await } fn build_pda_load( @@ -262,11 +336,16 @@ where let hot_addresses: Vec = specs.iter().map(|s| s.address()).collect(); let cold_accounts: Vec<(CompressedAccount, V)> = specs .iter() - .map(|s| { - let compressed = s.compressed().expect("cold spec must have data").clone(); - (compressed, s.variant.clone()) + .map(|s| -> Result<_, LoadAccountsError> { + let compressed = + s.compressed() + .cloned() + .ok_or(LoadAccountsError::MissingPdaCompressedData { + pubkey: s.address(), + })?; + Ok((compressed, s.variant.clone())) }) - .collect(); + .collect::, _>>()?; let program_id = specs.first().map(|s| s.program_id()).unwrap_or_default(); @@ -345,11 +424,7 @@ fn build_transfer2( fee_payer: Pubkey, ) -> Result { let mut packed = PackedAccounts::default(); - let packed_trees = proof.pack_tree_infos(&mut packed); - let tree_infos = packed_trees - .state_trees - .as_ref() - .ok_or_else(|| LoadAccountsError::BuildInstruction("no state trees".into()))?; + let tree_infos = proof.pack_state_tree_infos(&mut packed); let mut token_accounts = Vec::with_capacity(contexts.len()); let mut tlv_data: Vec> = Vec::with_capacity(contexts.len()); @@ -357,12 +432,12 @@ fn build_transfer2( for (i, ctx) in contexts.iter().enumerate() { let token = &ctx.compressed.token; - let tree = tree_infos.packed_tree_infos.get(i).ok_or( - LoadAccountsError::TreeInfoIndexOutOfBounds { + let tree = tree_infos + .get(i) + .ok_or(LoadAccountsError::TreeInfoIndexOutOfBounds { index: i, - len: tree_infos.packed_tree_infos.len(), - }, - )?; + len: tree_infos.len(), + })?; let owner_idx = packed.insert_or_get_config(ctx.wallet_owner, true, false); let ata_idx = packed.insert_or_get(derive_token_ata(&ctx.wallet_owner, &ctx.mint)); diff --git a/sdk-libs/client/src/interface/pack.rs b/sdk-libs/client/src/interface/pack.rs index 804a48751d..97505adabe 100644 --- a/sdk-libs/client/src/interface/pack.rs +++ b/sdk-libs/client/src/interface/pack.rs @@ -12,6 +12,9 @@ use crate::indexer::{TreeInfo, ValidityProofWithContext}; pub enum PackError { #[error("Failed to add system accounts: {0}")] SystemAccounts(#[from] light_sdk::error::LightSdkError), + + #[error("Failed to pack tree infos: {0}")] + Indexer(#[from] crate::indexer::IndexerError), } /// Packed state tree infos from validity proof. @@ -87,7 +90,7 @@ fn pack_proof_internal( // For mint creation: pack address tree first (index 1), then state tree. let (client_packed_tree_infos, state_tree_index) = if include_state_tree { // Pack tree infos first to ensure address tree is at index 1 - let tree_infos = proof.pack_tree_infos(&mut packed); + let tree_infos = proof.pack_tree_infos(&mut packed)?; // Then add state tree (will be after address tree) let state_tree = output_tree @@ -99,7 +102,7 @@ fn pack_proof_internal( (tree_infos, Some(state_idx)) } else { - let tree_infos = proof.pack_tree_infos(&mut packed); + let tree_infos = proof.pack_tree_infos(&mut packed)?; (tree_infos, None) }; let (remaining_accounts, system_offset, _) = packed.to_account_metas(); diff --git a/sdk-libs/client/src/local_test_validator.rs b/sdk-libs/client/src/local_test_validator.rs index 36ed7c04b3..b27daa6a25 100644 --- a/sdk-libs/client/src/local_test_validator.rs +++ b/sdk-libs/client/src/local_test_validator.rs @@ -1,6 +1,7 @@ -use std::process::{Command, Stdio}; +use std::process::Stdio; use light_prover_client::helpers::get_project_root; +use tokio::process::Command; /// Configuration for an upgradeable program to deploy to the validator. #[derive(Debug, Clone)] @@ -57,25 +58,25 @@ impl Default for LightValidatorConfig { pub async fn spawn_validator(config: LightValidatorConfig) { if let Some(project_root) = get_project_root() { - let path = "cli/test_bin/run test-validator"; - let mut path = format!("{}/{}", project_root.trim(), path); + let command = "cli/test_bin/run test-validator"; + let mut command = format!("{}/{}", project_root.trim(), command); if !config.enable_indexer { - path.push_str(" --skip-indexer"); + command.push_str(" --skip-indexer"); } if let Some(limit_ledger_size) = config.limit_ledger_size { - path.push_str(&format!(" --limit-ledger-size {}", limit_ledger_size)); + command.push_str(&format!(" --limit-ledger-size {}", limit_ledger_size)); } for sbf_program in config.sbf_programs.iter() { - path.push_str(&format!( + command.push_str(&format!( " --sbf-program {} {}", sbf_program.0, sbf_program.1 )); } for upgradeable_program in config.upgradeable_programs.iter() { - path.push_str(&format!( + command.push_str(&format!( " --upgradeable-program {} {} {}", upgradeable_program.program_id, upgradeable_program.program_path, @@ -84,18 +85,18 @@ pub async fn spawn_validator(config: LightValidatorConfig) { } if !config.enable_prover { - path.push_str(" --skip-prover"); + command.push_str(" --skip-prover"); } if config.use_surfpool { - path.push_str(" --use-surfpool"); + command.push_str(" --use-surfpool"); } for arg in config.validator_args.iter() { - path.push_str(&format!(" {}", arg)); + command.push_str(&format!(" {}", arg)); } - println!("Starting validator with command: {}", path); + println!("Starting validator with command: {}", command); if config.use_surfpool { // The CLI starts surfpool, prover, and photon, then exits once all @@ -103,24 +104,25 @@ pub async fn spawn_validator(config: LightValidatorConfig) { // is up before the test proceeds. let mut child = Command::new("sh") .arg("-c") - .arg(path) + .arg(command) .stdin(Stdio::null()) .stdout(Stdio::inherit()) .stderr(Stdio::inherit()) .spawn() .expect("Failed to start server process"); - let status = child.wait().expect("Failed to wait for CLI process"); + let status = child.wait().await.expect("Failed to wait for CLI process"); assert!(status.success(), "CLI exited with error: {}", status); } else { - let child = Command::new("sh") + let _child = Command::new("sh") .arg("-c") - .arg(path) + .arg(command) .stdin(Stdio::null()) .stdout(Stdio::null()) .stderr(Stdio::null()) .spawn() .expect("Failed to start server process"); - std::mem::drop(child); + // Intentionally detaching the spawned child; the caller only waits + // for the validator services to become available. tokio::time::sleep(tokio::time::Duration::from_secs(config.wait_time)).await; } } diff --git a/sdk-libs/client/src/utils.rs b/sdk-libs/client/src/utils.rs index b8f2e05ecb..0055f8dbea 100644 --- a/sdk-libs/client/src/utils.rs +++ b/sdk-libs/client/src/utils.rs @@ -15,8 +15,11 @@ pub fn find_light_bin() -> Option { if !output.status.success() { return None; } - // Convert the output into a string (removing any trailing newline) - let light_path = String::from_utf8_lossy(&output.stdout).trim().to_string(); + let light_path = std::str::from_utf8(&output.stdout) + .ok()? + .trim_end_matches("\r\n") + .trim_end_matches('\n') + .to_string(); // Get the parent directory of the 'light' binary let mut light_bin_path = PathBuf::from(light_path); light_bin_path.pop(); // Remove the 'light' binary itself @@ -30,16 +33,16 @@ pub fn find_light_bin() -> Option { #[cfg(feature = "devenv")] { println!("Use only in light protocol monorepo. Using 'git rev-parse --show-toplevel' to find the location of 'light' binary"); - let light_protocol_toplevel = String::from_utf8_lossy( - &std::process::Command::new("git") - .arg("rev-parse") - .arg("--show-toplevel") - .output() - .expect("Failed to get top-level directory") - .stdout, - ) - .trim() - .to_string(); + let output = std::process::Command::new("git") + .arg("rev-parse") + .arg("--show-toplevel") + .output() + .expect("Failed to get top-level directory"); + let light_protocol_toplevel = std::str::from_utf8(&output.stdout) + .ok()? + .trim_end_matches("\r\n") + .trim_end_matches('\n') + .to_string(); let light_path = PathBuf::from(format!("{}/target/deploy/", light_protocol_toplevel)); Some(light_path) } diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index 0b5b0583a3..7618e045f3 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -95,6 +95,21 @@ use crate::accounts::{ }; use crate::indexer::TestIndexerExtensions; +fn build_compressed_proof(body: &str) -> Result { + let proof_json = deserialize_gnark_proof_json(body) + .map_err(|error| IndexerError::CustomError(error.to_string()))?; + let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json) + .map_err(|error| IndexerError::CustomError(error.to_string()))?; + let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c) + .map_err(|error| IndexerError::CustomError(error.to_string()))?; + + Ok(CompressedProof { + a: proof_a, + b: proof_b, + c: proof_c, + }) +} + #[derive(Debug)] pub struct TestIndexer { pub state_merkle_trees: Vec, @@ -472,8 +487,6 @@ impl Indexer for TestIndexer { let account_data = account.value.ok_or(IndexerError::AccountNotFound)?; state_merkle_tree_pubkeys.push(account_data.tree_info.tree); } - println!("state_merkle_tree_pubkeys {:?}", state_merkle_tree_pubkeys); - println!("hashes {:?}", hashes); let mut proof_inputs = vec![]; let mut indices_to_remove = Vec::new(); @@ -495,14 +508,7 @@ impl Indexer for TestIndexer { .output_queue_elements .iter() .find(|(hash, _)| hash == compressed_account); - println!("queue_element {:?}", queue_element); - if let Some((_, index)) = queue_element { - println!("index {:?}", index); - println!( - "accounts.output_queue_batch_size {:?}", - accounts.output_queue_batch_size - ); if accounts.output_queue_batch_size.is_some() && accounts.leaf_index_in_queue_range(*index as usize)? { @@ -513,12 +519,7 @@ impl Indexer for TestIndexer { hash: *compressed_account, root: [0u8; 32], root_index: RootIndex::new_none(), - leaf_index: accounts - .output_queue_elements - .iter() - .position(|(x, _)| x == compressed_account) - .unwrap() - as u64, + leaf_index: *index, tree_info: light_client::indexer::TreeInfo { cpi_context: Some(accounts.accounts.cpi_context), tree: accounts.accounts.merkle_tree, @@ -727,6 +728,7 @@ impl Indexer for TestIndexer { leaves_hash_chains: Vec::new(), subtrees: address_tree_bundle.get_subtrees(), start_index: start as u64, + tree_next_insertion_index: address_tree_bundle.right_most_index() as u64, root_seq: address_tree_bundle.sequence_number(), }) } else { @@ -2084,6 +2086,107 @@ impl TestIndexer { } } +#[cfg(all(test, feature = "v2"))] +mod tests { + use light_compressed_account::compressed_account::CompressedAccount; + + use super::*; + + fn queued_account( + owner: [u8; 32], + merkle_tree: Pubkey, + queue: Pubkey, + leaf_index: u32, + ) -> CompressedAccountWithMerkleContext { + CompressedAccountWithMerkleContext { + compressed_account: CompressedAccount { + owner: owner.into(), + lamports: 0, + address: None, + data: None, + }, + merkle_context: MerkleContext { + merkle_tree_pubkey: merkle_tree.to_bytes().into(), + queue_pubkey: queue.to_bytes().into(), + leaf_index, + prove_by_index: false, + tree_type: TreeType::StateV2, + }, + } + } + + #[tokio::test] + async fn get_validity_proof_preserves_sparse_queue_leaf_indices() { + let merkle_tree = Pubkey::new_unique(); + let queue = Pubkey::new_unique(); + let sparse_leaf_indices = [5_u32, 1, 0, 4]; + + let compressed_accounts: Vec<_> = sparse_leaf_indices + .iter() + .enumerate() + .map(|(i, &leaf_index)| { + queued_account([i as u8 + 1; 32], merkle_tree, queue, leaf_index) + }) + .collect(); + let hashes: Vec<_> = compressed_accounts + .iter() + .map(|account| account.hash().unwrap()) + .collect(); + + let output_queue_elements = hashes + .iter() + .zip(sparse_leaf_indices.iter()) + .map(|(hash, &leaf_index)| (*hash, leaf_index as u64)) + .collect(); + + let indexer = TestIndexer { + state_merkle_trees: vec![StateMerkleTreeBundle { + rollover_fee: 0, + network_fee: 0, + merkle_tree: Box::new(MerkleTree::::new_with_history( + DEFAULT_BATCH_STATE_TREE_HEIGHT, + 0, + 0, + DEFAULT_BATCH_ROOT_HISTORY_LEN, + )), + accounts: StateMerkleTreeAccounts { + merkle_tree, + nullifier_queue: queue, + cpi_context: Pubkey::new_unique(), + tree_type: TreeType::StateV2, + }, + tree_type: TreeType::StateV2, + output_queue_elements, + input_leaf_indices: vec![], + output_queue_batch_size: Some(500), + num_inserted_batches: 0, + }], + address_merkle_trees: vec![], + payer: Keypair::new(), + governance_authority: Keypair::new(), + group_pda: Pubkey::new_unique(), + compressed_accounts, + nullified_compressed_accounts: vec![], + token_compressed_accounts: vec![], + token_nullified_compressed_accounts: vec![], + events: vec![], + onchain_pubkey_index: HashMap::new(), + }; + + let response = Indexer::get_validity_proof(&indexer, hashes, vec![], None) + .await + .unwrap(); + let leaf_indices: Vec = response + .value + .accounts + .iter() + .map(|account| account.leaf_index) + .collect(); + + assert_eq!(leaf_indices, sparse_leaf_indices.map(u64::from)); + } +} + impl TestIndexer { async fn process_inclusion_proofs( &self, @@ -2345,7 +2448,16 @@ impl TestIndexer { new_addresses.unwrap().len() ))); } - let client = Client::new(); + let client = Client::builder() + .no_proxy() + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(120)) + .build() + .map_err(|error| { + IndexerError::CustomError(format!( + "failed to build prover HTTP client: {error}" + )) + })?; let (account_proof_inputs, address_proof_inputs, json_payload) = match (compressed_accounts, new_addresses) { (Some(accounts), None) => { @@ -2470,6 +2582,7 @@ impl TestIndexer { }; let mut retries = 3; + let mut last_error = "Failed to get proof from server".to_string(); while retries > 0 { let response_result = client .post(format!("{}{}", SERVER_ADDRESS, PROVE_PATH)) @@ -2477,33 +2590,40 @@ impl TestIndexer { .body(json_payload.clone()) .send() .await; - if let Ok(response_result) = response_result { - if response_result.status().is_success() { - let body = response_result.text().await.unwrap(); - let proof_json = deserialize_gnark_proof_json(&body).unwrap(); - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = - compress_proof(&proof_a, &proof_b, &proof_c); - return Ok(ValidityProofWithContext { - accounts: account_proof_inputs, - addresses: address_proof_inputs, - proof: CompressedProof { - a: proof_a, - b: proof_b, - c: proof_c, - } - .into(), - }); + match response_result { + Ok(response_result) => { + let status = response_result.status(); + let body = response_result.text().await.map_err(|error| { + IndexerError::CustomError(format!( + "failed to read prover response body: {error}" + )) + })?; + + if status.is_success() { + return Ok(ValidityProofWithContext { + accounts: account_proof_inputs, + addresses: address_proof_inputs, + proof: build_compressed_proof(&body)?.into(), + }); + } + + let body_preview: String = body.chars().take(512).collect(); + last_error = format!( + "prover returned HTTP {status} for validity proof request: {body_preview}" + ); } - } else { - println!("Error: {:#?}", response_result); + Err(error) => { + last_error = + format!("failed to contact prover for validity proof: {error}"); + } + } + + retries -= 1; + if retries > 0 { tokio::time::sleep(Duration::from_secs(5)).await; - retries -= 1; } } - Err(IndexerError::CustomError( - "Failed to get proof from server".to_string(), - )) + Err(IndexerError::CustomError(last_error)) } } } diff --git a/sdk-libs/sdk-types/src/interface/account/token_seeds.rs b/sdk-libs/sdk-types/src/interface/account/token_seeds.rs index f22657590a..2bd0ee7bdc 100644 --- a/sdk-libs/sdk-types/src/interface/account/token_seeds.rs +++ b/sdk-libs/sdk-types/src/interface/account/token_seeds.rs @@ -265,7 +265,7 @@ where fn into_in_token_data( &self, tree_info: &PackedStateTreeInfo, - output_queue_index: u8, + _output_queue_index: u8, ) -> Result { Ok(MultiInputTokenDataWithContext { amount: self.token_data.amount, @@ -277,7 +277,7 @@ where root_index: tree_info.root_index, merkle_context: PackedMerkleContext { merkle_tree_pubkey_index: tree_info.merkle_tree_pubkey_index, - queue_pubkey_index: output_queue_index, + queue_pubkey_index: tree_info.queue_pubkey_index, leaf_index: tree_info.leaf_index, prove_by_index: tree_info.prove_by_index, }, diff --git a/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs b/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs index 3e32ec6ef3..cc7aa4ba1f 100644 --- a/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs +++ b/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs @@ -142,7 +142,10 @@ where let pda_key = pda_account.key(); let address = derive_address(&pda_key, &ctx.light_config.address_space[0], ctx.program_id); - // 10. Build CompressedAccountInfo for CPI + // 10. Build CompressedAccountInfo for CPI. + // Input nullifiers must keep their original queue basis. The later system-program path + // groups nullifiers by queue index, so rewriting mixed PDA+token inputs onto a shared + // output queue drops whole tree/queue pairs from insertion. let input = InAccountInfo { data_hash: input_data_hash, lamports: 0, diff --git a/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs b/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs index 2fc7cfc811..2fc1f037b8 100644 --- a/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs +++ b/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs @@ -126,7 +126,7 @@ pub struct DecompressCtx<'a, AI: AccountInfoTrait + Clone> { #[cfg(feature = "token")] pub in_tlv: Option>>, #[cfg(feature = "token")] - pub token_seeds: Vec>, + pub token_seeds: Vec>>, } // ============================================================================ @@ -296,7 +296,7 @@ pub struct DecompressAccountsBuilt<'a, AI: AccountInfoTrait + Clone> { pub cpi_context: bool, pub in_token_data: Vec, pub in_tlv: Option>>, - pub token_seeds: Vec>, + pub token_seeds: Vec>>, } /// Validates accounts, dispatches all variants, and collects CPI inputs for @@ -649,13 +649,20 @@ where .map_err(|e| LightSdkTypesError::ProgramError(e.into()))?; } else { // At least one regular token account - use invoke_signed with PDA seeds - let signer_seed_refs: Vec<&[u8]> = token_seeds.iter().map(|s| s.as_slice()).collect(); + let signer_seed_storage: Vec> = token_seeds + .iter() + .map(|seed_group| seed_group.iter().map(|seed| seed.as_slice()).collect()) + .collect(); + let signer_seed_refs: Vec<&[&[u8]]> = signer_seed_storage + .iter() + .map(|seed_group| seed_group.as_slice()) + .collect(); AI::invoke_cpi( &LIGHT_TOKEN_PROGRAM_ID, &transfer2_data, &account_metas, remaining_accounts, - &[signer_seed_refs.as_slice()], + signer_seed_refs.as_slice(), ) .map_err(|e| LightSdkTypesError::ProgramError(e.into()))?; } diff --git a/sdk-libs/sdk-types/src/interface/program/decompression/token.rs b/sdk-libs/sdk-types/src/interface/program/decompression/token.rs index 153943e275..2b6fa37a7a 100644 --- a/sdk-libs/sdk-types/src/interface/program/decompression/token.rs +++ b/sdk-libs/sdk-types/src/interface/program/decompression/token.rs @@ -148,8 +148,9 @@ where ) .map_err(|e| LightSdkTypesError::ProgramError(e.into()))?; - // Push seeds for the Transfer2 CPI (needed for invoke_signed) - ctx.token_seeds.extend(seeds.iter().map(|s| s.to_vec())); + // Push one signer seed group per vault PDA for the later Transfer2 CPI. + ctx.token_seeds + .push(seeds.iter().map(|seed| seed.to_vec()).collect()); } // Push token data for the Transfer2 CPI (common for both ATA and regular paths) diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs index 747c75ce32..585a828f03 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs @@ -472,13 +472,12 @@ async fn test_create_pdas_and_mint_auto() { .await .expect("create_load_instructions should succeed"); - println!("all_instructions.len() = {:?}", all_instructions); - - // Expected: 1 PDA+Token ix + 2 ATA ixs (1 create_ata + 1 decompress) + 1 mint ix = 4 + // Expected: 1 mint load, 1 grouped PDA/token load, and 2 ATA instructions + // (create ATA + Transfer2 decompression) = 4 total. assert_eq!( all_instructions.len(), - 6, - "Should have 6 instructions: 1 PDA, 1 Token, 2 create_ata, 1 decompress_ata, 1 mint" + 4, + "Should have 4 instructions: 1 mint, 1 grouped PDA/token load, 1 create_ata, 1 ATA Transfer2" ); // Capture rent sponsor balance before decompression diff --git a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs index 154f4e2045..3e3fb4934d 100644 --- a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs +++ b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs @@ -127,7 +127,9 @@ async fn create_compressed_account( ) .await? .value; - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result + .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; let output_tree_index = rpc .get_random_state_tree_info() @@ -178,6 +180,7 @@ async fn read_sha256_light_system_cpi( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -231,6 +234,7 @@ async fn read_sha256_lowlevel( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -289,7 +293,9 @@ async fn create_compressed_account_poseidon( ) .await? .value; - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result + .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; let output_tree_index = rpc .get_random_state_tree_info() @@ -340,6 +346,7 @@ async fn read_poseidon_light_system_cpi( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -393,6 +400,7 @@ async fn read_poseidon_lowlevel( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs index e19d0742de..e5cde869bd 100644 --- a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs +++ b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs @@ -171,7 +171,9 @@ async fn create_compressed_account( ) .await? .value; - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result + .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; let output_tree_index = rpc .get_random_state_tree_info() @@ -223,6 +225,7 @@ async fn update_compressed_account( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -277,6 +280,7 @@ async fn close_compressed_account( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -340,6 +344,7 @@ async fn reinit_closed_account( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -388,6 +393,7 @@ async fn close_compressed_account_permanent( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-native-test/tests/test.rs b/sdk-tests/sdk-native-test/tests/test.rs index 30d792487f..eb81e8cf47 100644 --- a/sdk-tests/sdk-native-test/tests/test.rs +++ b/sdk-tests/sdk-native-test/tests/test.rs @@ -103,7 +103,10 @@ pub async fn create_pda( .value; let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey); - let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0]; + let packed_address_tree_info = rpc_result + .pack_tree_infos(&mut accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? + .address_trees[0]; let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas(); let instruction_data = CreatePdaInstructionData { @@ -147,6 +150,7 @@ pub async fn update_pda( let packed_accounts = rpc_result .pack_tree_infos(&mut accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs b/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs index 0ae7f5c029..83e205bbf4 100644 --- a/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs +++ b/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs @@ -101,7 +101,10 @@ pub async fn create_pda( .value; let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey); - let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0]; + let packed_address_tree_info = rpc_result + .pack_tree_infos(&mut accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? + .address_trees[0]; let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas(); let instruction_data = CreatePdaInstructionData { proof: rpc_result.proof, @@ -145,6 +148,7 @@ pub async fn update_pda( let packed_accounts = rpc_result .pack_tree_infos(&mut accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs b/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs index 59a0562c63..510c98b2b5 100644 --- a/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs +++ b/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs @@ -111,7 +111,8 @@ pub async fn create_pda( .value; let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey); - let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0]; + let packed_tree_infos = rpc_result.pack_tree_infos(&mut accounts)?; + let packed_address_tree_info = packed_tree_infos.address_trees[0]; let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas(); let instruction_data = CreatePdaInstructionData { proof: rpc_result.proof, @@ -154,7 +155,7 @@ pub async fn update_pda( .value; let packed_accounts = rpc_result - .pack_tree_infos(&mut accounts) + .pack_tree_infos(&mut accounts)? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-token-test/tests/ctoken_pda.rs b/sdk-tests/sdk-token-test/tests/ctoken_pda.rs index 8e2b595285..10a4cc4680 100644 --- a/sdk-tests/sdk-token-test/tests/ctoken_pda.rs +++ b/sdk-tests/sdk-token-test/tests/ctoken_pda.rs @@ -156,7 +156,9 @@ pub async fn create_mint( let config = SystemAccountMetaConfig::new_with_cpi_context(ID, tree_info.cpi_context.unwrap()); packed_accounts.add_system_accounts_v2(config).unwrap(); // packed_accounts.insert_or_get(tree_info.get_output_pubkey()?); - rpc_result.pack_tree_infos(&mut packed_accounts); + rpc_result + .pack_tree_infos(&mut packed_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; // Create PDA parameters let pda_amount = 100u64; diff --git a/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs b/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs index 5f096af560..04cef5e47a 100644 --- a/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs +++ b/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs @@ -8,7 +8,7 @@ use light_compressed_token_sdk::compressed_token::{ create_compressed_mint::find_mint_address, decompress_full::DecompressFullAccounts, }; use light_program_test::{Indexer, LightProgramTest, ProgramTestConfig, Rpc}; -use light_sdk::instruction::PackedAccounts; +use light_sdk::instruction::{PackedAccounts, PackedStateTreeInfo}; use light_test_utils::{ actions::{legacy::instructions::mint_action::NewMint, mint_action_comprehensive}, airdrop_lamports, @@ -34,6 +34,23 @@ struct TestContext { total_compressed_amount: u64, } +fn pack_input_state_tree_infos( + rpc_result: &light_client::indexer::ValidityProofWithContext, + remaining_accounts: &mut PackedAccounts, +) -> Vec { + rpc_result + .accounts + .iter() + .map(|account| PackedStateTreeInfo { + root_index: account.root_index.root_index().unwrap_or_default(), + merkle_tree_pubkey_index: remaining_accounts.insert_or_get(account.tree_info.tree), + queue_pubkey_index: remaining_accounts.insert_or_get(account.tree_info.queue), + leaf_index: account.leaf_index as u32, + prove_by_index: account.root_index.proof_by_index(), + }) + .collect() +} + /// Setup function for decompress_full tests /// Creates compressed tokens (source) and empty decompressed accounts (destination) async fn setup_decompress_full_test(num_inputs: usize) -> (LightProgramTest, TestContext) { @@ -213,7 +230,7 @@ async fn test_decompress_full_cpi() { .unwrap() .value; - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_infos = pack_input_state_tree_infos(&rpc_result, &mut remaining_accounts); let config = DecompressFullAccounts::new(None); remaining_accounts .add_custom_system_accounts(config) @@ -236,12 +253,7 @@ async fn test_decompress_full_cpi() { let indices: Vec<_> = token_data .iter() .zip( - packed_tree_info - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos - .iter(), + packed_tree_infos.iter(), ) .zip(ctx.destination_accounts.iter()) .zip(versions.iter()) @@ -370,7 +382,7 @@ async fn test_decompress_full_cpi_with_context() { .value; // Add tree accounts first, then custom system accounts (no CPI context since params is None) - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_infos = pack_input_state_tree_infos(&rpc_result, &mut remaining_accounts); let config = DecompressFullAccounts::new(None); remaining_accounts .add_custom_system_accounts(config) @@ -393,12 +405,7 @@ async fn test_decompress_full_cpi_with_context() { let indices: Vec<_> = token_data .iter() .zip( - packed_tree_info - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos - .iter(), + packed_tree_infos.iter(), ) .zip(ctx.destination_accounts.iter()) .zip(versions.iter()) diff --git a/sdk-tests/sdk-token-test/tests/pda_ctoken.rs b/sdk-tests/sdk-token-test/tests/pda_ctoken.rs index 91e0f2db9e..0d38a38a21 100644 --- a/sdk-tests/sdk-token-test/tests/pda_ctoken.rs +++ b/sdk-tests/sdk-token-test/tests/pda_ctoken.rs @@ -214,7 +214,9 @@ pub async fn create_mint( let mut packed_accounts = PackedAccounts::default(); let config = SystemAccountMetaConfig::new_with_cpi_context(ID, tree_info.cpi_context.unwrap()); packed_accounts.add_system_accounts_v2(config).unwrap(); - rpc_result.pack_tree_infos(&mut packed_accounts); + rpc_result + .pack_tree_infos(&mut packed_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; // Create PDA parameters let pda_amount = 100u64; diff --git a/sdk-tests/sdk-token-test/tests/test.rs b/sdk-tests/sdk-token-test/tests/test.rs index 3c6941881d..26646657ce 100644 --- a/sdk-tests/sdk-token-test/tests/test.rs +++ b/sdk-tests/sdk-token-test/tests/test.rs @@ -367,7 +367,9 @@ async fn transfer_compressed_tokens( .await? .value; - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result + .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; let output_tree_index = packed_tree_info .state_trees .as_ref() @@ -433,7 +435,9 @@ async fn decompress_compressed_tokens( .await? .value; - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result + .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; let output_tree_index = packed_tree_info .state_trees .as_ref() diff --git a/sdk-tests/sdk-token-test/tests/test_4_invocations.rs b/sdk-tests/sdk-token-test/tests/test_4_invocations.rs index 9e70170056..913bdf39a9 100644 --- a/sdk-tests/sdk-token-test/tests/test_4_invocations.rs +++ b/sdk-tests/sdk-token-test/tests/test_4_invocations.rs @@ -9,7 +9,7 @@ use light_compressed_token_sdk::{ use light_program_test::{AddressWithTree, Indexer, LightProgramTest, ProgramTestConfig, Rpc}; use light_sdk::{ address::v1::derive_address, - instruction::{PackedAccounts, SystemAccountMetaConfig}, + instruction::{PackedAccounts, PackedStateTreeInfo, SystemAccountMetaConfig}, }; use light_test_utils::{ spl::{create_mint_helper, create_token_account, mint_spl_tokens}, @@ -22,6 +22,36 @@ use solana_sdk::{ signature::{Keypair, Signature, Signer}, }; +fn pack_input_state_tree_infos( + rpc_result: &light_client::indexer::ValidityProofWithContext, + remaining_accounts: &mut PackedAccounts, +) -> Vec { + rpc_result + .accounts + .iter() + .map(|account| PackedStateTreeInfo { + root_index: account.root_index.root_index().unwrap_or_default(), + merkle_tree_pubkey_index: remaining_accounts.insert_or_get(account.tree_info.tree), + queue_pubkey_index: remaining_accounts.insert_or_get(account.tree_info.queue), + leaf_index: account.leaf_index as u32, + prove_by_index: account.root_index.proof_by_index(), + }) + .collect() +} + +fn pack_selected_output_tree_index( + tree_info: light_client::indexer::TreeInfo, + remaining_accounts: &mut PackedAccounts, +) -> Result> { + tree_info + .next_tree_info + .map(|next| next.pack_output_tree_index(remaining_accounts)) + .unwrap_or_else(|| tree_info.pack_output_tree_index(remaining_accounts)) + .map_err(|error| Box::new( + RpcError::CustomError(format!("Failed to pack output tree index: {error}")) + )) +} + #[ignore = "fix cpi context usage"] #[tokio::test] async fn test_4_invocations() { @@ -389,7 +419,9 @@ async fn create_compressed_escrow_pda( .await? .value; - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result + .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; let new_address_params = packed_tree_info.address_trees[0] .into_new_address_params_assigned_packed(address_seed, Some(0)); @@ -495,29 +527,18 @@ async fn test_four_invokes_instruction( ) .await? .value; - // We need to pack the tree after the cpi context. - remaining_accounts.insert_or_get(rpc_result.accounts[0].tree_info.tree); - - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); - let output_tree_index = packed_tree_info - .state_trees - .as_ref() - .unwrap() - .output_tree_index; + let output_tree_index = pack_selected_output_tree_index( + mint2_token_account.account.tree_info, + &mut remaining_accounts, + ) + .map_err(|error| *error)?; + let packed_tree_infos = pack_input_state_tree_infos(&rpc_result, &mut remaining_accounts); // Create token metas from compressed accounts - each uses its respective tree info index // Index 0: escrow PDA, Index 1: mint2 token account, Index 2: mint3 token account - let mint2_tree_info = packed_tree_info - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[1]; + let mint2_tree_info = packed_tree_infos[1]; - let mint3_tree_info = packed_tree_info - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[2]; + let mint3_tree_info = packed_tree_infos[2]; // Create FourInvokesParams let four_invokes_params = sdk_token_test::FourInvokesParams { @@ -557,11 +578,7 @@ async fn test_four_invokes_instruction( }; // Create PdaParams - escrow PDA uses tree info index 0 - let escrow_tree_info = packed_tree_info - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[0]; + let escrow_tree_info = packed_tree_infos[0]; let pda_params = sdk_token_test::PdaParams { account_meta: light_sdk::instruction::account_meta::CompressedAccountMeta { diff --git a/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs b/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs index d7ef38a08c..8c7b8d422c 100644 --- a/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs +++ b/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs @@ -27,6 +27,36 @@ use solana_sdk::{ signature::{Keypair, Signer}, }; +fn pack_input_state_tree_infos( + rpc_result: &light_client::indexer::ValidityProofWithContext, + remaining_accounts: &mut PackedAccounts, +) -> Vec { + rpc_result + .accounts + .iter() + .map(|account| PackedStateTreeInfo { + root_index: account.root_index.root_index().unwrap_or_default(), + merkle_tree_pubkey_index: remaining_accounts.insert_or_get(account.tree_info.tree), + queue_pubkey_index: remaining_accounts.insert_or_get(account.tree_info.queue), + leaf_index: account.leaf_index as u32, + prove_by_index: account.root_index.proof_by_index(), + }) + .collect() +} + +fn pack_selected_output_tree_index( + tree_info: light_client::indexer::TreeInfo, + remaining_accounts: &mut PackedAccounts, +) -> Result { + tree_info + .next_tree_info + .map(|next| next.pack_output_tree_index(remaining_accounts)) + .unwrap_or_else(|| tree_info.pack_output_tree_index(remaining_accounts)) + .map_err(|error| { + RpcError::CustomError(format!("Failed to pack output tree index: {error}")) + }) +} + #[tokio::test] async fn test_4_transfer2() { // Initialize the test environment @@ -339,7 +369,9 @@ async fn create_compressed_escrow_pda( .await? .value; - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result + .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; let new_address_params = packed_tree_info.address_trees[0] .into_new_address_params_assigned_packed(address_seed, Some(0)); @@ -435,29 +467,17 @@ async fn test_four_transfer2_instruction( ) .await? .value; - // We need to pack the tree after the cpi context. - remaining_accounts.insert_or_get(rpc_result.accounts[0].tree_info.tree); - - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); - let output_tree_index = packed_tree_info - .state_trees - .as_ref() - .unwrap() - .output_tree_index; + let output_tree_index = pack_selected_output_tree_index( + mint2_token_account.account.tree_info, + &mut remaining_accounts, + )?; + let packed_tree_infos = pack_input_state_tree_infos(&rpc_result, &mut remaining_accounts); // Create token metas from compressed accounts - each uses its respective tree info index // Index 0: escrow PDA, Index 1: mint2 token account, Index 2: mint3 token account - let mint2_tree_info = packed_tree_info - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[1]; + let mint2_tree_info = packed_tree_infos[1]; - let mint3_tree_info = packed_tree_info - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[2]; + let mint3_tree_info = packed_tree_infos[2]; // Create FourTransfer2Params let four_transfer2_params = sdk_token_test::process_four_transfer2::FourTransfer2Params { @@ -491,11 +511,7 @@ async fn test_four_transfer2_instruction( }; // Create PdaParams - escrow PDA uses tree info index 0 - let escrow_tree_info = packed_tree_info - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[0]; + let escrow_tree_info = packed_tree_infos[0]; let pda_params = sdk_token_test::PdaParams { account_meta: light_sdk::instruction::account_meta::CompressedAccountMeta { diff --git a/sdk-tests/sdk-token-test/tests/test_deposit.rs b/sdk-tests/sdk-token-test/tests/test_deposit.rs index 9ebcbd8549..dd962abeec 100644 --- a/sdk-tests/sdk-token-test/tests/test_deposit.rs +++ b/sdk-tests/sdk-token-test/tests/test_deposit.rs @@ -10,7 +10,10 @@ use light_compressed_token_sdk::{ use light_program_test::{AddressWithTree, Indexer, LightProgramTest, ProgramTestConfig, Rpc}; use light_sdk::{ address::v1::derive_address, - instruction::{account_meta::CompressedAccountMeta, PackedAccounts, SystemAccountMetaConfig}, + instruction::{ + account_meta::CompressedAccountMeta, PackedAccounts, PackedStateTreeInfo, + SystemAccountMetaConfig, + }, }; use light_test_utils::{ spl::{create_mint_helper, create_token_account, mint_spl_tokens}, @@ -23,6 +26,55 @@ use solana_sdk::{ signature::{Keypair, Signature, Signer}, }; +fn pack_input_state_tree_infos( + rpc_result: &light_client::indexer::ValidityProofWithContext, + remaining_accounts: &mut PackedAccounts, +) -> Vec { + rpc_result + .accounts + .iter() + .map(|account| PackedStateTreeInfo { + root_index: account.root_index.root_index().unwrap_or_default(), + merkle_tree_pubkey_index: remaining_accounts.insert_or_get(account.tree_info.tree), + queue_pubkey_index: remaining_accounts.insert_or_get(account.tree_info.queue), + leaf_index: account.leaf_index as u32, + prove_by_index: account.root_index.proof_by_index(), + }) + .collect() +} + +fn pack_selected_output_tree_context( + tree_info: light_client::indexer::TreeInfo, + remaining_accounts: &mut PackedAccounts, +) -> Result<(u8, u8, u8), RpcError> { + let (tree, queue, output_state_tree_index) = if let Some(next) = tree_info.next_tree_info { + ( + next.tree, + next.queue, + next.pack_output_tree_index(remaining_accounts) + .map_err(|error| { + RpcError::CustomError(format!("Failed to pack output tree index: {error}")) + })?, + ) + } else { + ( + tree_info.tree, + tree_info.queue, + tree_info + .pack_output_tree_index(remaining_accounts) + .map_err(|error| { + RpcError::CustomError(format!("Failed to pack output tree index: {error}")) + })?, + ) + }; + + Ok(( + remaining_accounts.insert_or_get(tree), + remaining_accounts.insert_or_get(queue), + output_state_tree_index, + )) +} + #[ignore = "fix cpi context usage"] #[tokio::test] async fn test_deposit_compressed_account() { @@ -206,7 +258,9 @@ async fn create_deposit_compressed_account( ) .await? .value; - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result + .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; println!("packed_accounts {:?}", packed_accounts.state_trees); // Create token meta from compressed account @@ -302,9 +356,14 @@ async fn update_deposit_compressed_account( "rpc_result.accounts[0].tree_info.queue {:?}", rpc_result.accounts[0].tree_info.queue.to_bytes() ); - // We need to pack the tree after the cpi context. - let index = remaining_accounts.insert_or_get(rpc_result.accounts[0].tree_info.tree); - println!("index {}", index); + let (output_tree_index, output_tree_queue_index, output_state_tree_index) = + pack_selected_output_tree_context( + rpc_result.accounts[0].tree_info, + &mut remaining_accounts, + )?; + println!("output_tree_index {}", output_tree_index); + println!("output_tree_queue_index {}", output_tree_queue_index); + println!("output_state_tree_index {}", output_state_tree_index); // Get mint from the compressed token account let mint = deposit_ctoken_account.token.mint; println!( @@ -318,15 +377,11 @@ async fn update_deposit_compressed_account( // Get validity proof for the compressed token account and new address println!("rpc_result {:?}", rpc_result); - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); - println!("packed_accounts {:?}", packed_accounts.state_trees); + let packed_tree_infos = pack_input_state_tree_infos(&rpc_result, &mut remaining_accounts); + println!("packed_tree_infos {:?}", packed_tree_infos); // TODO: investigate why packed_tree_infos seem to be out of order // Create token meta from compressed account - let tree_info = packed_accounts - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[1]; + let tree_info = packed_tree_infos[1]; let depositing_token_metas = vec![TokenAccountMeta { amount: deposit_ctoken_account.token.amount, delegate_index: None, @@ -335,11 +390,7 @@ async fn update_deposit_compressed_account( tlv: None, }]; println!("depositing_token_metas {:?}", depositing_token_metas); - let tree_info = packed_accounts - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[2]; + let tree_info = packed_tree_infos[2]; let escrowed_token_meta = TokenAccountMeta { amount: escrow_ctoken_account.token.amount, delegate_index: None, @@ -354,19 +405,11 @@ async fn update_deposit_compressed_account( let system_accounts_start_offset = system_accounts_start_offset as u8; println!("remaining_accounts {:?}", remaining_accounts); - let tree_info = packed_accounts - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[0]; + let tree_info = packed_tree_infos[0]; let account_meta = CompressedAccountMeta { tree_info, address: escrow_pda.address.unwrap(), - output_state_tree_index: packed_accounts - .state_trees - .as_ref() - .unwrap() - .output_tree_index, + output_state_tree_index, }; let instruction = Instruction { @@ -381,14 +424,8 @@ async fn update_deposit_compressed_account( .concat(), data: sdk_token_test::instruction::UpdateDeposit { proof: rpc_result.proof, - output_tree_index: packed_accounts - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[0] - .merkle_tree_pubkey_index, - output_tree_queue_index: packed_accounts.state_trees.unwrap().packed_tree_infos[0] - .queue_pubkey_index, + output_tree_index, + output_tree_queue_index, system_accounts_start_offset, token_params: sdk_token_test::TokenParams { deposit_amount: amount, diff --git a/sdk-tests/sdk-v1-native-test/tests/test.rs b/sdk-tests/sdk-v1-native-test/tests/test.rs index a93beab599..2e10e61e14 100644 --- a/sdk-tests/sdk-v1-native-test/tests/test.rs +++ b/sdk-tests/sdk-v1-native-test/tests/test.rs @@ -94,7 +94,8 @@ pub async fn create_pda( .value; let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey); - let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0]; + let packed_tree_infos = rpc_result.pack_tree_infos(&mut accounts)?; + let packed_address_tree_info = packed_tree_infos.address_trees[0]; let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas(); let instruction_data = CreatePdaInstructionData { @@ -137,7 +138,7 @@ pub async fn update_pda( .value; let packed_accounts = rpc_result - .pack_tree_infos(&mut accounts) + .pack_tree_infos(&mut accounts)? .state_trees .unwrap(); diff --git a/sparse-merkle-tree/src/indexed_changelog.rs b/sparse-merkle-tree/src/indexed_changelog.rs index bbd30e1ee6..7e6a26cff7 100644 --- a/sparse-merkle-tree/src/indexed_changelog.rs +++ b/sparse-merkle-tree/src/indexed_changelog.rs @@ -29,7 +29,7 @@ pub fn patch_indexed_changelogs( low_element: &mut IndexedElement, new_element: &mut IndexedElement, low_element_next_value: &mut BigUint, - low_leaf_proof: &mut Vec<[u8; 32]>, + low_leaf_proof: &mut [[u8; 32]; HEIGHT], ) -> Result<(), SparseMerkleTreeError> { // Tests are in program-tests/merkle-tree/tests/indexed_changelog.rs let next_indexed_changelog_indices: Vec = (*indexed_changelogs) @@ -69,7 +69,7 @@ pub fn patch_indexed_changelogs( // Patch the next value. *low_element_next_value = BigUint::from_bytes_be(&changelog_entry.element.next_value); // Patch the proof. - *low_leaf_proof = changelog_entry.proof.to_vec(); + *low_leaf_proof = changelog_entry.proof; } // If we found a new low element. @@ -82,7 +82,7 @@ pub fn patch_indexed_changelogs( next_index: new_low_element_changelog_entry.element.next_index, }; - *low_leaf_proof = new_low_element_changelog_entry.proof.to_vec(); + *low_leaf_proof = new_low_element_changelog_entry.proof; new_element.next_index = low_element.next_index; if new_low_element_changelog_index == indexed_changelogs.len() - 1 { return Ok(()); diff --git a/sparse-merkle-tree/tests/indexed_changelog.rs b/sparse-merkle-tree/tests/indexed_changelog.rs index 7d37142b46..59efda6fde 100644 --- a/sparse-merkle-tree/tests/indexed_changelog.rs +++ b/sparse-merkle-tree/tests/indexed_changelog.rs @@ -92,7 +92,8 @@ fn test_indexed_changelog() { next_index: low_element_next_indices[i], }; println!("unpatched new_element: {:?}", new_element); - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof: [[u8; 32]; 8] = + low_element_proofs[i].as_slice().try_into().unwrap(); let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); if i > 0 { @@ -114,7 +115,7 @@ fn test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), index: low_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); indexed_changelog.push(IndexedChangelogEntry { @@ -124,7 +125,7 @@ fn test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&low_element_next_value).unwrap(), index: new_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); println!("patched -------------------"); @@ -206,7 +207,8 @@ fn debug_test_indexed_changelog() { next_index: low_element_next_indices[i], }; println!("unpatched new_element: {:?}", new_element); - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof: [[u8; 32]; 8] = + low_element_proofs[i].as_slice().try_into().unwrap(); let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); if i > 0 { @@ -228,7 +230,7 @@ fn debug_test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), index: low_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); indexed_changelog.push(IndexedChangelogEntry { @@ -238,7 +240,7 @@ fn debug_test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&low_element_next_value).unwrap(), index: new_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); man_indexed_array.elements[low_element.index()] = low_element.clone();