diff --git a/circuit-std-rs/src/logup.rs b/circuit-std-rs/src/logup.rs
index 01123c23..24f04ee2 100644
--- a/circuit-std-rs/src/logup.rs
+++ b/circuit-std-rs/src/logup.rs
@@ -326,6 +326,41 @@ impl LogUpSingleKeyTable {
             &alpha,
         );
+        assert_eq_rational(builder, &v_table, &v_query);
+    }
+
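+    /// Like `final_check`, but takes the per-key query counts as circuit
+    /// variables supplied by the caller. It enforces the LogUp rational
+    /// identity
+    ///
+    ///   sum_i query_count[i] / (alpha - table_row[i])
+    ///     == sum_j 1 / (alpha - query_row[j])
+    ///
+    /// at a random alpha, where each row is the random linear combination of
+    /// its value columns produced by `combine_columns`, and `query_count[i]`
+    /// must equal the number of times the i-th table entry is queried.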
+    pub fn final_check_with_query_count<C: Config, B: RootAPI<C>>(
+        &mut self,
+        builder: &mut B,
+        query_count: &[Variable],
+    ) {
+        if self.table.is_empty() || self.query_keys.is_empty() {
+            panic!("empty table or empty query");
+        }
+
+        let value_len = self.table[0].len();
+
+        let alpha = builder.get_random_value();
+        let randomness = get_column_randomness(builder, value_len);
+
+        let table_combined = combine_columns(builder, &self.table, &randomness);
+        let mut inputs = vec![builder.constant(self.table.len() as u32)];
+        // append table keys
+        for i in 0..self.table.len() {
+            inputs.push(self.table[i][0]);
+        }
+        // append query keys
+        inputs.extend(self.query_keys.clone());
+        let v_table = logup_poly_val(builder, &table_combined, &query_count, &alpha);
+
+        let query_combined = combine_columns(builder, &self.query_results, &randomness);
+        let one = builder.constant(1);
+        let v_query = logup_poly_val(
+            builder,
+            &query_combined,
+            &vec![one; query_combined.len()],
+            &alpha,
+        );
+
+        assert_eq_rational(builder, &v_table, &v_query);
+    }
 }
@@ -455,6 +490,26 @@ impl LogUpRangeProofTable {
         );
         assert_eq_rational(builder, &v_table, &v_query);
     }
+
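+    /// Range-proof variant of the query-count check: `query_count[i]` must be
+    /// the number of queries equal to `table_keys[i]`, and the same LogUp
+    /// rational identity as above is asserted at a random alpha.
+    /// Illustrative usage, assuming the counts are computed out-of-circuit
+    /// (e.g. by a hint in the spirit of `query_count_hint` below) and then
+    /// exposed as circuit variables:
+    /// ```ignore
+    /// let counts: Vec<Variable> = /* one multiplicity per table key */;
+    /// table.final_check_with_query_count(builder, &counts);
+    /// ```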
+    pub fn final_check_with_query_count<C: Config, B: RootAPI<C>>(
+        &mut self,
+        builder: &mut B,
+        query_count: &[Variable],
+    ) {
+        let alpha = builder.get_random_value();
+        let inputs = self.query_keys.clone();
+
+        let v_table = logup_poly_val(builder, &self.table_keys, &query_count, &alpha);
+
+        let one = builder.constant(1);
+        let v_query = logup_poly_val(
+            builder,
+            &self.query_keys,
+            &vec![one; self.query_keys.len()],
+            &alpha,
+        );
+        assert_eq_rational(builder, &v_table, &v_query);
+    }
 }
 
 pub fn query_count_hint<F: Field>(inputs: &[F], outputs: &mut [F]) -> Result<(), Error> {
diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs
index a95555a4..1bec7e9f 100644
--- a/expander_compiler/src/zkcuda/context.rs
+++ b/expander_compiler/src/zkcuda/context.rs
@@ -577,7 +577,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
 
         let dm_shapes = self.propagate_and_get_shapes();
 
-        let (mut cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg {
+        let (cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg {
             for (i, kernel) in cg.kernels.iter().enumerate() {
                 assert_eq!(self.kernels.add(kernel), i);
             }
@@ -616,8 +616,9 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
             .map(get_pad_shape)
             .collect::<Vec<_>>();
         let kernel_primitive = self.kernel_primitives.get(kernel_call.kernel_id);
-        let kernel = if let Some(cg_kernels) = cg_kernels.as_mut() {
-            cg_kernels.drain(..1).next().unwrap()
+        let kernel = if cg_kernels.is_some() {
+            // Get kernel from loaded kernels by kernel_id
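+            // This relies on kernel ids being stable indices into the kernel
+            // pool: load_computation_graph added the loaded kernels in order
+            // (assert_eq!(self.kernels.add(kernel), i) above), so the id
+            // recorded in kernel_call indexes the same kernel here.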
+            self.kernels.get(kernel_call.kernel_id).clone()
         } else {
             let mut psi = Vec::new();
             for (s, &ib) in pad_shapes_input.iter().zip(kernel_call.is_broadcast.iter()) {
@@ -708,8 +709,9 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
             });
         }
 
-        if let Some(cg_kernels) = cg_kernels {
-            assert!(cg_kernels.is_empty());
+        if cg_kernels.is_some() {
+            // No longer checking whether cg_kernels is empty, since it is no
+            // longer consumed; the kernels were already added via self.kernels.add().
             assert_eq!(cg_proof_templates.unwrap(), self.proof_templates);
             assert_eq!(cg_commitments_lens.unwrap(), commitments_lens);
             Ok(None)
diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs
index 7d7fed98..6a559fa1 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs
@@ -73,3 +73,15 @@ where
         wait_async(ClientHttpHelper::request_exit())
     }
 }
+
+impl<ZC: ZKCudaConfig> ExpanderNoOverSubscribe<ZC>
+where
+    <GetPCS<ZC> as ExpanderPCS<GetFieldConfig<ZC>>>::Commitment:
+        AsRef<<GetPCS<ZC> as ExpanderPCS<GetFieldConfig<ZC>>>::Commitment>,
+{
+    /// Lightweight prove that doesn't require computation_graph or prover_setup.
+    /// Use this after setup() to allow releasing those large data structures before proving.
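+    /// Illustrative call sequence (hypothetical glue code; `setup` is the
+    /// `ProvingSystem` method and `export_device_memories` is used as in the
+    /// zkcuda tests):
+    /// ```ignore
+    /// let (prover_setup, verifier_setup) = ExpanderNoOverSubscribe::setup(&computation_graph);
+    /// drop(prover_setup); // release the large setup state early
+    /// ExpanderNoOverSubscribe::prove_lightweight(ctx.export_device_memories());
+    /// ```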
+    pub fn prove_lightweight(device_memories: Vec<Vec<SIMDField<ZC::ECCConfig>>>) {
+        client_send_witness_and_prove::<ZC::GKRConfig, ZC::ECCConfig>(device_memories);
+    }
+}
diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs
index bc980372..cd5c894b 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs
@@ -4,6 +4,8 @@ use gkr_engine::{
     BN254ConfigXN, ExpanderDualVarChallenge, FieldEngine, GKREngine, MPIConfig, MPIEngine,
     Transcript,
 };
+use std::collections::HashMap;
+use std::fs;
 
 use crate::{
     frontend::{Config, SIMDField},
@@ -43,6 +45,23 @@ pub fn mpi_prove_no_oversubscribe_impl<ZC: ZKCudaConfig>(
 where
     <ZC::GKRConfig as GKREngine>::FieldConfig: FieldEngine,
 {
+    // Check for a schedule file and use the scheduler if available.
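+    // The scheduler reads from the working directory: schedule.txt (its
+    // presence is what enables this path), task_mapping.txt (maps TaskN to a
+    // proof-template index), and, if present, task_dependencies.txt
+    // (cross-task ordering constraints enforced via .task_sync marker files).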
+    if std::path::Path::new("schedule.txt").exists() {
+        let my_rank = global_mpi_config.world_rank();
+        if my_rank == 0 {
+            eprintln!("⚡ Schedule file detected, using scheduled execution");
+        }
+        return mpi_prove_no_oversubscribe_with_schedule::<ZC>(
+            global_mpi_config,
+            "schedule.txt",
+            Some("task_mapping.txt"),
+            prover_setup,
+            computation_graph,
+            values,
+            n_bytes_profiler,
+        );
+    }
+
     let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root());
     let (commitments, states) = if global_mpi_config.is_root() {
         let (commitments, states) = values
@@ -58,6 +77,7 @@ where
                 ),
             })
             .unzip::<_, _, Vec<_>, Vec<_>>();
+
         (Some(commitments), Some(states))
     } else {
         (None, None)
@@ -162,7 +182,6 @@ where
         true => {
             if global_mpi_config.is_root() {
                 let mut proofs = proofs.into_iter().map(|p| p.unwrap()).collect::<Vec<_>>();
-
                 let pcs_opening_timer = Timer::new("Batch PCS Opening for all kernels", true);
                 let pcs_batch_opening = open_defered_pcs::<ZC>(
                     prover_setup,
@@ -421,6 +440,7 @@ where
     expander_circuit.layers[0].input_vals = input_vals;
     expander_circuit.fill_rnd_coefs(transcript);
+
     expander_circuit.evaluate();
 
     #[cfg(feature = "zkcuda_profile")]
@@ -436,6 +456,7 @@ where
     let (claimed_v, challenge) =
         gkr::gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config);
+
     assert_eq!(claimed_v, FBasic::ChallengeField::from(0u32));
 
     let n_simd_vars_basic = FBasic::SimdCircuitField::PACK_SIZE.ilog2() as usize;
@@ -451,3 +472,961 @@ where
         },
     }
 }
+
+// ==================== SCHEDULE-BASED EXECUTION ====================
+
+/// Schedule representation: rank -> sequence of tasks.
+#[derive(Debug, Clone)]
+pub struct Schedule {
+    /// Map from rank to list of task names.
+    pub rank_tasks: HashMap<usize, Vec<String>>,
+}
+
+impl Schedule {
+    /// Parse a schedule from a text file.
+    /// Format: "Rank 0: Task14 -> Task1 -> Task12"
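+    /// e.g. a two-rank schedule:
+    ///
+    /// ```text
+    /// Rank 0: Task14 -> Task1 -> Task12
+    /// Rank 1: Task3 -> idle -> Task12
+    /// ```
+    ///
+    /// Steps are aligned by position, and "idle" / "..." entries are skipped
+    /// at execution time. (Rank 1's line here is illustrative.)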
+    pub fn from_file(path: &str) -> Result<Self, String> {
+        let content =
+            fs::read_to_string(path).map_err(|e| format!("Failed to read schedule file: {}", e))?;
+
+        let mut rank_tasks = HashMap::new();
+
+        for line in content.lines() {
+            let line = line.trim();
+            if line.is_empty() {
+                continue;
+            }
+
+            let parts: Vec<&str> = line.split(':').collect();
+            if parts.len() != 2 {
+                return Err(format!("Invalid line format: {}", line));
+            }
+
+            // Tolerate extra whitespace: "Rank 0:" as well as "Rank  0:".
+            let rank_part = parts[0].trim();
+            if !rank_part.starts_with("Rank") {
+                return Err(format!("Expected 'Rank X', got: {}", parts[0]));
+            }
+
+            let rank_str = rank_part
+                .strip_prefix("Rank")
+                .ok_or_else(|| format!("Expected 'Rank X', got: {}", parts[0]))?
+                .trim(); // trim to handle repeated spaces
+
+            let rank: usize = rank_str
+                .parse()
+                .map_err(|e| format!("Invalid rank number '{}': {}", rank_str, e))?;
+
+            let tasks_str = parts[1].trim();
+            let tasks: Vec<String> = tasks_str
+                .split("->")
+                .map(|t| t.trim().to_string())
+                .filter(|t| !t.is_empty())
+                .collect();
+
+            rank_tasks.insert(rank, tasks);
+        }
+
+        Ok(Schedule { rank_tasks })
+    }
+
+    pub fn get_tasks(&self, rank: usize) -> Option<&Vec<String>> {
+        self.rank_tasks.get(&rank)
+    }
+
+    pub fn find_peers_at_step(&self, step: usize, task_name: &str) -> Vec<usize> {
+        let mut peers = Vec::new();
+        for (rank, tasks) in &self.rank_tasks {
+            if tasks.len() > step && tasks[step] == task_name {
+                peers.push(*rank);
+            }
+        }
+        peers.sort();
+        peers
+    }
+
+    pub fn max_steps(&self) -> usize {
+        self.rank_tasks
+            .values()
+            .map(|tasks| tasks.len())
+            .max()
+            .unwrap_or(0)
+    }
+
+    /// Get all unique task names across all ranks.
+    pub fn get_all_unique_tasks(&self) -> Vec<String> {
+        use std::collections::HashSet;
+        let mut unique_tasks = HashSet::new();
+
+        for tasks in self.rank_tasks.values() {
+            for task in tasks {
+                if task != "idle" && task != "..." {
+                    unique_tasks.insert(task.clone());
+                }
+            }
+        }
+
+        let mut tasks: Vec<_> = unique_tasks.into_iter().collect();
+        tasks.sort();
+        tasks
+    }
+
+    /// Find all ranks that will execute a given task (across all steps).
+    pub fn find_all_peers_for_task(&self, task_name: &str) -> Vec<usize> {
+        let mut peers = Vec::new();
+
+        for (rank, tasks) in &self.rank_tasks {
+            if tasks.contains(&task_name.to_string()) {
+                peers.push(*rank);
+            }
+        }
+
+        peers.sort();
+        peers
+    }
+}
+
+/// Parse the task mapping file.
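+/// Format: one "TaskName: template_index" entry per line (e.g. "Task14: 3");
+/// blank lines and lines starting with '#' are ignored.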
+pub fn parse_task_mapping(path: &str) -> Result<HashMap<String, usize>, String> {
+    let content =
+        fs::read_to_string(path).map_err(|e| format!("Failed to read task mapping file: {}", e))?;
+
+    let mut mapping = HashMap::new();
+
+    for line in content.lines() {
+        let line = line.trim();
+        if line.is_empty() || line.starts_with('#') {
+            continue;
+        }
+
+        let parts: Vec<&str> = line.split(':').collect();
+        if parts.len() != 2 {
+            continue;
+        }
+
+        let task_name = parts[0].trim().to_string();
+        let template_idx: usize = parts[1]
+            .trim()
+            .parse()
+            .map_err(|e| format!("Invalid template index: {}", e))?;
+
+        mapping.insert(task_name, template_idx);
+    }
+
+    Ok(mapping)
+}
+
+/// Parse the task dependencies file.
+/// Format: "Task22: Task9, Task23, Task8, Task17, Task18"
+pub fn parse_task_dependencies(path: &str) -> Result<HashMap<String, Vec<String>>, String> {
+    let content =
+        fs::read_to_string(path).map_err(|e| format!("Failed to read dependencies file: {}", e))?;
+
+    let mut dependencies = HashMap::new();
+
+    for line in content.lines() {
+        let line = line.trim();
+        if line.is_empty() || line.starts_with('#') {
+            continue;
+        }
+
+        let parts: Vec<&str> = line.split(':').collect();
+        if parts.len() != 2 {
+            continue;
+        }
+
+        let task_name = parts[0].trim().to_string();
+        let deps: Vec<String> = parts[1]
+            .split(',')
+            .map(|s| s.trim().to_string())
+            .filter(|s| !s.is_empty())
+            .collect();
+
+        dependencies.insert(task_name, deps);
+    }
+
+    Ok(dependencies)
+}
+
+/// Mark a task as completed by creating a marker file.
+fn mark_task_completed(task_name: &str, my_rank: usize, peers: &[usize]) {
+    // Only the minimum rank in the peer group writes the file.
+    if let Some(&min_rank) = peers.iter().min() {
+        if my_rank == min_rank {
+            let marker_path = format!(".task_sync/{}.done", task_name);
+            if let Err(e) = fs::write(
+                &marker_path,
+                format!(
+                    "{}",
+                    std::time::SystemTime::now()
+                        .duration_since(std::time::UNIX_EPOCH)
+                        .unwrap()
+                        .as_secs()
+                ),
+            ) {
+                eprintln!(
+                    "[RANK {}] Warning: Failed to write marker for {}: {}",
+                    my_rank, task_name, e
+                );
+            }
+        }
+    }
+}
+
+/// Wait for task dependencies to be satisfied.
+fn wait_for_dependencies(task_name: &str, dependencies: &[String], my_rank: usize) {
+    if dependencies.is_empty() {
+        return;
+    }
+
+    let mut need_wait = false;
+    let mut waited_deps = vec![];
+
+    for dep in dependencies {
+        let marker_path = format!(".task_sync/{}.done", dep);
+
+        if std::path::Path::new(&marker_path).exists() {
+            continue;
+        }
+
+        if !need_wait {
+            need_wait = true;
+        }
+
+        let start_time = std::time::Instant::now();
+        while !std::path::Path::new(&marker_path).exists() {
+            std::thread::sleep(std::time::Duration::from_millis(100));
+
+            // Timeout warning.
+            if start_time.elapsed().as_secs() > 600 {
+                eprintln!(
+                    "[RANK {}] ⚠️ WARNING: Waiting for {} over 10 minutes!",
+                    my_rank, dep
+                );
+            }
+        }
+
+        if start_time.elapsed().as_secs_f64() > 0.5 {
+            waited_deps.push((dep.clone(), start_time.elapsed().as_secs_f64()));
+        }
+    }
+
+    // Only log if we actually waited.
+    if !waited_deps.is_empty() && my_rank % 8 == 0 {
+        eprintln!(
+            "[RANK {}] Task {} waited for {} deps (longest: {:.1}s)",
+            my_rank,
+            task_name,
+            waited_deps.len(),
+            waited_deps
+                .iter()
+                .map(|(_, t)| t)
+                .fold(0.0f64, |a, &b| a.max(b))
+        );
+    }
+}
+
+/// Create an MPI subgroup for a specific task.
+/// CRITICAL: This is a collective operation - ALL ranks must call this function.
+fn create_mpi_subgroup_for_task(
+    global_mpi_config: &MPIConfig<'static>,
+    peers: &[usize],
+    task_name: &str,
+) -> Option<MPIConfig<'static>> {
+    use mpi::topology::Communicator;
+    use std::collections::hash_map::DefaultHasher;
+    use std::hash::{Hash, Hasher};
+
+    let my_rank = global_mpi_config.world_rank();
+
+    // Check whether I'm in the peer list.
+    let my_position_opt = peers.iter().position(|&r| r == my_rank);
+
+    // Participants and non-participants get disjoint, task-derived color ranges.
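+    // Worked example: suppose hash("Task7") % 5000 == 137 and peers == [2, 5].
+    // Ranks 2 and 5 call split with (color 137, key 0) and (color 137, key 1)
+    // and land in a two-rank subcommunicator ordered by key; every other rank
+    // calls split with (color 5137, key = its world rank) and discards the
+    // result. Either way the split remains a collective call on all ranks.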
+    let (color_value, key) = if let Some(pos) = my_position_opt {
+        // I'm in this task group.
+        let mut hasher = DefaultHasher::new();
+        task_name.hash(&mut hasher);
+        let base_color = (hasher.finish() % 5000) as i32;
+        (base_color, pos as i32)
+    } else {
+        // I'm not in this task group, so use a different color.
+        let mut hasher = DefaultHasher::new();
+        task_name.hash(&mut hasher);
+        let base_color = (hasher.finish() % 5000) as i32;
+        (base_color + 5000, my_rank as i32) // different color range for non-participants
+    };
+
+    let color = mpi::topology::Color::with_value(color_value);
+
+    // CRITICAL: all ranks must call split (collective operation).
+    let split_comm = unsafe {
+        global_mpi_config
+            .world
+            .unwrap()
+            .split_by_color_with_key(color, key)
+    };
+
+    // Only participants return a valid MPIConfig.
+    if my_position_opt.is_some() {
+        let split_comm_static: &'static Option<_> = Box::leak(Box::new(split_comm));
+        Some(MPIConfig::prover_new(
+            global_mpi_config.universe,
+            split_comm_static.as_ref(),
+        ))
+    } else {
+        // Non-participants still called split but don't use the result.
+        None
+    }
+}
+
+/// Main prove function with schedule support.
+pub fn mpi_prove_no_oversubscribe_with_schedule<ZC: ZKCudaConfig>(
+    global_mpi_config: &MPIConfig<'static>,
+    schedule_path: &str,
+    task_mapping_path: Option<&str>,
+    prover_setup: &ExpanderProverSetup<GetFieldConfig<ZC>, GetPCS<ZC>>,
+    computation_graph: &ComputationGraph<ZC::ECCConfig>,
+    values: &[impl AsRef<[SIMDField<ZC::ECCConfig>]>],
+    n_bytes_profiler: &mut NBytesProfiler,
+) -> Option<CombinedProof<ZC::ECCConfig, Expander<ZC::GKRConfig>>>
+where
+    <ZC::GKRConfig as GKREngine>::FieldConfig: FieldEngine,
+{
+    let my_rank = global_mpi_config.world_rank();
+
+    // Load the schedule.
+    let schedule = match Schedule::from_file(schedule_path) {
+        Ok(s) => s,
+        Err(e) => {
+            eprintln!("[RANK {}] Failed to load schedule: {}", my_rank, e);
+            return None;
+        }
+    };
+
+    if global_mpi_config.is_root() {
+        eprintln!("========== SCHEDULER MODE ==========");
+        eprintln!(
+            "  Schedule: {} ranks, max {} steps",
+            schedule.rank_tasks.len(),
+            schedule.max_steps()
+        );
+    }
+
+    // Safety checks.
+    let num_templates = computation_graph.proof_templates().len();
+    let num_values = values.len();
+
+    if num_templates == 0 {
+        eprintln!(
+            "[RANK {}] ERROR: No templates in computation graph!",
+            my_rank
+        );
+        return if my_rank == 0 {
+            Some(CombinedProof {
+                commitments: vec![],
+                proofs: vec![],
+            })
+        } else {
+            None
+        };
+    }
+
+    if num_values == 0 {
+        eprintln!("[RANK {}] ERROR: Values array is empty!", my_rank);
+        return None;
+    }
+
+    // Load the task mapping.
+    let task_mapping = if let Some(path) = task_mapping_path {
+        match parse_task_mapping(path) {
+            Ok(m) => m,
+            Err(e) => {
+                eprintln!("[RANK {}] Failed to load task mapping: {}", my_rank, e);
+                return None;
+            }
+        }
+    } else {
+        let mut default_mapping = HashMap::new();
+        for i in 0..computation_graph.proof_templates().len() {
+            default_mapping.insert(format!("Task{}", i), i);
+        }
+        default_mapping
+    };
+
+    // Load task dependencies (optional).
+    let task_dependencies = if std::path::Path::new("task_dependencies.txt").exists() {
+        match parse_task_dependencies("task_dependencies.txt") {
+            Ok(deps) => {
+                if global_mpi_config.is_root() {
+                    eprintln!("  Loaded {} task dependencies", deps.len());
+                }
+                deps
+            }
+            Err(e) => {
+                eprintln!(
+                    "[RANK {}] ERROR: Failed to load dependencies: {}",
+                    my_rank, e
+                );
+                HashMap::new()
+            }
+        }
+    } else {
+        HashMap::new()
+    };
+
+    // Create the task sync directory (root only).
+    if global_mpi_config.is_root() {
+        fs::create_dir_all(".task_sync").ok();
+        // Clean old markers.
+        if let Ok(entries) = fs::read_dir(".task_sync") {
+            for entry in entries.flatten() {
+                fs::remove_file(entry.path()).ok();
+            }
+        }
+    }
+    global_mpi_config.barrier();
+
+    // ========== PRE-CREATE ALL MPI SUBGROUPS ==========
+    let all_unique_tasks = schedule.get_all_unique_tasks();
+    let mut task_mpi_configs: HashMap<String, Option<MPIConfig<'static>>> = HashMap::new();
+
+    if global_mpi_config.is_root() {
+        eprintln!(
+            "  Pre-creating MPI subgroups for {} tasks...",
+            all_unique_tasks.len()
+        );
+    }
+
+    for task_name in &all_unique_tasks {
+        let peers = schedule.find_all_peers_for_task(task_name);
+
+        // All ranks call this together (collective operation).
+        let mpi_config = if !peers.is_empty() {
+            create_mpi_subgroup_for_task(global_mpi_config, &peers, task_name)
+        } else {
+            None
+        };
+
+        task_mpi_configs.insert(task_name.clone(), mpi_config);
+    }
+
+    if global_mpi_config.is_root() {
+        eprintln!("  MPI subgroups ready");
+    }
+    global_mpi_config.barrier();
+
+    // Commit phase (root only).
+    let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root());
+    let (commitments, states) = if global_mpi_config.is_root() {
+        eprintln!("[RANK {}] === COMMIT PHASE ===", my_rank);
+
+        let (commitments, states) = values
+            .iter()
+            .map(|value| match ZC::BATCH_PCS {
+                true => max_len_setup_commit_impl::<ZC>(
+                    prover_setup,
+                    value.as_ref(),
+                ),
+                false => local_commit_impl::<ZC>(
+                    prover_setup.p_keys.get(&value.as_ref().len()).unwrap(),
+                    value.as_ref(),
+                ),
+            })
+            .unzip::<_, _, Vec<_>, Vec<_>>();
+        (Some(commitments), Some(states))
+    } else {
+        (None, None)
+    };
+    commit_timer.stop();
+
+    // Use indexed storage to maintain template order.
+    let num_templates = computation_graph.proof_templates().len();
+
+    // Store vals and challenges per template to maintain order.
+    let mut vals_per_template: Vec<Option<Vec<Vec<SIMDField<ZC::ECCConfig>>>>> =
+        vec![None; num_templates];
+    let mut challenges_per_template: Vec<
+        Option<Vec<ExpanderSingleVarChallenge<GetFieldConfig<ZC>>>>,
+    > = vec![None; num_templates];
+
+    let mut vals_ref: Vec<&[SIMDField<ZC::ECCConfig>]> = vec![]; // reference version kept for non-BATCH_PCS compatibility
+    let mut challenges: Vec<ExpanderSingleVarChallenge<GetFieldConfig<ZC>>> = vec![]; // for non-BATCH_PCS compatibility
+
+    // Track which ranks are subgroup roots for result collection.
+    let mut i_am_subgroup_root_for_tasks = vec![];
+
+    // Get my tasks.
+    let my_tasks = match schedule.get_tasks(my_rank) {
+        Some(tasks) => tasks,
+        None => {
+            eprintln!("[RANK {}] No tasks assigned in schedule", my_rank);
+            return if my_rank == 0 {
+                Some(CombinedProof {
+                    commitments: commitments.unwrap(),
+                    proofs: vec![],
+                })
+            } else {
+                None
+            };
+        }
+    };
+
+    // Execute tasks step by step.
+    let mut all_proofs: Vec<Option<ExpanderProof>> =
+        vec![None; computation_graph.proof_templates().len()];
+    let prove_timer = Timer::new("Prove all kernels", global_mpi_config.is_root());
+
+    for (step, task_name) in my_tasks.iter().enumerate() {
+        // Skip idle steps.
+        if task_name == "idle" || task_name == "..." {
+            continue;
+        }
+
+        // Wait for dependencies before proceeding.
+        if let Some(deps) = task_dependencies.get(task_name) {
+            wait_for_dependencies(task_name, deps, my_rank);
+        }
+
+        // Find the template index.
+        let template_idx = match task_mapping.get(task_name) {
+            Some(&idx) => idx,
+            None => {
+                eprintln!("[RANK {}] Unknown task: {}", my_rank, task_name);
+                continue;
+            }
+        };
+
+        if template_idx >= computation_graph.proof_templates().len() {
+            eprintln!(
+                "[RANK {}] Invalid template index: {}",
+                my_rank, template_idx
+            );
+            continue;
+        }
+
+        // Get the pre-created MPI config for this task.
+        let local_mpi_config = task_mpi_configs.get(task_name).and_then(|c| c.clone());
+
+        // Check whether I'm a participant (valid MPI config or a solo task).
+        let all_peers = schedule.find_all_peers_for_task(task_name);
+        let i_am_participant = all_peers.contains(&my_rank);
+
+        if !i_am_participant {
+            continue;
+        }
+
+        let template = &computation_graph.proof_templates()[template_idx];
+
+        // Safety check: verify all commitment indices are in bounds.
+        let commit_indices = template.commitment_indices();
+        let mut has_error = false;
+        for &idx in commit_indices {
+            if idx >= values.len() {
+                eprintln!(
+                    "[RANK {}] ERROR: Template {} requires value index {} but values.len() = {}",
+                    my_rank,
+                    template_idx,
+                    idx,
+                    values.len()
+                );
+                has_error = true;
+            }
+        }
+        if has_error {
+            eprintln!(
+                "[RANK {}] Skipping task {} due to index out of bounds",
+                my_rank, task_name
+            );
+            continue;
+        }
+
+        let commitment_values = template
+            .commitment_indices()
+            .iter()
+            .map(|&idx| values[idx].as_ref())
+            .collect::<Vec<_>>();
+
+        // Execute GKR.
+        let single_kernel_gkr_timer = Timer::new(
+            &format!("Task {} GKR", task_name),
+            local_mpi_config
+                .as_ref()
+                .map(|c| c.is_root())
+                .unwrap_or(true),
+        );
+
+        let gkr_end_state = if let Some(ref local_config) = local_mpi_config {
+            prove_kernel_gkr_no_oversubscribe::<GetFieldConfig<ZC>, GetTranscript<ZC>, ZC::ECCConfig>(
+                local_config,
+                &computation_graph.kernels()[template.kernel_id()],
+                &commitment_values,
+                next_power_of_two(template.parallel_count()),
+                template.is_broadcast(),
+                n_bytes_profiler,
+            )
+        } else {
+            eprintln!(
+                "[RANK {}] Executing task {} solo (creating single-rank MPI config)",
+                my_rank, task_name
+            );
+
+            // Create a single-rank MPI config for solo tasks;
+            // we use the pre-created config from task_mpi_configs.
+            let solo_config = task_mpi_configs.get(task_name).and_then(|c| c.as_ref());
+
+            if let Some(config) = solo_config {
+                prove_kernel_gkr_no_oversubscribe::<
+                    GetFieldConfig<ZC>,
+                    GetTranscript<ZC>,
+                    ZC::ECCConfig,
+                >(
+                    config,
+                    &computation_graph.kernels()[template.kernel_id()],
+                    &commitment_values,
+                    next_power_of_two(template.parallel_count()),
+                    template.is_broadcast(),
+                    n_bytes_profiler,
+                )
+            } else {
+                // Fallback: use the global MPI config for the solo task.
+                prove_kernel_gkr_no_oversubscribe::<
+                    GetFieldConfig<ZC>,
+                    GetTranscript<ZC>,
+                    ZC::ECCConfig,
+                >(
+                    global_mpi_config,
+                    &computation_graph.kernels()[template.kernel_id()],
+                    &commitment_values,
+                    next_power_of_two(template.parallel_count()),
+                    template.is_broadcast(),
+                    n_bytes_profiler,
+                )
+            }
+        };
+
+        single_kernel_gkr_timer.stop();
+
+        // PCS opening.
+        if let Some((mut transcript, challenge)) = gkr_end_state {
+            let is_subgroup_root = local_mpi_config
+                .as_ref()
+                .map(|c| c.is_root())
+                .unwrap_or(true);
+
+            if is_subgroup_root {
+                i_am_subgroup_root_for_tasks.push(template_idx);
+
+                match ZC::BATCH_PCS {
+                    true => {
+                        assert!(challenge.challenge_y().is_none());
+                        let challenge_x = challenge.challenge_x();
+
+                        let (local_vals_ref, local_challenges) = extract_pcs_claims::<ZC>(
+                            &commitment_values,
+                            &challenge_x,
+                            template.is_broadcast(),
+                            next_power_of_two(template.parallel_count()),
+                        );
+
+                        // Store in the indexed structure to maintain template order.
+                        let owned_vals: Vec<Vec<SIMDField<ZC::ECCConfig>>> =
+                            local_vals_ref.iter().map(|v| v.to_vec()).collect();
+                        vals_per_template[template_idx] = Some(owned_vals);
+                        challenges_per_template[template_idx] = Some(local_challenges);
+
+                        all_proofs[template_idx] = Some(ExpanderProof {
+                            data: vec![transcript.finalize_and_get_proof()],
+                        });
+                    }
+                    false => {
+                        let pcs_open_timer = Timer::new(&format!("Task {} PCS", task_name), true);
+                        let challenge_list = if let Some(challenge_y) = challenge.challenge_y() {
+                            vec![challenge.challenge_x(), challenge_y]
+                        } else {
+                            vec![challenge.challenge_x()]
+                        };
+
+                        challenge_list.iter().for_each(|c| {
+                            partition_single_gkr_claim_and_open_pcs_mpi::<ZC>(
+                                prover_setup,
+                                &commitment_values,
+                                &template
+                                    .commitment_indices()
+                                    .iter()
+                                    .map(|&idx| &states.as_ref().unwrap()[idx])
+                                    .collect::<Vec<_>>(),
+                                c,
+                                template.is_broadcast(),
+                                &mut transcript,
+                            );
+                        });
+
+                        pcs_open_timer.stop();
+                        all_proofs[template_idx] = Some(ExpanderProof {
+                            data: vec![transcript.finalize_and_get_proof()],
+                        });
+                    }
+                }
+            }
+        }
+
+        // Mark the task as completed (file-based synchronization).
+        mark_task_completed(task_name, my_rank, &all_peers);
+    }
+
+    // Wait for all ranks to complete all tasks.
+    global_mpi_config.barrier();
+    prove_timer.stop();
+
+    // ========== MPI result collection (BATCH_PCS mode) ==========
+    // Collect vals and challenges from all subgroup roots at the global root.
+    if ZC::BATCH_PCS {
+        use mpi::traits::*;
+        use serdes::ExpSerde;
+
+        let i_am_subgroup_root = !i_am_subgroup_root_for_tasks.is_empty();
+
+        // Step 0: every rank reports via all_gather whether it is a subgroup
+        // root. The collective involves all ranks; only the root keeps the
+        // result.
+        let my_flag = if i_am_subgroup_root { 1u8 } else { 0u8 };
+
+        let all_flags = if global_mpi_config.is_root() {
+            let mut flags = vec![0u8; global_mpi_config.world_size()];
+            unsafe {
+                global_mpi_config
+                    .world
+                    .unwrap()
+                    .all_gather_into(&my_flag, &mut flags[..]);
+            }
+            Some(flags)
+        } else {
+            unsafe {
+                let mut flags = vec![0u8; global_mpi_config.world_size()];
+                global_mpi_config
+                    .world
+                    .unwrap()
+                    .all_gather_into(&my_flag, &mut flags[..]);
+            }
+            None
+        };
+
+        // Step 1: non-root subgroup roots send their results to rank 0.
+        if i_am_subgroup_root && my_rank != 0 {
+            // Serialize the indexed structures (maintains template order).
+            let mut vals_bytes = Vec::new();
+            vals_per_template.serialize_into(&mut vals_bytes).unwrap();
+
+            let mut challenges_bytes = Vec::new();
+            challenges_per_template
+                .serialize_into(&mut challenges_bytes)
+                .unwrap();
+
+            let mut proofs_bytes = Vec::new();
+            all_proofs.serialize_into(&mut proofs_bytes).unwrap();
+
+            // Send the sizes first.
+            let sizes = [vals_bytes.len(), challenges_bytes.len(), proofs_bytes.len()];
+
+            unsafe {
+                global_mpi_config
+                    .world
+                    .unwrap()
+                    .process_at_rank(0)
+                    .synchronous_send(&sizes[..]);
+
+                global_mpi_config
+                    .world
+                    .unwrap()
+                    .process_at_rank(0)
+                    .synchronous_send(&vals_bytes[..]);
+
+                global_mpi_config
+                    .world
+                    .unwrap()
+                    .process_at_rank(0)
+                    .synchronous_send(&challenges_bytes[..]);
+
+                global_mpi_config
+                    .world
+                    .unwrap()
+                    .process_at_rank(0)
+                    .synchronous_send(&proofs_bytes[..]);
+            }
+        }
+
+        // Step 2: the global root receives all results.
+        if global_mpi_config.is_root() {
+            let flags = all_flags.unwrap();
+
+            // Identify which ranks are subgroup roots (flag == 1).
+            let subgroup_roots: Vec<usize> = flags
+                .iter()
+                .enumerate()
+                .filter(|(_, &flag)| flag == 1)
+                .map(|(rank, _)| rank)
+                .collect();
+
+            // Receive from each subgroup root (except self).
+            for &sender_rank in &subgroup_roots {
+                if sender_rank == 0 {
+                    continue; // skip self
+                }
+
+                // Drain the size message; receive_vec sizes the payload
+                // buffers itself, so the values are not needed here.
+                let (_sizes, _status) = unsafe {
+                    global_mpi_config
+                        .world
+                        .unwrap()
+                        .process_at_rank(sender_rank as i32)
+                        .receive_vec::<usize>()
+                };
+
+                // Receive vals_per_template (indexed structure).
+                let (vals_bytes, _) = unsafe {
+                    global_mpi_config
+                        .world
+                        .unwrap()
+                        .process_at_rank(sender_rank as i32)
+                        .receive_vec::<u8>()
+                };
+                let received_vals_per_template: Vec<Option<Vec<Vec<SIMDField<ZC::ECCConfig>>>>> =
+                    Vec::deserialize_from(&mut vals_bytes.as_slice()).unwrap();
+
+                // Receive challenges_per_template.
+                let (challenges_bytes, _) = unsafe {
+                    global_mpi_config
+                        .world
+                        .unwrap()
+                        .process_at_rank(sender_rank as i32)
+                        .receive_vec::<u8>()
+                };
+                let received_challenges_per_template: Vec<
+                    Option<Vec<ExpanderSingleVarChallenge<GetFieldConfig<ZC>>>>,
+                > = Vec::deserialize_from(&mut challenges_bytes.as_slice()).unwrap();
+
+                // Receive the proofs (indexed by template).
+                let (proofs_bytes, _) = unsafe {
+                    global_mpi_config
+                        .world
+                        .unwrap()
+                        .process_at_rank(sender_rank as i32)
+                        .receive_vec::<u8>()
+                };
+                let received_all_proofs: Vec<Option<ExpanderProof>> =
+                    Vec::deserialize_from(&mut proofs_bytes.as_slice()).unwrap();
+
+                // Merge the indexed data (maintains template order).
+                for template_idx in 0..num_templates {
+                    // Merge vals.
+                    if received_vals_per_template[template_idx].is_some() {
+                        vals_per_template[template_idx] =
+                            received_vals_per_template[template_idx].clone();
+                    }
+
+                    // Merge challenges.
+                    if received_challenges_per_template[template_idx].is_some() {
+                        challenges_per_template[template_idx] =
+                            received_challenges_per_template[template_idx].clone();
+                    }
+
+                    // Merge proofs.
+                    if received_all_proofs[template_idx].is_some() {
+                        all_proofs[template_idx] = received_all_proofs[template_idx].clone();
+                    }
+                }
+            }
+
+            // Build the final vals and challenges in template order.
+            let mut vals_ref_owned: Vec<Vec<SIMDField<ZC::ECCConfig>>> = vec![];
+            let mut challenges_final = vec![];
+
+            for template_idx in 0..num_templates {
+                if let Some(vals) = &vals_per_template[template_idx] {
+                    vals_ref_owned.extend(vals.clone());
+                }
+                if let Some(chals) = &challenges_per_template[template_idx] {
+                    challenges_final.extend(chals.clone());
+                }
+            }
+
+            let completed_templates = all_proofs.iter().filter(|p| p.is_some()).count();
+            eprintln!(
+                "Result collection: {}/{} templates, {} vals, {} challenges",
+                completed_templates,
+                num_templates,
+                vals_ref_owned.len(),
+                challenges_final.len()
+            );
+
+            if completed_templates < num_templates {
+                eprintln!(
+                    "⚠️ WARNING: Only {}/{} templates completed!",
+                    completed_templates, num_templates
+                );
+                for (idx, val) in vals_per_template.iter().enumerate() {
+                    if val.is_none() {
+                        eprintln!("  Missing: Template {}", idx);
+                    }
+                }
+            }
+        }
+    }
+
+    // Collect results.
+    match ZC::BATCH_PCS {
+        true => {
+            if global_mpi_config.is_root() {
+                let mut proofs = all_proofs.into_iter().flatten().collect::<Vec<_>>();
+
+                let pcs_opening_timer = Timer::new("Batch PCS Opening for all kernels", true);
+
+                // Build the final vals and challenges in template order.
+                let mut vals_ref_owned: Vec<Vec<SIMDField<ZC::ECCConfig>>> = vec![];
+                let mut challenges_final = vec![];
+
+                for template_idx in 0..num_templates {
+                    if let Some(vals) = &vals_per_template[template_idx] {
+                        vals_ref_owned.extend(vals.clone());
+                    }
+                    if let Some(chals) = &challenges_per_template[template_idx] {
+                        challenges_final.extend(chals.clone());
+                    }
+                }
+
+                // Convert to references for open_defered_pcs.
+                let vals_ref_for_pcs: Vec<&[SIMDField<ZC::ECCConfig>]> =
+                    vals_ref_owned.iter().map(|v| v.as_slice()).collect();
+
+                let pcs_batch_opening = open_defered_pcs::<ZC>(
+                    prover_setup,
+                    &vals_ref_for_pcs,
+                    &challenges_final,
+                );
+                pcs_opening_timer.stop();
+
+                proofs.push(pcs_batch_opening);
+                Some(CombinedProof {
+                    commitments: commitments.unwrap(),
+                    proofs,
+                })
+            } else {
+                None
+            }
+        }
+        false => {
+            if global_mpi_config.is_root() {
+                let proofs = all_proofs.into_iter().flatten().collect::<Vec<_>>();
+                Some(CombinedProof {
+                    commitments: commitments.unwrap(),
+                    proofs,
+                })
+            } else {
+                None
+            }
+        }
+    }
+}
diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs
index 42315b39..c67517f2 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs
@@ -87,7 +87,7 @@ where
     C: GKREngine,
     ECCConfig: Config<FieldConfig = C::FieldConfig>,
 {
-    let setup_timer = Timer::new("setup", true);
+    let setup_timer = Timer::new("new setup", true);
     println!("Starting server with binary: {server_binary}");
 
     let mut bytes = vec![];
@@ -112,7 +112,11 @@ where
     let mpi_size = if allow_oversubscribe {
         max_parallel_count
     } else {
-        let num_cpus = prev_power_of_two(num_cpus::get_physical());
+        let num_cpus = std::env::var("ZKML_NUM_CPUS")
+            .ok()
+            .and_then(|s| s.parse().ok())
+            .unwrap_or_else(num_cpus::get_physical);
+        let num_cpus = prev_power_of_two(num_cpus);
         if max_parallel_count > num_cpus {
             num_cpus
         } else {
@@ -136,7 +140,11 @@ where
 
     setup_timer.stop();
 
-    SharedMemoryEngine::read_pcs_setup_from_shared_memory()
+    // SharedMemoryEngine::read_pcs_setup_from_shared_memory()
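+    // NOTE: this now returns default (empty) placeholder setups instead of
+    // reading the PCS setup back from shared memory, so callers must not
+    // rely on the returned values containing real keys.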
+    (
+        ExpanderProverSetup::default(),
+        ExpanderVerifierSetup::default(),
+    )
 }
 
 pub fn client_send_witness_and_prove<C, ECCConfig>(
diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs
index b8d6c830..ebaef2b3 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs
@@ -24,8 +24,16 @@ pub fn start_server(
 
 fn parse_config(mpi_size: usize) -> (String, String, String, String)
 where {
-    let oversubscription = if mpi_size > num_cpus::get_physical() {
-        println!("Warning: Not enough cores available for the requested number of processes. Using oversubscription.");
+    let force_oversubscribe = std::env::var("ZKML_FORCE_OVERSUBSCRIBE").is_ok();
+
+    let num_cpus = std::env::var("ZKML_NUM_CPUS")
+        .ok()
+        .and_then(|s| s.parse().ok())
+        .unwrap_or_else(num_cpus::get_physical);
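+    // e.g. `ZKML_NUM_CPUS=16` caps the core count used to size the MPI world,
+    // and setting ZKML_FORCE_OVERSUBSCRIBE (to any value) always passes
+    // --oversubscribe to the MPI launcher.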
+    let oversubscription = if force_oversubscribe || mpi_size > num_cpus {
+        if mpi_size > num_cpus {
+            println!("Warning: Not enough cores available for the requested number of processes. Using oversubscription.");
+        }
         "--oversubscribe"
     } else {
         ""
diff --git a/expander_compiler/tests/test_bn254_new_data.rs b/expander_compiler/tests/test_bn254_new_data.rs
new file mode 100644
index 00000000..2bccb60e
--- /dev/null
+++ b/expander_compiler/tests/test_bn254_new_data.rs
@@ -0,0 +1,134 @@
+use expander_compiler::frontend::*;
+use expander_compiler::zkcuda::proving_system::expander::config::ZKCudaBN254KZGBatchPCS;
+use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ProvingSystem};
+use expander_compiler::zkcuda::shape::Reshape;
+use expander_compiler::zkcuda::{context::*, kernel::*};
+
+#[kernel]
+fn add_2_macro<C: Config>(api: &mut API<C>, a: &[InputVariable; 2], b: &mut OutputVariable) {
+    *b = api.add(a[0], a[1]);
+}
+
+#[kernel]
+fn add_16_macro<C: Config>(api: &mut API<C>, a: &[InputVariable; 16], b: &mut OutputVariable) {
+    let mut sum = api.constant(0);
+    for i in 0..16 {
+        sum = api.add(sum, a[i]);
+    }
+    *b = sum;
+}
+
+fn test_bn254_load_graph_with_new_data_impl<C: Config, P: ProvingSystem<C>>() {
+    let kernel_add_2: KernelPrimitive<C> = compile_add_2_macro().unwrap();
+    let kernel_add_16: KernelPrimitive<C> = compile_add_16_macro().unwrap();
+
+    println!("\n===== First execution: create and save graph (BN254) =====");
+    let mut ctx1: Context<C> = Context::default();
+
+    // First set of input data (BN254 field elements).
+    let mut a1: Vec<Vec<CircuitField<C>>> = vec![];
+    for i in 0..16 {
+        a1.push(vec![]);
+        for j in 0..2 {
+            a1[i].push(CircuitField::<C>::from((i * 2 + j + 1) as u32));
+        }
+    }
+    let a1 = ctx1.copy_to_device(&a1);
+    let mut b1: DeviceMemoryHandle = None;
+    call_kernel!(ctx1, kernel_add_2, 16, a1, mut b1).unwrap();
+    let b1 = b1.reshape(&[1, 16]);
+    let mut c1: DeviceMemoryHandle = None;
+    call_kernel!(ctx1, kernel_add_16, 1, b1, mut c1).unwrap();
+    let c1 = c1.reshape(&[]);
+    let result1: CircuitField<C> = ctx1.copy_to_host(c1);
+    println!("First result: {:?}", result1);
+    // 1 + 2 + ... + 32 = 32 * 33 / 2 = 528
+    assert_eq!(result1, CircuitField::<C>::from((32 * 33 / 2) as u32));
+
+    let computation_graph = ctx1.compile_computation_graph().unwrap();
+    ctx1.solve_witness().unwrap();
+    println!("Starting setup (may take some time)...");
+    let (prover_setup, verifier_setup) = P::setup(&computation_graph);
+    println!("Starting prove...");
+    let proof1 = P::prove(
+        &prover_setup,
+        &computation_graph,
+        ctx1.export_device_memories(),
+    );
+    println!("Starting verify...");
+    assert!(P::verify(&verifier_setup, &computation_graph, &proof1));
+    println!("First verification passed!");
+
+    println!("\n===== Second execution: call_kernel first (new BN254 data), then load_graph =====");
+    let mut ctx2: Context<C> = Context::default();
+
+    // Second set of input data (different BN254 field elements).
+    let mut a2: Vec<Vec<CircuitField<C>>> = vec![];
+    for i in 0..16 {
+        a2.push(vec![]);
+        for j in 0..2 {
+            // Use different values, starting from 1000.
+            a2[i].push(CircuitField::<C>::from((i * 2 + j + 1000) as u32));
+        }
+    }
+    let a2 = ctx2.copy_to_device(&a2);
+
+    // Call the kernels first (same order as the first run).
+    let mut b2: DeviceMemoryHandle = None;
+    println!("Calling first kernel (using new data)...");
+    call_kernel!(ctx2, kernel_add_2, 16, a2, mut b2).unwrap();
+
+    let b2 = b2.reshape(&[1, 16]);
+    let mut c2: DeviceMemoryHandle = None;
+    println!("Calling second kernel...");
+    call_kernel!(ctx2, kernel_add_16, 1, b2, mut c2).unwrap();
+
+    let c2 = c2.reshape(&[]);
+    let result2: CircuitField<C> = ctx2.copy_to_host(c2);
+    println!("Second computation result: {:?}", result2);
+
+    // Verify the results are indeed different.
+    assert_ne!(result1, result2, "The two results should be different");
+
+    // Expected result for the second run:
+    // Input: [1000,1001], [1002,1003], ..., [1030,1031] (32 numbers total)
+    // add_2: 2001, 2005, 2009, ..., 2061 (16 numbers)
+    // add_16: sum(2001, 2005, ..., 2061) = 16 * (2001 + 2061) / 2 = 32496
+    let expected2 = CircuitField::<C>::from(32496u32);
+    assert_eq!(result2, expected2, "Second result should be 32496");
+
+    // Now load the graph (reusing the compiled kernels).
+    println!("Loading computation_graph...");
+    ctx2.load_computation_graph(computation_graph.clone())
+        .unwrap();
+    println!("Graph loaded successfully!");
+
+    // solve_witness (recalculates using the new data).
+    println!("solve_witness (recalculating witness)...");
+    ctx2.solve_witness().unwrap();
+    println!("solve_witness succeeded!");
+
+    // prove (using the new data).
+    println!("prove (generating proof with new data)...");
+    let proof2 = P::prove(
+        &prover_setup,
+        &computation_graph,
+        ctx2.export_device_memories(),
+    );
+    println!("prove succeeded!");
+
+    // verify
+    println!("verify (verifying proof with new data)...");
+    assert!(P::verify(&verifier_setup, &computation_graph, &proof2));
+    println!("✓ Second verification passed!");
+    println!("✓ Successfully generated and verified different proofs using new BN254 data");
+    println!("  - First result: {:?}", result1);
+    println!("  - Second result: {:?}", result2);
+
+    P::post_process();
+}
+
+#[test]
+fn test_bn254_load_graph_with_new_data() {
+    test_bn254_load_graph_with_new_data_impl::<_, ExpanderNoOverSubscribe<ZKCudaBN254KZGBatchPCS>>(
+    );
+}
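To run the new test (package and test-target names follow the file paths above; `--release` is advisable given the KZG setup cost):
cargo test -p expander_compiler --test test_bn254_new_data --release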