12 changes: 7 additions & 5 deletions expander_compiler/src/zkcuda/context.rs
@@ -577,7 +577,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {

let dm_shapes = self.propagate_and_get_shapes();

- let (mut cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg {
+ let (cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg {
Review comment (severity: medium):

To improve maintainability and prevent future misuse, it would be beneficial to add a comment explaining the key assumption for reusing a computation graph. Specifically, that call_kernel must be invoked with the same sequence of kernel primitives as when the graph was originally compiled. This ensures that kernel_ids align correctly when fetching pre-compiled kernels from the loaded graph.
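For example, a brief note above the kernel-registration loop could make this assumption explicit; the wording below is only an illustrative sketch, not code from this PR:

// Assumption: a loaded computation graph can only be reused if call_kernel is
// invoked with the same sequence of kernel primitives as when the graph was
// originally compiled; otherwise kernel_call.kernel_id will not line up with
// the pre-compiled kernels registered here.
for (i, kernel) in cg.kernels.iter().enumerate() {
    assert_eq!(self.kernels.add(kernel), i);
}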

for (i, kernel) in cg.kernels.iter().enumerate() {
assert_eq!(self.kernels.add(kernel), i);
}
@@ -616,8 +616,9 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.map(get_pad_shape)
.collect::<Vec<_>>();
let kernel_primitive = self.kernel_primitives.get(kernel_call.kernel_id);
- let kernel = if let Some(cg_kernels) = cg_kernels.as_mut() {
- cg_kernels.drain(..1).next().unwrap()
+ let kernel = if cg_kernels.is_some() {
+ // Get kernel from loaded kernels by kernel_id
+ self.kernels.get(kernel_call.kernel_id).clone()
} else {
Comment on lines 619 to 622
Severity: medium

While this change correctly fixes the bug by not consuming the kernel, the use of .clone() on Kernel<C> could be inefficient if the kernel object is large, as it involves a deep copy of the circuit structure. This might impact performance when loading large, pre-compiled computation graphs.

Consider refactoring to avoid this clone. One approach is to use std::borrow::Cow<Kernel<C>> to hold either a borrowed reference (in the if case) or an owned kernel (in the else case). This would prevent the unnecessary clone when a kernel is already available in self.kernels.

Example using Cow:

use std::borrow::Cow;

// ...

let kernel: Cow<Kernel<C>> = if cg_kernels.is_some() {
    Cow::Borrowed(self.kernels.get(kernel_call.kernel_id))
} else {
    // ... logic to compile primitive ...
    Cow::Owned(compile_primitive(kernel_primitive, &psi, &pso)?)
};

// ... later

let kernel_id = self.kernels.add(&*kernel);
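One caveat with this Cow sketch (assuming self.kernels.get returns a shared reference): Cow::Borrowed keeps self.kernels immutably borrowed for as long as kernel is alive, so the later self.kernels.add(&*kernel) would need a mutable borrow of the same collection and may not pass the borrow checker without some restructuring. Cloning only in the loaded-graph branch, as this PR does, keeps the control flow simple; whether that clone is worth optimizing depends on how often large pre-compiled graphs are reloaded.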

let mut psi = Vec::new();
for (s, &ib) in pad_shapes_input.iter().zip(kernel_call.is_broadcast.iter()) {
@@ -708,8 +709,9 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
});
}

- if let Some(cg_kernels) = cg_kernels {
- assert!(cg_kernels.is_empty());
+ if let Some(_cg_kernels) = cg_kernels {
+ // No longer checking if cg_kernels is empty since we no longer consume it
+ // Kernels were already added earlier via self.kernels.add()
assert_eq!(cg_proof_templates.unwrap(), self.proof_templates);
assert_eq!(cg_commitments_lens.unwrap(), commitments_lens);
Ok(None)
134 changes: 134 additions & 0 deletions expander_compiler/tests/test_bn254_new_data.rs
@@ -0,0 +1,134 @@
use expander_compiler::frontend::*;
use expander_compiler::zkcuda::proving_system::expander::config::ZKCudaBN254KZGBatchPCS;
use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ProvingSystem};
use expander_compiler::zkcuda::shape::Reshape;
use expander_compiler::zkcuda::{context::*, kernel::*};

#[kernel]
fn add_2_macro<C: Config>(api: &mut API<C>, a: &[InputVariable; 2], b: &mut OutputVariable) {
    *b = api.add(a[0], a[1]);
}

#[kernel]
fn add_16_macro<C: Config>(api: &mut API<C>, a: &[InputVariable; 16], b: &mut OutputVariable) {
    let mut sum = api.constant(0);
    for i in 0..16 {
        sum = api.add(sum, a[i]);
    }
    *b = sum;
}

fn test_bn254_load_graph_with_new_data_impl<C: Config, P: ProvingSystem<C>>() {
    let kernel_add_2: KernelPrimitive<C> = compile_add_2_macro().unwrap();
    let kernel_add_16: KernelPrimitive<C> = compile_add_16_macro().unwrap();

    println!("\n===== First execution: create and save graph (BN254) =====");
    let mut ctx1: Context<C> = Context::default();

    // First set of input data (BN254 field elements)
    let mut a1: Vec<Vec<CircuitField<C>>> = vec![];
    for i in 0..16 {
        a1.push(vec![]);
        for j in 0..2 {
            a1[i].push(CircuitField::<C>::from((i * 2 + j + 1) as u32));
        }
    }
    let a1 = ctx1.copy_to_device(&a1);
    let mut b1: DeviceMemoryHandle = None;
    call_kernel!(ctx1, kernel_add_2, 16, a1, mut b1).unwrap();
    let b1 = b1.reshape(&[1, 16]);
    let mut c1: DeviceMemoryHandle = None;
    call_kernel!(ctx1, kernel_add_16, 1, b1, mut c1).unwrap();
    let c1 = c1.reshape(&[]);
    let result1: CircuitField<C> = ctx1.copy_to_host(c1);
    println!("First result: {:?}", result1);
    assert_eq!(result1, CircuitField::<C>::from(32 * 33 / 2 as u32));

    let computation_graph = ctx1.compile_computation_graph().unwrap();
    ctx1.solve_witness().unwrap();
    println!("Starting setup (may take some time)...");
    let (prover_setup, verifier_setup) = P::setup(&computation_graph);
    println!("Starting prove...");
    let proof1 = P::prove(
        &prover_setup,
        &computation_graph,
        ctx1.export_device_memories(),
    );
    println!("Starting verify...");
    assert!(P::verify(&verifier_setup, &computation_graph, &proof1));
    println!("First verification passed!");

    println!("\n===== Second execution: call_kernel first (new BN254 data), then load_graph =====");
    let mut ctx2: Context<C> = Context::default();

    // Second set of input data (different BN254 field elements)
    let mut a2: Vec<Vec<CircuitField<C>>> = vec![];
    for i in 0..16 {
        a2.push(vec![]);
        for j in 0..2 {
            // Use different values: starting from 1000
            a2[i].push(CircuitField::<C>::from((i * 2 + j + 1000) as u32));
        }
    }
    let a2 = ctx2.copy_to_device(&a2);

    // Call kernels first (same order as the first time)
    let mut b2: DeviceMemoryHandle = None;
    println!("Calling first kernel (using new data)...");
    call_kernel!(ctx2, kernel_add_2, 16, a2, mut b2).unwrap();

    let b2 = b2.reshape(&[1, 16]);
    let mut c2: DeviceMemoryHandle = None;
    println!("Calling second kernel...");
    call_kernel!(ctx2, kernel_add_16, 1, b2, mut c2).unwrap();

    let c2 = c2.reshape(&[]);
    let result2: CircuitField<C> = ctx2.copy_to_host(c2);
    println!("Second computation result: {:?}", result2);

    // Verify results are indeed different
    assert_ne!(result1, result2, "The two results should be different");

    // Expected result for the second run:
    // Input: [1000,1001], [1002,1003], ..., [1030,1031] (32 numbers total)
    // add_2: 2001, 2005, 2009, ..., 2061 (16 numbers)
    // add_16: sum(2001, 2005, ..., 2061) = 16 * (2001 + 2061) / 2 = 32496
    let expected2 = CircuitField::<C>::from(32496u32);
    assert_eq!(result2, expected2, "Second result should be 32496");

    // Now load the graph (reuse compiled kernels)
    println!("Loading computation_graph...");
    ctx2.load_computation_graph(computation_graph.clone())
        .unwrap();
    println!("Graph loaded successfully!");

    // solve_witness (will recalculate using new data)
    println!("solve_witness (recalculating witness)...");
    ctx2.solve_witness().unwrap();
    println!("solve_witness succeeded!");

    // prove (using new data)
    println!("prove (generating proof with new data)...");
    let proof2 = P::prove(
        &prover_setup,
        &computation_graph,
        ctx2.export_device_memories(),
    );
    println!("prove succeeded!");

    // verify
    println!("verify (verifying proof with new data)...");
    assert!(P::verify(&verifier_setup, &computation_graph, &proof2));
    println!("✓ Second verification passed!");
    println!("✓ Successfully generated and verified different proofs using new BN254 data");
    println!(" - First result: {:?}", result1);
    println!(" - Second result: {:?}", result2);

    P::post_process();
}

#[test]
fn test_bn254_load_graph_with_new_data() {
    test_bn254_load_graph_with_new_data_impl::<_, ExpanderNoOverSubscribe<ZKCudaBN254KZGBatchPCS>>(
    );
}