From 18818f3284584f35ce201a2c6896fc7a06123005 Mon Sep 17 00:00:00 2001 From: hczphn Date: Sun, 18 Jan 2026 05:32:19 +0000 Subject: [PATCH 1/2] fix bugs to reuse graph --- expander_compiler/src/zkcuda/context.rs | 10 +- .../tests/test_bn254_new_data.rs | 133 ++++++++++++++++++ 2 files changed, 139 insertions(+), 4 deletions(-) create mode 100644 expander_compiler/tests/test_bn254_new_data.rs diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index a95555a4..dbc55b9a 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -616,8 +616,9 @@ impl>> Context { .map(get_pad_shape) .collect::>(); let kernel_primitive = self.kernel_primitives.get(kernel_call.kernel_id); - let kernel = if let Some(cg_kernels) = cg_kernels.as_mut() { - cg_kernels.drain(..1).next().unwrap() + let kernel = if cg_kernels.is_some() { + // 从已加载的 kernels 中通过 kernel_id 获取 + self.kernels.get(kernel_call.kernel_id).clone() } else { let mut psi = Vec::new(); for (s, &ib) in pad_shapes_input.iter().zip(kernel_call.is_broadcast.iter()) { @@ -708,8 +709,9 @@ impl>> Context { }); } - if let Some(cg_kernels) = cg_kernels { - assert!(cg_kernels.is_empty()); + if let Some(_cg_kernels) = cg_kernels { + // 不再检查 cg_kernels 是否为空,因为我们不再消耗它 + // kernels 已经在之前通过 self.kernels.add() 添加了 assert_eq!(cg_proof_templates.unwrap(), self.proof_templates); assert_eq!(cg_commitments_lens.unwrap(), commitments_lens); Ok(None) diff --git a/expander_compiler/tests/test_bn254_new_data.rs b/expander_compiler/tests/test_bn254_new_data.rs new file mode 100644 index 00000000..39e54f96 --- /dev/null +++ b/expander_compiler/tests/test_bn254_new_data.rs @@ -0,0 +1,133 @@ +use expander_compiler::frontend::*; +use expander_compiler::zkcuda::proving_system::expander::config::ZKCudaBN254KZGBatchPCS; +use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ProvingSystem}; +use expander_compiler::zkcuda::shape::Reshape; +use expander_compiler::zkcuda::{context::*, kernel::*}; + +#[kernel] +fn add_2_macro(api: &mut API, a: &[InputVariable; 2], b: &mut OutputVariable) { + *b = api.add(a[0], a[1]); +} + +#[kernel] +fn add_16_macro(api: &mut API, a: &[InputVariable; 16], b: &mut OutputVariable) { + let mut sum = api.constant(0); + for i in 0..16 { + sum = api.add(sum, a[i]); + } + *b = sum; +} + +fn test_bn254_load_graph_with_new_data_impl>() { + + let kernel_add_2: KernelPrimitive = compile_add_2_macro().unwrap(); + let kernel_add_16: KernelPrimitive = compile_add_16_macro().unwrap(); + + println!("\n===== 第一次执行:创建并保存图(BN254) ====="); + let mut ctx1: Context = Context::default(); + + // 第一组输入数据(BN254 field 元素) + let mut a1: Vec>> = vec![]; + for i in 0..16 { + a1.push(vec![]); + for j in 0..2 { + a1[i].push(CircuitField::::from((i * 2 + j + 1) as u32)); + } + } + let a1 = ctx1.copy_to_device(&a1); + let mut b1: DeviceMemoryHandle = None; + call_kernel!(ctx1, kernel_add_2, 16, a1, mut b1).unwrap(); + let b1 = b1.reshape(&[1, 16]); + let mut c1: DeviceMemoryHandle = None; + call_kernel!(ctx1, kernel_add_16, 1, b1, mut c1).unwrap(); + let c1 = c1.reshape(&[]); + let result1: CircuitField = ctx1.copy_to_host(c1); + println!("第一次结果: {:?}", result1); + assert_eq!(result1, CircuitField::::from(32 * 33 / 2 as u32)); + + let computation_graph = ctx1.compile_computation_graph().unwrap(); + ctx1.solve_witness().unwrap(); + println!("开始 setup(可能需要一些时间)..."); + let (prover_setup, verifier_setup) = P::setup(&computation_graph); + println!("开始 prove..."); + let proof1 = P::prove( + &prover_setup, + &computation_graph, + ctx1.export_device_memories(), + ); + println!("开始 verify..."); + assert!(P::verify(&verifier_setup, &computation_graph, &proof1)); + println!("第一次验证通过!"); + + println!("\n===== 第二次执行:先 call_kernel(新的 BN254 数据),再 load_graph ====="); + let mut ctx2: Context = Context::default(); + + // 第二组输入数据(不同的 BN254 field 元素) + let mut a2: Vec>> = vec![]; + for i in 0..16 { + a2.push(vec![]); + for j in 0..2 { + // 使用不同的值:从 1000 开始 + a2[i].push(CircuitField::::from((i * 2 + j + 1000) as u32)); + } + } + let a2 = ctx2.copy_to_device(&a2); + + // 先调用 kernels(和第一次一样的顺序) + let mut b2: DeviceMemoryHandle = None; + println!("调用第一个 kernel(使用新数据)..."); + call_kernel!(ctx2, kernel_add_2, 16, a2, mut b2).unwrap(); + + let b2 = b2.reshape(&[1, 16]); + let mut c2: DeviceMemoryHandle = None; + println!("调用第二个 kernel..."); + call_kernel!(ctx2, kernel_add_16, 1, b2, mut c2).unwrap(); + + let c2 = c2.reshape(&[]); + let result2: CircuitField = ctx2.copy_to_host(c2); + println!("第二次计算结果: {:?}", result2); + + // 验证结果确实不同 + assert_ne!(result1, result2, "两次结果应该不同"); + + // 第二次的预期结果: + // 输入: [1000,1001], [1002,1003], ..., [1030,1031] (共32个数) + // add_2: 2001, 2005, 2009, ..., 2061 (16个数) + // add_16: sum(2001, 2005, ..., 2061) = 16 * (2001 + 2061) / 2 = 32496 + let expected2 = CircuitField::::from(32496u32); + assert_eq!(result2, expected2, "第二次结果应该是 32496"); + + // 现在加载图(复用编译好的 kernels) + println!("加载 computation_graph..."); + ctx2.load_computation_graph(computation_graph.clone()).unwrap(); + println!("图加载成功!"); + + // solve_witness(会使用新数据重新计算) + println!("solve_witness(重新计算 witness)..."); + ctx2.solve_witness().unwrap(); + println!("solve_witness 成功!"); + + // prove(使用新数据) + println!("prove(使用新数据生成证明)..."); + let proof2 = P::prove( + &prover_setup, + &computation_graph, + ctx2.export_device_memories(), + ); + println!("prove 成功!"); + + // verify + println!("verify(验证新数据的证明)..."); + assert!(P::verify(&verifier_setup, &computation_graph, &proof2)); + println!("✓ 第二次验证通过!"); + println!("✓ 成功使用新的 BN254 数据生成并验证了不同的证明"); + println!(" - 第一次结果: {:?}", result1); + println!(" - 第二次结果: {:?}", result2); + + P::post_process(); +} + +#[test] +fn test_bn254_load_graph_with_new_data() { + test_bn254_load_graph_with_new_data_impl::<_, ExpanderNoOverSubscribe>(); +} From ca286e5d3ae193288ff9ba3c6c8969d50d08bd01 Mon Sep 17 00:00:00 2001 From: hczphn Date: Tue, 3 Feb 2026 16:48:44 +0000 Subject: [PATCH 2/2] fix fmt, warning --- expander_compiler/src/zkcuda/context.rs | 8 +- .../tests/test_bn254_new_data.rs | 75 ++++++++++--------- 2 files changed, 42 insertions(+), 41 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index dbc55b9a..40081f72 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -577,7 +577,7 @@ impl>> Context { let dm_shapes = self.propagate_and_get_shapes(); - let (mut cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg { + let (cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg { for (i, kernel) in cg.kernels.iter().enumerate() { assert_eq!(self.kernels.add(kernel), i); } @@ -617,7 +617,7 @@ impl>> Context { .collect::>(); let kernel_primitive = self.kernel_primitives.get(kernel_call.kernel_id); let kernel = if cg_kernels.is_some() { - // 从已加载的 kernels 中通过 kernel_id 获取 + // Get kernel from loaded kernels by kernel_id self.kernels.get(kernel_call.kernel_id).clone() } else { let mut psi = Vec::new(); @@ -710,8 +710,8 @@ impl>> Context { } if let Some(_cg_kernels) = cg_kernels { - // 不再检查 cg_kernels 是否为空,因为我们不再消耗它 - // kernels 已经在之前通过 self.kernels.add() 添加了 + // No longer checking if cg_kernels is empty since we no longer consume it + // Kernels were already added earlier via self.kernels.add() assert_eq!(cg_proof_templates.unwrap(), self.proof_templates); assert_eq!(cg_commitments_lens.unwrap(), commitments_lens); Ok(None) diff --git a/expander_compiler/tests/test_bn254_new_data.rs b/expander_compiler/tests/test_bn254_new_data.rs index 39e54f96..2bccb60e 100644 --- a/expander_compiler/tests/test_bn254_new_data.rs +++ b/expander_compiler/tests/test_bn254_new_data.rs @@ -19,14 +19,13 @@ fn add_16_macro(api: &mut API, a: &[InputVariable; 16], b: &mut Ou } fn test_bn254_load_graph_with_new_data_impl>() { - let kernel_add_2: KernelPrimitive = compile_add_2_macro().unwrap(); let kernel_add_16: KernelPrimitive = compile_add_16_macro().unwrap(); - println!("\n===== 第一次执行:创建并保存图(BN254) ====="); + println!("\n===== First execution: create and save graph (BN254) ====="); let mut ctx1: Context = Context::default(); - // 第一组输入数据(BN254 field 元素) + // First set of input data (BN254 field elements) let mut a1: Vec>> = vec![]; for i in 0..16 { a1.push(vec![]); @@ -42,92 +41,94 @@ fn test_bn254_load_graph_with_new_data_impl>() { call_kernel!(ctx1, kernel_add_16, 1, b1, mut c1).unwrap(); let c1 = c1.reshape(&[]); let result1: CircuitField = ctx1.copy_to_host(c1); - println!("第一次结果: {:?}", result1); + println!("First result: {:?}", result1); assert_eq!(result1, CircuitField::::from(32 * 33 / 2 as u32)); let computation_graph = ctx1.compile_computation_graph().unwrap(); ctx1.solve_witness().unwrap(); - println!("开始 setup(可能需要一些时间)..."); + println!("Starting setup (may take some time)..."); let (prover_setup, verifier_setup) = P::setup(&computation_graph); - println!("开始 prove..."); + println!("Starting prove..."); let proof1 = P::prove( &prover_setup, &computation_graph, ctx1.export_device_memories(), ); - println!("开始 verify..."); + println!("Starting verify..."); assert!(P::verify(&verifier_setup, &computation_graph, &proof1)); - println!("第一次验证通过!"); + println!("First verification passed!"); - println!("\n===== 第二次执行:先 call_kernel(新的 BN254 数据),再 load_graph ====="); + println!("\n===== Second execution: call_kernel first (new BN254 data), then load_graph ====="); let mut ctx2: Context = Context::default(); - // 第二组输入数据(不同的 BN254 field 元素) + // Second set of input data (different BN254 field elements) let mut a2: Vec>> = vec![]; for i in 0..16 { a2.push(vec![]); for j in 0..2 { - // 使用不同的值:从 1000 开始 + // Use different values: starting from 1000 a2[i].push(CircuitField::::from((i * 2 + j + 1000) as u32)); } } let a2 = ctx2.copy_to_device(&a2); - // 先调用 kernels(和第一次一样的顺序) + // Call kernels first (same order as the first time) let mut b2: DeviceMemoryHandle = None; - println!("调用第一个 kernel(使用新数据)..."); + println!("Calling first kernel (using new data)..."); call_kernel!(ctx2, kernel_add_2, 16, a2, mut b2).unwrap(); let b2 = b2.reshape(&[1, 16]); let mut c2: DeviceMemoryHandle = None; - println!("调用第二个 kernel..."); + println!("Calling second kernel..."); call_kernel!(ctx2, kernel_add_16, 1, b2, mut c2).unwrap(); let c2 = c2.reshape(&[]); let result2: CircuitField = ctx2.copy_to_host(c2); - println!("第二次计算结果: {:?}", result2); + println!("Second computation result: {:?}", result2); - // 验证结果确实不同 - assert_ne!(result1, result2, "两次结果应该不同"); + // Verify results are indeed different + assert_ne!(result1, result2, "The two results should be different"); - // 第二次的预期结果: - // 输入: [1000,1001], [1002,1003], ..., [1030,1031] (共32个数) - // add_2: 2001, 2005, 2009, ..., 2061 (16个数) + // Expected result for the second run: + // Input: [1000,1001], [1002,1003], ..., [1030,1031] (32 numbers total) + // add_2: 2001, 2005, 2009, ..., 2061 (16 numbers) // add_16: sum(2001, 2005, ..., 2061) = 16 * (2001 + 2061) / 2 = 32496 let expected2 = CircuitField::::from(32496u32); - assert_eq!(result2, expected2, "第二次结果应该是 32496"); + assert_eq!(result2, expected2, "Second result should be 32496"); - // 现在加载图(复用编译好的 kernels) - println!("加载 computation_graph..."); - ctx2.load_computation_graph(computation_graph.clone()).unwrap(); - println!("图加载成功!"); + // Now load the graph (reuse compiled kernels) + println!("Loading computation_graph..."); + ctx2.load_computation_graph(computation_graph.clone()) + .unwrap(); + println!("Graph loaded successfully!"); - // solve_witness(会使用新数据重新计算) - println!("solve_witness(重新计算 witness)..."); + // solve_witness (will recalculate using new data) + println!("solve_witness (recalculating witness)..."); ctx2.solve_witness().unwrap(); - println!("solve_witness 成功!"); + println!("solve_witness succeeded!"); - // prove(使用新数据) - println!("prove(使用新数据生成证明)..."); + // prove (using new data) + println!("prove (generating proof with new data)..."); let proof2 = P::prove( &prover_setup, &computation_graph, ctx2.export_device_memories(), ); - println!("prove 成功!"); + println!("prove succeeded!"); // verify - println!("verify(验证新数据的证明)..."); + println!("verify (verifying proof with new data)..."); assert!(P::verify(&verifier_setup, &computation_graph, &proof2)); - println!("✓ 第二次验证通过!"); - println!("✓ 成功使用新的 BN254 数据生成并验证了不同的证明"); - println!(" - 第一次结果: {:?}", result1); - println!(" - 第二次结果: {:?}", result2); + println!("✓ Second verification passed!"); + println!("✓ Successfully generated and verified different proofs using new BN254 data"); + println!(" - First result: {:?}", result1); + println!(" - Second result: {:?}", result2); P::post_process(); } #[test] fn test_bn254_load_graph_with_new_data() { - test_bn254_load_graph_with_new_data_impl::<_, ExpanderNoOverSubscribe>(); + test_bn254_load_graph_with_new_data_impl::<_, ExpanderNoOverSubscribe>( + ); }