-
Notifications
You must be signed in to change notification settings - Fork 25
Fix computation graph reuse issue by preventing kernel eviction #198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -577,7 +577,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |
|
|
||
| let dm_shapes = self.propagate_and_get_shapes(); | ||
|
|
||
| let (mut cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg { | ||
| let (cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg { | ||
| for (i, kernel) in cg.kernels.iter().enumerate() { | ||
| assert_eq!(self.kernels.add(kernel), i); | ||
| } | ||
|
|
@@ -616,8 +616,9 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |
| .map(get_pad_shape) | ||
| .collect::<Vec<_>>(); | ||
| let kernel_primitive = self.kernel_primitives.get(kernel_call.kernel_id); | ||
| let kernel = if let Some(cg_kernels) = cg_kernels.as_mut() { | ||
| cg_kernels.drain(..1).next().unwrap() | ||
| let kernel = if cg_kernels.is_some() { | ||
| // Get kernel from loaded kernels by kernel_id | ||
| self.kernels.get(kernel_call.kernel_id).clone() | ||
| } else { | ||
|
Comment on lines 619 to 622
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While this change correctly fixes the bug by not consuming the kernel, the use of Consider refactoring to avoid this clone. One approach is to use Example using use std::borrow::Cow;
// ...
let kernel: Cow<Kernel<C>> = if cg_kernels.is_some() {
Cow::Borrowed(self.kernels.get(kernel_call.kernel_id))
} else {
// ... logic to compile primitive ...
Cow::Owned(compile_primitive(kernel_primitive, &psi, &pso)?)
};
// ... later
let kernel_id = self.kernels.add(&*kernel); |
||
| let mut psi = Vec::new(); | ||
| for (s, &ib) in pad_shapes_input.iter().zip(kernel_call.is_broadcast.iter()) { | ||
|
|
@@ -708,8 +709,9 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> { | |
| }); | ||
| } | ||
|
|
||
| if let Some(cg_kernels) = cg_kernels { | ||
| assert!(cg_kernels.is_empty()); | ||
| if let Some(_cg_kernels) = cg_kernels { | ||
| // No longer checking if cg_kernels is empty since we no longer consume it | ||
| // Kernels were already added earlier via self.kernels.add() | ||
| assert_eq!(cg_proof_templates.unwrap(), self.proof_templates); | ||
| assert_eq!(cg_commitments_lens.unwrap(), commitments_lens); | ||
| Ok(None) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| use expander_compiler::frontend::*; | ||
| use expander_compiler::zkcuda::proving_system::expander::config::ZKCudaBN254KZGBatchPCS; | ||
| use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ProvingSystem}; | ||
| use expander_compiler::zkcuda::shape::Reshape; | ||
| use expander_compiler::zkcuda::{context::*, kernel::*}; | ||
|
|
||
| #[kernel] | ||
| fn add_2_macro<C: Config>(api: &mut API<C>, a: &[InputVariable; 2], b: &mut OutputVariable) { | ||
| *b = api.add(a[0], a[1]); | ||
| } | ||
|
|
||
| #[kernel] | ||
| fn add_16_macro<C: Config>(api: &mut API<C>, a: &[InputVariable; 16], b: &mut OutputVariable) { | ||
| let mut sum = api.constant(0); | ||
| for i in 0..16 { | ||
| sum = api.add(sum, a[i]); | ||
| } | ||
| *b = sum; | ||
| } | ||
|
|
||
| fn test_bn254_load_graph_with_new_data_impl<C: Config, P: ProvingSystem<C>>() { | ||
| let kernel_add_2: KernelPrimitive<C> = compile_add_2_macro().unwrap(); | ||
| let kernel_add_16: KernelPrimitive<C> = compile_add_16_macro().unwrap(); | ||
|
|
||
| println!("\n===== First execution: create and save graph (BN254) ====="); | ||
| let mut ctx1: Context<C> = Context::default(); | ||
|
|
||
| // First set of input data (BN254 field elements) | ||
| let mut a1: Vec<Vec<CircuitField<C>>> = vec![]; | ||
| for i in 0..16 { | ||
| a1.push(vec![]); | ||
| for j in 0..2 { | ||
| a1[i].push(CircuitField::<C>::from((i * 2 + j + 1) as u32)); | ||
| } | ||
| } | ||
| let a1 = ctx1.copy_to_device(&a1); | ||
| let mut b1: DeviceMemoryHandle = None; | ||
| call_kernel!(ctx1, kernel_add_2, 16, a1, mut b1).unwrap(); | ||
| let b1 = b1.reshape(&[1, 16]); | ||
| let mut c1: DeviceMemoryHandle = None; | ||
| call_kernel!(ctx1, kernel_add_16, 1, b1, mut c1).unwrap(); | ||
| let c1 = c1.reshape(&[]); | ||
| let result1: CircuitField<C> = ctx1.copy_to_host(c1); | ||
| println!("First result: {:?}", result1); | ||
| assert_eq!(result1, CircuitField::<C>::from(32 * 33 / 2 as u32)); | ||
|
|
||
| let computation_graph = ctx1.compile_computation_graph().unwrap(); | ||
| ctx1.solve_witness().unwrap(); | ||
| println!("Starting setup (may take some time)..."); | ||
| let (prover_setup, verifier_setup) = P::setup(&computation_graph); | ||
| println!("Starting prove..."); | ||
| let proof1 = P::prove( | ||
| &prover_setup, | ||
| &computation_graph, | ||
| ctx1.export_device_memories(), | ||
| ); | ||
| println!("Starting verify..."); | ||
| assert!(P::verify(&verifier_setup, &computation_graph, &proof1)); | ||
| println!("First verification passed!"); | ||
|
|
||
| println!("\n===== Second execution: call_kernel first (new BN254 data), then load_graph ====="); | ||
| let mut ctx2: Context<C> = Context::default(); | ||
|
|
||
| // Second set of input data (different BN254 field elements) | ||
| let mut a2: Vec<Vec<CircuitField<C>>> = vec![]; | ||
| for i in 0..16 { | ||
| a2.push(vec![]); | ||
| for j in 0..2 { | ||
| // Use different values: starting from 1000 | ||
| a2[i].push(CircuitField::<C>::from((i * 2 + j + 1000) as u32)); | ||
| } | ||
| } | ||
| let a2 = ctx2.copy_to_device(&a2); | ||
|
|
||
| // Call kernels first (same order as the first time) | ||
| let mut b2: DeviceMemoryHandle = None; | ||
| println!("Calling first kernel (using new data)..."); | ||
| call_kernel!(ctx2, kernel_add_2, 16, a2, mut b2).unwrap(); | ||
|
|
||
| let b2 = b2.reshape(&[1, 16]); | ||
| let mut c2: DeviceMemoryHandle = None; | ||
| println!("Calling second kernel..."); | ||
| call_kernel!(ctx2, kernel_add_16, 1, b2, mut c2).unwrap(); | ||
|
|
||
| let c2 = c2.reshape(&[]); | ||
| let result2: CircuitField<C> = ctx2.copy_to_host(c2); | ||
| println!("Second computation result: {:?}", result2); | ||
|
|
||
| // Verify results are indeed different | ||
| assert_ne!(result1, result2, "The two results should be different"); | ||
|
|
||
| // Expected result for the second run: | ||
| // Input: [1000,1001], [1002,1003], ..., [1030,1031] (32 numbers total) | ||
| // add_2: 2001, 2005, 2009, ..., 2061 (16 numbers) | ||
| // add_16: sum(2001, 2005, ..., 2061) = 16 * (2001 + 2061) / 2 = 32496 | ||
| let expected2 = CircuitField::<C>::from(32496u32); | ||
| assert_eq!(result2, expected2, "Second result should be 32496"); | ||
|
|
||
| // Now load the graph (reuse compiled kernels) | ||
| println!("Loading computation_graph..."); | ||
| ctx2.load_computation_graph(computation_graph.clone()) | ||
| .unwrap(); | ||
| println!("Graph loaded successfully!"); | ||
|
|
||
| // solve_witness (will recalculate using new data) | ||
| println!("solve_witness (recalculating witness)..."); | ||
| ctx2.solve_witness().unwrap(); | ||
| println!("solve_witness succeeded!"); | ||
|
|
||
| // prove (using new data) | ||
| println!("prove (generating proof with new data)..."); | ||
| let proof2 = P::prove( | ||
| &prover_setup, | ||
| &computation_graph, | ||
| ctx2.export_device_memories(), | ||
| ); | ||
| println!("prove succeeded!"); | ||
|
|
||
| // verify | ||
| println!("verify (verifying proof with new data)..."); | ||
| assert!(P::verify(&verifier_setup, &computation_graph, &proof2)); | ||
| println!("✓ Second verification passed!"); | ||
| println!("✓ Successfully generated and verified different proofs using new BN254 data"); | ||
| println!(" - First result: {:?}", result1); | ||
| println!(" - Second result: {:?}", result2); | ||
|
|
||
| P::post_process(); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_bn254_load_graph_with_new_data() { | ||
| test_bn254_load_graph_with_new_data_impl::<_, ExpanderNoOverSubscribe<ZKCudaBN254KZGBatchPCS>>( | ||
| ); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
To improve maintainability and prevent future misuse, it would be beneficial to add a comment explaining the key assumption for reusing a computation graph. Specifically, that
`call_kernel` must be invoked with the same sequence of kernel primitives as when the graph was originally compiled. This ensures that `kernel_id`s align correctly when fetching pre-compiled kernels from the loaded graph.