Skip to content

Commit a5bc5c2

Browse files
Feat: Add batch_prove and batch_verify (#203)
* feat: batch prove and verify working * chore: cargo clippy * chore: cargo clippy
1 parent 15cf666 commit a5bc5c2

File tree

4 files changed

+1048
-48
lines changed

4 files changed

+1048
-48
lines changed

src/whir/domainsep.rs

Lines changed: 123 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,19 @@ pub trait WhirDomainSeparator<F: FftField, MerkleConfig: Config> {
2020
params: &WhirConfig<F, MerkleConfig, PowStrategy>,
2121
) -> Self;
2222

23+
/// Domain separator for regular single-commitment proving
2324
#[must_use]
2425
fn add_whir_proof<PowStrategy>(self, params: &WhirConfig<F, MerkleConfig, PowStrategy>)
2526
-> Self;
27+
28+
/// Domain separator for batch proving multiple commitments
29+
#[must_use]
30+
fn add_whir_batch_proof<PowStrategy>(
31+
self,
32+
params: &WhirConfig<F, MerkleConfig, PowStrategy>,
33+
num_witnesses: usize,
34+
num_constraints_total: usize,
35+
) -> Self;
2636
}
2737

2838
impl<F, MerkleConfig, DomainSeparator> WhirDomainSeparator<F, MerkleConfig> for DomainSeparator
@@ -65,57 +75,124 @@ where
6575
}
6676

6777
fn add_whir_proof<PowStrategy>(
68-
mut self,
78+
self,
6979
params: &WhirConfig<F, MerkleConfig, PowStrategy>,
7080
) -> Self {
71-
// TODO: Add statement
72-
if params.initial_statement {
73-
self = self
74-
.challenge_scalars(1, "initial_combination_randomness")
75-
.add_sumcheck(
76-
params.folding_factor.at_round(0),
77-
params.starting_folding_pow_bits,
78-
);
79-
} else {
80-
self = self
81-
.challenge_scalars(params.folding_factor.at_round(0), "folding_randomness")
82-
.pow(params.starting_folding_pow_bits);
83-
}
81+
add_whir_proof_impl(self, params, None)
82+
}
8483

85-
let mut domain_size = params.starting_domain.size();
86-
for (round, r) in params.round_parameters.iter().enumerate() {
87-
let folded_domain_size = domain_size >> params.folding_factor.at_round(round);
88-
let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize).div_ceil(8);
89-
90-
self = self
91-
.add_digest("merkle_digest")
92-
.add_ood(r.ood_samples, 1)
93-
.pow(r.pow_bits)
94-
.challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries")
95-
.hint("stir_answers")
96-
.hint("merkle_proof");
97-
98-
self = self
99-
.challenge_scalars(1, "combination_randomness")
100-
.add_sumcheck(
101-
params.folding_factor.at_round(round + 1),
102-
r.folding_pow_bits,
103-
);
104-
domain_size >>= 1;
105-
}
84+
fn add_whir_batch_proof<PowStrategy>(
85+
self,
86+
params: &WhirConfig<F, MerkleConfig, PowStrategy>,
87+
num_witnesses: usize,
88+
num_constraints_total: usize,
89+
) -> Self {
90+
// Step 1: Commit the full N×M constraint evaluation matrix to the transcript.
91+
// This binds the prover to all cross-term evaluations before sampling γ,
92+
// preventing adaptive attacks where the prover could choose cross-terms
93+
// after seeing the batching challenge.
94+
let matrix_size = num_witnesses * num_constraints_total;
95+
let this = self.add_scalars(matrix_size, "constraint_evaluation_matrix");
96+
97+
// Step 2: Sample batching randomness γ after committing evaluations
98+
let this = this.challenge_scalars(1, "batching_randomness");
99+
100+
// Step 3: Continue with standard WHIR proof protocol
101+
add_whir_proof_impl(this, params, Some(num_witnesses))
102+
}
103+
}
106104

