From 18818f3284584f35ce201a2c6896fc7a06123005 Mon Sep 17 00:00:00 2001
From: hczphn
Date: Sun, 18 Jan 2026 05:32:19 +0000
Subject: [PATCH 01/10] fix bugs to reuse graph

---
 expander_compiler/src/zkcuda/context.rs |  10 +-
 .../tests/test_bn254_new_data.rs        | 133 ++++++++++++++++++
 2 files changed, 139 insertions(+), 4 deletions(-)
 create mode 100644 expander_compiler/tests/test_bn254_new_data.rs

diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs
index a95555a4..dbc55b9a 100644
--- a/expander_compiler/src/zkcuda/context.rs
+++ b/expander_compiler/src/zkcuda/context.rs
@@ -616,8 +616,9 @@ impl>> Context {
            .map(get_pad_shape)
            .collect::>();
        let kernel_primitive = self.kernel_primitives.get(kernel_call.kernel_id);
-        let kernel = if let Some(cg_kernels) = cg_kernels.as_mut() {
-            cg_kernels.drain(..1).next().unwrap()
+        let kernel = if cg_kernels.is_some() {
+            // Fetch the kernel by kernel_id from the already-loaded kernels
+            self.kernels.get(kernel_call.kernel_id).clone()
        } else {
            let mut psi = Vec::new();
            for (s, &ib) in pad_shapes_input.iter().zip(kernel_call.is_broadcast.iter()) {
@@ -708,8 +709,9 @@ impl>> Context {
            });
        }
-        if let Some(cg_kernels) = cg_kernels {
-            assert!(cg_kernels.is_empty());
+        if let Some(_cg_kernels) = cg_kernels {
+            // No longer check that cg_kernels is empty, since we no longer consume it;
+            // the kernels were already added earlier via self.kernels.add()
            assert_eq!(cg_proof_templates.unwrap(), self.proof_templates);
            assert_eq!(cg_commitments_lens.unwrap(), commitments_lens);
            Ok(None)
diff --git a/expander_compiler/tests/test_bn254_new_data.rs b/expander_compiler/tests/test_bn254_new_data.rs
new file mode 100644
index 00000000..39e54f96
--- /dev/null
+++ b/expander_compiler/tests/test_bn254_new_data.rs
@@ -0,0 +1,133 @@
+use expander_compiler::frontend::*;
+use expander_compiler::zkcuda::proving_system::expander::config::ZKCudaBN254KZGBatchPCS;
+use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ProvingSystem};
+use expander_compiler::zkcuda::shape::Reshape;
+use expander_compiler::zkcuda::{context::*, kernel::*};
+
+#[kernel]
+fn add_2_macro(api: &mut API, a: &[InputVariable; 2], b: &mut OutputVariable) {
+    *b = api.add(a[0], a[1]);
+}
+
+#[kernel]
+fn add_16_macro(api: &mut API, a: &[InputVariable; 16], b: &mut OutputVariable) {
+    let mut sum = api.constant(0);
+    for i in 0..16 {
+        sum = api.add(sum, a[i]);
+    }
+    *b = sum;
+}
+
+fn test_bn254_load_graph_with_new_data_impl>() {
+    let kernel_add_2: KernelPrimitive = compile_add_2_macro().unwrap();
+    let kernel_add_16: KernelPrimitive = compile_add_16_macro().unwrap();
+
+    println!("\n===== First run: build and save the graph (BN254) =====");
+    let mut ctx1: Context = Context::default();
+
+    // First set of input data (BN254 field elements)
+    let mut a1: Vec>> = vec![];
+    for i in 0..16 {
+        a1.push(vec![]);
+        for j in 0..2 {
+            a1[i].push(CircuitField::::from((i * 2 + j + 1) as u32));
+        }
+    }
+    let a1 = ctx1.copy_to_device(&a1);
+    let mut b1: DeviceMemoryHandle = None;
+    call_kernel!(ctx1, kernel_add_2, 16, a1, mut b1).unwrap();
+    let b1 = b1.reshape(&[1, 16]);
+    let mut c1: DeviceMemoryHandle = None;
+    call_kernel!(ctx1, kernel_add_16, 1, b1, mut c1).unwrap();
+    let c1 = c1.reshape(&[]);
+    let result1: CircuitField = ctx1.copy_to_host(c1);
+    println!("First result: {:?}", result1);
+    assert_eq!(result1, CircuitField::::from(32 * 33 / 2 as u32));
+
+    let computation_graph = ctx1.compile_computation_graph().unwrap();
+    ctx1.solve_witness().unwrap();
+    println!("Starting setup (this may take a while)...");
+    let (prover_setup, verifier_setup) = P::setup(&computation_graph);
+    println!("Starting prove...");
+    let proof1 = P::prove(
+        &prover_setup,
+        &computation_graph,
+        ctx1.export_device_memories(),
+    );
+    println!("Starting verify...");
+    assert!(P::verify(&verifier_setup, &computation_graph, &proof1));
+    println!("First verification passed!");
+
+    println!("\n===== Second run: call_kernel first (new BN254 data), then load_graph =====");
+    let mut ctx2: Context = Context::default();
+
+    // Second set of input data (different BN254 field elements)
+    let mut a2: Vec>> = vec![];
+    for i in 0..16 {
+        a2.push(vec![]);
+        for j in 0..2 {
+            // Use different values, starting from 1000
+            a2[i].push(CircuitField::::from((i * 2 + j + 1000) as u32));
+        }
+    }
+    let a2 = ctx2.copy_to_device(&a2);
+
+    // Call the kernels first (same order as the first run)
+    let mut b2: DeviceMemoryHandle = None;
+    println!("Calling the first kernel (with new data)...");
+    call_kernel!(ctx2, kernel_add_2, 16, a2, mut b2).unwrap();
+
+    let b2 = b2.reshape(&[1, 16]);
+    let mut c2: DeviceMemoryHandle = None;
+    println!("Calling the second kernel...");
+    call_kernel!(ctx2, kernel_add_16, 1, b2, mut c2).unwrap();
+
+    let c2 = c2.reshape(&[]);
+    let result2: CircuitField = ctx2.copy_to_host(c2);
+    println!("Second result: {:?}", result2);
+
+    // Verify that the two results are indeed different
+    assert_ne!(result1, result2, "the two results should differ");
+
+    // Expected result of the second run:
+    // inputs: [1000,1001], [1002,1003], ..., [1030,1031] (32 numbers in total)
+    // add_2: 2001, 2005, 2009, ..., 2061 (16 numbers)
+    // add_16: sum(2001, 2005, ..., 2061) = 16 * (2001 + 2061) / 2 = 32496
+    let expected2 = CircuitField::::from(32496u32);
+    assert_eq!(result2, expected2, "the second result should be 32496");
+
+    // Now load the graph (reusing the compiled kernels)
+    println!("Loading computation_graph...");
+    ctx2.load_computation_graph(computation_graph.clone()).unwrap();
+    println!("Graph loaded successfully!");
+
+    // solve_witness (recomputes the witness with the new data)
+    println!("solve_witness (recomputing the witness)...");
+    ctx2.solve_witness().unwrap();
+    println!("solve_witness succeeded!");
+
+    // prove (with the new data)
+    println!("prove (generating a proof from the new data)...");
+    let proof2 = P::prove(
+        &prover_setup,
+        &computation_graph,
+        ctx2.export_device_memories(),
+    );
+    println!("prove succeeded!");
+
+    // verify
+    println!("verify (checking the proof over the new data)...");
+    assert!(P::verify(&verifier_setup, &computation_graph, &proof2));
+    println!("✓ Second verification passed!");
+    println!("✓ Successfully generated and verified a different proof with the new BN254 data");
+    println!(" - First result: {:?}", result1);
+    println!(" - Second result: {:?}", result2);
+
+    P::post_process();
+}
+
+#[test]
+fn test_bn254_load_graph_with_new_data() {
+    test_bn254_load_graph_with_new_data_impl::<_, ExpanderNoOverSubscribe>();
+}

From 3d69ff89789115c64ae0a99f6320f0cd651e02f5 Mon Sep 17 00:00:00 2001
From: hczphn
Date: Mon, 19 Jan 2026 05:04:12 +0000
Subject: [PATCH 02/10] add schedule prove version, pass test

---
 expander_compiler/src/zkcuda/context.rs |   10 +-
 expander_compiler/src/zkcuda/cpu_monitor.rs | 196 +++
 expander_compiler/src/zkcuda/mod.rs |    1 +
 .../proving_system/expander/prove_impl.rs |  167 +++
 .../expander_no_oversubscribe/prove_impl.rs | 1140 ++++++++++++++++-
 .../expander_no_oversubscribe/server_bin.rs |    1 +
 .../expander_parallelized/client_utils.rs |    7 +-
 .../expander_parallelized/cmd_utils.rs |   14 +-
 .../expander_parallelized/prove_impl.rs |  380 ++++++
 .../expander_parallelized/server_ctrl.rs |  174 ++-
 .../expander_parallelized/server_fns.rs |   75 +-
 .../shared_memory_utils.rs |  205 +++
 12 files changed, 2360 insertions(+), 10 deletions(-)
 create mode 100644 expander_compiler/src/zkcuda/cpu_monitor.rs

diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs
index a95555a4..dbc55b9a 100644
--- a/expander_compiler/src/zkcuda/context.rs
+++ b/expander_compiler/src/zkcuda/context.rs
@@ -616,8 +616,9 @@ impl>> Context {
            .map(get_pad_shape)
            .collect::>();
        let kernel_primitive =
self.kernel_primitives.get(kernel_call.kernel_id); - let kernel = if let Some(cg_kernels) = cg_kernels.as_mut() { - cg_kernels.drain(..1).next().unwrap() + let kernel = if cg_kernels.is_some() { + // 从已加载的 kernels 中通过 kernel_id 获取 + self.kernels.get(kernel_call.kernel_id).clone() } else { let mut psi = Vec::new(); for (s, &ib) in pad_shapes_input.iter().zip(kernel_call.is_broadcast.iter()) { @@ -708,8 +709,9 @@ impl>> Context { }); } - if let Some(cg_kernels) = cg_kernels { - assert!(cg_kernels.is_empty()); + if let Some(_cg_kernels) = cg_kernels { + // 不再检查 cg_kernels 是否为空,因为我们不再消耗它 + // kernels 已经在之前通过 self.kernels.add() 添加了 assert_eq!(cg_proof_templates.unwrap(), self.proof_templates); assert_eq!(cg_commitments_lens.unwrap(), commitments_lens); Ok(None) diff --git a/expander_compiler/src/zkcuda/cpu_monitor.rs b/expander_compiler/src/zkcuda/cpu_monitor.rs new file mode 100644 index 00000000..b02c26e3 --- /dev/null +++ b/expander_compiler/src/zkcuda/cpu_monitor.rs @@ -0,0 +1,196 @@ +/// CPU使用率监控模块 +/// 用于验证commit和PCS opening是否真正用满所有CPU核心 +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; +use std::thread; +use std::time::Duration; + +pub struct CpuMonitor { + stop_flag: Arc, + handle: Option>, +} + +impl CpuMonitor { + /// 开始监控CPU使用率 + /// interval_ms: 采样间隔(毫秒) + pub fn start(tag: &str, interval_ms: u64) -> Self { + let stop_flag = Arc::new(AtomicBool::new(false)); + let stop_flag_clone = stop_flag.clone(); + let tag = tag.to_string(); + + let handle = thread::spawn(move || { + let mut prev_stats = get_cpu_stats(); + + while !stop_flag_clone.load(Ordering::Relaxed) { + thread::sleep(Duration::from_millis(interval_ms)); + + let curr_stats = get_cpu_stats(); + if let (Some(ref prev), Some(ref curr)) = (prev_stats.as_ref(), curr_stats.as_ref()) + { + let usage = calculate_cpu_usage(prev, curr); + let num_cpus = num_cpus::get(); + eprintln!( + "[CPU_MONITOR] {} | Total CPUs: {} | Usage: {:.2}% | Active cores estimate: {:.1}", + tag, num_cpus, usage, usage / 100.0 * num_cpus as f64 + ); + } + + prev_stats = curr_stats; + } + }); + + Self { + stop_flag, + handle: Some(handle), + } + } + + /// 停止监控并返回 + pub fn stop(mut self) { + self.stop_flag.store(true, Ordering::Relaxed); + if let Some(handle) = self.handle.take() { + let _ = handle.join(); + } + } +} + +impl Drop for CpuMonitor { + fn drop(&mut self) { + self.stop_flag.store(true, Ordering::Relaxed); + if let Some(handle) = self.handle.take() { + let _ = handle.join(); + } + } +} + +#[derive(Debug, Clone)] +struct CpuStats { + user: u64, + nice: u64, + system: u64, + idle: u64, + iowait: u64, + irq: u64, + softirq: u64, + steal: u64, +} + +#[cfg(target_os = "linux")] +fn get_cpu_stats() -> Option { + use std::fs; + + let content = fs::read_to_string("/proc/stat").ok()?; + let line = content.lines().next()?; + + // 第一行格式: cpu user nice system idle iowait irq softirq steal + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() < 9 || parts[0] != "cpu" { + return None; + } + + Some(CpuStats { + user: parts[1].parse().ok()?, + nice: parts[2].parse().ok()?, + system: parts[3].parse().ok()?, + idle: parts[4].parse().ok()?, + iowait: parts[5].parse().ok()?, + irq: parts[6].parse().ok()?, + softirq: parts[7].parse().ok()?, + steal: parts[8].parse().ok()?, + }) +} + +#[cfg(not(target_os = "linux"))] +fn get_cpu_stats() -> Option { + None +} + +fn calculate_cpu_usage(prev: &CpuStats, curr: &CpuStats) -> f64 { + let prev_idle = prev.idle + prev.iowait; + let curr_idle = curr.idle + curr.iowait; + + let prev_total = 
prev.user + + prev.nice + + prev.system + + prev.idle + + prev.iowait + + prev.irq + + prev.softirq + + prev.steal; + let curr_total = curr.user + + curr.nice + + curr.system + + curr.idle + + curr.iowait + + curr.irq + + curr.softirq + + curr.steal; + + let total_diff = curr_total.saturating_sub(prev_total); + let idle_diff = curr_idle.saturating_sub(prev_idle); + + if total_diff == 0 { + return 0.0; + } + + (total_diff.saturating_sub(idle_diff) as f64 / total_diff as f64) * 100.0 +} + +/// 简化版本:单次快照CPU使用情况 +pub fn snapshot_cpu_usage(tag: &str) { + let num_cpus = num_cpus::get(); + + #[cfg(target_os = "linux")] + { + if let Some(stats) = get_cpu_stats() { + // 等待一小段时间再采样 + thread::sleep(Duration::from_millis(100)); + if let Some(stats2) = get_cpu_stats() { + let usage = calculate_cpu_usage(&stats, &stats2); + eprintln!( + "[CPU_SNAPSHOT] {} | Total CPUs: {} | Usage: {:.2}% | Estimated active cores: {:.1}", + tag, num_cpus, usage, usage / 100.0 * num_cpus as f64 + ); + return; + } + } + } + + eprintln!( + "[CPU_SNAPSHOT] {} | Total CPUs: {} | (monitoring not available on this platform)", + tag, num_cpus + ); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cpu_monitor() { + let monitor = CpuMonitor::start("test", 100); + + // 模拟一些CPU密集型工作 + let handles: Vec<_> = (0..4) + .map(|_| { + thread::spawn(|| { + let mut sum = 0u64; + for i in 0..100_000_000 { + sum = sum.wrapping_add(i); + } + sum + }) + }) + .collect(); + + thread::sleep(Duration::from_secs(2)); + + for h in handles { + let _ = h.join(); + } + + monitor.stop(); + } +} diff --git a/expander_compiler/src/zkcuda/mod.rs b/expander_compiler/src/zkcuda/mod.rs index 41acc8d5..072aefe9 100644 --- a/expander_compiler/src/zkcuda/mod.rs +++ b/expander_compiler/src/zkcuda/mod.rs @@ -1,4 +1,5 @@ pub mod context; +pub mod cpu_monitor; pub mod kernel; pub mod mpi_mem_share; pub mod proving_system; diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index cc2afad4..7924c343 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -7,6 +7,7 @@ use gkr_engine::{ }; use polynomials::RefMultiLinearPoly; use serdes::ExpSerde; +use std::time::Instant; use sumcheck::ProverScratchPad; use crate::{ @@ -17,6 +18,28 @@ use crate::{ }, }; +/// 获取当前进程的内存使用情况 (RSS, 单位: KB) +fn get_memory_usage_kb() -> Option { + #[cfg(target_os = "linux")] + { + if let Ok(content) = std::fs::read_to_string("/proc/self/statm") { + let parts: Vec<&str> = content.split_whitespace().collect(); + if parts.len() >= 2 { + // statm 第二个字段是 RSS (以页为单位) + // Linux 页大小通常是 4KB + if let Ok(rss_pages) = parts[1].parse::() { + return Some(rss_pages * 4); // 转换为 KB + } + } + } + None + } + #[cfg(not(target_os = "linux"))] + { + None + } +} + /// ECCCircuit -> ExpanderCircuit /// Returns an additional prover scratch pad for later use in GKR. 
pub fn prepare_expander_circuit( @@ -28,12 +51,156 @@ where ECCConfig: Config, ECCConfig::FieldConfig: FieldEngine, { + // 记录开始时间 + let start_time = Instant::now(); + eprintln!("[prepare_expander_circuit] ============== Start =============="); + + // 记录开始时的内存 + let mem_before = get_memory_usage_kb(); + eprintln!( + "[prepare_expander_circuit] Memory before: {:?} KB", + mem_before + ); + + // Step 1: export_to_expander().flatten() + let step1_start = Instant::now(); let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten(); + let step1_duration = step1_start.elapsed(); + eprintln!( + "[prepare_expander_circuit] Step 1 (export_to_expander + flatten) took: {:.3}s", + step1_duration.as_secs_f64() + ); + + // Step 2: 打印电路大小信息 + let step2_start = Instant::now(); + let num_layers = expander_circuit.layers.len(); + let mut total_gates = 0usize; + let mut total_add_gates = 0usize; + let mut total_mul_gates = 0usize; + let mut total_const_gates = 0usize; + + for layer in expander_circuit.layers.iter() { + total_add_gates += layer.add.len(); + total_mul_gates += layer.mul.len(); + total_const_gates += layer.const_.len(); + total_gates += layer.add.len() + layer.mul.len() + layer.const_.len(); + } + + eprintln!("[prepare_expander_circuit] Circuit stats:"); + eprintln!(" - num_layers: {}", num_layers); + eprintln!(" - total_gates: {}", total_gates); + eprintln!(" - total_add_gates: {}", total_add_gates); + eprintln!(" - total_mul_gates: {}", total_mul_gates); + eprintln!(" - total_const_gates: {}", total_const_gates); + eprintln!(" - log_input_size: {}", expander_circuit.log_input_size()); + let step2_duration = step2_start.elapsed(); + eprintln!( + "[prepare_expander_circuit] Step 2 (circuit stats calculation) took: {:.3}s", + step2_duration.as_secs_f64() + ); + + // 记录 export_to_expander().flatten() 后的内存 + let mem_after_flatten = get_memory_usage_kb(); + eprintln!( + "[prepare_expander_circuit] Memory after flatten: {:?} KB", + mem_after_flatten + ); + if let (Some(before), Some(after)) = (mem_before, mem_after_flatten) { + eprintln!( + "[prepare_expander_circuit] Memory delta (flatten): {} KB ({:.2} MB)", + after as i64 - before as i64, + (after as i64 - before as i64) as f64 / 1024.0 + ); + } + + // Step 3: pre_process_gkr + let step3_start = Instant::now(); expander_circuit.pre_process_gkr(); + let step3_duration = step3_start.elapsed(); + eprintln!( + "[prepare_expander_circuit] Step 3 (pre_process_gkr) took: {:.3}s", + step3_duration.as_secs_f64() + ); let (max_num_input_var, max_num_output_var) = super::utils::max_n_vars(&expander_circuit); + eprintln!(" - max_num_input_var: {}", max_num_input_var); + eprintln!(" - max_num_output_var: {}", max_num_output_var); + + // 记录 pre_process_gkr 后的内存 + let mem_after_preprocess = get_memory_usage_kb(); + eprintln!( + "[prepare_expander_circuit] Memory after pre_process_gkr: {:?} KB", + mem_after_preprocess + ); + if let (Some(before), Some(after)) = (mem_after_flatten, mem_after_preprocess) { + eprintln!( + "[prepare_expander_circuit] Memory delta (pre_process_gkr): {} KB ({:.2} MB)", + after as i64 - before as i64, + (after as i64 - before as i64) as f64 / 1024.0 + ); + } + + // Step 4: create ProverScratchPad + let step4_start = Instant::now(); let prover_scratch = ProverScratchPad::::new(max_num_input_var, max_num_output_var, mpi_world_size); + let step4_duration = step4_start.elapsed(); + eprintln!( + "[prepare_expander_circuit] Step 4 (create ProverScratchPad) took: {:.3}s", + step4_duration.as_secs_f64() + ); + + // 记录分配 
ProverScratchPad 后的内存 + let mem_after_scratch = get_memory_usage_kb(); + eprintln!( + "[prepare_expander_circuit] Memory after ProverScratchPad: {:?} KB", + mem_after_scratch + ); + if let (Some(before), Some(after)) = (mem_after_preprocess, mem_after_scratch) { + eprintln!( + "[prepare_expander_circuit] Memory delta (ProverScratchPad): {} KB ({:.2} MB)", + after as i64 - before as i64, + (after as i64 - before as i64) as f64 / 1024.0 + ); + } + + // 总内存增量 + if let (Some(before), Some(after)) = (mem_before, mem_after_scratch) { + eprintln!( + "[prepare_expander_circuit] Total memory delta: {} KB ({:.2} MB)", + after as i64 - before as i64, + (after as i64 - before as i64) as f64 / 1024.0 + ); + } + + // 总时间统计 + let total_duration = start_time.elapsed(); + eprintln!("[prepare_expander_circuit] ============== Summary =============="); + eprintln!( + "[prepare_expander_circuit] Step 1 (export + flatten): {:.3}s ({:.1}%)", + step1_duration.as_secs_f64(), + step1_duration.as_secs_f64() / total_duration.as_secs_f64() * 100.0 + ); + eprintln!( + "[prepare_expander_circuit] Step 2 (circuit stats): {:.3}s ({:.1}%)", + step2_duration.as_secs_f64(), + step2_duration.as_secs_f64() / total_duration.as_secs_f64() * 100.0 + ); + eprintln!( + "[prepare_expander_circuit] Step 3 (pre_process_gkr): {:.3}s ({:.1}%)", + step3_duration.as_secs_f64(), + step3_duration.as_secs_f64() / total_duration.as_secs_f64() * 100.0 + ); + eprintln!( + "[prepare_expander_circuit] Step 4 (scratch pad): {:.3}s ({:.1}%)", + step4_duration.as_secs_f64(), + step4_duration.as_secs_f64() / total_duration.as_secs_f64() * 100.0 + ); + eprintln!( + "[prepare_expander_circuit] Total time: {:.3}s", + total_duration.as_secs_f64() + ); + eprintln!("[prepare_expander_circuit] ===================================="); (expander_circuit, prover_scratch) } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index bc980372..34a7e415 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -1,9 +1,39 @@ +use crate::zkcuda::cpu_monitor::CpuMonitor; use arith::{Field, Fr, SimdField}; use expander_utils::timer::Timer; use gkr_engine::{ BN254ConfigXN, ExpanderDualVarChallenge, FieldEngine, GKREngine, MPIConfig, MPIEngine, Transcript, }; +use std::collections::HashMap; +use std::fs; + +/// 获取当前进程的内存使用情况 (RSS, 单位: KB) +fn get_memory_kb() -> u64 { + #[cfg(target_os = "linux")] + { + if let Ok(content) = std::fs::read_to_string("/proc/self/statm") { + let parts: Vec<&str> = content.split_whitespace().collect(); + if parts.len() >= 2 { + if let Ok(rss_pages) = parts[1].parse::() { + return rss_pages * 4; // 页大小 4KB + } + } + } + } + 0 +} + +fn log_memory(rank: usize, tag: &str) { + let mem_kb = get_memory_kb(); + eprintln!( + "[MEM] rank={} {} : {} KB ({:.2} MB)", + rank, + tag, + mem_kb, + mem_kb as f64 / 1024.0 + ); +} use crate::{ frontend::{Config, SIMDField}, @@ -43,8 +73,36 @@ pub fn mpi_prove_no_oversubscribe_impl( where ::FieldConfig: FieldEngine, { + // Check for schedule file and use scheduler if available + if std::path::Path::new("schedule.txt").exists() { + let my_rank = global_mpi_config.world_rank(); + eprintln!( + "[RANK {}] ⚡ Schedule file detected, using scheduled execution", + my_rank + ); + return mpi_prove_no_oversubscribe_with_schedule::( + global_mpi_config, + 
"schedule.txt", + Some("task_mapping.txt"), + prover_setup, + computation_graph, + values, + n_bytes_profiler, + ); + } + let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); let (commitments, states) = if global_mpi_config.is_root() { + eprintln!("\n========== COMMIT PHASE START =========="); + eprintln!( + "[RANK {}] Starting commit on {} values", + global_mpi_config.world_rank(), + values.len() + ); + + // 启动CPU监控(每200ms采样一次) + let _cpu_monitor = CpuMonitor::start("COMMIT", 200); + let (commitments, states) = values .iter() .map(|value| match ZC::BATCH_PCS { @@ -58,8 +116,16 @@ where ), }) .unzip::<_, _, Vec<_>, Vec<_>>(); + + // _cpu_monitor在这里自动drop,停止监控 + eprintln!("========== COMMIT PHASE END ==========\n"); + (Some(commitments), Some(states)) } else { + eprintln!( + "[RANK {}] Skipping commit (not root)", + global_mpi_config.world_rank() + ); (None, None) }; commit_timer.stop(); @@ -162,14 +228,23 @@ where true => { if global_mpi_config.is_root() { let mut proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); - + eprintln!("\n========== PCS OPENING PHASE START =========="); + eprintln!( + "[RANK {}] Starting batch PCS opening for {} values, {} challenges", + global_mpi_config.world_rank(), + vals_ref.len(), + challenges.len() + ); let pcs_opening_timer = Timer::new("Batch PCS Opening for all kernels", true); + // 启动CPU监控 + let _cpu_monitor = CpuMonitor::start("PCS_OPENING", 200); let pcs_batch_opening = open_defered_pcs::( prover_setup, &vals_ref, &challenges, ); pcs_opening_timer.stop(); + eprintln!("========== PCS OPENING PHASE END ==========\n"); proofs.push(pcs_batch_opening); Some(CombinedProof { @@ -338,6 +413,8 @@ where let world_size = mpi_config.world_size(); let n_copies = parallel_count / world_size; + log_memory(world_rank, "prove_kernel_gkr_internal::start"); + let local_commitment_values = get_local_vals_multi_copies( commitments_values, is_broadcast, @@ -345,9 +422,17 @@ where n_copies, parallel_count, ); + log_memory( + world_rank, + "prove_kernel_gkr_internal::after_get_local_vals", + ); let (mut expander_circuit, mut prover_scratch) = prepare_expander_circuit::(kernel, world_size); + log_memory( + world_rank, + "prove_kernel_gkr_internal::after_prepare_expander_circuit", + ); let mut transcript = T::new(); let challenge = prove_gkr_with_local_vals_multi_copies::( @@ -359,6 +444,12 @@ where mpi_config, n_bytes_profiler, ); + log_memory(world_rank, "prove_kernel_gkr_internal::after_prove_gkr"); + + // expander_circuit 和 prover_scratch 在这里被 drop + drop(expander_circuit); + drop(prover_scratch); + log_memory(world_rank, "prove_kernel_gkr_internal::after_drop_circuit"); Some((transcript, challenge)) } @@ -397,6 +488,9 @@ where FieldEngine, T: Transcript, { + let world_rank = mpi_config.world_rank(); + log_memory(world_rank, "prove_gkr::start"); + let input_vals_multi_copies = local_commitment_values_multi_copies .iter() .map(|local_commitment_values| { @@ -407,6 +501,7 @@ where ) }) .collect::>(); + log_memory(world_rank, "prove_gkr::after_prepare_inputs"); let mut input_vals = vec![FMulti::SimdCircuitField::ZERO; 1 << expander_circuit.log_input_size()]; @@ -419,9 +514,13 @@ where *vals = FMulti::SimdCircuitField::pack(&vals_unpacked); } expander_circuit.layers[0].input_vals = input_vals; + log_memory(world_rank, "prove_gkr::after_set_input_vals"); expander_circuit.fill_rnd_coefs(transcript); + log_memory(world_rank, "prove_gkr::after_fill_rnd_coefs"); + expander_circuit.evaluate(); + log_memory(world_rank, 
"prove_gkr::after_evaluate"); #[cfg(feature = "zkcuda_profile")] { @@ -436,6 +535,8 @@ where let (claimed_v, challenge) = gkr::gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config); + log_memory(world_rank, "prove_gkr::after_gkr_prove"); + assert_eq!(claimed_v, FBasic::ChallengeField::from(0u32)); let n_simd_vars_basic = FBasic::SimdCircuitField::PACK_SIZE.ilog2() as usize; @@ -451,3 +552,1040 @@ where }, } } + +// ==================== SCHEDULE-BASED EXECUTION ==================== + +/// Schedule representation: rank -> sequence of tasks +#[derive(Debug, Clone)] +pub struct Schedule { + /// Map from rank to list of task names + pub rank_tasks: HashMap>, +} + +impl Schedule { + /// Parse schedule from text file + /// Format: "Rank 0: Task14 -> Task1 -> Task12" + pub fn from_file(path: &str) -> Result { + let content = + fs::read_to_string(path).map_err(|e| format!("Failed to read schedule file: {}", e))?; + + let mut rank_tasks = HashMap::new(); + + for line in content.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + + let parts: Vec<&str> = line.split(':').collect(); + if parts.len() != 2 { + return Err(format!("Invalid line format: {}", line)); + } + + // Handle multiple spaces: "Rank 0:" or "Rank 0:" + let rank_part = parts[0].trim(); + if !rank_part.starts_with("Rank") { + return Err(format!("Expected 'Rank X', got: {}", parts[0])); + } + + let rank_str = rank_part + .strip_prefix("Rank") + .ok_or_else(|| format!("Expected 'Rank X', got: {}", parts[0]))? + .trim(); // Trim to handle multiple spaces + + let rank: usize = rank_str + .parse() + .map_err(|e| format!("Invalid rank number '{}': {}", rank_str, e))?; + + let tasks_str = parts[1].trim(); + let tasks: Vec = tasks_str + .split("->") + .map(|t| t.trim().to_string()) + .filter(|t| !t.is_empty()) + .collect(); + + rank_tasks.insert(rank, tasks); + } + + Ok(Schedule { rank_tasks }) + } + + pub fn get_tasks(&self, rank: usize) -> Option<&Vec> { + self.rank_tasks.get(&rank) + } + + pub fn find_peers_at_step(&self, step: usize, task_name: &str) -> Vec { + let mut peers = Vec::new(); + for (rank, tasks) in &self.rank_tasks { + if tasks.len() > step && tasks[step] == task_name { + peers.push(*rank); + } + } + peers.sort(); + peers + } + + pub fn max_steps(&self) -> usize { + self.rank_tasks + .values() + .map(|tasks| tasks.len()) + .max() + .unwrap_or(0) + } + + /// Get all unique task names across all ranks + pub fn get_all_unique_tasks(&self) -> Vec { + use std::collections::HashSet; + let mut unique_tasks = HashSet::new(); + + for tasks in self.rank_tasks.values() { + for task in tasks { + if task != "idle" && task != "..." 
{ + unique_tasks.insert(task.clone()); + } + } + } + + let mut tasks: Vec<_> = unique_tasks.into_iter().collect(); + tasks.sort(); + tasks + } + + /// Find all ranks that will execute a given task (across all steps) + pub fn find_all_peers_for_task(&self, task_name: &str) -> Vec { + let mut peers = Vec::new(); + + for (rank, tasks) in &self.rank_tasks { + if tasks.contains(&task_name.to_string()) { + peers.push(*rank); + } + } + + peers.sort(); + peers + } +} + +/// Parse task mapping file +pub fn parse_task_mapping(path: &str) -> Result, String> { + let content = + fs::read_to_string(path).map_err(|e| format!("Failed to read task mapping file: {}", e))?; + + let mut mapping = HashMap::new(); + + for line in content.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + + let parts: Vec<&str> = line.split(':').collect(); + if parts.len() != 2 { + continue; + } + + let task_name = parts[0].trim().to_string(); + let template_idx: usize = parts[1] + .trim() + .parse() + .map_err(|e| format!("Invalid template index: {}", e))?; + + mapping.insert(task_name, template_idx); + } + + Ok(mapping) +} + +/// Parse task dependencies file +/// Format: "Task22: Task9, Task23, Task8, Task17, Task18" +pub fn parse_task_dependencies(path: &str) -> Result>, String> { + let content = + fs::read_to_string(path).map_err(|e| format!("Failed to read dependencies file: {}", e))?; + + let mut dependencies = HashMap::new(); + + for line in content.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + + let parts: Vec<&str> = line.split(':').collect(); + if parts.len() != 2 { + continue; + } + + let task_name = parts[0].trim().to_string(); + let deps: Vec = parts[1] + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + + dependencies.insert(task_name, deps); + } + + Ok(dependencies) +} + +/// Mark a task as completed by creating a marker file +fn mark_task_completed(task_name: &str, my_rank: usize, peers: &[usize]) { + // Only the minimum rank in the peer group writes the file + if let Some(&min_rank) = peers.iter().min() { + if my_rank == min_rank { + let marker_path = format!(".task_sync/{}.done", task_name); + if let Err(e) = fs::write( + &marker_path, + format!( + "{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + ), + ) { + eprintln!( + "[RANK {}] Warning: Failed to write marker for {}: {}", + my_rank, task_name, e + ); + } else { + eprintln!( + "[RANK {}] ✓ Marked task {} as completed", + my_rank, task_name + ); + } + } + } +} + +/// Wait for task dependencies to be satisfied +fn wait_for_dependencies(task_name: &str, dependencies: &[String], my_rank: usize) { + if dependencies.is_empty() { + return; + } + + eprintln!( + "[RANK {}] Task {} waiting for {} dependencies: {:?}", + my_rank, + task_name, + dependencies.len(), + dependencies + ); + + for dep in dependencies { + let marker_path = format!(".task_sync/{}.done", dep); + + if std::path::Path::new(&marker_path).exists() { + eprintln!( + "[RANK {}] ✓ Dependency {} already satisfied", + my_rank, dep + ); + continue; + } + + eprintln!("[RANK {}] ⏳ Waiting for dependency: {}", my_rank, dep); + + let start_time = std::time::Instant::now(); + while !std::path::Path::new(&marker_path).exists() { + std::thread::sleep(std::time::Duration::from_millis(100)); + + // Timeout check (optional, for debugging) + if start_time.elapsed().as_secs() > 600 { + eprintln!( + "[RANK {}] ⚠️ 
WARNING: Waiting for {} over 10 minutes!", + my_rank, dep + ); + } + } + + eprintln!( + "[RANK {}] ✓ Dependency {} satisfied (waited {:.1}s)", + my_rank, + dep, + start_time.elapsed().as_secs_f64() + ); + } + + eprintln!( + "[RANK {}] All dependencies for {} satisfied", + my_rank, task_name + ); +} + +/// Create MPI subgroup for a specific task +/// CRITICAL: This is a collective operation - ALL ranks must call this function +fn create_mpi_subgroup_for_task( + global_mpi_config: &MPIConfig<'static>, + peers: &[usize], + task_name: &str, +) -> Option> { + use mpi::topology::Communicator; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let my_rank = global_mpi_config.world_rank(); + + // Check if I'm in the peer list + let my_position_opt = peers.iter().position(|&r| r == my_rank); + + // Use different colors: 0 for participants, 1 for non-participants + let (color_value, key) = if let Some(pos) = my_position_opt { + // I'm in this task group + let mut hasher = DefaultHasher::new(); + task_name.hash(&mut hasher); + let base_color = (hasher.finish() % 5000) as i32; + (base_color, pos as i32) + } else { + // I'm not in this task group, use a different color + let mut hasher = DefaultHasher::new(); + task_name.hash(&mut hasher); + let base_color = (hasher.finish() % 5000) as i32; + (base_color + 5000, my_rank as i32) // Different color range for non-participants + }; + + let color = mpi::topology::Color::with_value(color_value); + + // CRITICAL: All ranks must call split (collective operation) + let split_comm = unsafe { + global_mpi_config + .world + .unwrap() + .split_by_color_with_key(color, key) + }; + + // Only participants return a valid MPIConfig + if my_position_opt.is_some() { + let split_comm_static: &'static Option<_> = Box::leak(Box::new(split_comm)); + Some(MPIConfig::prover_new( + global_mpi_config.universe, + split_comm_static.as_ref(), + )) + } else { + // Non-participants: still called split but don't use the result + None + } +} + +/// Main prove function with schedule support +pub fn mpi_prove_no_oversubscribe_with_schedule( + global_mpi_config: &MPIConfig<'static>, + schedule_path: &str, + task_mapping_path: Option<&str>, + prover_setup: &ExpanderProverSetup, GetPCS>, + computation_graph: &ComputationGraph, + values: &[impl AsRef<[SIMDField]>], + n_bytes_profiler: &mut NBytesProfiler, +) -> Option>> +where + ::FieldConfig: FieldEngine, +{ + let my_rank = global_mpi_config.world_rank(); + + // Load schedule + let schedule = match Schedule::from_file(schedule_path) { + Ok(s) => s, + Err(e) => { + eprintln!("[RANK {}] Failed to load schedule: {}", my_rank, e); + return None; + } + }; + + eprintln!("[RANK {}] ========== SCHEDULER MODE ==========", my_rank); + eprintln!( + "[RANK {}] Loaded schedule with {} ranks, max {} steps", + my_rank, + schedule.rank_tasks.len(), + schedule.max_steps() + ); + + // Safety checks + let num_templates = computation_graph.proof_templates().len(); + let num_values = values.len(); + eprintln!( + "[RANK {}] Computation graph has {} templates", + my_rank, num_templates + ); + eprintln!( + "[RANK {}] Values array has {} elements", + my_rank, num_values + ); + + if num_templates == 0 { + eprintln!( + "[RANK {}] ERROR: No templates in computation graph!", + my_rank + ); + return if my_rank == 0 { + Some(CombinedProof { + commitments: vec![], + proofs: vec![], + }) + } else { + None + }; + } + + if num_values == 0 { + eprintln!("[RANK {}] ERROR: Values array is empty!", my_rank); + return None; + } + + // Load task 
mapping + let task_mapping = if let Some(path) = task_mapping_path { + match parse_task_mapping(path) { + Ok(m) => m, + Err(e) => { + eprintln!("[RANK {}] Failed to load task mapping: {}", my_rank, e); + return None; + } + } + } else { + let mut default_mapping = HashMap::new(); + for i in 0..computation_graph.proof_templates().len() { + default_mapping.insert(format!("Task{}", i), i); + } + default_mapping + }; + + // Load task dependencies (optional) + let task_dependencies = if std::path::Path::new("task_dependencies.txt").exists() { + match parse_task_dependencies("task_dependencies.txt") { + Ok(deps) => { + eprintln!("[RANK {}] Loaded {} task dependencies", my_rank, deps.len()); + deps + } + Err(e) => { + eprintln!( + "[RANK {}] Warning: Failed to load dependencies: {}", + my_rank, e + ); + HashMap::new() + } + } + } else { + eprintln!( + "[RANK {}] No task_dependencies.txt found, using file-based sync", + my_rank + ); + HashMap::new() + }; + + // Create task sync directory (only root) + if global_mpi_config.is_root() { + fs::create_dir_all(".task_sync").ok(); + // Clean old markers + if let Ok(entries) = fs::read_dir(".task_sync") { + for entry in entries { + if let Ok(entry) = entry { + fs::remove_file(entry.path()).ok(); + } + } + } + eprintln!("[RANK 0] Initialized .task_sync directory"); + } + global_mpi_config.barrier(); // Wait for directory creation + + // ========== PRE-CREATE ALL MPI SUBGROUPS ========== + // CRITICAL: Create all task subgroups BEFORE any task execution + // This allows ranks to proceed asynchronously without collective deadlock + eprintln!( + "[RANK {}] Pre-creating MPI subgroups for all tasks...", + my_rank + ); + + let all_unique_tasks = schedule.get_all_unique_tasks(); + let mut task_mpi_configs: HashMap>> = HashMap::new(); + + for task_name in &all_unique_tasks { + let peers = schedule.find_all_peers_for_task(task_name); + + eprintln!( + "[RANK {}] Creating MPI subgroup for task {} (peers: {:?})", + my_rank, task_name, peers + ); + + // All 32 ranks call this together (collective operation) + let mpi_config = if peers.len() >= 1 { + create_mpi_subgroup_for_task(global_mpi_config, &peers, task_name) + } else { + None + }; + + task_mpi_configs.insert(task_name.clone(), mpi_config); + } + + eprintln!( + "[RANK {}] Pre-created {} MPI subgroups", + my_rank, + task_mpi_configs.len() + ); + global_mpi_config.barrier(); // Ensure all subgroups created before proceeding + + // Commit phase (only root) + let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); + let (commitments, states) = if global_mpi_config.is_root() { + eprintln!("[RANK {}] === COMMIT PHASE ===", my_rank); + let _cpu_monitor = CpuMonitor::start("COMMIT", 200); + + let (commitments, states) = values + .iter() + .map(|value| match ZC::BATCH_PCS { + true => max_len_setup_commit_impl::( + prover_setup, + value.as_ref(), + ), + false => local_commit_impl::( + prover_setup.p_keys.get(&value.as_ref().len()).unwrap(), + value.as_ref(), + ), + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + (Some(commitments), Some(states)) + } else { + (None, None) + }; + commit_timer.stop(); + + // Use indexed storage to maintain template order + let num_templates = computation_graph.proof_templates().len(); + + // Store vals and challenges per template to maintain order + let mut vals_per_template: Vec>>>> = + vec![None; num_templates]; + let mut challenges_per_template: Vec< + Option>>>, + > = vec![None; num_templates]; + + let mut vals_ref: Vec<&[SIMDField]> = vec![]; // Keep reference version 
for non-BATCH_PCS compatibility + let mut challenges: Vec>> = vec![]; // For non-BATCH_PCS compatibility + + // Track which ranks are subgroup roots for result collection + let mut i_am_subgroup_root_for_tasks = vec![]; + + // Get my tasks + let my_tasks = match schedule.get_tasks(my_rank) { + Some(tasks) => tasks, + None => { + eprintln!("[RANK {}] No tasks assigned in schedule", my_rank); + return if my_rank == 0 { + Some(CombinedProof { + commitments: commitments.unwrap(), + proofs: vec![], + }) + } else { + None + }; + } + }; + + eprintln!("[RANK {}] My tasks: {:?}", my_rank, my_tasks); + + // Execute tasks step by step + let mut all_proofs: Vec> = + vec![None; computation_graph.proof_templates().len()]; + + for (step, task_name) in my_tasks.iter().enumerate() { + eprintln!( + "[RANK {}] === STEP {} === Task: {}", + my_rank, step, task_name + ); + + // Skip idle steps + if task_name == "idle" || task_name == "..." { + // Idle ranks still participate in MPI collective operations + continue; + } + + // Wait for dependencies before proceeding + if let Some(deps) = task_dependencies.get(task_name) { + wait_for_dependencies(task_name, deps, my_rank); + } + + // Find template index + let template_idx = match task_mapping.get(task_name) { + Some(&idx) => idx, + None => { + eprintln!("[RANK {}] Unknown task: {}", my_rank, task_name); + continue; + } + }; + + if template_idx >= computation_graph.proof_templates().len() { + eprintln!( + "[RANK {}] Invalid template index: {}", + my_rank, template_idx + ); + continue; + } + + // Get pre-created MPI config for this task + let local_mpi_config = task_mpi_configs.get(task_name).and_then(|c| c.clone()); + + // Check if I'm a participant (have valid MPI config or solo task) + let all_peers = schedule.find_all_peers_for_task(task_name); + let i_am_participant = all_peers.contains(&my_rank); + + if !i_am_participant { + eprintln!("[RANK {}] Not participating in task {}", my_rank, task_name); + continue; + } + + eprintln!( + "[RANK {}] Task {} peers: {:?} (using pre-created MPI subgroup)", + my_rank, task_name, all_peers + ); + + let template = &computation_graph.proof_templates()[template_idx]; + + // Safety check: verify all commitment indices are in bounds + let commit_indices = template.commitment_indices(); + let mut has_error = false; + for &idx in commit_indices { + if idx >= values.len() { + eprintln!( + "[RANK {}] ERROR: Template {} requires value index {} but values.len() = {}", + my_rank, + template_idx, + idx, + values.len() + ); + has_error = true; + } + } + if has_error { + eprintln!( + "[RANK {}] Skipping task {} due to index out of bounds", + my_rank, task_name + ); + continue; + } + + let commitment_values = template + .commitment_indices() + .iter() + .map(|&idx| values[idx].as_ref()) + .collect::>(); + + // Execute GKR + let single_kernel_gkr_timer = Timer::new( + &format!("Task {} GKR", task_name), + local_mpi_config + .as_ref() + .map(|c| c.is_root()) + .unwrap_or(true), + ); + + let gkr_end_state = if let Some(ref local_config) = local_mpi_config { + eprintln!( + "[RANK {}] Executing task {} with {} peers (local_rank={}, group_size={})", + my_rank, + task_name, + all_peers.len(), + local_config.world_rank(), + local_config.world_size() + ); + + prove_kernel_gkr_no_oversubscribe::, GetTranscript, ZC::ECCConfig>( + local_config, + &computation_graph.kernels()[template.kernel_id()], + &commitment_values, + next_power_of_two(template.parallel_count()), + template.is_broadcast(), + n_bytes_profiler, + ) + } else { + eprintln!( + "[RANK 
{}] Executing task {} solo (creating singl + e-rank MPI config)", + my_rank, task_name + ); + + // Create a single-rank MPI config for solo tasks + // We use the pre-created config from task_mpi_configs + let solo_config = task_mpi_configs.get(task_name).and_then(|c| c.as_ref()); + + if let Some(config) = solo_config { + prove_kernel_gkr_no_oversubscribe::< + GetFieldConfig, + GetTranscript, + ZC::ECCConfig, + >( + config, + &computation_graph.kernels()[template.kernel_id()], + &commitment_values, + next_power_of_two(template.parallel_count()), + template.is_broadcast(), + n_bytes_profiler, + ) + } else { + // Fallback: use global MPI config for solo task + prove_kernel_gkr_no_oversubscribe::< + GetFieldConfig, + GetTranscript, + ZC::ECCConfig, + >( + global_mpi_config, + &computation_graph.kernels()[template.kernel_id()], + &commitment_values, + next_power_of_two(template.parallel_count()), + template.is_broadcast(), + n_bytes_profiler, + ) + } + }; + + single_kernel_gkr_timer.stop(); + + // PCS opening + if let Some((mut transcript, challenge)) = gkr_end_state { + let is_subgroup_root = local_mpi_config + .as_ref() + .map(|c| c.is_root()) + .unwrap_or(true); + + if is_subgroup_root { + eprintln!( + "[RANK {}] I am subgroup root for task {}", + my_rank, task_name + ); + i_am_subgroup_root_for_tasks.push(template_idx); + + match ZC::BATCH_PCS { + true => { + assert!(challenge.challenge_y().is_none()); + let challenge_x = challenge.challenge_x(); + + let (local_vals_ref, local_challenges) = extract_pcs_claims::( + &commitment_values, + &challenge_x, + template.is_broadcast(), + next_power_of_two(template.parallel_count()), + ); + + // Store in indexed structure to maintain template order + let owned_vals: Vec> = + local_vals_ref.iter().map(|v| v.to_vec()).collect(); + vals_per_template[template_idx] = Some(owned_vals); + challenges_per_template[template_idx] = Some(local_challenges); + + all_proofs[template_idx] = Some(ExpanderProof { + data: vec![transcript.finalize_and_get_proof()], + }); + } + false => { + let pcs_open_timer = Timer::new(&format!("Task {} PCS", task_name), true); + let challenge_list = if let Some(challenge_y) = challenge.challenge_y() { + vec![challenge.challenge_x(), challenge_y] + } else { + vec![challenge.challenge_x()] + }; + + challenge_list.iter().for_each(|c| { + partition_single_gkr_claim_and_open_pcs_mpi::( + prover_setup, + &commitment_values, + &template + .commitment_indices() + .iter() + .map(|&idx| &states.as_ref().unwrap()[idx]) + .collect::>(), + c, + template.is_broadcast(), + &mut transcript, + ); + }); + + pcs_open_timer.stop(); + all_proofs[template_idx] = Some(ExpanderProof { + data: vec![transcript.finalize_and_get_proof()], + }); + } + } + } + } + + // Mark task as completed (file-based synchronization) + mark_task_completed(task_name, my_rank, &all_peers); + } + + // ========== CRITICAL: Global barrier ========== + // Wait for all ranks to complete all their tasks before proceeding to PCS opening + eprintln!( + "[RANK {}] All my tasks completed, waiting for other ranks...", + my_rank + ); + global_mpi_config.barrier(); + eprintln!( + "[RANK {}] All ranks ready, proceeding to result collection", + my_rank + ); + + // ========== MPI Result Collection (for BATCH_PCS mode) ========== + // Collect vals_ref and challenges from all subgroup roots to global root + if ZC::BATCH_PCS { + use mpi::traits::*; + use serdes::ExpSerde; + + let i_am_subgroup_root = !i_am_subgroup_root_for_tasks.is_empty(); + + eprintln!( + "[RANK {}] Am I subgroup root? 
{} (for {} tasks)", + my_rank, + i_am_subgroup_root, + i_am_subgroup_root_for_tasks.len() + ); + + // Step 0: All ranks send a flag to root indicating if they are subgroup roots + let my_flag = if i_am_subgroup_root { 1u8 } else { 0u8 }; + + let all_flags = if global_mpi_config.is_root() { + // Root collects flags from all ranks + let mut flags = vec![0u8; global_mpi_config.world_size()]; + unsafe { + global_mpi_config + .world + .unwrap() + .all_gather_into(&my_flag, &mut flags[..]); + } + Some(flags) + } else { + unsafe { + let mut flags = vec![0u8; global_mpi_config.world_size()]; + global_mpi_config + .world + .unwrap() + .all_gather_into(&my_flag, &mut flags[..]); + } + None + }; + + // Step 1: Non-root subgroup roots send their results to rank 0 + if i_am_subgroup_root && my_rank != 0 { + eprintln!( + "[RANK {}] Sending my results to global root (indexed by template)", + my_rank + ); + + // Serialize the indexed structures (maintains template order) + let mut vals_bytes = Vec::new(); + vals_per_template.serialize_into(&mut vals_bytes).unwrap(); + + let mut challenges_bytes = Vec::new(); + challenges_per_template + .serialize_into(&mut challenges_bytes) + .unwrap(); + + let mut proofs_bytes = Vec::new(); + all_proofs.serialize_into(&mut proofs_bytes).unwrap(); + + // Send sizes first + let sizes = [vals_bytes.len(), challenges_bytes.len(), proofs_bytes.len()]; + + unsafe { + global_mpi_config + .world + .unwrap() + .process_at_rank(0) + .synchronous_send(&sizes[..]); + + global_mpi_config + .world + .unwrap() + .process_at_rank(0) + .synchronous_send(&vals_bytes[..]); + + global_mpi_config + .world + .unwrap() + .process_at_rank(0) + .synchronous_send(&challenges_bytes[..]); + + global_mpi_config + .world + .unwrap() + .process_at_rank(0) + .synchronous_send(&proofs_bytes[..]); + } + + eprintln!("[RANK {}] Results sent to global root", my_rank); + } + + // Step 2: Global root receives all results + if global_mpi_config.is_root() { + eprintln!("[RANK 0] Collecting results from all subgroup roots..."); + + let flags = all_flags.unwrap(); + + // Identify which ranks are subgroup roots (have flag=1) + let subgroup_roots: Vec = flags + .iter() + .enumerate() + .filter(|(_, &flag)| flag == 1) + .map(|(rank, _)| rank) + .collect(); + + eprintln!("[RANK 0] Subgroup roots detected: {:?}", subgroup_roots); + + // Receive from each subgroup root (except self) + for &sender_rank in &subgroup_roots { + if sender_rank == 0 { + continue; // Skip self + } + + eprintln!("[RANK 0] Receiving results from rank {}", sender_rank); + + // Receive sizes + let (sizes, _status) = unsafe { + global_mpi_config + .world + .unwrap() + .process_at_rank(sender_rank as i32) + .receive_vec::() + }; + + // Receive vals_per_template (indexed structure) + let (vals_bytes, _) = unsafe { + global_mpi_config + .world + .unwrap() + .process_at_rank(sender_rank as i32) + .receive_vec::() + }; + let received_vals_per_template: Vec>>>> = + Vec::deserialize_from(&mut vals_bytes.as_slice()).unwrap(); + + // Receive challenges_per_template + let (challenges_bytes, _) = unsafe { + global_mpi_config + .world + .unwrap() + .process_at_rank(sender_rank as i32) + .receive_vec::() + }; + let received_challenges_per_template: Vec< + Option>>>, + > = Vec::deserialize_from(&mut challenges_bytes.as_slice()).unwrap(); + + // Receive proofs (indexed by template) + let (proofs_bytes, _) = unsafe { + global_mpi_config + .world + .unwrap() + .process_at_rank(sender_rank as i32) + .receive_vec::() + }; + let received_all_proofs: Vec> = + 
Vec::deserialize_from(&mut proofs_bytes.as_slice()).unwrap(); + + // Merge indexed data (maintains template order) + for template_idx in 0..num_templates { + // Merge vals + if received_vals_per_template[template_idx].is_some() { + vals_per_template[template_idx] = + received_vals_per_template[template_idx].clone(); + } + + // Merge challenges + if received_challenges_per_template[template_idx].is_some() { + challenges_per_template[template_idx] = + received_challenges_per_template[template_idx].clone(); + } + + // Merge proofs + if received_all_proofs[template_idx].is_some() { + all_proofs[template_idx] = received_all_proofs[template_idx].clone(); + } + } + + let received_count = received_vals_per_template + .iter() + .filter(|v| v.is_some()) + .count(); + eprintln!( + "[RANK 0] Received results from rank {} ({} templates)", + sender_rank, received_count + ); + } + + // Build final vals_ref and challenges in template order + let mut vals_ref_owned: Vec>> = vec![]; + let mut challenges_final = vec![]; + + for template_idx in 0..num_templates { + if let Some(vals) = &vals_per_template[template_idx] { + vals_ref_owned.extend(vals.clone()); + } + if let Some(chals) = &challenges_per_template[template_idx] { + challenges_final.extend(chals.clone()); + } + } + + eprintln!( + "[RANK 0] Result collection complete. Total: {} vals, {} challenges, {} proofs", + vals_ref_owned.len(), + challenges_final.len(), + all_proofs.iter().filter(|p| p.is_some()).count() + ); + + eprintln!("[RANK 0] Templates coverage:"); + for (idx, val) in vals_per_template.iter().enumerate() { + let status = if val.is_some() { "✓" } else { "✗" }; + eprintln!(" Template {}: {}", idx, status); + } + } + } + + // Collect results + match ZC::BATCH_PCS { + true => { + if global_mpi_config.is_root() { + let mut proofs = all_proofs.into_iter().filter_map(|p| p).collect::>(); + + let pcs_opening_timer = Timer::new("Batch PCS Opening for all kernels", true); + + // Build final vals_ref and challenges in template order + let mut vals_ref_owned: Vec>> = vec![]; + let mut challenges_final = vec![]; + + for template_idx in 0..num_templates { + if let Some(vals) = &vals_per_template[template_idx] { + vals_ref_owned.extend(vals.clone()); + } + if let Some(chals) = &challenges_per_template[template_idx] { + challenges_final.extend(chals.clone()); + } + } + + // Convert to references for open_defered_pcs + let vals_ref_for_pcs: Vec<&[SIMDField]> = + vals_ref_owned.iter().map(|v| v.as_slice()).collect(); + + let pcs_batch_opening = open_defered_pcs::( + prover_setup, + &vals_ref_for_pcs, + &challenges_final, + ); + pcs_opening_timer.stop(); + + proofs.push(pcs_batch_opening); + Some(CombinedProof { + commitments: commitments.unwrap(), + proofs, + }) + } else { + None + } + } + false => { + if global_mpi_config.is_root() { + let proofs = all_proofs.into_iter().filter_map(|p| p).collect::>(); + Some(CombinedProof { + commitments: commitments.unwrap(), + proofs, + }) + } else { + None + } + } + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs index e6eb38d6..a49c5eff 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs @@ -73,6 +73,7 @@ async fn async_main() { } pub fn main() { + println!("Enter expander_server no oversubscribe!"); let stack_size_mb = 
std::env::var("THREAD_STACK_SIZE_MB") .ok() .and_then(|v| v.parse::().ok()) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index 42315b39..ea63c5d1 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -112,7 +112,12 @@ where let mpi_size = if allow_oversubscribe { max_parallel_count } else { - let num_cpus = prev_power_of_two(num_cpus::get_physical()); + // 支持通过环境变量 ZKML_NUM_CPUS 覆盖 CPU 数量(用于 Docker 等环境) + let num_cpus = std::env::var("ZKML_NUM_CPUS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or_else(num_cpus::get_physical); + let num_cpus = prev_power_of_two(num_cpus); if max_parallel_count > num_cpus { num_cpus } else { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs index b8d6c830..0f4a4d8a 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs @@ -24,8 +24,18 @@ pub fn start_server( fn parse_config(mpi_size: usize) -> (String, String, String, String) where { - let oversubscription = if mpi_size > num_cpus::get_physical() { - println!("Warning: Not enough cores available for the requested number of processes. Using oversubscription."); + // 支持通过环境变量强制启用 oversubscribe(用于 Docker 等 CPU ID 不连续的环境) + let force_oversubscribe = std::env::var("ZKML_FORCE_OVERSUBSCRIBE").is_ok(); + + // 支持通过环境变量 ZKML_NUM_CPUS 覆盖 CPU 数量 + let num_cpus = std::env::var("ZKML_NUM_CPUS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or_else(num_cpus::get_physical); + let oversubscription = if force_oversubscribe || mpi_size > num_cpus { + if mpi_size > num_cpus { + println!("Warning: Not enough cores available for the requested number of processes. 
Using oversubscription."); + } "--oversubscribe" } else { "" diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index 5605daf0..600036fd 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -4,6 +4,8 @@ use gkr_engine::{ ExpanderDualVarChallenge, ExpanderSingleVarChallenge, FieldEngine, GKREngine, MPIConfig, MPIEngine, Transcript, }; +use std::collections::HashMap; +use std::fs; use crate::{ frontend::{Config, SIMDField}, @@ -221,3 +223,381 @@ pub fn partition_single_gkr_claim_and_open_pcs_mpi( ); } } + +// ==================== SCHEDULE-BASED EXECUTION ==================== + +/// Schedule representation: rank -> sequence of tasks +#[derive(Debug, Clone)] +pub struct Schedule { + /// Map from rank to list of task names + /// e.g., rank 0 -> ["Task14", "Task1", "Task12"] + pub rank_tasks: HashMap>, +} + +impl Schedule { + /// Parse schedule from text file + /// Format: "Rank 0: Task14 -> Task1 -> Task12" + pub fn from_file(path: &str) -> Result { + let content = + fs::read_to_string(path).map_err(|e| format!("Failed to read schedule file: {}", e))?; + + let mut rank_tasks = HashMap::new(); + + for line in content.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + + // Parse "Rank X: TaskA -> TaskB -> TaskC" + let parts: Vec<&str> = line.split(':').collect(); + if parts.len() != 2 { + return Err(format!("Invalid line format: {}", line)); + } + + // Handle multiple spaces: "Rank 0:" or "Rank 0:" + let rank_part = parts[0].trim(); + if !rank_part.starts_with("Rank") { + return Err(format!("Expected 'Rank X', got: {}", parts[0])); + } + + let rank_str = rank_part + .strip_prefix("Rank") + .ok_or_else(|| format!("Expected 'Rank X', got: {}", parts[0]))? 
+ .trim(); // Trim to handle multiple spaces + + let rank: usize = rank_str + .parse() + .map_err(|e| format!("Invalid rank number '{}': {}", rank_str, e))?; + + // Extract tasks + let tasks_str = parts[1].trim(); + let tasks: Vec = tasks_str + .split("->") + .map(|t| t.trim().to_string()) + .filter(|t| !t.is_empty()) + .collect(); + + rank_tasks.insert(rank, tasks); + } + + Ok(Schedule { rank_tasks }) + } + + /// Get tasks for a specific rank + pub fn get_tasks(&self, rank: usize) -> Option<&Vec> { + self.rank_tasks.get(&rank) + } + + /// Find which ranks are executing the same task at the same step + pub fn find_peers_at_step(&self, step: usize, task_name: &str) -> Vec { + let mut peers = Vec::new(); + + for (rank, tasks) in &self.rank_tasks { + if tasks.len() > step && tasks[step] == task_name { + peers.push(*rank); + } + } + + peers.sort(); + peers + } + + /// Get maximum number of steps across all ranks + pub fn max_steps(&self) -> usize { + self.rank_tasks + .values() + .map(|tasks| tasks.len()) + .max() + .unwrap_or(0) + } +} + +/// Parse task mapping file +/// Format: "Task1: 0" (Task1 maps to template index 0) +pub fn parse_task_mapping(path: &str) -> Result, String> { + let content = + fs::read_to_string(path).map_err(|e| format!("Failed to read task mapping file: {}", e))?; + + let mut mapping = HashMap::new(); + + for line in content.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + + let parts: Vec<&str> = line.split(':').collect(); + if parts.len() != 2 { + continue; + } + + let task_name = parts[0].trim().to_string(); + let template_idx: usize = parts[1] + .trim() + .parse() + .map_err(|e| format!("Invalid template index: {}", e))?; + + mapping.insert(task_name, template_idx); + } + + Ok(mapping) +} + +/// Create MPI subgroup for a specific task based on peers +fn create_mpi_subgroup_for_task( + global_mpi_config: &MPIConfig<'static>, + peers: &[usize], + task_name: &str, +) -> Option> { + use mpi::topology::Communicator; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let my_rank = global_mpi_config.world_rank(); + + // Find my position in peers + let my_position = peers.iter().position(|&r| r == my_rank)?; + + // Use task name hash as color + let mut hasher = DefaultHasher::new(); + task_name.hash(&mut hasher); + let color_value = (hasher.finish() % 10000) as i32; + let color = mpi::topology::Color::with_value(color_value); + + // Split communicator and leak it to get 'static lifetime + let split_comm = unsafe { + global_mpi_config + .world + .unwrap() + .split_by_color_with_key(color, my_position as i32) + }; + + // Leak the communicator to get 'static lifetime + // split_comm is Option, we need to leak the inner value + let split_comm_static: &'static Option<_> = Box::leak(Box::new(split_comm)); + + Some(MPIConfig::prover_new( + global_mpi_config.universe, + split_comm_static.as_ref(), + )) +} + +/// Main prove function with schedule +pub fn mpi_prove_with_schedule( + global_mpi_config: &MPIConfig<'static>, + schedule_path: &str, + task_mapping_path: Option<&str>, // Optional: if None, use template index as task name + prover_setup: &ExpanderProverSetup, + computation_graph: &ComputationGraph, + values: &[impl AsRef<[SIMDField]>], +) -> Option>> +where + C: GKREngine, + ECCConfig: Config, +{ + let my_rank = global_mpi_config.world_rank(); + + // 1. 
Load schedule + let schedule = match Schedule::from_file(schedule_path) { + Ok(s) => s, + Err(e) => { + eprintln!("[RANK {}] Failed to load schedule: {}", my_rank, e); + return None; + } + }; + + eprintln!( + "[RANK {}] Loaded schedule with {} ranks, max {} steps", + my_rank, + schedule.rank_tasks.len(), + schedule.max_steps() + ); + + // 2. Load task mapping (or use default: TaskX -> template X) + let task_mapping = if let Some(path) = task_mapping_path { + match parse_task_mapping(path) { + Ok(m) => m, + Err(e) => { + eprintln!("[RANK {}] Failed to load task mapping: {}", my_rank, e); + return None; + } + } + } else { + // Default: Task0->0, Task1->1, etc. + let mut default_mapping = HashMap::new(); + for i in 0..computation_graph.proof_templates().len() { + default_mapping.insert(format!("Task{}", i), i); + } + default_mapping + }; + + // 3. Commit phase (only root) + let (commitments, states) = if global_mpi_config.is_root() { + eprintln!("[RANK {}] === COMMIT PHASE ===", my_rank); + let commit_timer = Timer::new("Commit to all input", true); + let (commitments, states) = values + .iter() + .map(|value| { + local_commit_impl::( + prover_setup.p_keys.get(&value.as_ref().len()).unwrap(), + value.as_ref(), + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + commit_timer.stop(); + (Some(commitments), Some(states)) + } else { + (None, None) + }; + + // 4. Get my tasks + let my_tasks = match schedule.get_tasks(my_rank) { + Some(tasks) => tasks, + None => { + eprintln!("[RANK {}] No tasks assigned in schedule", my_rank); + return if my_rank == 0 { + Some(CombinedProof { + commitments: commitments.unwrap(), + proofs: vec![], + }) + } else { + None + }; + } + }; + + eprintln!("[RANK {}] My tasks: {:?}", my_rank, my_tasks); + + // 5. Execute tasks step by step + let mut all_proofs: Vec> = + vec![None; computation_graph.proof_templates().len()]; + + for (step, task_name) in my_tasks.iter().enumerate() { + eprintln!( + "[RANK {}] === STEP {} === Task: {}", + my_rank, step, task_name + ); + + // Find template index + let template_idx = match task_mapping.get(task_name) { + Some(&idx) => idx, + None => { + eprintln!("[RANK {}] Unknown task: {}", my_rank, task_name); + continue; + } + }; + + if template_idx >= computation_graph.proof_templates().len() { + eprintln!( + "[RANK {}] Invalid template index: {}", + my_rank, template_idx + ); + continue; + } + + // Find peers for this task at this step + let peers = schedule.find_peers_at_step(step, task_name); + eprintln!("[RANK {}] Task {} peers: {:?}", my_rank, task_name, peers); + + if peers.is_empty() || !peers.contains(&my_rank) { + eprintln!("[RANK {}] Not participating in task {}", my_rank, task_name); + continue; + } + + let template = &computation_graph.proof_templates()[template_idx]; + let kernel = &computation_graph.kernels()[template.kernel_id()]; + + let commitment_values = template + .commitment_indices() + .iter() + .map(|&idx| values[idx].as_ref()) + .collect::>(); + + // Create MPI subgroup if multiple peers + let local_mpi_config = if peers.len() > 1 { + create_mpi_subgroup_for_task(global_mpi_config, &peers, task_name) + } else { + None + }; + + // Execute GKR + let gkr_result = if let Some(ref local_config) = local_mpi_config { + eprintln!( + "[RANK {}] Executing task {} with {} peers (local_rank={})", + my_rank, + task_name, + peers.len(), + local_config.world_rank() + ); + + prove_kernel_gkr::( + local_config, + kernel, + &commitment_values, + next_power_of_two(template.parallel_count()), + template.is_broadcast(), + ) + } else { + // 
Single rank task + eprintln!("[RANK {}] Executing task {} solo", my_rank, task_name); + None // Skip for now + }; + + // PCS opening (only subgroup root) + if let Some((mut transcript, challenge)) = gkr_result { + let is_subgroup_root = local_mpi_config + .as_ref() + .map(|c| c.is_root()) + .unwrap_or(true); + + if is_subgroup_root { + eprintln!( + "[RANK {}] Performing PCS opening for task {}", + my_rank, task_name + ); + + let pcs_timer = Timer::new(&format!("PCS for {}", task_name), true); + let challenges = if let Some(challenge_y) = challenge.challenge_y() { + vec![challenge.challenge_x(), challenge_y] + } else { + vec![challenge.challenge_x()] + }; + + challenges.iter().for_each(|c| { + partition_single_gkr_claim_and_open_pcs_mpi::( + prover_setup, + &commitment_values, + &template + .commitment_indices() + .iter() + .map(|&idx| &states.as_ref().unwrap()[idx]) + .collect::>(), + c, + template.is_broadcast(), + &mut transcript, + ); + }); + + pcs_timer.stop(); + + all_proofs[template_idx] = Some(ExpanderProof { + data: vec![transcript.finalize_and_get_proof()], + }); + } + } + } + + // 6. Collect results (only root) + if global_mpi_config.is_root() { + let proofs = all_proofs.into_iter().filter_map(|p| p).collect::>(); + eprintln!("[RANK {}] Collected {} proofs", my_rank, proofs.len()); + + Some(CombinedProof { + commitments: commitments.unwrap(), + proofs, + }) + } else { + None + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 27919a50..0bde9127 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -27,6 +27,57 @@ use std::sync::Arc; use std::sync::Mutex as SyncMutex; use tokio::sync::{oneshot, Mutex}; +/// 获取所有 expander_server 进程的内存占用(单位:MB) +/// 返回 (VmRSS物理内存, VmSize虚拟内存) +fn get_total_expander_memory_mb() -> (usize, usize) { + use std::fs; + use std::io::{BufRead, BufReader}; + + let mut total_rss_kb = 0usize; + let mut total_vmsize_kb = 0usize; + + // 遍历 /proc 目录 + if let Ok(entries) = fs::read_dir("/proc") { + for entry in entries.flatten() { + if let Ok(file_name) = entry.file_name().into_string() { + // 只处理数字目录(进程PID) + if file_name.chars().all(|c| c.is_ascii_digit()) { + // 读取 /proc/[pid]/comm 检查进程名 + let comm_path = format!("/proc/{}/comm", file_name); + if let Ok(comm) = fs::read_to_string(&comm_path) { + if comm.trim() == "expander_server" { + // 读取 /proc/[pid]/status 获取内存信息 + let status_path = format!("/proc/{}/status", file_name); + if let Ok(file) = fs::File::open(&status_path) { + let reader = BufReader::new(file); + for line in reader.lines().flatten() { + if line.starts_with("VmRSS:") { + // VmRSS: 12345 kB (物理内存) + if let Some(rss_str) = line.split_whitespace().nth(1) { + if let Ok(rss_kb) = rss_str.parse::() { + total_rss_kb += rss_kb; + } + } + } else if line.starts_with("VmSize:") { + // VmSize: 12345 kB (虚拟内存) + if let Some(size_str) = line.split_whitespace().nth(1) { + if let Ok(size_kb) = size_str.parse::() { + total_vmsize_kb += size_kb; + } + } + } + } + } + } + } + } + } + } + } + + (total_rss_kb / 1024, total_vmsize_kb / 1024) // 转换为MB +} + pub static SERVER_IP: &str = "127.0.0.1"; pub static SERVER_PORT: Lazy> = Lazy::new(|| SyncMutex::new(3000)); @@ -140,11 +191,21 @@ where setup_timer.stop(); } RequestType::Prove => { - println!("Received prove request"); + let (rss_start, vmsize_start) 
= get_total_expander_memory_mb(); + println!( + "[MPI Rank {}] Received prove request, MEMORY = {} MB (RSS), {} MB (VmSize)", + state.global_mpi_config.world_rank(), + rss_start, + vmsize_start + ); // Handle proving logic here let prove_timer = Timer::new("server prove", true); let _ = broadcast_request_type(&state.global_mpi_config, 2); + println!( + "[MPI Rank {}] Acquiring witness lock...", + state.global_mpi_config.world_rank() + ); let mut witness = state.witness.lock().await; let mut witness_win = state.wt_shared_memory_win.lock().await; S::setup_shared_witness(&state.global_mpi_config, &mut witness, &mut witness_win); @@ -161,6 +222,10 @@ where SharedMemoryEngine::write_proof_to_shared_memory(proof.as_ref().unwrap()); prove_timer.stop(); + + let (rss_end, vmsize_end) = get_total_expander_memory_mb(); + println!("[MPI Rank {}] Prove request done - witness lock will be released, but witness remains in state.witness, MEMORY = {} MB (RSS), {} MB (VmSize)", + state.global_mpi_config.world_rank(), rss_end, vmsize_end); } RequestType::Exit => { println!("Received exit request, shutting down server"); @@ -209,6 +274,14 @@ pub async fn worker_main( } 2 => { // Prove + let (rss_start, vmsize_start) = get_total_expander_memory_mb(); + println!("[MPI Rank {}] Worker received prove broadcast, MEMORY = {} MB (RSS), {} MB (VmSize)", + state.global_mpi_config.world_rank(), rss_start, vmsize_start); + + println!( + "[MPI Rank {}] Worker acquiring witness lock...", + state.global_mpi_config.world_rank() + ); let mut witness = state.witness.lock().await; let mut witness_win = state.wt_shared_memory_win.lock().await; S::setup_shared_witness(&state.global_mpi_config, &mut witness, &mut witness_win); @@ -222,6 +295,10 @@ pub async fn worker_main( &witness, ); assert!(proof.is_none()); + + let (rss_end, vmsize_end) = get_total_expander_memory_mb(); + println!("[MPI Rank {}] Worker prove done - witness lock will be released, but witness remains in state.witness, MEMORY = {} MB (RSS), {} MB (VmSize)", + state.global_mpi_config.world_rank(), rss_end, vmsize_end); } 255 => { break; @@ -276,12 +353,25 @@ where S: ServerFns + 'static, { + use std::time::Instant; + + let serve_start = Instant::now(); + println!("[TIMING] serve() START"); + + let step_start = Instant::now(); let global_mpi_config = unsafe { UNIVERSE = MPIConfig::init(); GLOBAL_COMMUNICATOR = UNIVERSE.as_ref().map(|u| u.world()); MPIConfig::prover_new(UNIVERSE.as_ref(), GLOBAL_COMMUNICATOR.as_ref()) }; + let rank = global_mpi_config.world_rank(); + println!( + "[TIMING Rank {}] MPI initialization took {:.3}s", + rank, + step_start.elapsed().as_secs_f64() + ); + let step_start = Instant::now(); let state = ServerState { lock: Arc::new(Mutex::new(())), global_mpi_config: global_mpi_config.clone(), @@ -294,24 +384,67 @@ where wt_shared_memory_win: Arc::new(Mutex::new(None)), shutdown_tx: Arc::new(Mutex::new(None)), }; + println!( + "[TIMING Rank {}] ServerState creation took {:.3}s", + rank, + step_start.elapsed().as_secs_f64() + ); if global_mpi_config.is_root() { + println!( + "[TIMING Rank {}] Root process: setting up HTTP server", + rank + ); + + let step_start = Instant::now(); let (tx, rx) = oneshot::channel::<()>(); state.shutdown_tx.lock().await.replace(tx); + println!( + "[TIMING Rank {}] Shutdown channel setup took {:.3}s", + rank, + step_start.elapsed().as_secs_f64() + ); + let step_start = Instant::now(); let app = Router::new() .route("/", post(root_main::)) .route("/", get(|| async { "Expander Server is running" })) 
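            // Routes above: POST "/" dispatches control requests (setup / prove / exit)
            // to root_main, while GET "/" is a plain liveness probe returning a static
            // string. The .with_state call below hands each handler a clone of the
            // ServerState built earlier (global MPI config, computation graph, witness,
            // shared-memory window, shutdown channel).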
.with_state(state.clone()); + println!( + "[TIMING Rank {}] Router creation took {:.3}s", + rank, + step_start.elapsed().as_secs_f64() + ); + let step_start = Instant::now(); let ip: IpAddr = SERVER_IP.parse().expect("Invalid SERVER_IP"); let port_val = port_number.parse::().unwrap_or_else(|e| { eprintln!("Error: Invalid port number '{port_number}'. {e}."); std::process::exit(1); }); let addr = SocketAddr::new(ip, port_val); + println!( + "[TIMING Rank {}] Address parsing took {:.3}s", + rank, + step_start.elapsed().as_secs_f64() + ); + + let step_start = Instant::now(); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + println!( + "[TIMING Rank {}] TCP listener bind took {:.3}s", + rank, + step_start.elapsed().as_secs_f64() + ); + println!("Server running at http://{addr}"); + println!( + "[TIMING Rank {}] Total startup time: {:.3}s", + rank, + serve_start.elapsed().as_secs_f64() + ); + + let step_start = Instant::now(); axum::serve(listener, app.into_make_service()) .with_graceful_shutdown(async { rx.await.ok(); @@ -319,8 +452,14 @@ where }) .await .unwrap(); + println!( + "[TIMING Rank {}] Server shutdown after {:.3}s of running", + rank, + step_start.elapsed().as_secs_f64() + ); // it might need some time for the server to properly shutdown + let step_start = Instant::now(); loop { match Arc::strong_count(&state.computation_graph) { 1 => { @@ -332,10 +471,26 @@ where } } } + println!( + "[TIMING Rank {}] Waiting for clean shutdown took {:.3}s", + rank, + step_start.elapsed().as_secs_f64() + ); } else { + println!( + "[TIMING Rank {}] Worker process: entering worker_main", + rank + ); + let step_start = Instant::now(); worker_main::(global_mpi_config, state.clone()).await; + println!( + "[TIMING Rank {}] Worker finished after {:.3}s", + rank, + step_start.elapsed().as_secs_f64() + ); } + let step_start = Instant::now(); match ( Arc::try_unwrap(state.computation_graph), Arc::try_unwrap(state.witness), @@ -355,12 +510,29 @@ where panic!("Failed to unwrap Arc, multiple references exist"); } } + println!( + "[TIMING Rank {}] Shared memory cleanup took {:.3}s", + rank, + step_start.elapsed().as_secs_f64() + ); if state.global_mpi_config.is_root() { println!("Server has been shut down."); } + let step_start = Instant::now(); unsafe { mpi::ffi::MPI_Finalize() }; + println!( + "[TIMING Rank {}] MPI_Finalize took {:.3}s", + rank, + step_start.elapsed().as_secs_f64() + ); + + println!( + "[TIMING Rank {}] serve() TOTAL TIME: {:.3}s", + rank, + serve_start.elapsed().as_secs_f64() + ); } #[derive(Parser, Debug)] diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs index d3e326d9..47c095d0 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs @@ -19,6 +19,57 @@ use crate::{ }, }; +/// 获取所有 expander_server 进程的内存占用(单位:MB) +/// 返回 (VmRSS物理内存, VmSize虚拟内存) +fn get_total_expander_memory_mb() -> (usize, usize) { + use std::fs; + use std::io::{BufRead, BufReader}; + + let mut total_rss_kb = 0usize; + let mut total_vmsize_kb = 0usize; + + // 遍历 /proc 目录 + if let Ok(entries) = fs::read_dir("/proc") { + for entry in entries.flatten() { + if let Ok(file_name) = entry.file_name().into_string() { + // 只处理数字目录(进程PID) + if file_name.chars().all(|c| c.is_ascii_digit()) { + // 读取 /proc/[pid]/comm 检查进程名 + let comm_path = format!("/proc/{}/comm", file_name); 
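                        // What the status scan below relies on (sketch, values made up):
                        // /proc/<pid>/status contains lines such as
                        //     VmRSS:    123456 kB
                        //     VmSize:   234567 kB
                        // The second whitespace-separated token is the size in kB; it is
                        // accumulated for every process whose /proc/<pid>/comm trims to
                        // "expander_server", and the totals are divided by 1024 to report MB.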
+ if let Ok(comm) = fs::read_to_string(&comm_path) { + if comm.trim() == "expander_server" { + // 读取 /proc/[pid]/status 获取内存信息 + let status_path = format!("/proc/{}/status", file_name); + if let Ok(file) = fs::File::open(&status_path) { + let reader = BufReader::new(file); + for line in reader.lines().flatten() { + if line.starts_with("VmRSS:") { + // VmRSS: 12345 kB (物理内存) + if let Some(rss_str) = line.split_whitespace().nth(1) { + if let Ok(rss_kb) = rss_str.parse::() { + total_rss_kb += rss_kb; + } + } + } else if line.starts_with("VmSize:") { + // VmSize: 12345 kB (虚拟内存) + if let Some(size_str) = line.split_whitespace().nth(1) { + if let Ok(size_kb) = size_str.parse::() { + total_vmsize_kb += size_kb; + } + } + } + } + } + } + } + } + } + } + } + + (total_rss_kb / 1024, total_vmsize_kb / 1024) // 转换为MB +} + pub trait ServerFns where C: gkr_engine::GKREngine, @@ -45,6 +96,10 @@ where witness_target: &mut Vec>>, mpi_shared_memory_win: &mut Option, ) { + let (rss_start, vmsize_start) = get_total_expander_memory_mb(); + println!("[MPI Rank {}] setup_shared_witness: START - disposing old witness, MEMORY = {} MB (RSS), {} MB (VmSize)", + global_mpi_config.world_rank(), rss_start, vmsize_start); + // dispose of the previous shared memory if it exists while let Some(w) = witness_target.pop() { w.discard_control_of_shared_mem(); @@ -55,6 +110,10 @@ where global_mpi_config.free_shared_mem(&mut win_wrapper.win); } + let (rss_after_dispose, vmsize_after_dispose) = get_total_expander_memory_mb(); + println!("[MPI Rank {}] setup_shared_witness: Old witness disposed, MEMORY = {} MB (RSS), {} MB (VmSize), calling read_shared_witness_from_shared_memory", + global_mpi_config.world_rank(), rss_after_dispose, vmsize_after_dispose); + // Allocate new shared memory for the witness let (witness_v, wt_shared_memory_win) = SharedMemoryEngine::read_shared_witness_from_shared_memory::( @@ -62,6 +121,10 @@ where ); *witness_target = witness_v; *mpi_shared_memory_win = Some(wt_shared_memory_win); + + let (rss_end, vmsize_end) = get_total_expander_memory_mb(); + println!("[MPI Rank {}] setup_shared_witness: DONE - witness loaded into local memory, MEMORY = {} MB (RSS), {} MB (VmSize)", + global_mpi_config.world_rank(), rss_end, vmsize_end); } fn shared_memory_clean_up( @@ -123,7 +186,17 @@ where C: GKREngine, ECCConfig: Config, { - mpi_prove_impl(global_mpi_config, prover_setup, computation_graph, values) + let (rss_start, vmsize_start) = get_total_expander_memory_mb(); + println!("[MPI Rank {}] prove_request_handler: START - witness is being used for proving, MEMORY = {} MB (RSS), {} MB (VmSize)", + global_mpi_config.world_rank(), rss_start, vmsize_start); + + let proof = mpi_prove_impl(global_mpi_config, prover_setup, computation_graph, values); + + let (rss_end, vmsize_end) = get_total_expander_memory_mb(); + println!("[MPI Rank {}] prove_request_handler: DONE - witness is still in memory but no longer actively used, MEMORY = {} MB (RSS), {} MB (VmSize)", + global_mpi_config.world_rank(), rss_end, vmsize_end); + + proof } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs index 648f33a8..d01bb253 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs @@ -9,6 +9,57 @@ use shared_memory::{Shmem, ShmemConf}; use 
crate::circuit::config::Config; +/// 获取所有 expander_server 进程的内存占用(单位:MB) +/// 返回 (VmRSS物理内存, VmSize虚拟内存) +fn get_total_expander_memory_mb() -> (usize, usize) { + use std::fs; + use std::io::{BufRead, BufReader}; + + let mut total_rss_kb = 0usize; + let mut total_vmsize_kb = 0usize; + + // 遍历 /proc 目录 + if let Ok(entries) = fs::read_dir("/proc") { + for entry in entries.flatten() { + if let Ok(file_name) = entry.file_name().into_string() { + // 只处理数字目录(进程PID) + if file_name.chars().all(|c| c.is_ascii_digit()) { + // 读取 /proc/[pid]/comm 检查进程名 + let comm_path = format!("/proc/{}/comm", file_name); + if let Ok(comm) = fs::read_to_string(&comm_path) { + if comm.trim() == "expander_server" { + // 读取 /proc/[pid]/status 获取内存信息 + let status_path = format!("/proc/{}/status", file_name); + if let Ok(file) = fs::File::open(&status_path) { + let reader = BufReader::new(file); + for line in reader.lines().flatten() { + if line.starts_with("VmRSS:") { + // VmRSS: 12345 kB (物理内存) + if let Some(rss_str) = line.split_whitespace().nth(1) { + if let Ok(rss_kb) = rss_str.parse::() { + total_rss_kb += rss_kb; + } + } + } else if line.starts_with("VmSize:") { + // VmSize: 12345 kB (虚拟内存) + if let Some(size_str) = line.split_whitespace().nth(1) { + if let Ok(size_kb) = size_str.parse::() { + total_vmsize_kb += size_kb; + } + } + } + } + } + } + } + } + } + } + } + + (total_rss_kb / 1024, total_vmsize_kb / 1024) // 转换为MB +} + use crate::zkcuda::proving_system::expander::structs::{ ExpanderProverSetup, ExpanderVerifierSetup, }; @@ -177,6 +228,12 @@ impl SharedMemoryEngine { pub fn read_shared_witness_from_shared_memory( global_mpi_config: &MPIConfig<'static>, ) -> (Vec>, SharedMemoryWINWrapper) { + use std::time::Instant; + + let (rss_before, vmsize_before) = get_total_expander_memory_mb(); + // 打印关键信息:进程rank和witness长度 + println!("[MPI Rank {}] read_shared_witness_from_shared_memory: MEMORY_BEFORE = {} MB (RSS), {} MB (VmSize)", + global_mpi_config.world_rank(), rss_before, vmsize_before); let (mut mpi_shared_mem_ptr, mem_win) = if global_mpi_config.is_root() { let witness = Self::read_witness_from_shared_memory::(); let bytes_size = std::mem::size_of::() @@ -196,10 +253,158 @@ impl SharedMemoryEngine { global_mpi_config.barrier(); + // ⏸️ 等待检查点:等待 /tmp/continue_witness_test 文件出现才继续 + let checkpoint_file = "/tmp/continue_witness_test1"; + println!( + "[MPI Rank {}] ⏸️ CHECKPOINT: Waiting for file '{}' to continue...", + global_mpi_config.world_rank(), + checkpoint_file + ); + println!("[MPI Rank {}] ⏸️ You can now check memory usage. 
Create the file to continue: touch {}", + global_mpi_config.world_rank(), checkpoint_file); + + let mut check_count = 0; + loop { + if std::path::Path::new(checkpoint_file).exists() { + println!( + "[MPI Rank {}] ✅ Checkpoint file detected, continuing execution", + global_mpi_config.world_rank() + ); + break; + } + + check_count += 1; + std::thread::sleep(std::time::Duration::from_millis(500)); + } + // ⏱️ 开始计时:测量从共享内存读取witness的耗时 + let read_start = Instant::now(); + let n_witness = usize::new_from_memory(&mut mpi_shared_mem_ptr); + let read_n_witness_duration = read_start.elapsed(); + + println!( + "[MPI Rank {}] ⏱️ Read n_witness={} took {:.3} µs", + global_mpi_config.world_rank(), + n_witness, + read_n_witness_duration.as_micros() + ); + + let witness_read_start = Instant::now(); let witness = (0..n_witness) .map(|_| Vec::::new_from_memory(&mut mpi_shared_mem_ptr)) .collect::>(); + let witness_read_duration = witness_read_start.elapsed(); + + println!("[MPI Rank {}] ⏱️ Read {} witness components from shared memory took {:.3} ms ({:.3} µs)", + global_mpi_config.world_rank(), + n_witness, + witness_read_duration.as_secs_f64() * 1000.0, + witness_read_duration.as_micros()); + + let (rss_after, vmsize_after) = get_total_expander_memory_mb(); + + // 打印每个witness component的大小 + let total_elements: usize = witness.iter().map(|v| v.len()).sum(); + let total_bytes: usize = witness + .iter() + .map(|v| v.len() * std::mem::size_of_val(&v[0])) + .sum(); + let rss_increase = rss_after.saturating_sub(rss_before); + let vmsize_increase = vmsize_after.saturating_sub(vmsize_before); + println!("[MPI Rank {}] Copied witness to local memory: {} components, {} total elements, ~{} MB witness data", + global_mpi_config.world_rank(), + witness.len(), + total_elements, + total_bytes / 1024 / 1024); + println!( + "[MPI Rank {}] MEMORY_AFTER_COPY: RSS = {} MB (+{} MB), VmSize = {} MB (+{} MB)", + global_mpi_config.world_rank(), + rss_after, + rss_increase, + vmsize_after, + vmsize_increase + ); + + // ⏸️ 等待检查点:等待 /tmp/continue_witness_test 文件出现才继续 + let checkpoint_file = "/tmp/continue_witness_test"; + println!( + "[MPI Rank {}] ⏸️ CHECKPOINT: Waiting for file '{}' to continue...", + global_mpi_config.world_rank(), + checkpoint_file + ); + println!("[MPI Rank {}] ⏸️ You can now check memory usage. Create the file to continue: touch {}", + global_mpi_config.world_rank(), checkpoint_file); + + let mut check_count = 0; + loop { + if std::path::Path::new(checkpoint_file).exists() { + println!( + "[MPI Rank {}] ✅ Checkpoint file detected, continuing execution", + global_mpi_config.world_rank() + ); + break; + } + + // 每10次检查打印一次内存状态(避免日志过多) + if check_count % 10 == 0 { + let (rss, vmsize) = get_total_expander_memory_mb(); + println!( + "[MPI Rank {}] ⏳ Still waiting... 
(check #{}, RSS = {} MB, VmSize = {} MB)", + global_mpi_config.world_rank(), + check_count, + rss, + vmsize + ); + } + + check_count += 1; + std::thread::sleep(std::time::Duration::from_millis(500)); + } + + // 🔥 主动访问witness数据,强制触发物理页分配 + println!("[MPI Rank {}] 🔥 Now actively accessing witness data to trigger physical page allocation...", + global_mpi_config.world_rank()); + + let access_start = Instant::now(); + + // 遍历所有witness数据,真正读取每个元素的字节 + let mut dummy_sum = 0u64; + for component in witness.iter() { + // 将Vec转为字节切片,确保访问实际内存 + let bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + component.as_ptr() as *const u8, + component.len() * std::mem::size_of::(), + ) + }; + + // 每隔4KB(页面大小)读取一个字节,确保触碰所有页面 + for i in (0..bytes.len()).step_by(4096) { + unsafe { + // 使用read_volatile防止编译器优化 + dummy_sum = dummy_sum.wrapping_add(std::ptr::read_volatile(&bytes[i]) as u64); + } + } + } + + let access_duration = access_start.elapsed(); + println!( + "[MPI Rank {}] 🔥 Finished accessing witness data (dummy_sum = {}, took {:.3}s)", + global_mpi_config.world_rank(), + dummy_sum, + access_duration.as_secs_f64() + ); + + // 再次测量内存,看是否因为访问而增长 + let (rss_after_access, vmsize_after_access) = get_total_expander_memory_mb(); + let rss_increase_by_access = rss_after_access.saturating_sub(rss_after); + println!( + "[MPI Rank {}] 📊 MEMORY_AFTER_ACCESS: RSS = {} MB (+{} MB from copy), VmSize = {} MB", + global_mpi_config.world_rank(), + rss_after_access, + rss_increase_by_access, + vmsize_after_access + ); (witness, SharedMemoryWINWrapper { win: mem_win }) } From c1fe2a7dadbb2fafb53361a60a267cf428ff2181 Mon Sep 17 00:00:00 2001 From: hczphn Date: Tue, 20 Jan 2026 02:14:42 +0000 Subject: [PATCH 03/10] remove debug-level print info --- expander_compiler/src/zkcuda/cpu_monitor.rs | 196 ------------------ expander_compiler/src/zkcuda/mod.rs | 1 - .../proving_system/expander/prove_impl.rs | 169 +-------------- .../expander_no_oversubscribe/prove_impl.rs | 81 +------- .../expander_no_oversubscribe/server_bin.rs | 1 - .../expander_parallelized/server_ctrl.rs | 178 +--------------- .../expander_parallelized/server_fns.rs | 75 +------ .../shared_memory_utils.rs | 153 -------------- 8 files changed, 7 insertions(+), 847 deletions(-) delete mode 100644 expander_compiler/src/zkcuda/cpu_monitor.rs diff --git a/expander_compiler/src/zkcuda/cpu_monitor.rs b/expander_compiler/src/zkcuda/cpu_monitor.rs deleted file mode 100644 index b02c26e3..00000000 --- a/expander_compiler/src/zkcuda/cpu_monitor.rs +++ /dev/null @@ -1,196 +0,0 @@ -/// CPU使用率监控模块 -/// 用于验证commit和PCS opening是否真正用满所有CPU核心 -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, -}; -use std::thread; -use std::time::Duration; - -pub struct CpuMonitor { - stop_flag: Arc, - handle: Option>, -} - -impl CpuMonitor { - /// 开始监控CPU使用率 - /// interval_ms: 采样间隔(毫秒) - pub fn start(tag: &str, interval_ms: u64) -> Self { - let stop_flag = Arc::new(AtomicBool::new(false)); - let stop_flag_clone = stop_flag.clone(); - let tag = tag.to_string(); - - let handle = thread::spawn(move || { - let mut prev_stats = get_cpu_stats(); - - while !stop_flag_clone.load(Ordering::Relaxed) { - thread::sleep(Duration::from_millis(interval_ms)); - - let curr_stats = get_cpu_stats(); - if let (Some(ref prev), Some(ref curr)) = (prev_stats.as_ref(), curr_stats.as_ref()) - { - let usage = calculate_cpu_usage(prev, curr); - let num_cpus = num_cpus::get(); - eprintln!( - "[CPU_MONITOR] {} | Total CPUs: {} | Usage: {:.2}% | Active cores estimate: {:.1}", - tag, num_cpus, usage, usage / 100.0 * 
num_cpus as f64 - ); - } - - prev_stats = curr_stats; - } - }); - - Self { - stop_flag, - handle: Some(handle), - } - } - - /// 停止监控并返回 - pub fn stop(mut self) { - self.stop_flag.store(true, Ordering::Relaxed); - if let Some(handle) = self.handle.take() { - let _ = handle.join(); - } - } -} - -impl Drop for CpuMonitor { - fn drop(&mut self) { - self.stop_flag.store(true, Ordering::Relaxed); - if let Some(handle) = self.handle.take() { - let _ = handle.join(); - } - } -} - -#[derive(Debug, Clone)] -struct CpuStats { - user: u64, - nice: u64, - system: u64, - idle: u64, - iowait: u64, - irq: u64, - softirq: u64, - steal: u64, -} - -#[cfg(target_os = "linux")] -fn get_cpu_stats() -> Option { - use std::fs; - - let content = fs::read_to_string("/proc/stat").ok()?; - let line = content.lines().next()?; - - // 第一行格式: cpu user nice system idle iowait irq softirq steal - let parts: Vec<&str> = line.split_whitespace().collect(); - if parts.len() < 9 || parts[0] != "cpu" { - return None; - } - - Some(CpuStats { - user: parts[1].parse().ok()?, - nice: parts[2].parse().ok()?, - system: parts[3].parse().ok()?, - idle: parts[4].parse().ok()?, - iowait: parts[5].parse().ok()?, - irq: parts[6].parse().ok()?, - softirq: parts[7].parse().ok()?, - steal: parts[8].parse().ok()?, - }) -} - -#[cfg(not(target_os = "linux"))] -fn get_cpu_stats() -> Option { - None -} - -fn calculate_cpu_usage(prev: &CpuStats, curr: &CpuStats) -> f64 { - let prev_idle = prev.idle + prev.iowait; - let curr_idle = curr.idle + curr.iowait; - - let prev_total = prev.user - + prev.nice - + prev.system - + prev.idle - + prev.iowait - + prev.irq - + prev.softirq - + prev.steal; - let curr_total = curr.user - + curr.nice - + curr.system - + curr.idle - + curr.iowait - + curr.irq - + curr.softirq - + curr.steal; - - let total_diff = curr_total.saturating_sub(prev_total); - let idle_diff = curr_idle.saturating_sub(prev_idle); - - if total_diff == 0 { - return 0.0; - } - - (total_diff.saturating_sub(idle_diff) as f64 / total_diff as f64) * 100.0 -} - -/// 简化版本:单次快照CPU使用情况 -pub fn snapshot_cpu_usage(tag: &str) { - let num_cpus = num_cpus::get(); - - #[cfg(target_os = "linux")] - { - if let Some(stats) = get_cpu_stats() { - // 等待一小段时间再采样 - thread::sleep(Duration::from_millis(100)); - if let Some(stats2) = get_cpu_stats() { - let usage = calculate_cpu_usage(&stats, &stats2); - eprintln!( - "[CPU_SNAPSHOT] {} | Total CPUs: {} | Usage: {:.2}% | Estimated active cores: {:.1}", - tag, num_cpus, usage, usage / 100.0 * num_cpus as f64 - ); - return; - } - } - } - - eprintln!( - "[CPU_SNAPSHOT] {} | Total CPUs: {} | (monitoring not available on this platform)", - tag, num_cpus - ); -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_cpu_monitor() { - let monitor = CpuMonitor::start("test", 100); - - // 模拟一些CPU密集型工作 - let handles: Vec<_> = (0..4) - .map(|_| { - thread::spawn(|| { - let mut sum = 0u64; - for i in 0..100_000_000 { - sum = sum.wrapping_add(i); - } - sum - }) - }) - .collect(); - - thread::sleep(Duration::from_secs(2)); - - for h in handles { - let _ = h.join(); - } - - monitor.stop(); - } -} diff --git a/expander_compiler/src/zkcuda/mod.rs b/expander_compiler/src/zkcuda/mod.rs index 072aefe9..41acc8d5 100644 --- a/expander_compiler/src/zkcuda/mod.rs +++ b/expander_compiler/src/zkcuda/mod.rs @@ -1,5 +1,4 @@ pub mod context; -pub mod cpu_monitor; pub mod kernel; pub mod mpi_mem_share; pub mod proving_system; diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs 
b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index 7924c343..c83ced65 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -7,7 +7,6 @@ use gkr_engine::{ }; use polynomials::RefMultiLinearPoly; use serdes::ExpSerde; -use std::time::Instant; use sumcheck::ProverScratchPad; use crate::{ @@ -18,28 +17,6 @@ use crate::{ }, }; -/// 获取当前进程的内存使用情况 (RSS, 单位: KB) -fn get_memory_usage_kb() -> Option { - #[cfg(target_os = "linux")] - { - if let Ok(content) = std::fs::read_to_string("/proc/self/statm") { - let parts: Vec<&str> = content.split_whitespace().collect(); - if parts.len() >= 2 { - // statm 第二个字段是 RSS (以页为单位) - // Linux 页大小通常是 4KB - if let Ok(rss_pages) = parts[1].parse::() { - return Some(rss_pages * 4); // 转换为 KB - } - } - } - None - } - #[cfg(not(target_os = "linux"))] - { - None - } -} - /// ECCCircuit -> ExpanderCircuit /// Returns an additional prover scratch pad for later use in GKR. pub fn prepare_expander_circuit( @@ -51,156 +28,12 @@ where ECCConfig: Config, ECCConfig::FieldConfig: FieldEngine, { - // 记录开始时间 - let start_time = Instant::now(); - eprintln!("[prepare_expander_circuit] ============== Start =============="); - - // 记录开始时的内存 - let mem_before = get_memory_usage_kb(); - eprintln!( - "[prepare_expander_circuit] Memory before: {:?} KB", - mem_before - ); - - // Step 1: export_to_expander().flatten() - let step1_start = Instant::now(); let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten(); - let step1_duration = step1_start.elapsed(); - eprintln!( - "[prepare_expander_circuit] Step 1 (export_to_expander + flatten) took: {:.3}s", - step1_duration.as_secs_f64() - ); - - // Step 2: 打印电路大小信息 - let step2_start = Instant::now(); - let num_layers = expander_circuit.layers.len(); - let mut total_gates = 0usize; - let mut total_add_gates = 0usize; - let mut total_mul_gates = 0usize; - let mut total_const_gates = 0usize; - - for layer in expander_circuit.layers.iter() { - total_add_gates += layer.add.len(); - total_mul_gates += layer.mul.len(); - total_const_gates += layer.const_.len(); - total_gates += layer.add.len() + layer.mul.len() + layer.const_.len(); - } - - eprintln!("[prepare_expander_circuit] Circuit stats:"); - eprintln!(" - num_layers: {}", num_layers); - eprintln!(" - total_gates: {}", total_gates); - eprintln!(" - total_add_gates: {}", total_add_gates); - eprintln!(" - total_mul_gates: {}", total_mul_gates); - eprintln!(" - total_const_gates: {}", total_const_gates); - eprintln!(" - log_input_size: {}", expander_circuit.log_input_size()); - let step2_duration = step2_start.elapsed(); - eprintln!( - "[prepare_expander_circuit] Step 2 (circuit stats calculation) took: {:.3}s", - step2_duration.as_secs_f64() - ); - - // 记录 export_to_expander().flatten() 后的内存 - let mem_after_flatten = get_memory_usage_kb(); - eprintln!( - "[prepare_expander_circuit] Memory after flatten: {:?} KB", - mem_after_flatten - ); - if let (Some(before), Some(after)) = (mem_before, mem_after_flatten) { - eprintln!( - "[prepare_expander_circuit] Memory delta (flatten): {} KB ({:.2} MB)", - after as i64 - before as i64, - (after as i64 - before as i64) as f64 / 1024.0 - ); - } - - // Step 3: pre_process_gkr - let step3_start = Instant::now(); expander_circuit.pre_process_gkr(); - let step3_duration = step3_start.elapsed(); - eprintln!( - "[prepare_expander_circuit] Step 3 (pre_process_gkr) took: {:.3}s", - step3_duration.as_secs_f64() - ); - + let 
(max_num_input_var, max_num_output_var) = super::utils::max_n_vars(&expander_circuit); - eprintln!(" - max_num_input_var: {}", max_num_input_var); - eprintln!(" - max_num_output_var: {}", max_num_output_var); - - // 记录 pre_process_gkr 后的内存 - let mem_after_preprocess = get_memory_usage_kb(); - eprintln!( - "[prepare_expander_circuit] Memory after pre_process_gkr: {:?} KB", - mem_after_preprocess - ); - if let (Some(before), Some(after)) = (mem_after_flatten, mem_after_preprocess) { - eprintln!( - "[prepare_expander_circuit] Memory delta (pre_process_gkr): {} KB ({:.2} MB)", - after as i64 - before as i64, - (after as i64 - before as i64) as f64 / 1024.0 - ); - } - - // Step 4: create ProverScratchPad - let step4_start = Instant::now(); let prover_scratch = ProverScratchPad::::new(max_num_input_var, max_num_output_var, mpi_world_size); - let step4_duration = step4_start.elapsed(); - eprintln!( - "[prepare_expander_circuit] Step 4 (create ProverScratchPad) took: {:.3}s", - step4_duration.as_secs_f64() - ); - - // 记录分配 ProverScratchPad 后的内存 - let mem_after_scratch = get_memory_usage_kb(); - eprintln!( - "[prepare_expander_circuit] Memory after ProverScratchPad: {:?} KB", - mem_after_scratch - ); - if let (Some(before), Some(after)) = (mem_after_preprocess, mem_after_scratch) { - eprintln!( - "[prepare_expander_circuit] Memory delta (ProverScratchPad): {} KB ({:.2} MB)", - after as i64 - before as i64, - (after as i64 - before as i64) as f64 / 1024.0 - ); - } - - // 总内存增量 - if let (Some(before), Some(after)) = (mem_before, mem_after_scratch) { - eprintln!( - "[prepare_expander_circuit] Total memory delta: {} KB ({:.2} MB)", - after as i64 - before as i64, - (after as i64 - before as i64) as f64 / 1024.0 - ); - } - - // 总时间统计 - let total_duration = start_time.elapsed(); - eprintln!("[prepare_expander_circuit] ============== Summary =============="); - eprintln!( - "[prepare_expander_circuit] Step 1 (export + flatten): {:.3}s ({:.1}%)", - step1_duration.as_secs_f64(), - step1_duration.as_secs_f64() / total_duration.as_secs_f64() * 100.0 - ); - eprintln!( - "[prepare_expander_circuit] Step 2 (circuit stats): {:.3}s ({:.1}%)", - step2_duration.as_secs_f64(), - step2_duration.as_secs_f64() / total_duration.as_secs_f64() * 100.0 - ); - eprintln!( - "[prepare_expander_circuit] Step 3 (pre_process_gkr): {:.3}s ({:.1}%)", - step3_duration.as_secs_f64(), - step3_duration.as_secs_f64() / total_duration.as_secs_f64() * 100.0 - ); - eprintln!( - "[prepare_expander_circuit] Step 4 (scratch pad): {:.3}s ({:.1}%)", - step4_duration.as_secs_f64(), - step4_duration.as_secs_f64() / total_duration.as_secs_f64() * 100.0 - ); - eprintln!( - "[prepare_expander_circuit] Total time: {:.3}s", - total_duration.as_secs_f64() - ); - eprintln!("[prepare_expander_circuit] ===================================="); (expander_circuit, prover_scratch) } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 34a7e415..9b4ef6c4 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -1,4 +1,3 @@ -use crate::zkcuda::cpu_monitor::CpuMonitor; use arith::{Field, Fr, SimdField}; use expander_utils::timer::Timer; use gkr_engine::{ @@ -8,33 +7,6 @@ use gkr_engine::{ use std::collections::HashMap; use std::fs; -/// 获取当前进程的内存使用情况 (RSS, 单位: KB) -fn get_memory_kb() -> u64 { - 
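// Note on the helper deleted here: /proc/self/statm reports sizes in pages, its second
// field being the resident set size (RSS); multiplying by 4 assumes the common 4 KiB
// Linux page size, so the function returned RSS in KB.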
#[cfg(target_os = "linux")] - { - if let Ok(content) = std::fs::read_to_string("/proc/self/statm") { - let parts: Vec<&str> = content.split_whitespace().collect(); - if parts.len() >= 2 { - if let Ok(rss_pages) = parts[1].parse::() { - return rss_pages * 4; // 页大小 4KB - } - } - } - } - 0 -} - -fn log_memory(rank: usize, tag: &str) { - let mem_kb = get_memory_kb(); - eprintln!( - "[MEM] rank={} {} : {} KB ({:.2} MB)", - rank, - tag, - mem_kb, - mem_kb as f64 / 1024.0 - ); -} - use crate::{ frontend::{Config, SIMDField}, utils::misc::next_power_of_two, @@ -93,16 +65,6 @@ where let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); let (commitments, states) = if global_mpi_config.is_root() { - eprintln!("\n========== COMMIT PHASE START =========="); - eprintln!( - "[RANK {}] Starting commit on {} values", - global_mpi_config.world_rank(), - values.len() - ); - - // 启动CPU监控(每200ms采样一次) - let _cpu_monitor = CpuMonitor::start("COMMIT", 200); - let (commitments, states) = values .iter() .map(|value| match ZC::BATCH_PCS { @@ -117,15 +79,8 @@ where }) .unzip::<_, _, Vec<_>, Vec<_>>(); - // _cpu_monitor在这里自动drop,停止监控 - eprintln!("========== COMMIT PHASE END ==========\n"); - (Some(commitments), Some(states)) } else { - eprintln!( - "[RANK {}] Skipping commit (not root)", - global_mpi_config.world_rank() - ); (None, None) }; commit_timer.stop(); @@ -228,23 +183,13 @@ where true => { if global_mpi_config.is_root() { let mut proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); - eprintln!("\n========== PCS OPENING PHASE START =========="); - eprintln!( - "[RANK {}] Starting batch PCS opening for {} values, {} challenges", - global_mpi_config.world_rank(), - vals_ref.len(), - challenges.len() - ); let pcs_opening_timer = Timer::new("Batch PCS Opening for all kernels", true); - // 启动CPU监控 - let _cpu_monitor = CpuMonitor::start("PCS_OPENING", 200); let pcs_batch_opening = open_defered_pcs::( prover_setup, &vals_ref, &challenges, ); pcs_opening_timer.stop(); - eprintln!("========== PCS OPENING PHASE END ==========\n"); proofs.push(pcs_batch_opening); Some(CombinedProof { @@ -413,8 +358,6 @@ where let world_size = mpi_config.world_size(); let n_copies = parallel_count / world_size; - log_memory(world_rank, "prove_kernel_gkr_internal::start"); - let local_commitment_values = get_local_vals_multi_copies( commitments_values, is_broadcast, @@ -422,17 +365,9 @@ where n_copies, parallel_count, ); - log_memory( - world_rank, - "prove_kernel_gkr_internal::after_get_local_vals", - ); let (mut expander_circuit, mut prover_scratch) = prepare_expander_circuit::(kernel, world_size); - log_memory( - world_rank, - "prove_kernel_gkr_internal::after_prepare_expander_circuit", - ); let mut transcript = T::new(); let challenge = prove_gkr_with_local_vals_multi_copies::( @@ -444,12 +379,6 @@ where mpi_config, n_bytes_profiler, ); - log_memory(world_rank, "prove_kernel_gkr_internal::after_prove_gkr"); - - // expander_circuit 和 prover_scratch 在这里被 drop - drop(expander_circuit); - drop(prover_scratch); - log_memory(world_rank, "prove_kernel_gkr_internal::after_drop_circuit"); Some((transcript, challenge)) } @@ -488,8 +417,6 @@ where FieldEngine, T: Transcript, { - let world_rank = mpi_config.world_rank(); - log_memory(world_rank, "prove_gkr::start"); let input_vals_multi_copies = local_commitment_values_multi_copies .iter() @@ -501,7 +428,6 @@ where ) }) .collect::>(); - log_memory(world_rank, "prove_gkr::after_prepare_inputs"); let mut input_vals = vec![FMulti::SimdCircuitField::ZERO; 1 << 
expander_circuit.log_input_size()]; @@ -514,13 +440,10 @@ where *vals = FMulti::SimdCircuitField::pack(&vals_unpacked); } expander_circuit.layers[0].input_vals = input_vals; - log_memory(world_rank, "prove_gkr::after_set_input_vals"); expander_circuit.fill_rnd_coefs(transcript); - log_memory(world_rank, "prove_gkr::after_fill_rnd_coefs"); expander_circuit.evaluate(); - log_memory(world_rank, "prove_gkr::after_evaluate"); #[cfg(feature = "zkcuda_profile")] { @@ -535,7 +458,6 @@ where let (claimed_v, challenge) = gkr::gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config); - log_memory(world_rank, "prove_gkr::after_gkr_prove"); assert_eq!(claimed_v, FBasic::ChallengeField::from(0u32)); @@ -1025,7 +947,6 @@ where let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); let (commitments, states) = if global_mpi_config.is_root() { eprintln!("[RANK {}] === COMMIT PHASE ===", my_rank); - let _cpu_monitor = CpuMonitor::start("COMMIT", 200); let (commitments, states) = values .iter() @@ -1083,6 +1004,7 @@ where // Execute tasks step by step let mut all_proofs: Vec> = vec![None; computation_graph.proof_templates().len()]; + let prove_timer = Timer::new("Prove all kernels", global_mpi_config.is_root()); for (step, task_name) in my_tasks.iter().enumerate() { eprintln!( @@ -1319,6 +1241,7 @@ where "[RANK {}] All ranks ready, proceeding to result collection", my_rank ); + prove_timer.stop(); // ========== MPI Result Collection (for BATCH_PCS mode) ========== // Collect vals_ref and challenges from all subgroup roots to global root diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs index a49c5eff..e6eb38d6 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs @@ -73,7 +73,6 @@ async fn async_main() { } pub fn main() { - println!("Enter expander_server no oversubscribe!"); let stack_size_mb = std::env::var("THREAD_STACK_SIZE_MB") .ok() .and_then(|v| v.parse::().ok()) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 0bde9127..70be640d 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -27,57 +27,6 @@ use std::sync::Arc; use std::sync::Mutex as SyncMutex; use tokio::sync::{oneshot, Mutex}; -/// 获取所有 expander_server 进程的内存占用(单位:MB) -/// 返回 (VmRSS物理内存, VmSize虚拟内存) -fn get_total_expander_memory_mb() -> (usize, usize) { - use std::fs; - use std::io::{BufRead, BufReader}; - - let mut total_rss_kb = 0usize; - let mut total_vmsize_kb = 0usize; - - // 遍历 /proc 目录 - if let Ok(entries) = fs::read_dir("/proc") { - for entry in entries.flatten() { - if let Ok(file_name) = entry.file_name().into_string() { - // 只处理数字目录(进程PID) - if file_name.chars().all(|c| c.is_ascii_digit()) { - // 读取 /proc/[pid]/comm 检查进程名 - let comm_path = format!("/proc/{}/comm", file_name); - if let Ok(comm) = fs::read_to_string(&comm_path) { - if comm.trim() == "expander_server" { - // 读取 /proc/[pid]/status 获取内存信息 - let status_path = format!("/proc/{}/status", file_name); - if let Ok(file) = fs::File::open(&status_path) { - let reader = BufReader::new(file); - for line in 
reader.lines().flatten() { - if line.starts_with("VmRSS:") { - // VmRSS: 12345 kB (物理内存) - if let Some(rss_str) = line.split_whitespace().nth(1) { - if let Ok(rss_kb) = rss_str.parse::() { - total_rss_kb += rss_kb; - } - } - } else if line.starts_with("VmSize:") { - // VmSize: 12345 kB (虚拟内存) - if let Some(size_str) = line.split_whitespace().nth(1) { - if let Ok(size_kb) = size_str.parse::() { - total_vmsize_kb += size_kb; - } - } - } - } - } - } - } - } - } - } - } - - (total_rss_kb / 1024, total_vmsize_kb / 1024) // 转换为MB -} - pub static SERVER_IP: &str = "127.0.0.1"; pub static SERVER_PORT: Lazy> = Lazy::new(|| SyncMutex::new(3000)); @@ -191,21 +140,11 @@ where setup_timer.stop(); } RequestType::Prove => { - let (rss_start, vmsize_start) = get_total_expander_memory_mb(); - println!( - "[MPI Rank {}] Received prove request, MEMORY = {} MB (RSS), {} MB (VmSize)", - state.global_mpi_config.world_rank(), - rss_start, - vmsize_start - ); + println!("Received prove request"); // Handle proving logic here let prove_timer = Timer::new("server prove", true); let _ = broadcast_request_type(&state.global_mpi_config, 2); - println!( - "[MPI Rank {}] Acquiring witness lock...", - state.global_mpi_config.world_rank() - ); let mut witness = state.witness.lock().await; let mut witness_win = state.wt_shared_memory_win.lock().await; S::setup_shared_witness(&state.global_mpi_config, &mut witness, &mut witness_win); @@ -222,10 +161,6 @@ where SharedMemoryEngine::write_proof_to_shared_memory(proof.as_ref().unwrap()); prove_timer.stop(); - - let (rss_end, vmsize_end) = get_total_expander_memory_mb(); - println!("[MPI Rank {}] Prove request done - witness lock will be released, but witness remains in state.witness, MEMORY = {} MB (RSS), {} MB (VmSize)", - state.global_mpi_config.world_rank(), rss_end, vmsize_end); } RequestType::Exit => { println!("Received exit request, shutting down server"); @@ -274,14 +209,6 @@ pub async fn worker_main( } 2 => { // Prove - let (rss_start, vmsize_start) = get_total_expander_memory_mb(); - println!("[MPI Rank {}] Worker received prove broadcast, MEMORY = {} MB (RSS), {} MB (VmSize)", - state.global_mpi_config.world_rank(), rss_start, vmsize_start); - - println!( - "[MPI Rank {}] Worker acquiring witness lock...", - state.global_mpi_config.world_rank() - ); let mut witness = state.witness.lock().await; let mut witness_win = state.wt_shared_memory_win.lock().await; S::setup_shared_witness(&state.global_mpi_config, &mut witness, &mut witness_win); @@ -295,10 +222,6 @@ pub async fn worker_main( &witness, ); assert!(proof.is_none()); - - let (rss_end, vmsize_end) = get_total_expander_memory_mb(); - println!("[MPI Rank {}] Worker prove done - witness lock will be released, but witness remains in state.witness, MEMORY = {} MB (RSS), {} MB (VmSize)", - state.global_mpi_config.world_rank(), rss_end, vmsize_end); } 255 => { break; @@ -353,25 +276,12 @@ where S: ServerFns + 'static, { - use std::time::Instant; - - let serve_start = Instant::now(); - println!("[TIMING] serve() START"); - - let step_start = Instant::now(); let global_mpi_config = unsafe { UNIVERSE = MPIConfig::init(); GLOBAL_COMMUNICATOR = UNIVERSE.as_ref().map(|u| u.world()); MPIConfig::prover_new(UNIVERSE.as_ref(), GLOBAL_COMMUNICATOR.as_ref()) }; - let rank = global_mpi_config.world_rank(); - println!( - "[TIMING Rank {}] MPI initialization took {:.3}s", - rank, - step_start.elapsed().as_secs_f64() - ); - let step_start = Instant::now(); let state = ServerState { lock: Arc::new(Mutex::new(())), global_mpi_config: 
global_mpi_config.clone(), @@ -384,67 +294,24 @@ where wt_shared_memory_win: Arc::new(Mutex::new(None)), shutdown_tx: Arc::new(Mutex::new(None)), }; - println!( - "[TIMING Rank {}] ServerState creation took {:.3}s", - rank, - step_start.elapsed().as_secs_f64() - ); if global_mpi_config.is_root() { - println!( - "[TIMING Rank {}] Root process: setting up HTTP server", - rank - ); - - let step_start = Instant::now(); let (tx, rx) = oneshot::channel::<()>(); state.shutdown_tx.lock().await.replace(tx); - println!( - "[TIMING Rank {}] Shutdown channel setup took {:.3}s", - rank, - step_start.elapsed().as_secs_f64() - ); - - let step_start = Instant::now(); + let app = Router::new() .route("/", post(root_main::)) .route("/", get(|| async { "Expander Server is running" })) .with_state(state.clone()); - println!( - "[TIMING Rank {}] Router creation took {:.3}s", - rank, - step_start.elapsed().as_secs_f64() - ); - - let step_start = Instant::now(); + let ip: IpAddr = SERVER_IP.parse().expect("Invalid SERVER_IP"); let port_val = port_number.parse::().unwrap_or_else(|e| { eprintln!("Error: Invalid port number '{port_number}'. {e}."); std::process::exit(1); }); let addr = SocketAddr::new(ip, port_val); - println!( - "[TIMING Rank {}] Address parsing took {:.3}s", - rank, - step_start.elapsed().as_secs_f64() - ); - - let step_start = Instant::now(); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); - println!( - "[TIMING Rank {}] TCP listener bind took {:.3}s", - rank, - step_start.elapsed().as_secs_f64() - ); - println!("Server running at http://{addr}"); - println!( - "[TIMING Rank {}] Total startup time: {:.3}s", - rank, - serve_start.elapsed().as_secs_f64() - ); - - let step_start = Instant::now(); axum::serve(listener, app.into_make_service()) .with_graceful_shutdown(async { rx.await.ok(); @@ -452,14 +319,8 @@ where }) .await .unwrap(); - println!( - "[TIMING Rank {}] Server shutdown after {:.3}s of running", - rank, - step_start.elapsed().as_secs_f64() - ); // it might need some time for the server to properly shutdown - let step_start = Instant::now(); loop { match Arc::strong_count(&state.computation_graph) { 1 => { @@ -471,26 +332,10 @@ where } } } - println!( - "[TIMING Rank {}] Waiting for clean shutdown took {:.3}s", - rank, - step_start.elapsed().as_secs_f64() - ); } else { - println!( - "[TIMING Rank {}] Worker process: entering worker_main", - rank - ); - let step_start = Instant::now(); worker_main::(global_mpi_config, state.clone()).await; - println!( - "[TIMING Rank {}] Worker finished after {:.3}s", - rank, - step_start.elapsed().as_secs_f64() - ); } - let step_start = Instant::now(); match ( Arc::try_unwrap(state.computation_graph), Arc::try_unwrap(state.witness), @@ -510,29 +355,12 @@ where panic!("Failed to unwrap Arc, multiple references exist"); } } - println!( - "[TIMING Rank {}] Shared memory cleanup took {:.3}s", - rank, - step_start.elapsed().as_secs_f64() - ); if state.global_mpi_config.is_root() { println!("Server has been shut down."); } - let step_start = Instant::now(); unsafe { mpi::ffi::MPI_Finalize() }; - println!( - "[TIMING Rank {}] MPI_Finalize took {:.3}s", - rank, - step_start.elapsed().as_secs_f64() - ); - - println!( - "[TIMING Rank {}] serve() TOTAL TIME: {:.3}s", - rank, - serve_start.elapsed().as_secs_f64() - ); } #[derive(Parser, Debug)] diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs index 47c095d0..d3e326d9 
100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs @@ -19,57 +19,6 @@ use crate::{ }, }; -/// 获取所有 expander_server 进程的内存占用(单位:MB) -/// 返回 (VmRSS物理内存, VmSize虚拟内存) -fn get_total_expander_memory_mb() -> (usize, usize) { - use std::fs; - use std::io::{BufRead, BufReader}; - - let mut total_rss_kb = 0usize; - let mut total_vmsize_kb = 0usize; - - // 遍历 /proc 目录 - if let Ok(entries) = fs::read_dir("/proc") { - for entry in entries.flatten() { - if let Ok(file_name) = entry.file_name().into_string() { - // 只处理数字目录(进程PID) - if file_name.chars().all(|c| c.is_ascii_digit()) { - // 读取 /proc/[pid]/comm 检查进程名 - let comm_path = format!("/proc/{}/comm", file_name); - if let Ok(comm) = fs::read_to_string(&comm_path) { - if comm.trim() == "expander_server" { - // 读取 /proc/[pid]/status 获取内存信息 - let status_path = format!("/proc/{}/status", file_name); - if let Ok(file) = fs::File::open(&status_path) { - let reader = BufReader::new(file); - for line in reader.lines().flatten() { - if line.starts_with("VmRSS:") { - // VmRSS: 12345 kB (物理内存) - if let Some(rss_str) = line.split_whitespace().nth(1) { - if let Ok(rss_kb) = rss_str.parse::() { - total_rss_kb += rss_kb; - } - } - } else if line.starts_with("VmSize:") { - // VmSize: 12345 kB (虚拟内存) - if let Some(size_str) = line.split_whitespace().nth(1) { - if let Ok(size_kb) = size_str.parse::() { - total_vmsize_kb += size_kb; - } - } - } - } - } - } - } - } - } - } - } - - (total_rss_kb / 1024, total_vmsize_kb / 1024) // 转换为MB -} - pub trait ServerFns where C: gkr_engine::GKREngine, @@ -96,10 +45,6 @@ where witness_target: &mut Vec>>, mpi_shared_memory_win: &mut Option, ) { - let (rss_start, vmsize_start) = get_total_expander_memory_mb(); - println!("[MPI Rank {}] setup_shared_witness: START - disposing old witness, MEMORY = {} MB (RSS), {} MB (VmSize)", - global_mpi_config.world_rank(), rss_start, vmsize_start); - // dispose of the previous shared memory if it exists while let Some(w) = witness_target.pop() { w.discard_control_of_shared_mem(); @@ -110,10 +55,6 @@ where global_mpi_config.free_shared_mem(&mut win_wrapper.win); } - let (rss_after_dispose, vmsize_after_dispose) = get_total_expander_memory_mb(); - println!("[MPI Rank {}] setup_shared_witness: Old witness disposed, MEMORY = {} MB (RSS), {} MB (VmSize), calling read_shared_witness_from_shared_memory", - global_mpi_config.world_rank(), rss_after_dispose, vmsize_after_dispose); - // Allocate new shared memory for the witness let (witness_v, wt_shared_memory_win) = SharedMemoryEngine::read_shared_witness_from_shared_memory::( @@ -121,10 +62,6 @@ where ); *witness_target = witness_v; *mpi_shared_memory_win = Some(wt_shared_memory_win); - - let (rss_end, vmsize_end) = get_total_expander_memory_mb(); - println!("[MPI Rank {}] setup_shared_witness: DONE - witness loaded into local memory, MEMORY = {} MB (RSS), {} MB (VmSize)", - global_mpi_config.world_rank(), rss_end, vmsize_end); } fn shared_memory_clean_up( @@ -186,17 +123,7 @@ where C: GKREngine, ECCConfig: Config, { - let (rss_start, vmsize_start) = get_total_expander_memory_mb(); - println!("[MPI Rank {}] prove_request_handler: START - witness is being used for proving, MEMORY = {} MB (RSS), {} MB (VmSize)", - global_mpi_config.world_rank(), rss_start, vmsize_start); - - let proof = mpi_prove_impl(global_mpi_config, prover_setup, computation_graph, values); - - let (rss_end, vmsize_end) = get_total_expander_memory_mb(); - 
println!("[MPI Rank {}] prove_request_handler: DONE - witness is still in memory but no longer actively used, MEMORY = {} MB (RSS), {} MB (VmSize)", - global_mpi_config.world_rank(), rss_end, vmsize_end); - - proof + mpi_prove_impl(global_mpi_config, prover_setup, computation_graph, values) } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs index d01bb253..75791bee 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs @@ -228,12 +228,6 @@ impl SharedMemoryEngine { pub fn read_shared_witness_from_shared_memory( global_mpi_config: &MPIConfig<'static>, ) -> (Vec>, SharedMemoryWINWrapper) { - use std::time::Instant; - - let (rss_before, vmsize_before) = get_total_expander_memory_mb(); - // 打印关键信息:进程rank和witness长度 - println!("[MPI Rank {}] read_shared_witness_from_shared_memory: MEMORY_BEFORE = {} MB (RSS), {} MB (VmSize)", - global_mpi_config.world_rank(), rss_before, vmsize_before); let (mut mpi_shared_mem_ptr, mem_win) = if global_mpi_config.is_root() { let witness = Self::read_witness_from_shared_memory::(); let bytes_size = std::mem::size_of::() @@ -253,158 +247,11 @@ impl SharedMemoryEngine { global_mpi_config.barrier(); - // ⏸️ 等待检查点:等待 /tmp/continue_witness_test 文件出现才继续 - let checkpoint_file = "/tmp/continue_witness_test1"; - println!( - "[MPI Rank {}] ⏸️ CHECKPOINT: Waiting for file '{}' to continue...", - global_mpi_config.world_rank(), - checkpoint_file - ); - println!("[MPI Rank {}] ⏸️ You can now check memory usage. Create the file to continue: touch {}", - global_mpi_config.world_rank(), checkpoint_file); - - let mut check_count = 0; - loop { - if std::path::Path::new(checkpoint_file).exists() { - println!( - "[MPI Rank {}] ✅ Checkpoint file detected, continuing execution", - global_mpi_config.world_rank() - ); - break; - } - - check_count += 1; - std::thread::sleep(std::time::Duration::from_millis(500)); - } - // ⏱️ 开始计时:测量从共享内存读取witness的耗时 - let read_start = Instant::now(); let n_witness = usize::new_from_memory(&mut mpi_shared_mem_ptr); - let read_n_witness_duration = read_start.elapsed(); - - println!( - "[MPI Rank {}] ⏱️ Read n_witness={} took {:.3} µs", - global_mpi_config.world_rank(), - n_witness, - read_n_witness_duration.as_micros() - ); - - let witness_read_start = Instant::now(); let witness = (0..n_witness) .map(|_| Vec::::new_from_memory(&mut mpi_shared_mem_ptr)) .collect::>(); - let witness_read_duration = witness_read_start.elapsed(); - - println!("[MPI Rank {}] ⏱️ Read {} witness components from shared memory took {:.3} ms ({:.3} µs)", - global_mpi_config.world_rank(), - n_witness, - witness_read_duration.as_secs_f64() * 1000.0, - witness_read_duration.as_micros()); - - let (rss_after, vmsize_after) = get_total_expander_memory_mb(); - - // 打印每个witness component的大小 - let total_elements: usize = witness.iter().map(|v| v.len()).sum(); - let total_bytes: usize = witness - .iter() - .map(|v| v.len() * std::mem::size_of_val(&v[0])) - .sum(); - let rss_increase = rss_after.saturating_sub(rss_before); - let vmsize_increase = vmsize_after.saturating_sub(vmsize_before); - println!("[MPI Rank {}] Copied witness to local memory: {} components, {} total elements, ~{} MB witness data", - global_mpi_config.world_rank(), - witness.len(), - total_elements, - total_bytes / 1024 / 1024); - println!( 
- "[MPI Rank {}] MEMORY_AFTER_COPY: RSS = {} MB (+{} MB), VmSize = {} MB (+{} MB)", - global_mpi_config.world_rank(), - rss_after, - rss_increase, - vmsize_after, - vmsize_increase - ); - - // ⏸️ 等待检查点:等待 /tmp/continue_witness_test 文件出现才继续 - let checkpoint_file = "/tmp/continue_witness_test"; - println!( - "[MPI Rank {}] ⏸️ CHECKPOINT: Waiting for file '{}' to continue...", - global_mpi_config.world_rank(), - checkpoint_file - ); - println!("[MPI Rank {}] ⏸️ You can now check memory usage. Create the file to continue: touch {}", - global_mpi_config.world_rank(), checkpoint_file); - - let mut check_count = 0; - loop { - if std::path::Path::new(checkpoint_file).exists() { - println!( - "[MPI Rank {}] ✅ Checkpoint file detected, continuing execution", - global_mpi_config.world_rank() - ); - break; - } - - // 每10次检查打印一次内存状态(避免日志过多) - if check_count % 10 == 0 { - let (rss, vmsize) = get_total_expander_memory_mb(); - println!( - "[MPI Rank {}] ⏳ Still waiting... (check #{}, RSS = {} MB, VmSize = {} MB)", - global_mpi_config.world_rank(), - check_count, - rss, - vmsize - ); - } - - check_count += 1; - std::thread::sleep(std::time::Duration::from_millis(500)); - } - - // 🔥 主动访问witness数据,强制触发物理页分配 - println!("[MPI Rank {}] 🔥 Now actively accessing witness data to trigger physical page allocation...", - global_mpi_config.world_rank()); - - let access_start = Instant::now(); - - // 遍历所有witness数据,真正读取每个元素的字节 - let mut dummy_sum = 0u64; - for component in witness.iter() { - // 将Vec转为字节切片,确保访问实际内存 - let bytes: &[u8] = unsafe { - std::slice::from_raw_parts( - component.as_ptr() as *const u8, - component.len() * std::mem::size_of::(), - ) - }; - - // 每隔4KB(页面大小)读取一个字节,确保触碰所有页面 - for i in (0..bytes.len()).step_by(4096) { - unsafe { - // 使用read_volatile防止编译器优化 - dummy_sum = dummy_sum.wrapping_add(std::ptr::read_volatile(&bytes[i]) as u64); - } - } - } - - let access_duration = access_start.elapsed(); - println!( - "[MPI Rank {}] 🔥 Finished accessing witness data (dummy_sum = {}, took {:.3}s)", - global_mpi_config.world_rank(), - dummy_sum, - access_duration.as_secs_f64() - ); - - // 再次测量内存,看是否因为访问而增长 - let (rss_after_access, vmsize_after_access) = get_total_expander_memory_mb(); - let rss_increase_by_access = rss_after_access.saturating_sub(rss_after); - println!( - "[MPI Rank {}] 📊 MEMORY_AFTER_ACCESS: RSS = {} MB (+{} MB from copy), VmSize = {} MB", - global_mpi_config.world_rank(), - rss_after_access, - rss_increase_by_access, - vmsize_after_access - ); (witness, SharedMemoryWINWrapper { win: mem_win }) } From 9a12c27f01763ce219df1e40be9c5b9005b818ec Mon Sep 17 00:00:00 2001 From: hczphn Date: Tue, 20 Jan 2026 03:21:18 +0000 Subject: [PATCH 04/10] remove duplicate schedule --- .../expander_parallelized/prove_impl.rs | 380 +----------------- 1 file changed, 1 insertion(+), 379 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index 600036fd..aaf3f5be 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -222,382 +222,4 @@ pub fn partition_single_gkr_claim_and_open_pcs_mpi( transcript, ); } -} - -// ==================== SCHEDULE-BASED EXECUTION ==================== - -/// Schedule representation: rank -> sequence of tasks -#[derive(Debug, Clone)] -pub struct Schedule { - /// Map from rank to list of task names - /// e.g., rank 0 -> ["Task14", 
"Task1", "Task12"] - pub rank_tasks: HashMap>, -} - -impl Schedule { - /// Parse schedule from text file - /// Format: "Rank 0: Task14 -> Task1 -> Task12" - pub fn from_file(path: &str) -> Result { - let content = - fs::read_to_string(path).map_err(|e| format!("Failed to read schedule file: {}", e))?; - - let mut rank_tasks = HashMap::new(); - - for line in content.lines() { - let line = line.trim(); - if line.is_empty() { - continue; - } - - // Parse "Rank X: TaskA -> TaskB -> TaskC" - let parts: Vec<&str> = line.split(':').collect(); - if parts.len() != 2 { - return Err(format!("Invalid line format: {}", line)); - } - - // Handle multiple spaces: "Rank 0:" or "Rank 0:" - let rank_part = parts[0].trim(); - if !rank_part.starts_with("Rank") { - return Err(format!("Expected 'Rank X', got: {}", parts[0])); - } - - let rank_str = rank_part - .strip_prefix("Rank") - .ok_or_else(|| format!("Expected 'Rank X', got: {}", parts[0]))? - .trim(); // Trim to handle multiple spaces - - let rank: usize = rank_str - .parse() - .map_err(|e| format!("Invalid rank number '{}': {}", rank_str, e))?; - - // Extract tasks - let tasks_str = parts[1].trim(); - let tasks: Vec = tasks_str - .split("->") - .map(|t| t.trim().to_string()) - .filter(|t| !t.is_empty()) - .collect(); - - rank_tasks.insert(rank, tasks); - } - - Ok(Schedule { rank_tasks }) - } - - /// Get tasks for a specific rank - pub fn get_tasks(&self, rank: usize) -> Option<&Vec> { - self.rank_tasks.get(&rank) - } - - /// Find which ranks are executing the same task at the same step - pub fn find_peers_at_step(&self, step: usize, task_name: &str) -> Vec { - let mut peers = Vec::new(); - - for (rank, tasks) in &self.rank_tasks { - if tasks.len() > step && tasks[step] == task_name { - peers.push(*rank); - } - } - - peers.sort(); - peers - } - - /// Get maximum number of steps across all ranks - pub fn max_steps(&self) -> usize { - self.rank_tasks - .values() - .map(|tasks| tasks.len()) - .max() - .unwrap_or(0) - } -} - -/// Parse task mapping file -/// Format: "Task1: 0" (Task1 maps to template index 0) -pub fn parse_task_mapping(path: &str) -> Result, String> { - let content = - fs::read_to_string(path).map_err(|e| format!("Failed to read task mapping file: {}", e))?; - - let mut mapping = HashMap::new(); - - for line in content.lines() { - let line = line.trim(); - if line.is_empty() || line.starts_with('#') { - continue; - } - - let parts: Vec<&str> = line.split(':').collect(); - if parts.len() != 2 { - continue; - } - - let task_name = parts[0].trim().to_string(); - let template_idx: usize = parts[1] - .trim() - .parse() - .map_err(|e| format!("Invalid template index: {}", e))?; - - mapping.insert(task_name, template_idx); - } - - Ok(mapping) -} - -/// Create MPI subgroup for a specific task based on peers -fn create_mpi_subgroup_for_task( - global_mpi_config: &MPIConfig<'static>, - peers: &[usize], - task_name: &str, -) -> Option> { - use mpi::topology::Communicator; - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - - let my_rank = global_mpi_config.world_rank(); - - // Find my position in peers - let my_position = peers.iter().position(|&r| r == my_rank)?; - - // Use task name hash as color - let mut hasher = DefaultHasher::new(); - task_name.hash(&mut hasher); - let color_value = (hasher.finish() % 10000) as i32; - let color = mpi::topology::Color::with_value(color_value); - - // Split communicator and leak it to get 'static lifetime - let split_comm = unsafe { - global_mpi_config - .world - .unwrap() - 
.split_by_color_with_key(color, my_position as i32) - }; - - // Leak the communicator to get 'static lifetime - // split_comm is Option, we need to leak the inner value - let split_comm_static: &'static Option<_> = Box::leak(Box::new(split_comm)); - - Some(MPIConfig::prover_new( - global_mpi_config.universe, - split_comm_static.as_ref(), - )) -} - -/// Main prove function with schedule -pub fn mpi_prove_with_schedule( - global_mpi_config: &MPIConfig<'static>, - schedule_path: &str, - task_mapping_path: Option<&str>, // Optional: if None, use template index as task name - prover_setup: &ExpanderProverSetup, - computation_graph: &ComputationGraph, - values: &[impl AsRef<[SIMDField]>], -) -> Option>> -where - C: GKREngine, - ECCConfig: Config, -{ - let my_rank = global_mpi_config.world_rank(); - - // 1. Load schedule - let schedule = match Schedule::from_file(schedule_path) { - Ok(s) => s, - Err(e) => { - eprintln!("[RANK {}] Failed to load schedule: {}", my_rank, e); - return None; - } - }; - - eprintln!( - "[RANK {}] Loaded schedule with {} ranks, max {} steps", - my_rank, - schedule.rank_tasks.len(), - schedule.max_steps() - ); - - // 2. Load task mapping (or use default: TaskX -> template X) - let task_mapping = if let Some(path) = task_mapping_path { - match parse_task_mapping(path) { - Ok(m) => m, - Err(e) => { - eprintln!("[RANK {}] Failed to load task mapping: {}", my_rank, e); - return None; - } - } - } else { - // Default: Task0->0, Task1->1, etc. - let mut default_mapping = HashMap::new(); - for i in 0..computation_graph.proof_templates().len() { - default_mapping.insert(format!("Task{}", i), i); - } - default_mapping - }; - - // 3. Commit phase (only root) - let (commitments, states) = if global_mpi_config.is_root() { - eprintln!("[RANK {}] === COMMIT PHASE ===", my_rank); - let commit_timer = Timer::new("Commit to all input", true); - let (commitments, states) = values - .iter() - .map(|value| { - local_commit_impl::( - prover_setup.p_keys.get(&value.as_ref().len()).unwrap(), - value.as_ref(), - ) - }) - .unzip::<_, _, Vec<_>, Vec<_>>(); - commit_timer.stop(); - (Some(commitments), Some(states)) - } else { - (None, None) - }; - - // 4. Get my tasks - let my_tasks = match schedule.get_tasks(my_rank) { - Some(tasks) => tasks, - None => { - eprintln!("[RANK {}] No tasks assigned in schedule", my_rank); - return if my_rank == 0 { - Some(CombinedProof { - commitments: commitments.unwrap(), - proofs: vec![], - }) - } else { - None - }; - } - }; - - eprintln!("[RANK {}] My tasks: {:?}", my_rank, my_tasks); - - // 5. 
Execute tasks step by step - let mut all_proofs: Vec> = - vec![None; computation_graph.proof_templates().len()]; - - for (step, task_name) in my_tasks.iter().enumerate() { - eprintln!( - "[RANK {}] === STEP {} === Task: {}", - my_rank, step, task_name - ); - - // Find template index - let template_idx = match task_mapping.get(task_name) { - Some(&idx) => idx, - None => { - eprintln!("[RANK {}] Unknown task: {}", my_rank, task_name); - continue; - } - }; - - if template_idx >= computation_graph.proof_templates().len() { - eprintln!( - "[RANK {}] Invalid template index: {}", - my_rank, template_idx - ); - continue; - } - - // Find peers for this task at this step - let peers = schedule.find_peers_at_step(step, task_name); - eprintln!("[RANK {}] Task {} peers: {:?}", my_rank, task_name, peers); - - if peers.is_empty() || !peers.contains(&my_rank) { - eprintln!("[RANK {}] Not participating in task {}", my_rank, task_name); - continue; - } - - let template = &computation_graph.proof_templates()[template_idx]; - let kernel = &computation_graph.kernels()[template.kernel_id()]; - - let commitment_values = template - .commitment_indices() - .iter() - .map(|&idx| values[idx].as_ref()) - .collect::>(); - - // Create MPI subgroup if multiple peers - let local_mpi_config = if peers.len() > 1 { - create_mpi_subgroup_for_task(global_mpi_config, &peers, task_name) - } else { - None - }; - - // Execute GKR - let gkr_result = if let Some(ref local_config) = local_mpi_config { - eprintln!( - "[RANK {}] Executing task {} with {} peers (local_rank={})", - my_rank, - task_name, - peers.len(), - local_config.world_rank() - ); - - prove_kernel_gkr::( - local_config, - kernel, - &commitment_values, - next_power_of_two(template.parallel_count()), - template.is_broadcast(), - ) - } else { - // Single rank task - eprintln!("[RANK {}] Executing task {} solo", my_rank, task_name); - None // Skip for now - }; - - // PCS opening (only subgroup root) - if let Some((mut transcript, challenge)) = gkr_result { - let is_subgroup_root = local_mpi_config - .as_ref() - .map(|c| c.is_root()) - .unwrap_or(true); - - if is_subgroup_root { - eprintln!( - "[RANK {}] Performing PCS opening for task {}", - my_rank, task_name - ); - - let pcs_timer = Timer::new(&format!("PCS for {}", task_name), true); - let challenges = if let Some(challenge_y) = challenge.challenge_y() { - vec![challenge.challenge_x(), challenge_y] - } else { - vec![challenge.challenge_x()] - }; - - challenges.iter().for_each(|c| { - partition_single_gkr_claim_and_open_pcs_mpi::( - prover_setup, - &commitment_values, - &template - .commitment_indices() - .iter() - .map(|&idx| &states.as_ref().unwrap()[idx]) - .collect::>(), - c, - template.is_broadcast(), - &mut transcript, - ); - }); - - pcs_timer.stop(); - - all_proofs[template_idx] = Some(ExpanderProof { - data: vec![transcript.finalize_and_get_proof()], - }); - } - } - } - - // 6. 
Collect results (only root) - if global_mpi_config.is_root() { - let proofs = all_proofs.into_iter().filter_map(|p| p).collect::>(); - eprintln!("[RANK {}] Collected {} proofs", my_rank, proofs.len()); - - Some(CombinedProof { - commitments: commitments.unwrap(), - proofs, - }) - } else { - None - } -} +} \ No newline at end of file From c89adbb8f3f737da222236670aa22003f1444b3f Mon Sep 17 00:00:00 2001 From: hczphn Date: Tue, 20 Jan 2026 04:17:18 +0000 Subject: [PATCH 05/10] reduce eprintln --- .../expander_no_oversubscribe/prove_impl.rs | 201 +++++------------- 1 file changed, 50 insertions(+), 151 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 9b4ef6c4..10f80de8 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -48,10 +48,9 @@ where // Check for schedule file and use scheduler if available if std::path::Path::new("schedule.txt").exists() { let my_rank = global_mpi_config.world_rank(); - eprintln!( - "[RANK {}] ⚡ Schedule file detected, using scheduled execution", - my_rank - ); + if my_rank == 0 { + eprintln!("⚡ Schedule file detected, using scheduled execution"); + } return mpi_prove_no_oversubscribe_with_schedule::( global_mpi_config, "schedule.txt", @@ -670,11 +669,6 @@ fn mark_task_completed(task_name: &str, my_rank: usize, peers: &[usize]) { "[RANK {}] Warning: Failed to write marker for {}: {}", my_rank, task_name, e ); - } else { - eprintln!( - "[RANK {}] ✓ Marked task {} as completed", - my_rank, task_name - ); } } } @@ -686,52 +680,44 @@ fn wait_for_dependencies(task_name: &str, dependencies: &[String], my_rank: usiz return; } - eprintln!( - "[RANK {}] Task {} waiting for {} dependencies: {:?}", - my_rank, - task_name, - dependencies.len(), - dependencies - ); + let mut need_wait = false; + let mut waited_deps = vec![]; for dep in dependencies { let marker_path = format!(".task_sync/{}.done", dep); if std::path::Path::new(&marker_path).exists() { - eprintln!( - "[RANK {}] ✓ Dependency {} already satisfied", - my_rank, dep - ); continue; } - eprintln!("[RANK {}] ⏳ Waiting for dependency: {}", my_rank, dep); + if !need_wait { + need_wait = true; + } let start_time = std::time::Instant::now(); while !std::path::Path::new(&marker_path).exists() { std::thread::sleep(std::time::Duration::from_millis(100)); - // Timeout check (optional, for debugging) + // Timeout warning if start_time.elapsed().as_secs() > 600 { eprintln!( - "[RANK {}] ⚠️ WARNING: Waiting for {} over 10 minutes!", + "[RANK {}] ⚠️ WARNING: Waiting for {} over 10 minutes!", my_rank, dep ); } } - eprintln!( - "[RANK {}] ✓ Dependency {} satisfied (waited {:.1}s)", - my_rank, - dep, - start_time.elapsed().as_secs_f64() - ); + if start_time.elapsed().as_secs_f64() > 0.5 { + waited_deps.push((dep.clone(), start_time.elapsed().as_secs_f64())); + } } - eprintln!( - "[RANK {}] All dependencies for {} satisfied", - my_rank, task_name - ); + // Only log if actually waited + if !waited_deps.is_empty() && my_rank % 8 == 0 { + eprintln!("[RANK {}] Task {} waited for {} deps (longest: {:.1}s)", + my_rank, task_name, waited_deps.len(), + waited_deps.iter().map(|(_, t)| t).fold(0.0f64, |a, &b| a.max(b))); + } } /// Create MPI subgroup for a specific task @@ -812,25 +798,15 @@ where } }; - eprintln!("[RANK {}] ========== SCHEDULER MODE 
==========", my_rank); - eprintln!( - "[RANK {}] Loaded schedule with {} ranks, max {} steps", - my_rank, - schedule.rank_tasks.len(), - schedule.max_steps() - ); + if global_mpi_config.is_root() { + eprintln!("========== SCHEDULER MODE =========="); + eprintln!(" Schedule: {} ranks, max {} steps", + schedule.rank_tasks.len(), schedule.max_steps()); + } // Safety checks let num_templates = computation_graph.proof_templates().len(); let num_values = values.len(); - eprintln!( - "[RANK {}] Computation graph has {} templates", - my_rank, num_templates - ); - eprintln!( - "[RANK {}] Values array has {} elements", - my_rank, num_values - ); if num_templates == 0 { eprintln!( @@ -873,22 +849,17 @@ where let task_dependencies = if std::path::Path::new("task_dependencies.txt").exists() { match parse_task_dependencies("task_dependencies.txt") { Ok(deps) => { - eprintln!("[RANK {}] Loaded {} task dependencies", my_rank, deps.len()); + if global_mpi_config.is_root() { + eprintln!(" Loaded {} task dependencies", deps.len()); + } deps } Err(e) => { - eprintln!( - "[RANK {}] Warning: Failed to load dependencies: {}", - my_rank, e - ); + eprintln!("[RANK {}] ERROR: Failed to load dependencies: {}", my_rank, e); HashMap::new() } } } else { - eprintln!( - "[RANK {}] No task_dependencies.txt found, using file-based sync", - my_rank - ); HashMap::new() }; @@ -903,31 +874,22 @@ where } } } - eprintln!("[RANK 0] Initialized .task_sync directory"); } - global_mpi_config.barrier(); // Wait for directory creation + global_mpi_config.barrier(); // ========== PRE-CREATE ALL MPI SUBGROUPS ========== - // CRITICAL: Create all task subgroups BEFORE any task execution - // This allows ranks to proceed asynchronously without collective deadlock - eprintln!( - "[RANK {}] Pre-creating MPI subgroups for all tasks...", - my_rank - ); - let all_unique_tasks = schedule.get_all_unique_tasks(); let mut task_mpi_configs: HashMap>> = HashMap::new(); + if global_mpi_config.is_root() { + eprintln!(" Pre-creating MPI subgroups for {} tasks...", all_unique_tasks.len()); + } + for task_name in &all_unique_tasks { let peers = schedule.find_all_peers_for_task(task_name); - eprintln!( - "[RANK {}] Creating MPI subgroup for task {} (peers: {:?})", - my_rank, task_name, peers - ); - // All 32 ranks call this together (collective operation) - let mpi_config = if peers.len() >= 1 { + let mpi_config = if !peers.is_empty() { create_mpi_subgroup_for_task(global_mpi_config, &peers, task_name) } else { None @@ -936,12 +898,10 @@ where task_mpi_configs.insert(task_name.clone(), mpi_config); } - eprintln!( - "[RANK {}] Pre-created {} MPI subgroups", - my_rank, - task_mpi_configs.len() - ); - global_mpi_config.barrier(); // Ensure all subgroups created before proceeding + if global_mpi_config.is_root() { + eprintln!(" MPI subgroups ready"); + } + global_mpi_config.barrier(); // Commit phase (only root) let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); @@ -999,22 +959,14 @@ where } }; - eprintln!("[RANK {}] My tasks: {:?}", my_rank, my_tasks); - // Execute tasks step by step let mut all_proofs: Vec> = vec![None; computation_graph.proof_templates().len()]; let prove_timer = Timer::new("Prove all kernels", global_mpi_config.is_root()); for (step, task_name) in my_tasks.iter().enumerate() { - eprintln!( - "[RANK {}] === STEP {} === Task: {}", - my_rank, step, task_name - ); - // Skip idle steps if task_name == "idle" || task_name == "..." 
{ - // Idle ranks still participate in MPI collective operations continue; } @@ -1048,15 +1000,9 @@ where let i_am_participant = all_peers.contains(&my_rank); if !i_am_participant { - eprintln!("[RANK {}] Not participating in task {}", my_rank, task_name); continue; } - eprintln!( - "[RANK {}] Task {} peers: {:?} (using pre-created MPI subgroup)", - my_rank, task_name, all_peers - ); - let template = &computation_graph.proof_templates()[template_idx]; // Safety check: verify all commitment indices are in bounds @@ -1098,14 +1044,6 @@ where ); let gkr_end_state = if let Some(ref local_config) = local_mpi_config { - eprintln!( - "[RANK {}] Executing task {} with {} peers (local_rank={}, group_size={})", - my_rank, - task_name, - all_peers.len(), - local_config.world_rank(), - local_config.world_size() - ); prove_kernel_gkr_no_oversubscribe::, GetTranscript, ZC::ECCConfig>( local_config, @@ -1166,10 +1104,6 @@ where .unwrap_or(true); if is_subgroup_root { - eprintln!( - "[RANK {}] I am subgroup root for task {}", - my_rank, task_name - ); i_am_subgroup_root_for_tasks.push(template_idx); match ZC::BATCH_PCS { @@ -1230,17 +1164,8 @@ where mark_task_completed(task_name, my_rank, &all_peers); } - // ========== CRITICAL: Global barrier ========== - // Wait for all ranks to complete all their tasks before proceeding to PCS opening - eprintln!( - "[RANK {}] All my tasks completed, waiting for other ranks...", - my_rank - ); + // Wait for all ranks to complete all tasks global_mpi_config.barrier(); - eprintln!( - "[RANK {}] All ranks ready, proceeding to result collection", - my_rank - ); prove_timer.stop(); // ========== MPI Result Collection (for BATCH_PCS mode) ========== @@ -1251,13 +1176,6 @@ where let i_am_subgroup_root = !i_am_subgroup_root_for_tasks.is_empty(); - eprintln!( - "[RANK {}] Am I subgroup root? 
{} (for {} tasks)", - my_rank, - i_am_subgroup_root, - i_am_subgroup_root_for_tasks.len() - ); - // Step 0: All ranks send a flag to root indicating if they are subgroup roots let my_flag = if i_am_subgroup_root { 1u8 } else { 0u8 }; @@ -1284,10 +1202,6 @@ where // Step 1: Non-root subgroup roots send their results to rank 0 if i_am_subgroup_root && my_rank != 0 { - eprintln!( - "[RANK {}] Sending my results to global root (indexed by template)", - my_rank - ); // Serialize the indexed structures (maintains template order) let mut vals_bytes = Vec::new(); @@ -1329,14 +1243,10 @@ where .process_at_rank(0) .synchronous_send(&proofs_bytes[..]); } - - eprintln!("[RANK {}] Results sent to global root", my_rank); } // Step 2: Global root receives all results if global_mpi_config.is_root() { - eprintln!("[RANK 0] Collecting results from all subgroup roots..."); - let flags = all_flags.unwrap(); // Identify which ranks are subgroup roots (have flag=1) @@ -1347,16 +1257,12 @@ where .map(|(rank, _)| rank) .collect(); - eprintln!("[RANK 0] Subgroup roots detected: {:?}", subgroup_roots); - // Receive from each subgroup root (except self) for &sender_rank in &subgroup_roots { if sender_rank == 0 { continue; // Skip self } - eprintln!("[RANK 0] Receiving results from rank {}", sender_rank); - // Receive sizes let (sizes, _status) = unsafe { global_mpi_config @@ -1420,14 +1326,6 @@ where } } - let received_count = received_vals_per_template - .iter() - .filter(|v| v.is_some()) - .count(); - eprintln!( - "[RANK 0] Received results from rank {} ({} templates)", - sender_rank, received_count - ); } // Build final vals_ref and challenges in template order @@ -1443,17 +1341,18 @@ where } } - eprintln!( - "[RANK 0] Result collection complete. Total: {} vals, {} challenges, {} proofs", - vals_ref_owned.len(), - challenges_final.len(), - all_proofs.iter().filter(|p| p.is_some()).count() - ); + let completed_templates = all_proofs.iter().filter(|p| p.is_some()).count(); + eprintln!("Result collection: {}/{} templates, {} vals, {} challenges", + completed_templates, num_templates, + vals_ref_owned.len(), challenges_final.len()); - eprintln!("[RANK 0] Templates coverage:"); - for (idx, val) in vals_per_template.iter().enumerate() { - let status = if val.is_some() { "✓" } else { "✗" }; - eprintln!(" Template {}: {}", idx, status); + if completed_templates < num_templates { + eprintln!("⚠️ WARNING: Only {}/{} templates completed!", completed_templates, num_templates); + for (idx, val) in vals_per_template.iter().enumerate() { + if val.is_none() { + eprintln!(" Missing: Template {}", idx); + } + } } } } From d91d499fa22823edadd9a07b508c03aff3a56363 Mon Sep 17 00:00:00 2001 From: hczphn Date: Tue, 20 Jan 2026 16:52:13 +0000 Subject: [PATCH 06/10] support no hint logup --- circuit-std-rs/src/logup.rs | 47 +++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/circuit-std-rs/src/logup.rs b/circuit-std-rs/src/logup.rs index 01123c23..b14cbd9d 100644 --- a/circuit-std-rs/src/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -326,6 +326,37 @@ impl LogUpSingleKeyTable { &alpha, ); + assert_eq_rational(builder, &v_table, &v_query); + } + pub fn final_check_with_query_count>(&mut self, builder: &mut B, query_count: &[Variable]) { + if self.table.is_empty() || self.query_keys.is_empty() { + panic!("empty table or empty query"); + } + + let value_len = self.table[0].len(); + + let alpha = builder.get_random_value(); + let randomness = get_column_randomness(builder, value_len); + + let table_combined = 
combine_columns(builder, &self.table, &randomness); + let mut inputs = vec![builder.constant(self.table.len() as u32)]; + //append table keys + for i in 0..self.table.len() { + inputs.push(self.table[i][0]); + } + //append query keys + inputs.extend(self.query_keys.clone()); + let v_table = logup_poly_val(builder, &table_combined, &query_count, &alpha); + + let query_combined = combine_columns(builder, &self.query_results, &randomness); + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &query_combined, + &vec![one; query_combined.len()], + &alpha, + ); + assert_eq_rational(builder, &v_table, &v_query); } } @@ -455,6 +486,22 @@ impl LogUpRangeProofTable { ); assert_eq_rational(builder, &v_table, &v_query); } + + pub fn final_check_with_query_count>(&mut self, builder: &mut B, query_count: &[Variable]) { + let alpha = builder.get_random_value(); + let inputs = self.query_keys.clone(); + + let v_table = logup_poly_val(builder, &self.table_keys, &query_count, &alpha); + + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &self.query_keys, + &vec![one; self.query_keys.len()], + &alpha, + ); + assert_eq_rational(builder, &v_table, &v_query); + } } pub fn query_count_hint(inputs: &[F], outputs: &mut [F]) -> Result<(), Error> { From 4fee13658d6a5f787200b75f574f97b2a151ac0b Mon Sep 17 00:00:00 2001 From: hczphn Date: Mon, 2 Feb 2026 20:13:34 +0000 Subject: [PATCH 07/10] lightweight prove, reduce memory cost after sending prove request --- .../api_no_oversubscribe.rs | 17 ++++++++++++++- .../expander_parallelized/client_utils.rs | 21 +++++++++++++++++-- .../expander_parallelized/server_ctrl.rs | 4 ++-- 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 7d7fed98..6280fb4b 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -6,7 +6,7 @@ use crate::zkcuda::proving_system::expander::structs::{ }; use crate::zkcuda::proving_system::expander_parallelized::client_utils::{ client_launch_server_and_setup, client_parse_args, client_send_witness_and_prove, wait_async, - ClientHttpHelper, + ClientHttpHelper, client_send_witness_and_prove_nowait }; use crate::zkcuda::proving_system::{ CombinedProof, ExpanderPCSDefered, ParallelizedExpander, ProvingSystem, @@ -73,3 +73,18 @@ where wait_async(ClientHttpHelper::request_exit()) } } + +impl ExpanderNoOverSubscribe +where + as ExpanderPCS>>::Commitment: + AsRef< as ExpanderPCS>>::Commitment>, +{ + /// Lightweight prove that doesn't require computation_graph or prover_setup. + /// Use this after setup() to allow releasing those large data structures before proving. 
+ pub fn prove_lightweight( + device_memories: Vec>>, + ) { + client_send_witness_and_prove::(device_memories); + // client_send_witness_and_prove_nowait::(device_memories); + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index ea63c5d1..432386fd 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -87,7 +87,7 @@ where C: GKREngine, ECCConfig: Config, { - let setup_timer = Timer::new("setup", true); + let setup_timer = Timer::new("new setup", true); println!("Starting server with binary: {server_binary}"); let mut bytes = vec![]; @@ -141,7 +141,11 @@ where setup_timer.stop(); - SharedMemoryEngine::read_pcs_setup_from_shared_memory() + // SharedMemoryEngine::read_pcs_setup_from_shared_memory() + ( + ExpanderProverSetup::default(), + ExpanderVerifierSetup::default(), + ) } pub fn client_send_witness_and_prove( @@ -163,6 +167,19 @@ where proof } + +pub fn client_send_witness_and_prove_nowait( + device_memories: Vec>>, +) +where + C: GKREngine, + ECCConfig: Config, +{ + let timer = Timer::new("prove", true); + + SharedMemoryEngine::write_witness_to_shared_memory::(device_memories); + ClientHttpHelper::request_prove(); +} /// Run an async function in a blocking context. #[inline(always)] pub fn wait_async(f: F) -> T diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 70be640d..14a04035 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -28,7 +28,7 @@ use std::sync::Mutex as SyncMutex; use tokio::sync::{oneshot, Mutex}; pub static SERVER_IP: &str = "127.0.0.1"; -pub static SERVER_PORT: Lazy> = Lazy::new(|| SyncMutex::new(3000)); +pub static SERVER_PORT: Lazy> = Lazy::new(|| SyncMutex::new(5555)); pub fn parse_port_number() -> u16 { let mut port = SERVER_PORT.lock().unwrap(); @@ -379,7 +379,7 @@ pub struct ExpanderExecArgs { pub poly_commit: String, /// The port number for the server to listen on. - #[arg(short, long, default_value = "3000")] + #[arg(short, long, default_value = "5555")] pub port_number: String, /// Whether to batch PCS opening in proving. 
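
A minimal client-side sketch of the intended prove_lightweight flow introduced above: run setup(), release the computation graph and the (now placeholder) setup structs, then prove from the exported device memories alone. The BN254Config / ZKCudaBN254KZGBatchPCS choice, the function name, and the assumption that `ctx` already holds the recorded call_kernel! trace are illustrative only, not part of this patch:

    use expander_compiler::frontend::*;
    use expander_compiler::zkcuda::context::Context;
    use expander_compiler::zkcuda::proving_system::expander::config::ZKCudaBN254KZGBatchPCS;
    use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ProvingSystem};

    type P = ExpanderNoOverSubscribe<ZKCudaBN254KZGBatchPCS>;

    // Hypothetical driver: `ctx` has already recorded its call_kernel! sequence,
    // as in the repository's existing BN254 tests.
    fn prove_with_low_client_memory(mut ctx: Context<BN254Config>) {
        let computation_graph = ctx.compile_computation_graph().unwrap();
        ctx.solve_witness().unwrap();

        // setup() launches the prover server; prove_lightweight does not need the
        // returned setup structs on the client, so they can be released at once.
        let (prover_setup, verifier_setup) = P::setup(&computation_graph);
        drop(prover_setup);
        drop(verifier_setup);

        // Keep only the witness data; the graph can also be released before proving,
        // which is the client-side memory saving this patch is after.
        let device_memories = ctx.export_device_memories();
        drop(computation_graph);

        // Ships the witness to the already-running server via shared memory and
        // triggers proving there; the client holds no graph or prover keys anymore.
        P::prove_lightweight(device_memories);
    }
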
From ca286e5d3ae193288ff9ba3c6c8969d50d08bd01 Mon Sep 17 00:00:00 2001 From: hczphn Date: Tue, 3 Feb 2026 16:48:44 +0000 Subject: [PATCH 08/10] fix fmt, warning --- expander_compiler/src/zkcuda/context.rs | 8 +- .../tests/test_bn254_new_data.rs | 75 ++++++++++--------- 2 files changed, 42 insertions(+), 41 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index dbc55b9a..40081f72 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -577,7 +577,7 @@ impl>> Context { let dm_shapes = self.propagate_and_get_shapes(); - let (mut cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg { + let (cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg { for (i, kernel) in cg.kernels.iter().enumerate() { assert_eq!(self.kernels.add(kernel), i); } @@ -617,7 +617,7 @@ impl>> Context { .collect::>(); let kernel_primitive = self.kernel_primitives.get(kernel_call.kernel_id); let kernel = if cg_kernels.is_some() { - // 从已加载的 kernels 中通过 kernel_id 获取 + // Get kernel from loaded kernels by kernel_id self.kernels.get(kernel_call.kernel_id).clone() } else { let mut psi = Vec::new(); @@ -710,8 +710,8 @@ impl>> Context { } if let Some(_cg_kernels) = cg_kernels { - // 不再检查 cg_kernels 是否为空,因为我们不再消耗它 - // kernels 已经在之前通过 self.kernels.add() 添加了 + // No longer checking if cg_kernels is empty since we no longer consume it + // Kernels were already added earlier via self.kernels.add() assert_eq!(cg_proof_templates.unwrap(), self.proof_templates); assert_eq!(cg_commitments_lens.unwrap(), commitments_lens); Ok(None) diff --git a/expander_compiler/tests/test_bn254_new_data.rs b/expander_compiler/tests/test_bn254_new_data.rs index 39e54f96..2bccb60e 100644 --- a/expander_compiler/tests/test_bn254_new_data.rs +++ b/expander_compiler/tests/test_bn254_new_data.rs @@ -19,14 +19,13 @@ fn add_16_macro(api: &mut API, a: &[InputVariable; 16], b: &mut Ou } fn test_bn254_load_graph_with_new_data_impl>() { - let kernel_add_2: KernelPrimitive = compile_add_2_macro().unwrap(); let kernel_add_16: KernelPrimitive = compile_add_16_macro().unwrap(); - println!("\n===== 第一次执行:创建并保存图(BN254) ====="); + println!("\n===== First execution: create and save graph (BN254) ====="); let mut ctx1: Context = Context::default(); - // 第一组输入数据(BN254 field 元素) + // First set of input data (BN254 field elements) let mut a1: Vec>> = vec![]; for i in 0..16 { a1.push(vec![]); @@ -42,92 +41,94 @@ fn test_bn254_load_graph_with_new_data_impl>() { call_kernel!(ctx1, kernel_add_16, 1, b1, mut c1).unwrap(); let c1 = c1.reshape(&[]); let result1: CircuitField = ctx1.copy_to_host(c1); - println!("第一次结果: {:?}", result1); + println!("First result: {:?}", result1); assert_eq!(result1, CircuitField::::from(32 * 33 / 2 as u32)); let computation_graph = ctx1.compile_computation_graph().unwrap(); ctx1.solve_witness().unwrap(); - println!("开始 setup(可能需要一些时间)..."); + println!("Starting setup (may take some time)..."); let (prover_setup, verifier_setup) = P::setup(&computation_graph); - println!("开始 prove..."); + println!("Starting prove..."); let proof1 = P::prove( &prover_setup, &computation_graph, ctx1.export_device_memories(), ); - println!("开始 verify..."); + println!("Starting verify..."); assert!(P::verify(&verifier_setup, &computation_graph, &proof1)); - println!("第一次验证通过!"); + println!("First verification passed!"); - println!("\n===== 第二次执行:先 call_kernel(新的 BN254 数据),再 load_graph ====="); + println!("\n===== Second 
execution: call_kernel first (new BN254 data), then load_graph ====="); let mut ctx2: Context = Context::default(); - // 第二组输入数据(不同的 BN254 field 元素) + // Second set of input data (different BN254 field elements) let mut a2: Vec>> = vec![]; for i in 0..16 { a2.push(vec![]); for j in 0..2 { - // 使用不同的值:从 1000 开始 + // Use different values: starting from 1000 a2[i].push(CircuitField::::from((i * 2 + j + 1000) as u32)); } } let a2 = ctx2.copy_to_device(&a2); - // 先调用 kernels(和第一次一样的顺序) + // Call kernels first (same order as the first time) let mut b2: DeviceMemoryHandle = None; - println!("调用第一个 kernel(使用新数据)..."); + println!("Calling first kernel (using new data)..."); call_kernel!(ctx2, kernel_add_2, 16, a2, mut b2).unwrap(); let b2 = b2.reshape(&[1, 16]); let mut c2: DeviceMemoryHandle = None; - println!("调用第二个 kernel..."); + println!("Calling second kernel..."); call_kernel!(ctx2, kernel_add_16, 1, b2, mut c2).unwrap(); let c2 = c2.reshape(&[]); let result2: CircuitField = ctx2.copy_to_host(c2); - println!("第二次计算结果: {:?}", result2); + println!("Second computation result: {:?}", result2); - // 验证结果确实不同 - assert_ne!(result1, result2, "两次结果应该不同"); + // Verify results are indeed different + assert_ne!(result1, result2, "The two results should be different"); - // 第二次的预期结果: - // 输入: [1000,1001], [1002,1003], ..., [1030,1031] (共32个数) - // add_2: 2001, 2005, 2009, ..., 2061 (16个数) + // Expected result for the second run: + // Input: [1000,1001], [1002,1003], ..., [1030,1031] (32 numbers total) + // add_2: 2001, 2005, 2009, ..., 2061 (16 numbers) // add_16: sum(2001, 2005, ..., 2061) = 16 * (2001 + 2061) / 2 = 32496 let expected2 = CircuitField::::from(32496u32); - assert_eq!(result2, expected2, "第二次结果应该是 32496"); + assert_eq!(result2, expected2, "Second result should be 32496"); - // 现在加载图(复用编译好的 kernels) - println!("加载 computation_graph..."); - ctx2.load_computation_graph(computation_graph.clone()).unwrap(); - println!("图加载成功!"); + // Now load the graph (reuse compiled kernels) + println!("Loading computation_graph..."); + ctx2.load_computation_graph(computation_graph.clone()) + .unwrap(); + println!("Graph loaded successfully!"); - // solve_witness(会使用新数据重新计算) - println!("solve_witness(重新计算 witness)..."); + // solve_witness (will recalculate using new data) + println!("solve_witness (recalculating witness)..."); ctx2.solve_witness().unwrap(); - println!("solve_witness 成功!"); + println!("solve_witness succeeded!"); - // prove(使用新数据) - println!("prove(使用新数据生成证明)..."); + // prove (using new data) + println!("prove (generating proof with new data)..."); let proof2 = P::prove( &prover_setup, &computation_graph, ctx2.export_device_memories(), ); - println!("prove 成功!"); + println!("prove succeeded!"); // verify - println!("verify(验证新数据的证明)..."); + println!("verify (verifying proof with new data)..."); assert!(P::verify(&verifier_setup, &computation_graph, &proof2)); - println!("✓ 第二次验证通过!"); - println!("✓ 成功使用新的 BN254 数据生成并验证了不同的证明"); - println!(" - 第一次结果: {:?}", result1); - println!(" - 第二次结果: {:?}", result2); + println!("✓ Second verification passed!"); + println!("✓ Successfully generated and verified different proofs using new BN254 data"); + println!(" - First result: {:?}", result1); + println!(" - Second result: {:?}", result2); P::post_process(); } #[test] fn test_bn254_load_graph_with_new_data() { - test_bn254_load_graph_with_new_data_impl::<_, ExpanderNoOverSubscribe>(); + test_bn254_load_graph_with_new_data_impl::<_, ExpanderNoOverSubscribe>( + ); } From 
736a6e2548409099bd92d8b70d063006e06d2918 Mon Sep 17 00:00:00 2001 From: hczphn Date: Tue, 3 Feb 2026 17:21:47 +0000 Subject: [PATCH 09/10] fix format --- circuit-std-rs/src/logup.rs | 12 ++++- .../proving_system/expander/prove_impl.rs | 2 +- .../api_no_oversubscribe.rs | 7 +-- .../expander_no_oversubscribe/prove_impl.rs | 49 +++++++++++------ .../expander_parallelized/client_utils.rs | 14 ----- .../expander_parallelized/cmd_utils.rs | 2 - .../expander_parallelized/prove_impl.rs | 4 +- .../expander_parallelized/server_ctrl.rs | 8 +-- .../shared_memory_utils.rs | 52 ------------------- 9 files changed, 52 insertions(+), 98 deletions(-) diff --git a/circuit-std-rs/src/logup.rs b/circuit-std-rs/src/logup.rs index b14cbd9d..24f04ee2 100644 --- a/circuit-std-rs/src/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -328,7 +328,11 @@ impl LogUpSingleKeyTable { assert_eq_rational(builder, &v_table, &v_query); } - pub fn final_check_with_query_count>(&mut self, builder: &mut B, query_count: &[Variable]) { + pub fn final_check_with_query_count>( + &mut self, + builder: &mut B, + query_count: &[Variable], + ) { if self.table.is_empty() || self.query_keys.is_empty() { panic!("empty table or empty query"); } @@ -487,7 +491,11 @@ impl LogUpRangeProofTable { assert_eq_rational(builder, &v_table, &v_query); } - pub fn final_check_with_query_count>(&mut self, builder: &mut B, query_count: &[Variable]) { + pub fn final_check_with_query_count>( + &mut self, + builder: &mut B, + query_count: &[Variable], + ) { let alpha = builder.get_random_value(); let inputs = self.query_keys.clone(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index c83ced65..cc2afad4 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -30,7 +30,7 @@ where { let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten(); expander_circuit.pre_process_gkr(); - + let (max_num_input_var, max_num_output_var) = super::utils::max_n_vars(&expander_circuit); let prover_scratch = ProverScratchPad::::new(max_num_input_var, max_num_output_var, mpi_world_size); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 6280fb4b..6a559fa1 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -6,7 +6,7 @@ use crate::zkcuda::proving_system::expander::structs::{ }; use crate::zkcuda::proving_system::expander_parallelized::client_utils::{ client_launch_server_and_setup, client_parse_args, client_send_witness_and_prove, wait_async, - ClientHttpHelper, client_send_witness_and_prove_nowait + ClientHttpHelper, }; use crate::zkcuda::proving_system::{ CombinedProof, ExpanderPCSDefered, ParallelizedExpander, ProvingSystem, @@ -81,10 +81,7 @@ where { /// Lightweight prove that doesn't require computation_graph or prover_setup. /// Use this after setup() to allow releasing those large data structures before proving. 
- pub fn prove_lightweight( - device_memories: Vec>>, - ) { + pub fn prove_lightweight(device_memories: Vec>>) { client_send_witness_and_prove::(device_memories); - // client_send_witness_and_prove_nowait::(device_memories); } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 10f80de8..cd5c894b 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -416,7 +416,6 @@ where FieldEngine, T: Transcript, { - let input_vals_multi_copies = local_commitment_values_multi_copies .iter() .map(|local_commitment_values| { @@ -714,9 +713,16 @@ fn wait_for_dependencies(task_name: &str, dependencies: &[String], my_rank: usiz // Only log if actually waited if !waited_deps.is_empty() && my_rank % 8 == 0 { - eprintln!("[RANK {}] Task {} waited for {} deps (longest: {:.1}s)", - my_rank, task_name, waited_deps.len(), - waited_deps.iter().map(|(_, t)| t).fold(0.0f64, |a, &b| a.max(b))); + eprintln!( + "[RANK {}] Task {} waited for {} deps (longest: {:.1}s)", + my_rank, + task_name, + waited_deps.len(), + waited_deps + .iter() + .map(|(_, t)| t) + .fold(0.0f64, |a, &b| a.max(b)) + ); } } @@ -800,8 +806,11 @@ where if global_mpi_config.is_root() { eprintln!("========== SCHEDULER MODE =========="); - eprintln!(" Schedule: {} ranks, max {} steps", - schedule.rank_tasks.len(), schedule.max_steps()); + eprintln!( + " Schedule: {} ranks, max {} steps", + schedule.rank_tasks.len(), + schedule.max_steps() + ); } // Safety checks @@ -855,7 +864,10 @@ where deps } Err(e) => { - eprintln!("[RANK {}] ERROR: Failed to load dependencies: {}", my_rank, e); + eprintln!( + "[RANK {}] ERROR: Failed to load dependencies: {}", + my_rank, e + ); HashMap::new() } } @@ -882,7 +894,10 @@ where let mut task_mpi_configs: HashMap>> = HashMap::new(); if global_mpi_config.is_root() { - eprintln!(" Pre-creating MPI subgroups for {} tasks...", all_unique_tasks.len()); + eprintln!( + " Pre-creating MPI subgroups for {} tasks...", + all_unique_tasks.len() + ); } for task_name in &all_unique_tasks { @@ -1044,7 +1059,6 @@ where ); let gkr_end_state = if let Some(ref local_config) = local_mpi_config { - prove_kernel_gkr_no_oversubscribe::, GetTranscript, ZC::ECCConfig>( local_config, &computation_graph.kernels()[template.kernel_id()], @@ -1202,7 +1216,6 @@ where // Step 1: Non-root subgroup roots send their results to rank 0 if i_am_subgroup_root && my_rank != 0 { - // Serialize the indexed structures (maintains template order) let mut vals_bytes = Vec::new(); vals_per_template.serialize_into(&mut vals_bytes).unwrap(); @@ -1325,7 +1338,6 @@ where all_proofs[template_idx] = received_all_proofs[template_idx].clone(); } } - } // Build final vals_ref and challenges in template order @@ -1342,12 +1354,19 @@ where } let completed_templates = all_proofs.iter().filter(|p| p.is_some()).count(); - eprintln!("Result collection: {}/{} templates, {} vals, {} challenges", - completed_templates, num_templates, - vals_ref_owned.len(), challenges_final.len()); + eprintln!( + "Result collection: {}/{} templates, {} vals, {} challenges", + completed_templates, + num_templates, + vals_ref_owned.len(), + challenges_final.len() + ); if completed_templates < num_templates { - eprintln!("⚠️ WARNING: Only {}/{} templates completed!", completed_templates, num_templates); + eprintln!( + "⚠️ WARNING: Only 
{}/{} templates completed!", + completed_templates, num_templates + ); for (idx, val) in vals_per_template.iter().enumerate() { if val.is_none() { eprintln!(" Missing: Template {}", idx); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index 432386fd..c67517f2 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -112,7 +112,6 @@ where let mpi_size = if allow_oversubscribe { max_parallel_count } else { - // 支持通过环境变量 ZKML_NUM_CPUS 覆盖 CPU 数量(用于 Docker 等环境) let num_cpus = std::env::var("ZKML_NUM_CPUS") .ok() .and_then(|s| s.parse().ok()) @@ -167,19 +166,6 @@ where proof } - -pub fn client_send_witness_and_prove_nowait( - device_memories: Vec>>, -) -where - C: GKREngine, - ECCConfig: Config, -{ - let timer = Timer::new("prove", true); - - SharedMemoryEngine::write_witness_to_shared_memory::(device_memories); - ClientHttpHelper::request_prove(); -} /// Run an async function in a blocking context. #[inline(always)] pub fn wait_async(f: F) -> T diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs index 0f4a4d8a..ebaef2b3 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs @@ -24,10 +24,8 @@ pub fn start_server( fn parse_config(mpi_size: usize) -> (String, String, String, String) where { - // 支持通过环境变量强制启用 oversubscribe(用于 Docker 等 CPU ID 不连续的环境) let force_oversubscribe = std::env::var("ZKML_FORCE_OVERSUBSCRIBE").is_ok(); - // 支持通过环境变量 ZKML_NUM_CPUS 覆盖 CPU 数量 let num_cpus = std::env::var("ZKML_NUM_CPUS") .ok() .and_then(|s| s.parse().ok()) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index aaf3f5be..5605daf0 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -4,8 +4,6 @@ use gkr_engine::{ ExpanderDualVarChallenge, ExpanderSingleVarChallenge, FieldEngine, GKREngine, MPIConfig, MPIEngine, Transcript, }; -use std::collections::HashMap; -use std::fs; use crate::{ frontend::{Config, SIMDField}, @@ -222,4 +220,4 @@ pub fn partition_single_gkr_claim_and_open_pcs_mpi( transcript, ); } -} \ No newline at end of file +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 14a04035..27919a50 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -28,7 +28,7 @@ use std::sync::Mutex as SyncMutex; use tokio::sync::{oneshot, Mutex}; pub static SERVER_IP: &str = "127.0.0.1"; -pub static SERVER_PORT: Lazy> = Lazy::new(|| SyncMutex::new(5555)); +pub static SERVER_PORT: Lazy> = Lazy::new(|| SyncMutex::new(3000)); pub fn parse_port_number() -> u16 { let mut port = SERVER_PORT.lock().unwrap(); @@ -298,12 +298,12 @@ where if global_mpi_config.is_root() { let (tx, rx) = oneshot::channel::<()>(); 
state.shutdown_tx.lock().await.replace(tx); - + let app = Router::new() .route("/", post(root_main::)) .route("/", get(|| async { "Expander Server is running" })) .with_state(state.clone()); - + let ip: IpAddr = SERVER_IP.parse().expect("Invalid SERVER_IP"); let port_val = port_number.parse::().unwrap_or_else(|e| { eprintln!("Error: Invalid port number '{port_number}'. {e}."); @@ -379,7 +379,7 @@ pub struct ExpanderExecArgs { pub poly_commit: String, /// The port number for the server to listen on. - #[arg(short, long, default_value = "5555")] + #[arg(short, long, default_value = "3000")] pub port_number: String, /// Whether to batch PCS opening in proving. diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs index 75791bee..648f33a8 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs @@ -9,57 +9,6 @@ use shared_memory::{Shmem, ShmemConf}; use crate::circuit::config::Config; -/// 获取所有 expander_server 进程的内存占用(单位:MB) -/// 返回 (VmRSS物理内存, VmSize虚拟内存) -fn get_total_expander_memory_mb() -> (usize, usize) { - use std::fs; - use std::io::{BufRead, BufReader}; - - let mut total_rss_kb = 0usize; - let mut total_vmsize_kb = 0usize; - - // 遍历 /proc 目录 - if let Ok(entries) = fs::read_dir("/proc") { - for entry in entries.flatten() { - if let Ok(file_name) = entry.file_name().into_string() { - // 只处理数字目录(进程PID) - if file_name.chars().all(|c| c.is_ascii_digit()) { - // 读取 /proc/[pid]/comm 检查进程名 - let comm_path = format!("/proc/{}/comm", file_name); - if let Ok(comm) = fs::read_to_string(&comm_path) { - if comm.trim() == "expander_server" { - // 读取 /proc/[pid]/status 获取内存信息 - let status_path = format!("/proc/{}/status", file_name); - if let Ok(file) = fs::File::open(&status_path) { - let reader = BufReader::new(file); - for line in reader.lines().flatten() { - if line.starts_with("VmRSS:") { - // VmRSS: 12345 kB (物理内存) - if let Some(rss_str) = line.split_whitespace().nth(1) { - if let Ok(rss_kb) = rss_str.parse::() { - total_rss_kb += rss_kb; - } - } - } else if line.starts_with("VmSize:") { - // VmSize: 12345 kB (虚拟内存) - if let Some(size_str) = line.split_whitespace().nth(1) { - if let Ok(size_kb) = size_str.parse::() { - total_vmsize_kb += size_kb; - } - } - } - } - } - } - } - } - } - } - } - - (total_rss_kb / 1024, total_vmsize_kb / 1024) // 转换为MB -} - use crate::zkcuda::proving_system::expander::structs::{ ExpanderProverSetup, ExpanderVerifierSetup, }; @@ -247,7 +196,6 @@ impl SharedMemoryEngine { global_mpi_config.barrier(); - let n_witness = usize::new_from_memory(&mut mpi_shared_mem_ptr); let witness = (0..n_witness) .map(|_| Vec::::new_from_memory(&mut mpi_shared_mem_ptr)) From 2a68eb101c667e15d52d62304add658ce8b328a6 Mon Sep 17 00:00:00 2001 From: hczphn Date: Tue, 3 Feb 2026 17:29:28 +0000 Subject: [PATCH 10/10] remove unuse variable --- expander_compiler/src/zkcuda/context.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 40081f72..1bec7e9f 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -709,7 +709,7 @@ impl>> Context { }); } - if let Some(_cg_kernels) = cg_kernels { + if cg_kernels.is_some() { // No longer checking if cg_kernels is empty since we no 
longer consume it // Kernels were already added earlier via self.kernels.add() assert_eq!(cg_proof_templates.unwrap(), self.proof_templates);