107-
let folded_domain_size = domain_size
108-
>> params
109-
.folding_factor
110-
.at_round(params.round_parameters.len());
105+
/// Private helper: shared implementation for both regular and batch proving.
106+
///
107+
/// # Arguments
108+
/// * `ds` - Domain separator state
109+
/// * `params` - WHIR protocol configuration
110+
/// * `num_witnesses` - `None` for regular proving, `Some(n)` for batch proving with n witnesses
111+
fn add_whir_proof_impl<F, MerkleConfig, DomainSeparator, PowStrategy>(
112+
mut ds: DomainSeparator,
113+
params: &WhirConfig<F, MerkleConfig, PowStrategy>,
114+
num_witnesses: Option<usize>,
115+
) -> DomainSeparator
116+
where
117+
F: FftField,
118+
MerkleConfig: Config,
119+
DomainSeparator:
120+
ByteDomainSeparator + FieldDomainSeparator<F> + DigestDomainSeparator<MerkleConfig>,
121+
{
122+
// Initial sumcheck (same for both regular and batch)
123+
if params.initial_statement {
124+
ds = ds
125+
.challenge_scalars(1, "initial_combination_randomness")
126+
.add_sumcheck(
127+
params.folding_factor.at_round(0),
128+
params.starting_folding_pow_bits,
129+
);
130+
} else {
131+
ds = ds
132+
.challenge_scalars(params.folding_factor.at_round(0), "folding_randomness")
133+
.pow(params.starting_folding_pow_bits);
134+
}
135+
136+
let mut domain_size = params.starting_domain.size();
137+
138+
// Round handling
139+
for (round, r) in params.round_parameters.iter().enumerate() {
140+
let folded_domain_size = domain_size >> params.folding_factor.at_round(round);
111141
let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize).div_ceil(8);
112142

113-
self.add_scalars(1 << params.final_sumcheck_rounds, "final_coeffs")
114-
.pow(params.final_pow_bits)
115-
.challenge_bytes(domain_size_bytes * params.final_queries, "final_queries")
116-
.hint("stir_answers")
117-
.hint("merkle_proof")
118-
.add_sumcheck(params.final_sumcheck_rounds, params.final_folding_pow_bits)
119-
.hint("deferred_weight_evaluations")
143+
// Digest label differs for batch round 0
144+
let digest_label = if round == 0 && num_witnesses.is_some() {
145+
"batched_merkle_digest" // Round 0 commits to the batched polynomial
146+
} else {
147+
"merkle_digest"
148+
};
149+
150+
ds = ds
151+
.add_digest(digest_label)
152+
.add_ood(r.ood_samples, 1)
153+
.pow(r.pow_bits)
154+
.challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries");
155+
156+
// Round 0 Merkle proofs: batch proving requires N proofs (one per original tree),
157+
// while regular proving requires just 1 proof.
158+
if round == 0 {
159+
if let Some(n) = num_witnesses {
160+
// Batch proving: verify openings in all N original commitment trees
161+
for i in 0..n {
162+
ds = ds
163+
.hint(&format!("stir_answers_witness_{i}"))
164+
.hint(&format!("merkle_proof_witness_{i}"));
165+
}
166+
} else {
167+
// Regular proving: single commitment tree
168+
ds = ds.hint("stir_answers").hint("merkle_proof");
169+
}
170+
} else {
171+
// Rounds 1+: all proving modes use the single batched tree
172+
ds = ds.hint("stir_answers").hint("merkle_proof");
173+
}
174+
175+
ds = ds
176+
.challenge_scalars(1, "combination_randomness")
177+
.add_sumcheck(
178+
params.folding_factor.at_round(round + 1),
179+
r.folding_pow_bits,
180+
);
181+
domain_size >>= 1;
120182
}
183+
184+
// Final round (same for both regular and batch)
185+
let folded_domain_size = domain_size
186+
>> params
187+
.folding_factor
188+
.at_round(params.round_parameters.len());
189+
let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize).div_ceil(8);
190+
191+
ds.add_scalars(1 << params.final_sumcheck_rounds, "final_coeffs")
192+
.pow(params.final_pow_bits)
193+
.challenge_bytes(domain_size_bytes * params.final_queries, "final_queries")
194+
.hint("stir_answers")
195+
.hint("merkle_proof")
196+
.add_sumcheck(params.final_sumcheck_rounds, params.final_folding_pow_bits)
197+
.hint("deferred_weight_evaluations")
121198
}

0 commit comments

Comments
 (0)