diff --git a/crates/core/executor/src/dependencies.rs b/crates/core/executor/src/dependencies.rs index 7ded3d59e..2007c7934 100644 --- a/crates/core/executor/src/dependencies.rs +++ b/crates/core/executor/src/dependencies.rs @@ -288,6 +288,12 @@ pub fn emit_misc_dependencies(executor: &mut Executor, event: MiscEvent) { } else if matches!(event.opcode, Opcode::EXT) { let lsb = event.c & 0x1f; let msbd = event.c >> 5; + // `execute_ext` rejects encodings with `lsb + msbd >= 32`, so the `31 - lsb - msbd` + // shift amounts below cannot underflow. + debug_assert!( + lsb + msbd < 32, + "EXT with lsb + msbd >= 32 must be rejected during execution" + ); let sll_val = event.b << (31 - lsb - msbd); let sll_event = AluEvent { pc: UNUSED_PC, diff --git a/crates/core/executor/src/events/byte.rs b/crates/core/executor/src/events/byte.rs index aa6d1a3be..6a9b7cc04 100644 --- a/crates/core/executor/src/events/byte.rs +++ b/crates/core/executor/src/events/byte.rs @@ -121,9 +121,13 @@ impl ByteRecord for Vec { fn add_byte_lookup_events_from_maps( &mut self, - _new_events: Vec<&HashMap>, + new_events: Vec<&HashMap>, ) { - todo!() + for new_blu_map in new_events { + for (blu_event, count) in new_blu_map.iter() { + self.extend(std::iter::repeat_n(*blu_event, *count)); + } + } } } @@ -185,3 +189,21 @@ impl ByteOpcode { F::from_canonical_u8(self as u8) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_vec_add_byte_lookup_events_from_maps_expands_counts() { + let event = ByteLookupEvent::new(ByteOpcode::U8Range, 0, 0, 1, 2); + let mut map = HashMap::new(); + map.insert(event, 3); + + let mut events = Vec::new(); + events.add_byte_lookup_events_from_maps(vec![&map]); + + assert_eq!(events.len(), 3); + assert!(events.iter().all(|e| *e == event)); + } +} diff --git a/crates/core/executor/src/executor.rs b/crates/core/executor/src/executor.rs index 4a482d1ac..44c2fdc4d 100644 --- a/crates/core/executor/src/executor.rs +++ b/crates/core/executor/src/executor.rs @@ -250,6 
+250,10 @@ pub enum ExecutionError { #[error("buffer length {1} must be greater than or equal to {0}")] BufferLengthTooSmall(usize, usize), + /// The execution failed because a hook received an unsupported elliptic curve identifier. + #[error("unsupported ecrecover curve id: {0}")] + UnsupportedEcrecoverCurveId(u8), + /// The execution failed while converting a slice to an array due to size mismatch. #[error("failed to convert slice {0} to array")] IntoArrayError(String), @@ -1545,11 +1549,11 @@ impl<'a> Executor<'a> { if instruction.opcode == Opcode::WSBH { (a, b, c) = self.execute_wsbh(instruction); } else if instruction.opcode == Opcode::EXT { - (a, b, c) = self.execute_ext(instruction); + (a, b, c) = self.execute_ext(instruction)?; } else if instruction.opcode == Opcode::MADDU { (hi_or_prev_a, a, b, c) = self.execute_maddu(instruction); } else if instruction.opcode == Opcode::INS { - (hi_or_prev_a, a, b, c) = self.execute_ins(instruction); + (hi_or_prev_a, a, b, c) = self.execute_ins(instruction)?; } else if instruction.opcode == Opcode::SEXT { (a, b, c) = self.execute_sext(instruction); } else if instruction.opcode == Opcode::TEQ { @@ -1784,20 +1788,32 @@ impl<'a> Executor<'a> { (a, b, 0) } - fn execute_ext(&mut self, instruction: &Instruction) -> (u32, u32, u32) { + fn execute_ext( + &mut self, + instruction: &Instruction, + ) -> Result<(u32, u32, u32), ExecutionError> { let (rd, rt, c) = (instruction.op_a.into(), (instruction.op_b as u8).into(), instruction.op_c); let b = self.rr_cpu(rt, MemoryAccessPosition::B); let msbd = c >> 5; let lsb = c & 0x1f; + // `lsb + msbd < 32` is architecturally required (and enforced by the EXT AIR + // constraint). Otherwise the `31 - lsb - msbd` shift amount used here and in trace + // generation underflows as a `u32`. Reject the undefined encoding instead of panicking. 
+ if msbd + lsb >= 32 { + return Err(ExecutionError::ExceptionOrTrap()); + } let mask_msb = if msbd + lsb + 1 == 32 { 0xFFFFFFFF } else { (1u32 << (msbd + lsb + 1)) - 1 }; let a = (b & mask_msb) >> lsb; self.rw_cpu(rd, a, MemoryAccessPosition::A); - (a, b, c) + Ok((a, b, c)) } - fn execute_ins(&mut self, instruction: &Instruction) -> (Option, u32, u32, u32) { + fn execute_ins( + &mut self, + instruction: &Instruction, + ) -> Result<(Option, u32, u32, u32), ExecutionError> { let (rd, rt, c) = (instruction.op_a.into(), (instruction.op_b as u8).into(), instruction.op_c); let b = self.rr_cpu(rt, MemoryAccessPosition::B); @@ -1805,11 +1821,14 @@ impl<'a> Executor<'a> { let prev_a = a; let msb = c >> 5; let lsb = c & 0x1f; + if msb < lsb { + return Err(ExecutionError::ExceptionOrTrap()); + } let mask = if msb - lsb + 1 == 32 { 0xFFFFFFFF } else { (1u32 << (msb - lsb + 1)) - 1 }; let mask_field = mask << lsb; let a = (a & !mask_field) | ((b << lsb) & mask_field); self.rw_cpu(rd, a, MemoryAccessPosition::A); - (Some(prev_a), a, b, c) + Ok((Some(prev_a), a, b, c)) } fn execute_teq( @@ -2069,7 +2088,7 @@ impl<'a> Executor<'a> { rt } // Opcode::SDC1 => 0, - _ => todo!(), + _ => unreachable!("unexpected store opcode: {:?}", instruction.opcode), }; if aligned_addr + 3 > MAX_MEMORY as u32 { diff --git a/crates/core/executor/src/hook.rs b/crates/core/executor/src/hook.rs index 1af3d7f08..040d26235 100644 --- a/crates/core/executor/src/hook.rs +++ b/crates/core/executor/src/hook.rs @@ -130,7 +130,7 @@ pub fn hook_ecrecover(_: HookEnv, buf: &[u8]) -> Result>, ExecutionE match curve_id { 1 => Ok(ecrecover::handle_secp256k1(r_bytes, alpha_bytes, r_is_y_odd)), 2 => Ok(ecrecover::handle_secp256r1(r_bytes, alpha_bytes, r_is_y_odd)), - _ => unimplemented!("Unsupported curve id: {}", curve_id), + _ => Err(ExecutionError::UnsupportedEcrecoverCurveId(curve_id)), } } diff --git a/crates/core/executor/src/syscalls/precompiles/boolean_circuit/garble.rs 
b/crates/core/executor/src/syscalls/precompiles/boolean_circuit/garble.rs index 604cb4d7e..3c7a79da7 100644 --- a/crates/core/executor/src/syscalls/precompiles/boolean_circuit/garble.rs +++ b/crates/core/executor/src/syscalls/precompiles/boolean_circuit/garble.rs @@ -94,3 +94,145 @@ impl Syscall for BooleanCircuitGarbleSyscall { Ok(None) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{events::PrecompileEvent, Executor, Program}; + use zkm_stark::ZKMCoreOpts; + + const INPUT_PTR: u32 = 0x1000; + const OUTPUT_PTR: u32 = 0x2000; + const OR_GATE_ID: u32 = 7; + + fn gate_info_words(gate_type: u32, delta: [u32; 4], valid: bool) -> [u32; GATE_INFO_BYTES] { + let h0 = [11, 12, 13, 14]; + let h1 = [21, 22, 23, 24]; + let label_b = [31, 32, 33, 34]; + let mut expected = [0u32; 4]; + for i in 0..4 { + expected[i] = h0[i] ^ h1[i] ^ label_b[i]; + if gate_type == OR_GATE_ID { + expected[i] ^= delta[i]; + } + } + if !valid { + expected[3] ^= 1; + } + + let mut words = [0u32; GATE_INFO_BYTES]; + words[0] = gate_type; + words[1..5].copy_from_slice(&h0); + words[5..9].copy_from_slice(&h1); + words[9..13].copy_from_slice(&label_b); + words[13..17].copy_from_slice(&expected); + words + } + + fn write_input( + runtime: &mut Executor<'_>, + gate_infos: &[[u32; GATE_INFO_BYTES]], + delta: [u32; 4], + ) { + let mut timestamp = 1; + let shard = 1; + runtime.mw(INPUT_PTR, gate_infos.len() as u32, shard, timestamp, None); + timestamp += 1; + for (i, value) in delta.into_iter().enumerate() { + runtime.mw(INPUT_PTR + 4 + i as u32 * 4, value, shard, timestamp, None); + timestamp += 1; + } + for (gate_idx, gate_info) in gate_infos.iter().enumerate() { + let gate_base = INPUT_PTR + 20 + gate_idx as u32 * (GATE_INFO_BYTES as u32) * 4; + for (word_idx, value) in gate_info.iter().enumerate() { + runtime.mw(gate_base + word_idx as u32 * 4, *value, shard, timestamp, None); + timestamp += 1; + } + } + runtime.mw(OUTPUT_PTR, u32::MAX, shard, timestamp, None); + } + + fn 
run_syscall(gate_infos: Vec<[u32; GATE_INFO_BYTES]>, delta: [u32; 4]) -> Executor<'static> { + let mut runtime = Executor::new(Program::default(), ZKMCoreOpts::default()); + write_input(&mut runtime, &gate_infos, delta); + runtime.state.current_shard = 2; + runtime.state.clk = 1; + + let syscall = BooleanCircuitGarbleSyscall; + let mut ctx = SyscallContext::new(&mut runtime); + syscall + .execute(&mut ctx, SyscallCode::BOOLEAN_CIRCUIT_GARBLE, INPUT_PTR, OUTPUT_PTR) + .unwrap(); + runtime + } + + #[test] + fn basic_and_gate_verification_succeeds() { + let delta = [101, 102, 103, 104]; + let mut runtime = run_syscall(vec![gate_info_words(0, delta, true)], delta); + assert_eq!(runtime.word(OUTPUT_PTR), 1); + + let events = runtime.record.get_precompile_events(SyscallCode::BOOLEAN_CIRCUIT_GARBLE); + assert_eq!(events.len(), 1); + let (_, event) = &events[0]; + let event = match event { + PrecompileEvent::BooleanCircuitGarble(event) => event, + _ => unreachable!(), + }; + assert_eq!(event.output, 1); + assert_eq!(event.num_gates, 1); + assert_eq!(event.gates_info.len(), GATE_INFO_BYTES); + } + + #[test] + fn basic_or_gate_verification_succeeds() { + let delta = [201, 202, 203, 204]; + let mut runtime = run_syscall(vec![gate_info_words(OR_GATE_ID, delta, true)], delta); + assert_eq!(runtime.word(OUTPUT_PTR), 1); + } + + #[test] + fn mixed_gates_with_bad_ciphertext_return_false() { + let delta = [111, 222, 333, 444]; + let gate_infos = vec![ + gate_info_words(0, delta, true), + gate_info_words(OR_GATE_ID, delta, true), + gate_info_words(0, delta, false), + ]; + let mut runtime = run_syscall(gate_infos, delta); + assert_eq!(runtime.word(OUTPUT_PTR), 0); + + let events = runtime.record.get_precompile_events(SyscallCode::BOOLEAN_CIRCUIT_GARBLE); + let (_, event) = &events[0]; + let event = match event { + PrecompileEvent::BooleanCircuitGarble(event) => event, + _ => unreachable!(), + }; + let accessed_addrs = event + .local_mem_access + .iter() + .map(|access| access.addr) + 
.collect::>(); + assert!(accessed_addrs.contains(&INPUT_PTR)); + assert!(accessed_addrs.contains(&(INPUT_PTR + 20))); + assert!(accessed_addrs.contains(&(INPUT_PTR + 20 + (GATE_INFO_BYTES as u32) * 4))); + assert!(accessed_addrs.contains(&OUTPUT_PTR)); + } + + #[test] + fn zero_gates_write_true() { + let delta = [1, 2, 3, 4]; + let mut runtime = run_syscall(vec![], delta); + assert_eq!(runtime.word(OUTPUT_PTR), 1); + + let events = runtime.record.get_precompile_events(SyscallCode::BOOLEAN_CIRCUIT_GARBLE); + let (_, event) = &events[0]; + let event = match event { + PrecompileEvent::BooleanCircuitGarble(event) => event, + _ => unreachable!(), + }; + assert_eq!(event.num_gates, 0); + assert!(event.gates_info.is_empty()); + assert_eq!(event.output, 1); + } +} diff --git a/crates/core/executor/src/syscalls/precompiles/sys_linux/sysbrk.rs b/crates/core/executor/src/syscalls/precompiles/sys_linux/sysbrk.rs index dff27983f..f710a4823 100644 --- a/crates/core/executor/src/syscalls/precompiles/sys_linux/sysbrk.rs +++ b/crates/core/executor/src/syscalls/precompiles/sys_linux/sysbrk.rs @@ -1,11 +1,37 @@ use crate::{ events::{LinuxEvent, PrecompileEvent}, + program::MAX_MEMORY, syscalls::{Syscall, SyscallCode, SyscallContext}, ExecutionError, Register, }; pub(crate) struct SysBrkSyscall; +/// Maximum amount of heap growth allowed from the program's initial BRK value. +/// +/// This is a prover-side safety bound to prevent guest-controlled `brk` values from +/// expanding memory usage without limit. 
+const MAX_HEAP_SIZE: u32 = 0x4000_0000; + +fn max_brk(initial_brk: u32) -> Result<u32, ExecutionError> { + let limit = + initial_brk.checked_add(MAX_HEAP_SIZE).ok_or(ExecutionError::InvalidSyscallArgs())?; + Ok(limit.min(MAX_MEMORY as u32)) +} + +fn resolve_brk( + initial_brk: u32, + current_brk: u32, + requested_brk: u32, +) -> Result<u32, ExecutionError> { + let limit = max_brk(initial_brk)?; + let v0 = requested_brk.max(current_brk); + if v0 > limit { + return Err(ExecutionError::InvalidSyscallArgs()); + } + Ok(v0) +} + impl Syscall for SysBrkSyscall { fn num_extra_cycles(&self) -> u32 { 0 @@ -20,7 +46,8 @@ impl Syscall for SysBrkSyscall { ) -> Result<Option<u32>, ExecutionError> { let start_clk = rt.clk; let (record, brk) = rt.rr_traced(Register::BRK); - let v0 = if a0 > brk { a0 } else { brk }; + let initial_brk = rt.rt.program.image.get(&(Register::BRK as u32)).copied().unwrap_or(brk); + let v0 = resolve_brk(initial_brk, brk, a0)?; let a3_record = rt.rw_traced(Register::A3, 0); let shard = rt.current_shard(); let event = PrecompileEvent::Linux(LinuxEvent { @@ -40,3 +67,36 @@ impl Syscall for SysBrkSyscall { Ok(Some(v0)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn resolve_brk_keeps_current_value_when_request_is_lower() { + assert_eq!(resolve_brk(0x1000, 0x2000, 0x1800).unwrap(), 0x2000); + } + + #[test] + fn resolve_brk_allows_growth_within_limit() { + assert_eq!(resolve_brk(0x1000, 0x2000, 0x3000).unwrap(), 0x3000); + } + + #[test] + fn resolve_brk_rejects_growth_past_limit() { + let initial_brk = 0x1000; + let limit = max_brk(initial_brk).unwrap(); + assert!(matches!( + resolve_brk(initial_brk, 0x2000, limit + 4), + Err(ExecutionError::InvalidSyscallArgs()) + )); + } + + #[test] + fn max_brk_rejects_overflowing_initial_brk() { + assert!(matches!( + max_brk(u32::MAX - MAX_HEAP_SIZE + 1), + Err(ExecutionError::InvalidSyscallArgs()) + )); + } +} diff --git a/crates/core/executor/src/syscalls/precompiles/sys_linux/sysmmap.rs b/crates/core/executor/src/syscalls/precompiles/sys_linux/sysmmap.rs
index ef59a554d..737b34105 100644 --- a/crates/core/executor/src/syscalls/precompiles/sys_linux/sysmmap.rs +++ b/crates/core/executor/src/syscalls/precompiles/sys_linux/sysmmap.rs @@ -10,6 +10,17 @@ pub const PAGE_ADDR_SIZE: usize = 12; pub const PAGE_ADDR_MASK: usize = (1 << PAGE_ADDR_SIZE) - 1; pub const PAGE_SIZE: usize = 1 << PAGE_ADDR_SIZE; +fn align_size(size: u32) -> Result<u32, ExecutionError> { + if size & (PAGE_ADDR_MASK as u32) == 0 { + return Ok(size); + } + + let aligned = size + .checked_add(PAGE_SIZE as u32 - (size & (PAGE_ADDR_MASK as u32))) + .ok_or(ExecutionError::InvalidSyscallArgs())?; + Ok(aligned) +} + impl Syscall for SysMmapSyscall { fn num_extra_cycles(&self) -> u32 { 0 @@ -22,12 +33,8 @@ impl Syscall for SysMmapSyscall { a0: u32, a1: u32, ) -> Result<Option<u32>, ExecutionError> { - let mut size = a1; let start_clk = rt.clk; - if size & (PAGE_ADDR_MASK as u32) != 0 { - // adjust size to align with page size - size = size.wrapping_add(PAGE_SIZE as u32 - (size & (PAGE_ADDR_MASK as u32))); - } + let size = align_size(a1)?; let a3_record = rt.rw_traced(Register::A3, 0); @@ -57,3 +64,23 @@ impl Syscall for SysMmapSyscall { Ok(Some(v0)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_align_size_preserves_aligned_values() { + assert_eq!(align_size(0x2000).unwrap(), 0x2000); + } + + #[test] + fn test_align_size_rounds_up_misaligned_values() { + assert_eq!(align_size(0x1001).unwrap(), 0x2000); + } + + #[test] + fn test_align_size_rejects_overflow() { + assert!(matches!(align_size(0xFFFF_F001), Err(ExecutionError::InvalidSyscallArgs()))); + } +} diff --git a/crates/core/executor/tests/ext_executor_edge_cases.rs b/crates/core/executor/tests/ext_executor_edge_cases.rs new file mode 100644 index 000000000..d02672111 --- /dev/null +++ b/crates/core/executor/tests/ext_executor_edge_cases.rs @@ -0,0 +1,36 @@ +use zkm_core_executor::{ExecutionError, Executor}; +use zkm_stark::ZKMCoreOpts; + +/// `EXT` with `lsb + msbd >= 32` is architecturally undefined.
The executor must reject it +/// with a trap instead of underflowing the `31 - lsb - msbd` shift amount (which would panic +/// or produce an incorrect trace). Here `op_c = 0x28F` encodes `msbd = 20, lsb = 15`. +#[test] +fn n75_ext_oversized_field_traps() { + let mut runtime = Executor::new( + zkm_core_executor::Program::new( + vec![ + zkm_core_executor::Instruction::new( + zkm_core_executor::Opcode::ADD, + zkm_core_executor::Register::T0 as u8, + 0, + 0x1234_5678, + false, + true, + ), + zkm_core_executor::Instruction::new( + zkm_core_executor::Opcode::EXT, + zkm_core_executor::Register::T1 as u8, + zkm_core_executor::Register::T0 as u32, + 0x28F, + false, + true, + ), + ], + 0, + 0, + ), + ZKMCoreOpts::default(), + ); + let err = runtime.run_very_fast().unwrap_err(); + assert!(matches!(err, ExecutionError::ExceptionOrTrap())); +} diff --git a/crates/core/machine/include/memory_local.hpp b/crates/core/machine/include/memory_local.hpp index 0ac208c60..1d3d948a3 100644 --- a/crates/core/machine/include/memory_local.hpp +++ b/crates/core/machine/include/memory_local.hpp @@ -8,15 +8,40 @@ namespace zkm_core_machine_sys::memory_local { template __ZKM_HOSTDEV__ void event_to_row(const MemoryLocalEvent* event, SingleMemoryLocal* cols) { cols->addr = F::from_canonical_u32(event->addr); - + + // Bit decomposition of `addr` for the KoalaBear field-element range check. 
+ for (uintptr_t i = 0; i < 32; i++) { + cols->addr_bits.bits[i] = F::from_canonical_u32(((event->addr) >> i) & 1); + } + cols->addr_bits.and_most_sig_byte_decomp_0_to_2 = + cols->addr_bits.bits[24] * cols->addr_bits.bits[25]; + cols->addr_bits.and_most_sig_byte_decomp_0_to_3 = + cols->addr_bits.and_most_sig_byte_decomp_0_to_2 * cols->addr_bits.bits[26]; + cols->addr_bits.and_most_sig_byte_decomp_0_to_4 = + cols->addr_bits.and_most_sig_byte_decomp_0_to_3 * cols->addr_bits.bits[27]; + cols->addr_bits.and_most_sig_byte_decomp_0_to_5 = + cols->addr_bits.and_most_sig_byte_decomp_0_to_4 * cols->addr_bits.bits[28]; + cols->addr_bits.and_most_sig_byte_decomp_0_to_6 = + cols->addr_bits.and_most_sig_byte_decomp_0_to_5 * cols->addr_bits.bits[29]; + cols->addr_bits.and_most_sig_byte_decomp_0_to_7 = + cols->addr_bits.and_most_sig_byte_decomp_0_to_6 * cols->addr_bits.bits[30]; + cols->initial_shard = F::from_canonical_u32(event->initial_mem_access.shard); cols->initial_clk = F::from_canonical_u32(event->initial_mem_access.timestamp); write_word_from_u32_v2(cols->initial_value, event->initial_mem_access.value); - + cols->final_shard = F::from_canonical_u32(event->final_mem_access.shard); cols->final_clk = F::from_canonical_u32(event->final_mem_access.timestamp); write_word_from_u32_v2(cols->final_value, event->final_mem_access.value); + // 16-bit limbs for shards and 16-bit + 8-bit limbs for the 24-bit clk range checks. 
+ cols->initial_shard_16bit_limb = F::from_canonical_u32(event->initial_mem_access.shard & 0xffff); + cols->final_shard_16bit_limb = F::from_canonical_u32(event->final_mem_access.shard & 0xffff); + cols->initial_clk_16bit_limb = F::from_canonical_u32(event->initial_mem_access.timestamp & 0xffff); + cols->initial_clk_8bit_limb = F::from_canonical_u32((event->initial_mem_access.timestamp >> 16) & 0xff); + cols->final_clk_16bit_limb = F::from_canonical_u32(event->final_mem_access.timestamp & 0xffff); + cols->final_clk_8bit_limb = F::from_canonical_u32((event->final_mem_access.timestamp >> 16) & 0xff); + cols->is_real = F::one(); } } // namespace zkm::memory_local diff --git a/crates/core/machine/src/air/memory.rs b/crates/core/machine/src/air/memory.rs index 883d5423a..acf54c614 100644 --- a/crates/core/machine/src/air/memory.rs +++ b/crates/core/machine/src/air/memory.rs @@ -8,7 +8,10 @@ use zkm_stark::{ LookupKind, }; -use crate::memory::{MemoryAccessCols, MemoryCols}; +use crate::{ + air::WordAirBuilder, + memory::{MemoryAccessCols, MemoryCols}, +}; pub trait MemoryAirBuilder: BaseAirBuilder { /// Constrain a memory read or write. @@ -35,6 +38,11 @@ pub trait MemoryAirBuilder: BaseAirBuilder { // Verify that the current memory access time is greater than the previous's. self.eval_memory_access_timestamp(mem_access, do_check.clone(), shard.clone(), clk.clone()); + // Defense-in-depth: memory words entering the subsystem must remain byte-shaped even + // if an upstream chip forgot to range check them. + self.slice_range_check_u8(&memory_access.prev_value().0, do_check.clone()); + self.slice_range_check_u8(&memory_access.value().0, do_check.clone()); + // Add to the memory argument. 
let addr = addr.into(); let prev_shard = mem_access.prev_shard.clone().into(); diff --git a/crates/core/machine/src/alu/divrem/mod.rs b/crates/core/machine/src/alu/divrem/mod.rs index 22aedb94d..8948c0601 100644 --- a/crates/core/machine/src/alu/divrem/mod.rs +++ b/crates/core/machine/src/alu/divrem/mod.rs @@ -55,11 +55,10 @@ //! elif c > 0: //! assert 0 <= remainder < c //! -//! if is_c_0: -//! # if division by 0, then quotient is UNPREDICTABLE per MIPS spec. -//! We restrict the quotient = 0xffffffff and remainder = b. -//! This needs special care since # b = 0 * quotient + b is satisfied by any quotient. -//! assert quotient = 0xffffffff +//! # Division by zero is undefined per the MIPS spec and is rejected by the executor +//! # (it traps), so an honest trace never contains a div-by-zero event. The AIR enforces +//! # the same to stay in agreement with the executor. +//! assert not is_c_0 # i.e. c != 0 on every real row use core::{ borrow::{Borrow, BorrowMut}, @@ -568,7 +567,12 @@ where .assert_zero(local.b_neg); // b is not negative. } - // When division by 0, quotient is UNPREDICTABLE per MIPS spec. We restrict the quotient = 0xffffffff + // Division by zero is architecturally undefined: the executor traps on it + // (`ExecutionError::ExceptionOrTrap`, see `execute_alu`), so an honest trace never + // contains a divrem event with c == 0. Enforce the same here so the AIR rejects + // div-by-zero rows rather than accepting them with a forced quotient. This keeps the + // executor and AIR in agreement and removes a path for injecting controlled + // quotient/remainder values into the trace. { // Calculate whether c is 0. IsZeroWordOperation::::eval( @@ -578,12 +582,8 @@ where is_real.clone(), ); - // If is_c_0 is true, then quotient must be 0xffffffff = u32::MAX. - for i in 0..WORD_SIZE { - builder - .when(local.is_c_0.result) - .assert_eq(local.quotient[i], AB::F::from_canonical_u8(u8::MAX)); - } + // c must be non-zero on every real divrem row. 
+ builder.when(is_real.clone()).assert_zero(local.is_c_0.result); } // Range check remainder. (i.e., |remainder| < |c| when not is_c_0) diff --git a/crates/core/machine/src/alu/mul/mod.rs b/crates/core/machine/src/alu/mul/mod.rs index af9c42dc6..baaeaad2b 100644 --- a/crates/core/machine/src/alu/mul/mod.rs +++ b/crates/core/machine/src/alu/mul/mod.rs @@ -42,7 +42,7 @@ use p3_matrix::{dense::RowMajorMatrix, Matrix}; use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator, ParallelSlice}; use zkm_core_executor::{ events::{ByteLookupEvent, ByteRecord, CompAluEvent, MemoryAccessPosition, MemoryRecordEnum}, - ByteOpcode, ExecutionRecord, Opcode, Program, + ByteOpcode, ExecutionRecord, Opcode, Program, UNUSED_PC, }; use zkm_derive::AlignedBorrow; #[cfg(feature = "picus")] @@ -513,6 +513,12 @@ where // if hi_record_is_real = 0, both clk and shard should be zero. builder.when_not(local.is_real).assert_zero(local.hi_record_is_real); builder.when(local.hi_record_is_real).assert_one(local.is_mult + local.is_multu); + // Hardware MULT/MULTU rows must write HI. Dependency-only multiply + // rows use UNUSED_PC and keep hi_record_is_real = 0. + builder.when(local.is_mult + local.is_multu).assert_zero( + (local.pc - AB::Expr::from_canonical_u32(UNUSED_PC)) + * (AB::Expr::one() - local.hi_record_is_real), + ); builder.when(local.hi_record_is_real).assert_word_eq(local.hi, *local.op_hi_access.value()); builder.when_not(local.hi_record_is_real).assert_zero(local.clk); builder.when_not(local.hi_record_is_real).assert_zero(local.shard); diff --git a/crates/core/machine/src/cpu/air/mod.rs b/crates/core/machine/src/cpu/air/mod.rs index 02fa3557c..9df2883ed 100644 --- a/crates/core/machine/src/cpu/air/mod.rs +++ b/crates/core/machine/src/cpu/air/mod.rs @@ -208,6 +208,16 @@ impl CpuChip { // If the last real row is the last row, verify the public value's next pc. 
builder.when_last_row().when(local.is_real).assert_eq(public_values.next_pc, local.next_pc); + + // A branch or jump row carries its post-delay-slot target in `next_next_pc`. + // Since shard public values only export `next_pc`, such rows must not be + // the last real row of a shard; otherwise the target is dropped at the + // boundary and the next shard can rederive fall-through. + builder + .when_transition() + .when(local.is_real - next.is_real) + .assert_one(local.is_sequential + local.is_halt); + builder.when_last_row().when(local.is_real).assert_one(local.is_sequential + local.is_halt); } /// Constraints related to the is_real column. diff --git a/crates/core/machine/src/cpu/air/register.rs b/crates/core/machine/src/cpu/air/register.rs index b5dc36f7a..82321ef2a 100644 --- a/crates/core/machine/src/cpu/air/register.rs +++ b/crates/core/machine/src/cpu/air/register.rs @@ -49,6 +49,13 @@ impl CpuChip { .when_not(local.instruction.op_a_0) .assert_word_eq(local.op_a_value, *local.op_a_access.value()); + // If `op_a` is an immutable read from register 0, the logical operand + // sent to instruction chips must also be zero. Writes to register 0 + // are intentionally excluded because their computed result is discarded. + builder + .when(local.instruction.op_a_0 * local.op_a_immutable) + .assert_word_zero(local.op_a_value); + // If we are maddu,msubu,madd, msub, ins,mne, meq, syscall and memory instruction then the hi_or_prev_a should equal to op_a_access.prev_value. builder .when(local.is_rw_a) diff --git a/crates/core/machine/src/memory/consistency/trace.rs b/crates/core/machine/src/memory/consistency/trace.rs index f63ff2988..ba0864fb8 100644 --- a/crates/core/machine/src/memory/consistency/trace.rs +++ b/crates/core/machine/src/memory/consistency/trace.rs @@ -74,6 +74,11 @@ impl MemoryAccessCols { ) { self.value = current_record.value.into(); + // Match the byte range checks emitted by `eval_memory_access` for both the previous and + // current memory words. 
+ output.add_u8_range_checks(&prev_record.value.to_le_bytes()); + output.add_u8_range_checks(&current_record.value.to_le_bytes()); + self.prev_shard = F::from_canonical_u32(prev_record.shard); self.prev_clk = F::from_canonical_u32(prev_record.timestamp); diff --git a/crates/core/machine/src/memory/local.rs b/crates/core/machine/src/memory/local.rs index be0943f32..d81fff807 100644 --- a/crates/core/machine/src/memory/local.rs +++ b/crates/core/machine/src/memory/local.rs @@ -3,14 +3,15 @@ use std::{ mem::size_of, }; -use p3_air::{Air, BaseAir}; +use hashbrown::HashMap; +use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::FieldAlgebra; use p3_field::PrimeField32; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use p3_maybe_rayon::prelude::{ IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator, }; -use zkm_core_executor::events::{GlobalLookupEvent, MemoryLocalEvent}; +use zkm_core_executor::events::{ByteLookupEvent, ByteRecord, GlobalLookupEvent, MemoryLocalEvent}; use zkm_core_executor::{ExecutionRecord, Program}; use zkm_derive::AlignedBorrow; use zkm_stark::{ @@ -19,6 +20,8 @@ use zkm_stark::{ }; use crate::{ + air::{MemoryAirBuilder, WordAirBuilder}, + operations::KoalaBearBitDecomposition, utils::{next_power_of_two, zeroed_f_vec}, CoreChipError, }; @@ -32,6 +35,10 @@ pub struct SingleMemoryLocal { /// The address of the memory access. pub addr: T, + /// The bit decomposition of `addr`, used to range check that `addr` is a valid KoalaBear + /// field element (i.e. strictly less than the modulus `0x7F000001`). + pub addr_bits: KoalaBearBitDecomposition, + /// The initial shard of the memory access. pub initial_shard: T, @@ -44,6 +51,22 @@ pub struct SingleMemoryLocal { /// The final clk of the memory access. pub final_clk: T, + /// The 16-bit limb of `initial_shard`, used for its 16-bit range check. + pub initial_shard_16bit_limb: T, + + /// The 16-bit limb of `final_shard`, used for its 16-bit range check.
+ pub final_shard_16bit_limb: T, + + /// The 16-bit limb of `initial_clk`, used for its 24-bit range check. + pub initial_clk_16bit_limb: T, + /// The 8-bit limb of `initial_clk`, used for its 24-bit range check. + pub initial_clk_8bit_limb: T, + + /// The 16-bit limb of `final_clk`, used for its 24-bit range check. + pub final_clk_16bit_limb: T, + /// The 8-bit limb of `final_clk`, used for its 24-bit range check. + pub final_clk_8bit_limb: T, + /// The initial value of the memory access. pub initial_value: Word, @@ -100,6 +123,8 @@ impl MachineAir for MemoryLocalChip { output: &mut ExecutionRecord, ) -> Result<(), Self::Error> { let mut events = Vec::new(); + // Byte lookups required by the defense-in-depth range checks emitted in `eval`. + let mut blu: HashMap = HashMap::new(); input.get_local_mem_events().for_each(|mem_event| { events.push(GlobalLookupEvent { @@ -128,9 +153,27 @@ impl MachineAir for MemoryLocalChip { is_receive: false, kind: LookupKind::Memory as u8, }); + + // Byte range check the eight value limbs (initial and final). + blu.add_u8_range_checks(&mem_event.initial_mem_access.value.to_le_bytes()); + blu.add_u8_range_checks(&mem_event.final_mem_access.value.to_le_bytes()); + + // 16-bit range checks for shards. + for value in [mem_event.initial_mem_access.shard, mem_event.final_mem_access.shard] { + blu.add_u16_range_check(value as u16); + } + + // 24-bit range checks (16-bit + 8-bit limbs) for the clk fields. 
+ for value in + [mem_event.initial_mem_access.timestamp, mem_event.final_mem_access.timestamp] + { + blu.add_u16_range_check((value & 0xffff) as u16); + blu.add_u8_range_check(0, ((value >> 16) & 0xff) as u8); + } }); output.global_lookup_events.extend(events); + output.add_byte_lookup_events_from_maps(vec![&blu]); Ok(()) } @@ -170,12 +213,37 @@ impl MachineAir for MemoryLocalChip { let cols = &mut cols.memory_local_entries[k]; if idx + k < events.len() { let event: &&MemoryLocalEvent = &events[idx + k]; + let initial_shard = event.initial_mem_access.shard; + let final_shard = event.final_mem_access.shard; + let initial_clk = event.initial_mem_access.timestamp; + let final_clk = event.final_mem_access.timestamp; + cols.addr = F::from_canonical_u32(event.addr); - cols.initial_shard = F::from_canonical_u32(event.initial_mem_access.shard); - cols.final_shard = F::from_canonical_u32(event.final_mem_access.shard); - cols.initial_clk = - F::from_canonical_u32(event.initial_mem_access.timestamp); - cols.final_clk = F::from_canonical_u32(event.final_mem_access.timestamp); + cols.addr_bits.populate(event.addr); + cols.initial_shard = F::from_canonical_u32(initial_shard); + cols.final_shard = F::from_canonical_u32(final_shard); + cols.initial_clk = F::from_canonical_u32(initial_clk); + cols.final_clk = F::from_canonical_u32(final_clk); + + // Populate the limbs backing the defense-in-depth range checks. 
+ for (value, limb_16, limb_8) in [ + ( + initial_clk, + &mut cols.initial_clk_16bit_limb, + &mut cols.initial_clk_8bit_limb, + ), + ( + final_clk, + &mut cols.final_clk_16bit_limb, + &mut cols.final_clk_8bit_limb, + ), + ] { + *limb_16 = F::from_canonical_u32(value & 0xffff); + *limb_8 = F::from_canonical_u32((value >> 16) & 0xff); + } + cols.initial_shard_16bit_limb = F::from_canonical_u32(initial_shard); + cols.final_shard_16bit_limb = F::from_canonical_u32(final_shard); + cols.initial_value = event.initial_mem_access.value.into(); cols.final_value = event.final_mem_access.value.into(); cols.is_real = F::ONE; @@ -213,6 +281,41 @@ where for local in local.memory_local_entries.iter() { builder.assert_bool(local.is_real); + // Defense-in-depth: byte range check all eight value limbs via the byte lookup table. + builder.slice_range_check_u8(&local.initial_value.0, local.is_real); + builder.slice_range_check_u8(&local.final_value.0, local.is_real); + + // Defense-in-depth: range check `addr` to be a valid KoalaBear field element + // (strictly less than the modulus `0x7F000001`). + KoalaBearBitDecomposition::::range_check( + builder, + local.addr, + local.addr_bits, + local.is_real.into(), + ); + + // Defense-in-depth: range check shards to 16 bits and clocks to 24 bits. 
+ builder + .when(local.is_real) + .assert_eq(local.initial_shard, local.initial_shard_16bit_limb); + builder.when(local.is_real).assert_eq(local.final_shard, local.final_shard_16bit_limb); + builder.slice_range_check_u16( + &[local.initial_shard_16bit_limb, local.final_shard_16bit_limb], + local.is_real, + ); + builder.eval_range_check_24bits( + local.initial_clk, + local.initial_clk_16bit_limb, + local.initial_clk_8bit_limb, + local.is_real, + ); + builder.eval_range_check_24bits( + local.final_clk, + local.final_clk_16bit_limb, + local.final_clk_8bit_limb, + local.is_real, + ); + let mut values = vec![local.initial_shard.into(), local.initial_clk.into(), local.addr.into()]; values.extend(local.initial_value.map(Into::into)); @@ -310,6 +413,49 @@ mod tests { } } + #[test] + fn test_memory_local_defense_in_depth_lookups() { + // Uses the inline `simple_program` (no guest ELF) so it can run without the zkVM + // toolchain. Verifies that the byte-lookup events recorded in `generate_dependencies` + // for the defense-in-depth range checks exactly balance the AIR `send_byte` calls, and + // that the memory lookups still balance. + setup_logger(); + let program = simple_program(); + let program_clone = program.clone(); + let mut runtime = Executor::new(program, ZKMCoreOpts::default()); + runtime.run().unwrap(); + + // Sanity check: the program must exercise the memory-local chip for the byte-lookup + // balance assertion below to be meaningful. 
+ let n_local_events: usize = + runtime.records.iter().map(|r| r.get_local_mem_events().count()).sum(); + assert!(n_local_events > 0, "expected the test program to produce local memory events"); + + let machine: StarkMachine> = + MipsAir::machine(KoalaBearPoseidon2::new()); + let (pkey, _) = machine.setup(&program_clone); + let opts = ZKMCoreOpts::default(); + machine.generate_dependencies(&mut runtime.records, &opts, None).unwrap(); + + let shards = runtime.records; + for shard in shards.clone() { + debug_lookups_with_all_chips::>( + &machine, + &pkey, + &[shard], + vec![LookupKind::Memory], + LookupScope::Local, + ); + } + debug_lookups_with_all_chips::>( + &machine, + &pkey, + &shards, + vec![LookupKind::Byte], + LookupScope::Global, + ); + } + #[test] fn test_memory_lookup_lookups() { setup_logger(); diff --git a/crates/core/machine/src/operations/global_accumulation.rs b/crates/core/machine/src/operations/global_accumulation.rs index 40b45b4b9..ffff5db52 100644 --- a/crates/core/machine/src/operations/global_accumulation.rs +++ b/crates/core/machine/src/operations/global_accumulation.rs @@ -143,6 +143,13 @@ impl GlobalAccumulationOperation { }), }; + let assert_on_curve = |builder: &mut AB, point: SepticCurve| { + builder.assert_septic_ext_eq( + point.y.square(), + SepticCurve::::curve_formula(point.x), + ); + }; + let ith_cumulative_sum = |idx: usize| SepticCurve:: { x: SepticExtension::::from_base_fn(|i| { local_accumulation.cumulative_sum[idx][0].0[i].into() @@ -166,12 +173,17 @@ impl GlobalAccumulationOperation { builder.when_first_row().assert_septic_ext_eq(initial_digest.x.clone(), zero_digest.x); builder.when_first_row().assert_septic_ext_eq(initial_digest.y.clone(), zero_digest.y); + // Defense-in-depth: every witnessed running digest must stay on-curve even if the + // incomplete Weierstrass addition edge case is triggered. 
+ assert_on_curve(builder, initial_digest.clone()); + // Constrain that when `is_real = 1`, addition is being carried out, and when `is_real = 0`, the sum remains the same. for i in 0..N { let current_sum = if i == 0 { initial_digest.clone() } else { ith_cumulative_sum(i - 1) }; let point_to_add = ith_point_to_add(i); let next_sum = ith_cumulative_sum(i); + assert_on_curve(builder, next_sum.clone()); // If `local_is_real[i] == 1`, current_sum + point_to_add == next_sum must hold. // To do this, constrain that `sum_checker_x` and `sum_checker_y` are both zero when `is_real == 1`. let sum_checker_x = SepticCurve::::sum_checker_x( diff --git a/crates/core/machine/src/operations/poseidon2/mod.rs b/crates/core/machine/src/operations/poseidon2/mod.rs index addc7d454..bdb5bedca 100644 --- a/crates/core/machine/src/operations/poseidon2/mod.rs +++ b/crates/core/machine/src/operations/poseidon2/mod.rs @@ -15,7 +15,10 @@ pub const RATE: usize = WIDTH / 2; pub const NUM_EXTERNAL_ROUNDS: usize = 8; /// The number of internal rounds. -pub const NUM_INTERNAL_ROUNDS: usize = 13; +/// +/// Must match `poseidon2_init()` in zkm-primitives (KoalaBear/α=3 reference R_P = 20), since the +/// Poseidon2Permute precompile witness is generated by that permutation. +pub const NUM_INTERNAL_ROUNDS: usize = 20; /// The total number of rounds. pub const NUM_ROUNDS: usize = NUM_EXTERNAL_ROUNDS + NUM_INTERNAL_ROUNDS; diff --git a/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/air.rs b/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/air.rs index 1f9b0128a..dfc545acd 100644 --- a/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/air.rs +++ b/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/air.rs @@ -55,17 +55,21 @@ impl BooleanCircuitGarbleChip { // In a true single-row trace, this chip only has the prelude row // (num_gates + delta read), never a gate row. 
let single_row_phase = builder.is_first_row() * builder.is_last_row(); - builder.when(single_row_phase).assert_one(local.is_first_row); + builder.when(single_row_phase.clone()).assert_one(local.is_first_row); + builder.when(single_row_phase).assert_one(local.is_empty); builder.assert_bool(local.is_real); builder.assert_bool(local.is_first_row); builder.assert_bool(local.is_first_gate); builder.assert_bool(local.not_last_gate); builder.assert_bool(local.is_gate); + builder.assert_bool(local.is_empty); builder.assert_bool(local.checks_acc); builder.assert_eq(local.is_first_gate * local.is_gate, local.is_first_gate); builder.assert_eq(local.is_last_gate * local.is_gate, local.is_last_gate); builder.assert_eq(local.not_last_gate * local.is_gate, local.not_last_gate); + builder.assert_eq(local.is_empty * local.is_first_row, local.is_empty); + builder.assert_zero(local.is_empty * local.is_gate); builder.assert_zero(local.is_last_gate * local.is_first_gate); builder.when(local.is_gate).assert_one(local.is_last_gate + local.not_last_gate); builder.assert_bool(local.gate_type[0]); @@ -109,26 +113,28 @@ impl BooleanCircuitGarbleChip { local.is_gate, ); } - // eval result write + let writes_result = local.is_last_gate + local.is_empty; builder.eval_memory_access( local.shard, local.clk, local.output_address, &local.result_mem, - local.is_last_gate, + writes_result.clone(), ); - // The syscall writes a boolean result (as u32) at the final gate. 
- builder - .when(local.is_last_gate) - .assert_eq(local.result_mem.access.value[0], local.checks_acc * local.checks[2]); - builder.when(local.is_last_gate).assert_zero(local.result_mem.access.value[1]); - builder.when(local.is_last_gate).assert_zero(local.result_mem.access.value[2]); - builder.when(local.is_last_gate).assert_zero(local.result_mem.access.value[3]); - builder.when(local.is_last_gate).assert_zero(local.result_mem.prev_value[0]); - builder.when(local.is_last_gate).assert_zero(local.result_mem.prev_value[1]); - builder.when(local.is_last_gate).assert_zero(local.result_mem.prev_value[2]); - builder.when(local.is_last_gate).assert_zero(local.result_mem.prev_value[3]); + // The syscall writes a boolean result (as u32) either after the final gate or directly on + // the prelude row for the empty circuit. + builder.when(writes_result.clone()).assert_eq( + local.result_mem.access.value[0], + local.is_empty + local.is_last_gate * local.checks_acc * local.checks[2], + ); + builder.when(writes_result.clone()).assert_zero(local.result_mem.access.value[1]); + builder.when(writes_result.clone()).assert_zero(local.result_mem.access.value[2]); + builder.when(writes_result.clone()).assert_zero(local.result_mem.access.value[3]); + builder.when(writes_result.clone()).assert_zero(local.result_mem.prev_value[0]); + builder.when(writes_result.clone()).assert_zero(local.result_mem.prev_value[1]); + builder.when(writes_result.clone()).assert_zero(local.result_mem.prev_value[2]); + builder.when(writes_result).assert_zero(local.result_mem.prev_value[3]); } fn eval_logic_check( @@ -217,12 +223,14 @@ impl BooleanCircuitGarbleChip { next: &BooleanCircuitGarbleCols, ) { let transition_continuation = local.not_last_gate * local.is_gate; + let first_row_with_gates = local.is_first_row - local.is_empty; let bytes_shift = AB::F::from_canonical_u32(256); let num_gates = local.gates_input_mem[0].access.value.0[0] + local.gates_input_mem[0].access.value.0[1] * bytes_shift + 
local.gates_input_mem[0].access.value.0[2] * bytes_shift * bytes_shift + local.gates_input_mem[0].access.value.0[3] * bytes_shift * bytes_shift * bytes_shift; builder.when(local.is_first_row).assert_eq(local.gates_num, num_gates.clone()); + builder.when(local.is_first_row).assert_zero(local.is_empty * local.gates_num); for i in 0..4 { let delta_i = local.gates_input_mem[i + 1].access.value; @@ -231,8 +239,7 @@ impl BooleanCircuitGarbleChip { } } - let gate_type_value = - local.gate_type[0] + local.gate_type[1] * AB::Expr::from_canonical_u32(OR_GATE_ID); + let gate_type_value = local.gate_type[1] * AB::Expr::from_canonical_u32(OR_GATE_ID); builder.when(local.is_gate).assert_eq(gate_type_value, num_gates); builder.when(local.is_first_gate).assert_zero(local.gate_id); @@ -243,19 +250,24 @@ impl BooleanCircuitGarbleChip { // Bridge the prelude row (num_gates + delta read) to the first gate row. builder - .when(local.is_first_row) + .when(first_row_with_gates.clone()) .assert_eq(next.input_address, local.input_address + AB::F::from_canonical_u32(20)); - builder.when(local.is_first_row).assert_eq(next.output_address, local.output_address); - builder.when(local.is_first_row).assert_eq(next.shard, local.shard); - builder.when(local.is_first_row).assert_eq(next.clk, local.clk); - builder.when(local.is_first_row).assert_eq(next.gates_num, local.gates_num); - builder.when(local.is_first_row).assert_zero(next.is_first_row); - builder.when(local.is_first_row).assert_eq(next.is_gate, next.is_first_gate); - builder.when(local.is_first_row).assert_eq(next.checks_acc, next.is_gate); + builder + .when(first_row_with_gates.clone()) + .assert_eq(next.output_address, local.output_address); + builder.when(first_row_with_gates.clone()).assert_eq(next.shard, local.shard); + builder.when(first_row_with_gates.clone()).assert_eq(next.clk, local.clk); + builder.when(first_row_with_gates.clone()).assert_eq(next.gates_num, local.gates_num); + 
builder.when(first_row_with_gates.clone()).assert_zero(next.is_first_row); + builder.when(first_row_with_gates.clone()).assert_zero(next.is_empty); + builder.when(first_row_with_gates.clone()).assert_one(next.is_first_gate); + builder.when(first_row_with_gates.clone()).assert_eq(next.is_gate, next.is_first_gate); + builder.when(first_row_with_gates.clone()).assert_eq(next.checks_acc, next.is_gate); // Continue with next gate row only when explicitly in same-event continuation. builder.when(transition_continuation.clone()).assert_one(next.is_gate); builder.when(transition_continuation.clone()).assert_zero(next.is_first_row); builder.when(transition_continuation.clone()).assert_zero(next.is_first_gate); + builder.when(transition_continuation.clone()).assert_zero(next.is_empty); builder .when(transition_continuation.clone()) .assert_eq(next.output_address, local.output_address); @@ -263,7 +275,9 @@ impl BooleanCircuitGarbleChip { builder.when(transition_continuation.clone()).assert_eq(next.clk, local.clk); for i in 0..4 { for j in 0..4 { - builder.when(local.is_first_row).assert_eq(local.delta[i][j], next.delta[i][j]); + builder + .when(first_row_with_gates.clone()) + .assert_eq(local.delta[i][j], next.delta[i][j]); } } diff --git a/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/columns.rs b/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/columns.rs index 7c5711426..84edea81b 100644 --- a/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/columns.rs +++ b/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/columns.rs @@ -31,6 +31,7 @@ pub struct BooleanCircuitGarbleCols { pub is_last_gate: T, #[cfg_attr(feature = "picus", picus(transition_input))] pub not_last_gate: T, // from first gate -> (last - 1)-th gate + pub is_empty: T, pub gate_type: [T; 2], #[cfg_attr(feature = "picus", picus(transition_input))] pub gate_id: T, diff --git 
a/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/trace.rs b/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/trace.rs index 7e558f626..6020a6ce1 100644 --- a/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/trace.rs +++ b/crates/core/machine/src/syscall/precompiles/boolean_circuit_garble/trace.rs @@ -4,7 +4,7 @@ use crate::syscall::precompiles::boolean_circuit_garble::columns::{ use crate::syscall::precompiles::boolean_circuit_garble::{ BooleanCircuitGarbleChip, GATE_INFO_BYTES, OR_GATE_ID, }; -use crate::CoreChipError; +use crate::{utils::next_power_of_two, CoreChipError}; use hashbrown::HashMap; use itertools::Itertools; use p3_field::PrimeField32; @@ -84,7 +84,11 @@ impl MachineAir for BooleanCircuitGarbleChip { }) .collect(); - let padded = if rows.is_empty() { 0 } else { rows.len().next_power_of_two() }; + let padded = next_power_of_two( + rows.len(), + input.fixed_log2_rows::(self), + >::name(self).as_str(), + ); rows.resize_with(padded, || [F::ZERO; NUM_BOOLEAN_CIRCUIT_GARBLE_COLS]); Ok(RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), @@ -124,6 +128,7 @@ impl BooleanCircuitGarbleChip { cols.is_real = F::ONE; cols.is_gate = F::ZERO; cols.is_first_row = F::ONE; + cols.is_empty = F::from_bool(gates_num == 0); cols.input_address = F::from_canonical_u32(input_address); cols.output_address = F::from_canonical_u32(event.output_addr); cols.gates_num = F::from_canonical_u32(gates_num as u32); @@ -142,6 +147,9 @@ impl BooleanCircuitGarbleChip { for i in 0..4 { cols.gates_input_mem[1 + i].populate(event.delta_read_records[i], blu); } + if gates_num == 0 { + cols.result_mem.populate(event.output_write_record, blu); + } rows.push(row); } @@ -155,6 +163,7 @@ impl BooleanCircuitGarbleChip { cols.is_gate = F::ONE; cols.input_address = F::from_canonical_u32(input_address); cols.output_address = F::from_canonical_u32(event.output_addr); + cols.is_empty = F::ZERO; cols.is_first_gate = F::from_bool(gate_id == 
0); cols.is_last_gate = F::from_bool(gate_id == gates_num - 1); cols.not_last_gate = F::from_bool(gate_id != gates_num - 1); @@ -230,3 +239,258 @@ impl BooleanCircuitGarbleChip { rows } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::syscall::precompiles::boolean_circuit_garble::columns::BooleanCircuitGarbleCols; + use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues}; + use p3_field::FieldAlgebra; + use p3_koala_bear::KoalaBear; + use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; + use p3_matrix::stack::VerticalPair; + use p3_matrix::Matrix; + use std::{ + borrow::BorrowMut, + panic::{catch_unwind, set_hook, take_hook, AssertUnwindSafe}, + }; + use zkm_core_executor::{ + events::{ + BooleanCircuitGarbleEvent, MemoryReadRecord, MemoryWriteRecord, PrecompileEvent, + SyscallEvent, + }, + syscalls::SyscallCode, + ExecutionRecord, + }; + use zkm_stark::air::{EmptyMessageBuilder, MachineAir}; + + fn gate_info_words(gate_type: u32, delta: [u32; 4], valid: bool) -> [u32; GATE_INFO_BYTES] { + let h0 = [11, 12, 13, 14]; + let h1 = [21, 22, 23, 24]; + let label_b = [31, 32, 33, 34]; + let mut expected = [0u32; 4]; + for i in 0..4 { + expected[i] = h0[i] ^ h1[i] ^ label_b[i]; + if gate_type == OR_GATE_ID { + expected[i] ^= delta[i]; + } + } + if !valid { + expected[3] ^= 1; + } + + let mut words = [0u32; GATE_INFO_BYTES]; + words[0] = gate_type; + words[1..5].copy_from_slice(&h0); + words[5..9].copy_from_slice(&h1); + words[9..13].copy_from_slice(&label_b); + words[13..17].copy_from_slice(&expected); + words + } + + fn make_event(gate_types: &[(u32, bool)], output: u32) -> BooleanCircuitGarbleEvent { + let shard = 1; + let clk = 5; + let input_addr = 0x1000; + let output_addr = 0x2000; + let delta = [101, 102, 103, 104]; + let mut gates_info = Vec::new(); + for &(gate_type, valid) in gate_types { + gates_info.extend_from_slice(&gate_info_words(gate_type, delta, valid)); + } + + let mut timestamp = 1u32; + let num_gates_read_record = + 
MemoryReadRecord::new(gate_types.len() as u32, shard, timestamp, 0, 0); + timestamp += 1; + + let delta_read_records = core::array::from_fn(|i| { + let record = MemoryReadRecord::new(delta[i], shard, timestamp, 0, 0); + timestamp += 1; + record + }); + + let gates_read_records = gates_info + .iter() + .map(|&value| { + let record = MemoryReadRecord::new(value, shard, timestamp, 0, 0); + timestamp += 1; + record + }) + .collect(); + + let output_write_record = MemoryWriteRecord::new(output, shard, timestamp, 0, 0, 0); + + BooleanCircuitGarbleEvent { + shard, + clk, + input_addr, + output_addr, + num_gates: gate_types.len() as u32, + delta, + gates_info, + output, + num_gates_read_record, + delta_read_records, + gates_read_records, + output_write_record, + local_mem_access: vec![], + } + } + + fn trace_for_event(event: BooleanCircuitGarbleEvent) -> RowMajorMatrix { + let mut record = ExecutionRecord::default(); + let syscall_code = SyscallCode::BOOLEAN_CIRCUIT_GARBLE; + let syscall_event = SyscallEvent { + pc: 32, + next_pc: 36, + shard: event.shard, + clk: event.clk, + a_record: MemoryWriteRecord::default(), + a_record_is_real: false, + syscall_id: syscall_code.syscall_id(), + arg1: event.input_addr, + arg2: event.output_addr, + }; + record.precompile_events.add_event( + syscall_code, + syscall_event, + PrecompileEvent::BooleanCircuitGarble(event), + ); + + BooleanCircuitGarbleChip.generate_trace(&record, &mut ExecutionRecord::default()).unwrap() + } + + struct EvalBuilder<'a> { + local: &'a [KoalaBear], + next: &'a [KoalaBear], + is_first_row: bool, + is_last_row: bool, + } + + impl<'a> AirBuilder for EvalBuilder<'a> { + type F = KoalaBear; + type Expr = KoalaBear; + type Var = KoalaBear; + type M = VerticalPair, RowMajorMatrixView<'a, KoalaBear>>; + + fn main(&self) -> Self::M { + VerticalPair::new( + RowMajorMatrixView::new_row(self.local), + RowMajorMatrixView::new_row(self.next), + ) + } + + fn is_first_row(&self) -> Self::Expr { + 
KoalaBear::from_bool(self.is_first_row) + } + + fn is_last_row(&self) -> Self::Expr { + KoalaBear::from_bool(self.is_last_row) + } + + fn is_transition_window(&self, size: usize) -> Self::Expr { + assert_eq!(size, 2); + KoalaBear::from_bool(!self.is_last_row) + } + + fn assert_zero>(&mut self, x: I) { + assert_eq!(x.into(), KoalaBear::ZERO, "constraints had nonzero value"); + } + } + + impl<'a> AirBuilderWithPublicValues for EvalBuilder<'a> { + type PublicVar = KoalaBear; + + fn public_values(&self) -> &[Self::PublicVar] { + &[] + } + } + + impl<'a> EmptyMessageBuilder for EvalBuilder<'a> {} + + fn check_trace(trace: &RowMajorMatrix) { + let air = BooleanCircuitGarbleChip; + let height = trace.height(); + for row_index in 0..height { + let row_index_next = (row_index + 1) % height; + let local = trace.row_slice(row_index); + let next = trace.row_slice(row_index_next); + let mut builder = EvalBuilder { + local: &*local, + next: &*next, + is_first_row: row_index == 0, + is_last_row: row_index == height - 1, + }; + air.eval(&mut builder); + } + } + fn assert_gate_row_encoding(row: &mut [KoalaBear], expected_gate_type: u32) { + let cols: &mut BooleanCircuitGarbleCols = row.borrow_mut(); + assert_eq!(cols.gate_type[0], KoalaBear::from_bool(expected_gate_type == 0)); + assert_eq!(cols.gate_type[1], KoalaBear::from_bool(expected_gate_type == OR_GATE_ID)); + + let encoded_gate_type = cols.gate_type[1] * KoalaBear::from_canonical_u32(OR_GATE_ID); + let gate_type_word = cols.gates_input_mem[0].access.value[0]; + assert_eq!(encoded_gate_type, gate_type_word); + } + + #[test] + fn test_zero_gate_trace_writes_true_result() { + let event = make_event(&[], 1); + let chip = BooleanCircuitGarbleChip::default(); + let mut rows = chip.event_to_rows::(&event, &mut Vec::new()); + assert_eq!(rows.len(), 1); + + let cols: &mut BooleanCircuitGarbleCols = rows[0].as_mut_slice().borrow_mut(); + assert_eq!(cols.is_first_row, KoalaBear::ONE); + assert_eq!(cols.is_empty, KoalaBear::ONE); + 
assert_eq!(cols.result_mem.access.value[0], KoalaBear::ONE); + } + + #[test] + fn test_and_gate_type_encoding_matches_gate_word() { + let event = make_event(&[(0, true)], 1); + let chip = BooleanCircuitGarbleChip::default(); + let mut rows = chip.event_to_rows::(&event, &mut Vec::new()); + assert_eq!(rows.len(), 2); + assert_gate_row_encoding(&mut rows[1], 0); + } + + #[test] + fn test_mixed_gate_type_encoding_matches_gate_words() { + let event = make_event(&[(0, true), (OR_GATE_ID, true)], 1); + let chip = BooleanCircuitGarbleChip::default(); + let mut rows = chip.event_to_rows::(&event, &mut Vec::new()); + assert_eq!(rows.len(), 3); + assert_gate_row_encoding(&mut rows[1], 0); + assert_gate_row_encoding(&mut rows[2], OR_GATE_ID); + + let second_gate: &mut BooleanCircuitGarbleCols = + rows[2].as_mut_slice().borrow_mut(); + assert_eq!(second_gate.input_address, KoalaBear::from_canonical_u32(0x1000 + 20 + 68)); + assert_eq!(second_gate.result_mem.access.value[0], KoalaBear::ONE); + } + + #[test] + fn test_boolean_circuit_garble_air_accepts_valid_trace() { + let trace = trace_for_event(make_event(&[(0, true), (OR_GATE_ID, true), (0, true)], 1)); + check_trace(&trace); + } + + #[test] + fn test_boolean_circuit_garble_air_accepts_false_result() { + let trace = trace_for_event(make_event(&[(0, true), (OR_GATE_ID, false)], 0)); + check_trace(&trace); + } + + #[test] + fn test_boolean_circuit_garble_air_rejects_inconsistent_output() { + let trace = trace_for_event(make_event(&[(0, true), (OR_GATE_ID, false)], 1)); + let prev_hook = take_hook(); + set_hook(Box::new(|_| {})); + let result = catch_unwind(AssertUnwindSafe(|| check_trace(&trace))); + set_hook(prev_hook); + assert!(result.is_err()); + } +} diff --git a/crates/core/machine/src/syscall/precompiles/keccak_sponge/air.rs b/crates/core/machine/src/syscall/precompiles/keccak_sponge/air.rs index b10cd32a5..f4292cd41 100644 --- a/crates/core/machine/src/syscall/precompiles/keccak_sponge/air.rs +++ 
b/crates/core/machine/src/syscall/precompiles/keccak_sponge/air.rs @@ -58,6 +58,26 @@ where LookupScope::Local, ); + // The first block flag is fixed on the first row and then held steady + // within the block. Once a non-final block ends, the next block must + // reset this flag to zero. + builder.when(local.receive_syscall).assert_one(first_block); + builder + .when_transition() + .when(not_final_step.clone()) + .assert_eq(next.is_first_input_block, first_block); + builder.when(local.is_absorbed).assert_zero(next.is_first_input_block); + + // The final block flag stays low on all non-final blocks and is fixed + // to one on the final block's write-output row. Within a block it stays + // constant across Keccak rounds. + builder.when(local.write_output).assert_one(final_block); + builder.when(local.is_absorbed).assert_zero(final_block); + builder + .when_transition() + .when(not_final_step.clone()) + .assert_eq(next.is_final_input_block, final_block); + // Constrain that the inputs stay the same throughout the rows of each cycle let mut transition_builder = builder.when_transition(); let mut transition_not_final_builder = transition_builder.when(not_final_sponge.clone()); @@ -142,6 +162,12 @@ where .assert_eq(local.input_address, next.input_address); // If this is the first block, absorbed bytes should be 0 builder.when(first_block).assert_eq(local.already_absorbed_u32s, AB::Expr::zero()); + // If this is the first block, the sponge state must start from the + // fixed all-zero Keccak IV. + let mut first_block_builder = builder.when(first_block); + for i in 0..KECCAK_STATE_U32S { + first_block_builder.assert_word_zero(local.original_state[i]); + } // If this is the final block, absorbed bytes should be equal to the input length - KECCAK_GENERAL_RATE_U32S builder.when(final_block).assert_eq( local.already_absorbed_u32s, @@ -269,6 +295,9 @@ impl KeccakSpongeChip { // enforce booleanity and mutual exclusion of first/final step flags. 
// This prevents degenerate witnesses where a single row is both // first-round and final-round when summary internals are hidden. + builder.assert_bool(first_block); + builder.assert_bool(final_block); + builder.assert_bool(local.read_block); builder.when(local.is_real).assert_bool(first_step); builder.when(local.is_real).assert_bool(final_step); builder.when(local.is_real).assert_zero(first_step * final_step); diff --git a/crates/core/machine/src/syscall/precompiles/keccak_sponge/trace.rs b/crates/core/machine/src/syscall/precompiles/keccak_sponge/trace.rs index 68fb182da..608082e49 100644 --- a/crates/core/machine/src/syscall/precompiles/keccak_sponge/trace.rs +++ b/crates/core/machine/src/syscall/precompiles/keccak_sponge/trace.rs @@ -228,3 +228,82 @@ impl KeccakSpongeChip { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::syscall::precompiles::keccak_sponge::columns::KeccakSpongeCols; + use p3_field::FieldAlgebra; + use p3_koala_bear::KoalaBear; + use std::borrow::Borrow; + use zkm_core_executor::events::{ + ByteLookupEvent, KeccakSpongeEvent, MemoryReadRecord, MemoryWriteRecord, + }; + + fn make_event(blocks: usize) -> KeccakSpongeEvent { + let shard = 1; + let clk = 7; + let input_addr = 0x1000; + let output_addr = 0x2000; + let input_len_u32s = (blocks * KECCAK_GENERAL_RATE_U32S) as u32; + let input = vec![0u32; input_len_u32s as usize]; + let output = [0u32; KECCAK_GENERAL_OUTPUT_U32S]; + let input_length_record = MemoryReadRecord::new(input_len_u32s, shard, 1, 0, 0); + + let mut timestamp = 2u32; + let input_read_records = input + .iter() + .map(|&value| { + let record = MemoryReadRecord::new(value, shard, timestamp, 0, 0); + timestamp += 1; + record + }) + .collect::>(); + + let output_write_records = output + .iter() + .map(|&value| { + let record = MemoryWriteRecord::new(value, shard, timestamp, 0, 0, 0); + timestamp += 1; + record + }) + .collect::>(); + + KeccakSpongeEvent { + shard, + clk, + input, + output, + input_len_u32s, + 
input_read_records, + input_length_record, + output_write_records, + xored_state_list: vec![[0u64; 25]; blocks], + input_addr, + output_addr, + local_mem_access: vec![], + } + } + + #[test] + fn test_keccak_sponge_block_flags_follow_trace_positions() { + let chip = KeccakSpongeChip::new(); + let event = make_event(2); + let mut rows = Some(Vec::new()); + let mut blu = Vec::::new(); + chip.event_to_rows::(&event, &mut rows, &mut blu); + + let rows = rows.unwrap(); + assert_eq!(rows.len(), 2 * NUM_ROUNDS); + + for (index, row) in rows.iter().enumerate() { + let cols: &KeccakSpongeCols = row.as_slice().borrow(); + let block_index = index / NUM_ROUNDS; + let round_index = index % NUM_ROUNDS; + + assert_eq!(cols.is_first_input_block, KoalaBear::from_bool(block_index == 0)); + assert_eq!(cols.is_final_input_block, KoalaBear::from_bool(block_index == 1)); + assert_eq!(cols.read_block, KoalaBear::from_bool(round_index == 0)); + } + } +} diff --git a/crates/core/machine/src/syscall/precompiles/u256x2048_mul/air.rs b/crates/core/machine/src/syscall/precompiles/u256x2048_mul/air.rs index 57033999d..731fd7bd7 100644 --- a/crates/core/machine/src/syscall/precompiles/u256x2048_mul/air.rs +++ b/crates/core/machine/src/syscall/precompiles/u256x2048_mul/air.rs @@ -1,7 +1,7 @@ use crate::{ air::MemoryAirBuilder, memory::{value_as_limbs, MemoryCols, MemoryReadCols, MemoryWriteCols}, - operations::field::field_op::FieldOpCols, + operations::{field::field_op::FieldOpCols, KoalaBearWordRangeChecker}, utils::{limbs_from_access, pad_rows_fixed, words_to_bytes_le}, CoreChipError, }; @@ -74,6 +74,9 @@ pub struct U256x2048MulCols { pub lo_ptr_memory: MemoryReadCols, pub hi_ptr_memory: MemoryReadCols, + pub lo_ptr_range_checker: KoalaBearWordRangeChecker, + pub hi_ptr_range_checker: KoalaBearWordRangeChecker, + // Memory columns. 
pub a_memory: [MemoryReadCols; WORDS_FIELD_ELEMENT], pub b_memory: [MemoryReadCols; WORDS_FIELD_ELEMENT * 8], @@ -144,6 +147,8 @@ impl MachineAir for U256x2048MulChip { .populate(event.lo_ptr_memory, &mut new_byte_lookup_events); cols.hi_ptr_memory .populate(event.hi_ptr_memory, &mut new_byte_lookup_events); + cols.lo_ptr_range_checker.populate(event.lo_ptr_memory.value); + cols.hi_ptr_range_checker.populate(event.hi_ptr_memory.value); // Populate memory columns. for i in 0..WORDS_FIELD_ELEMENT { @@ -400,6 +405,21 @@ where ); } + // Range-check the raw pointer words before reducing them to field + // addresses, so the reduction is injective for these memory values. + KoalaBearWordRangeChecker::::range_check( + builder, + *local.lo_ptr_memory.value(), + local.lo_ptr_range_checker, + local.is_real.into(), + ); + KoalaBearWordRangeChecker::::range_check( + builder, + *local.hi_ptr_memory.value(), + local.hi_ptr_range_checker, + local.is_real.into(), + ); + // Constrain that the lo_ptr is the value of lo_ptr_memory. builder .when(local.is_real) diff --git a/crates/core/machine/src/utils/mod.rs b/crates/core/machine/src/utils/mod.rs index c20c9a819..f4a44c938 100644 --- a/crates/core/machine/src/utils/mod.rs +++ b/crates/core/machine/src/utils/mod.rs @@ -179,8 +179,22 @@ pub fn zkm_debug_mode() -> bool { /// /// This function is safe to use only for fields that can be transmuted from 0u32. 
pub fn zeroed_f_vec(len: usize) -> Vec { - debug_assert!(std::mem::size_of::() == 4); + assert!(std::mem::size_of::() == 4, "zeroed_f_vec only supports 4-byte field elements"); let vec = vec![0u32; len]; unsafe { std::mem::transmute::, Vec>(vec) } } + +#[cfg(test)] +mod tests { + use super::*; + use p3_field::FieldAlgebra; + use p3_koala_bear::KoalaBear; + + #[test] + fn zeroed_f_vec_returns_zero_values_for_koalabear() { + let values = zeroed_f_vec::(4); + assert_eq!(values.len(), 4); + assert!(values.iter().all(|value| *value == KoalaBear::ZERO)); + } +} diff --git a/crates/go-runtime/zkvm_runtime/deserialize.go b/crates/go-runtime/zkvm_runtime/deserialize.go index c98b3fe54..61a6d03dd 100644 --- a/crates/go-runtime/zkvm_runtime/deserialize.go +++ b/crates/go-runtime/zkvm_runtime/deserialize.go @@ -34,70 +34,113 @@ func DeserializeData(data []byte, e any) { } } +func readBytes(data []byte, index, size int) ([]byte, error) { + if index < 0 || size < 0 || index > len(data) || size > len(data)-index { + return nil, fmt.Errorf("deserialize failed: need %d bytes at offset %d, have %d", size, index, len(data)-index) + } + return data[index : index+size], nil +} + +func maxInt() int { + return int(^uint(0) >> 1) +} + func deserializeData(data []byte, v reflect.Value, index int) (int, error) { switch v.Kind() { case reflect.Bool: - v.SetBool(data[index] == 1) + b, err := readBytes(data, index, 1) + if err != nil { + return index, err + } + v.SetBool(b[0] == 1) return index + 1, nil case reflect.Int8: - v.SetInt(int64(int8(data[index]))) + b, err := readBytes(data, index, 1) + if err != nil { + return index, err + } + v.SetInt(int64(int8(b[0]))) return index + 1, nil case reflect.Uint8: - v.SetUint(uint64(data[index])) + b, err := readBytes(data, index, 1) + if err != nil { + return index, err + } + v.SetUint(uint64(b[0])) return index + 1, nil case reflect.Int16: - b := []byte{data[index], data[index+1]} + b, err := readBytes(data, index, 2) + if err != nil { + return 
index, err + } a := binary.LittleEndian.Uint16(b) v.SetInt(int64(int16(a))) return index + 2, nil case reflect.Uint16: - b := []byte{data[index], data[index+1]} + b, err := readBytes(data, index, 2) + if err != nil { + return index, err + } a := binary.LittleEndian.Uint16(b) v.SetUint(uint64(a)) return index + 2, nil case reflect.Int32: - b := []byte{data[index], data[index+1], data[index+2], data[index+3]} + b, err := readBytes(data, index, 4) + if err != nil { + return index, err + } a := binary.LittleEndian.Uint32(b) v.SetInt(int64(int32(a))) return index + 4, nil case reflect.Uint32: - b := []byte{data[index], data[index+1], data[index+2], data[index+3]} + b, err := readBytes(data, index, 4) + if err != nil { + return index, err + } a := binary.LittleEndian.Uint32(b) v.SetUint(uint64(a)) return index + 4, nil case reflect.Int64: - b := []byte{data[index], data[index+1], data[index+2], data[index+3], - data[index+4], data[index+5], data[index+6], data[index+7]} + b, err := readBytes(data, index, 8) + if err != nil { + return index, err + } a := binary.LittleEndian.Uint64(b) v.SetInt(int64(a)) return index + 8, nil case reflect.Uint64: - b := []byte{data[index], data[index+1], data[index+2], data[index+3], - data[index+4], data[index+5], data[index+6], data[index+7]} + b, err := readBytes(data, index, 8) + if err != nil { + return index, err + } a := binary.LittleEndian.Uint64(b) v.SetUint(a) return index + 8, nil case reflect.Slice: const maxSliceLen = 1_000_000 - b := []byte{data[index], data[index+1], data[index+2], data[index+3], - data[index+4], data[index+5], data[index+6], data[index+7]} + b, err := readBytes(data, index, 8) + if err != nil { + return index, err + } length := binary.LittleEndian.Uint64(b) index += 8 elemKind := v.Type().Elem().Kind() + if length > uint64(maxInt()) { + return index, fmt.Errorf("deserialize failed: slice length %d exceeds int range", length) + } + l := int(length) if elemKind == reflect.Uint8 { - if length > 
uint64(len(data)-index) { - return index, fmt.Errorf("deserialize failed: []byte length %d exceeds remaining %d", length, len(data)-index) + if l > len(data)-index { + return index, fmt.Errorf("deserialize failed: []byte length %d exceeds remaining %d", l, len(data)-index) } - bytes := data[index : index+int(length)] - v.SetBytes(bytes) - return index + int(length), nil + v.SetBytes(data[index : index+l]) + return index + l, nil } - if length > maxSliceLen { - return index, fmt.Errorf("deserialize failed: slice length %d exceeds max %d", length, maxSliceLen) + if l > maxSliceLen { + return index, fmt.Errorf("deserialize failed: slice length %d exceeds max %d", l, maxSliceLen) } - l := int(length) slice := reflect.MakeSlice(v.Type(), l, l) for i := 0; i < l; i++ { var err error @@ -118,20 +161,35 @@ func deserializeData(data []byte, v reflect.Value, index int) (int, error) { } return index, nil case reflect.String: - b := []byte{data[index], data[index+1], data[index+2], data[index+3], - data[index+4], data[index+5], data[index+6], data[index+7]} + b, err := readBytes(data, index, 8) + if err != nil { + return index, err + } l := binary.LittleEndian.Uint64(b) index += 8 + if l > uint64(maxInt()) { + return index, fmt.Errorf("deserialize failed: string length %d exceeds int range", l) + } length := int(l) + if length > len(data)-index { + return index, fmt.Errorf("deserialize failed: string length %d exceeds remaining %d", length, len(data)-index) + } str := make([]byte, length) copy(str[:], data[index:index+length]) v.SetString(string(str)) return index + length, nil case reflect.Ptr: - if data[index] == 0 { + b, err := readBytes(data, index, 1) + if err != nil { + return index, err + } + if b[0] == 0 { v.SetZero() return index + 1, nil } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } return deserializeData(data, v.Elem(), index+1) case reflect.Struct: for i := 0; i < v.NumField(); i++ { diff --git a/crates/go-runtime/zkvm_runtime/deserialize_test.go 
b/crates/go-runtime/zkvm_runtime/deserialize_test.go new file mode 100644 index 000000000..6b0b05b53 --- /dev/null +++ b/crates/go-runtime/zkvm_runtime/deserialize_test.go @@ -0,0 +1,51 @@ +package zkvm_runtime + +import ( + "reflect" + "testing" +) + +type nestedPayload struct { + Value *uint32 +} + +func TestDeserializeDataRejectsTruncatedInput(t *testing.T) { + var out uint16 + _, err := deserializeData([]byte{0x01}, reflect.ValueOf(&out).Elem(), 0) + if err == nil { + t.Fatal("expected an error for truncated input") + } +} + +func TestDeserializeDataRejectsOversizedSliceLength(t *testing.T) { + var out []byte + data := []byte{ + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + } + _, err := deserializeData(data, reflect.ValueOf(&out).Elem(), 0) + if err == nil { + t.Fatal("expected an error for oversized slice length") + } +} + +func TestDeserializeDataAllocatesNilPointerField(t *testing.T) { + var out nestedPayload + data := []byte{ + 0x01, + 0x78, 0x56, 0x34, 0x12, + } + index, err := deserializeData(data, reflect.ValueOf(&out).Elem(), 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if index != len(data) { + t.Fatalf("unexpected index: got %d want %d", index, len(data)) + } + if out.Value == nil { + t.Fatal("expected nil pointer field to be allocated") + } + if *out.Value != 0x12345678 { + t.Fatalf("unexpected pointer value: got %x want %x", *out.Value, 0x12345678) + } +} diff --git a/crates/go-runtime/zkvm_runtime/runtime.go b/crates/go-runtime/zkvm_runtime/runtime.go index 2ddb90778..4b2e52a94 100644 --- a/crates/go-runtime/zkvm_runtime/runtime.go +++ b/crates/go-runtime/zkvm_runtime/runtime.go @@ -103,12 +103,23 @@ const MAX_MEMORY int = 0x7f000000 var RESERVED_INPUT_PTR int = MAX_MEMORY - EMBEDDED_RESERVED_INPUT_REGION_SIZE +func readReservedInput(capacity int) int { + addr := RESERVED_INPUT_PTR + if capacity < 0 { + panic("input region overflowed") + } + if addr < 0 || addr > MAX_MEMORY-capacity { + panic("input region overflowed") 
+ } + RESERVED_INPUT_PTR = addr + capacity + return addr +} + func Read[T any]() T { len := SyscallHintLen() var value []byte capacity := (len + 3) / 4 * 4 - addr := RESERVED_INPUT_PTR - RESERVED_INPUT_PTR += capacity + addr := readReservedInput(capacity) ptr := unsafe.Pointer(uintptr(addr)) value = unsafe.Slice((*byte)(ptr), capacity) var result T diff --git a/crates/go-runtime/zkvm_runtime/runtime_test.go b/crates/go-runtime/zkvm_runtime/runtime_test.go new file mode 100644 index 000000000..da798efdb --- /dev/null +++ b/crates/go-runtime/zkvm_runtime/runtime_test.go @@ -0,0 +1,33 @@ +//go:build mipsle +// +build mipsle + +package zkvm_runtime + +import "testing" + +func TestReadReservedInputAdvancesPointer(t *testing.T) { + oldPtr := RESERVED_INPUT_PTR + defer func() { RESERVED_INPUT_PTR = oldPtr }() + + RESERVED_INPUT_PTR = MAX_MEMORY - EMBEDDED_RESERVED_INPUT_REGION_SIZE + addr := readReservedInput(16) + if addr != MAX_MEMORY-EMBEDDED_RESERVED_INPUT_REGION_SIZE { + t.Fatalf("unexpected addr: got %x want %x", addr, MAX_MEMORY-EMBEDDED_RESERVED_INPUT_REGION_SIZE) + } + if RESERVED_INPUT_PTR != addr+16 { + t.Fatalf("unexpected ptr: got %x want %x", RESERVED_INPUT_PTR, addr+16) + } +} + +func TestReadReservedInputPanicsOnOverflow(t *testing.T) { + oldPtr := RESERVED_INPUT_PTR + defer func() { RESERVED_INPUT_PTR = oldPtr }() + + RESERVED_INPUT_PTR = MAX_MEMORY - 8 + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on overflow") + } + }() + _ = readReservedInput(16) +} diff --git a/crates/primitives/src/lib.rs b/crates/primitives/src/lib.rs index 991fb10e9..398de56d3 100644 --- a/crates/primitives/src/lib.rs +++ b/crates/primitives/src/lib.rs @@ -1106,7 +1106,9 @@ lazy_static! { pub fn poseidon2_init() -> Poseidon2KoalaBear<16> { const ROUNDS_F: usize = 8; - const ROUNDS_P: usize = 13; + // Plonky3's KoalaBear/α=3 reference (poseidon2_round_numbers_128) recommends R_P = 20 + // for a 31-bit field at width 16. 
(R_P=13 was a leftover from the BabyBear/α=7 fork.) + const ROUNDS_P: usize = 20; let mut round_constants = RC_16_30.to_vec(); let internal_start = ROUNDS_F / 2; let internal_end = (ROUNDS_F / 2) + ROUNDS_P; @@ -1148,3 +1150,25 @@ pub fn hash_deferred_proof( inputs.extend_from_slice(pv_digest); poseidon2_hash(inputs.to_vec()) } + +#[cfg(test)] +mod tests { + use p3_koala_bear::KoalaBear; + use p3_poseidon2::poseidon2_round_numbers_128; + + /// Ziren instantiates Poseidon2 over KoalaBear (width 16) with S-box degree α=3 (see + /// `poseidon2_init`). The partial- (internal-) round count must match Plonky3's own 128-bit + /// recommendation for that configuration; a lower-degree S-box needs *more* partial rounds for + /// algebraic-attack resistance, so under-counting silently weakens the hash. This pins R_F/R_P + /// to the library reference so switching the field or S-box can never leave the round count + /// stale (the original α=7 fork shipped R_P=13, which is the α=7 value, not the α=3 value). + #[test] + fn koalabear_poseidon2_round_counts_match_plonky3_reference() { + let (rounds_f, rounds_p) = poseidon2_round_numbers_128::(16, 3); + assert_eq!(rounds_f, 8, "external (full) round count drifted from the Plonky3 reference"); + assert_eq!( + rounds_p, 20, + "internal (partial) round count drifted from the Plonky3 reference" + ); + } +} diff --git a/crates/recursion/circuit/src/machine/compress.rs b/crates/recursion/circuit/src/machine/compress.rs index 5128eb2fc..b69e470a0 100644 --- a/crates/recursion/circuit/src/machine/compress.rs +++ b/crates/recursion/circuit/src/machine/compress.rs @@ -2,7 +2,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, marker::PhantomData, - mem::MaybeUninit, }; use itertools::{izip, Itertools}; @@ -103,9 +102,8 @@ where // Initialize the values for the aggregated public output. 
- let mut reduce_public_values_stream: Vec> = (0..RECURSIVE_PROOF_NUM_PV_ELTS) - .map(|_| unsafe { MaybeUninit::zeroed().assume_init() }) - .collect(); + let mut reduce_public_values_stream: Vec> = + (0..RECURSIVE_PROOF_NUM_PV_ELTS).map(|_| builder.uninit()).collect(); let compress_public_values: &mut RecursionPublicValues<_> = reduce_public_values_stream.as_mut_slice().borrow_mut(); @@ -115,27 +113,22 @@ where assert!(!vks_and_proofs.is_empty()); // Initialize the consistency check variables. - let mut zkm_vk_digest: [Felt<_>; DIGEST_SIZE] = - array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() }); - let mut pc: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() }; - let mut shard: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() }; + let mut zkm_vk_digest: [Felt<_>; DIGEST_SIZE] = array::from_fn(|_| builder.uninit()); + let mut pc: Felt<_> = builder.uninit(); + let mut shard: Felt<_> = builder.uninit(); let mut exit_code: Felt<_> = builder.uninit(); - let mut execution_shard: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() }; + let mut execution_shard: Felt<_> = builder.uninit(); let mut committed_value_digest: [Word>; PV_DIGEST_NUM_WORDS] = - array::from_fn(|_| { - Word(array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() })) - }); + array::from_fn(|_| Word(array::from_fn(|_| builder.uninit()))); let mut deferred_proofs_digest: [Felt<_>; POSEIDON_NUM_WORDS] = - array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() }); + array::from_fn(|_| builder.uninit()); let mut reconstruct_deferred_digest: [Felt<_>; POSEIDON_NUM_WORDS] = - core::array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() }); + core::array::from_fn(|_| builder.uninit()); let mut global_cumulative_sums = Vec::new(); - let mut init_addr_bits: [Felt<_>; 32] = - core::array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() }); - let mut finalize_addr_bits: [Felt<_>; 32] = - core::array::from_fn(|_| unsafe { 
MaybeUninit::zeroed().assume_init() }); + let mut init_addr_bits: [Felt<_>; 32] = core::array::from_fn(|_| builder.uninit()); + let mut finalize_addr_bits: [Felt<_>; 32] = core::array::from_fn(|_| builder.uninit()); // Initialize a flag to denote if any of the recursive proofs represents a shard range // where at least once of the shards is an execution shard (i.e. contains cpu). diff --git a/crates/recursion/circuit/src/machine/core.rs b/crates/recursion/circuit/src/machine/core.rs index b4d7617e7..c621e50bb 100644 --- a/crates/recursion/circuit/src/machine/core.rs +++ b/crates/recursion/circuit/src/machine/core.rs @@ -2,7 +2,6 @@ use std::{ array, borrow::{Borrow, BorrowMut}, marker::PhantomData, - mem::MaybeUninit, }; use itertools::Itertools; @@ -126,29 +125,27 @@ where input; // Initialize shard variables. - let mut initial_shard: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() }; - let mut current_shard: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() }; + let mut initial_shard: Felt<_> = builder.uninit(); + let mut current_shard: Felt<_> = builder.uninit(); // Initialize execution shard variables. - let mut initial_execution_shard: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() }; - let mut current_execution_shard: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() }; + let mut initial_execution_shard: Felt<_> = builder.uninit(); + let mut current_execution_shard: Felt<_> = builder.uninit(); // Initialize program counter variables. - let mut start_pc: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() }; - let mut current_pc: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() }; + let mut start_pc: Felt<_> = builder.uninit(); + let mut current_pc: Felt<_> = builder.uninit(); // Initialize memory initialization and finalization variables. 
let mut initial_previous_init_addr_bits: [Felt<_>; 32] = - unsafe { MaybeUninit::zeroed().assume_init() }; + array::from_fn(|_| builder.uninit()); let mut initial_previous_finalize_addr_bits: [Felt<_>; 32] = - unsafe { MaybeUninit::zeroed().assume_init() }; - let mut current_init_addr_bits: [Felt<_>; 32] = - unsafe { MaybeUninit::zeroed().assume_init() }; - let mut current_finalize_addr_bits: [Felt<_>; 32] = - unsafe { MaybeUninit::zeroed().assume_init() }; + array::from_fn(|_| builder.uninit()); + let mut current_init_addr_bits: [Felt<_>; 32] = array::from_fn(|_| builder.uninit()); + let mut current_finalize_addr_bits: [Felt<_>; 32] = array::from_fn(|_| builder.uninit()); // Initialize the exit code variable. - let mut exit_code: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() }; + let mut exit_code: Felt<_> = builder.uninit(); // Initialize the public values digest. let mut committed_value_digest: [Word>; PV_DIGEST_NUM_WORDS] = diff --git a/crates/recursion/circuit/src/machine/wrap.rs b/crates/recursion/circuit/src/machine/wrap.rs index 604669a13..b2586b85c 100644 --- a/crates/recursion/circuit/src/machine/wrap.rs +++ b/crates/recursion/circuit/src/machine/wrap.rs @@ -12,7 +12,7 @@ use zkm_stark::{air::MachineAir, StarkMachine}; use crate::{ challenger::CanObserveVariable, constraints::RecursiveVerifierConstraintFolder, - machine::{assert_root_public_values_valid, RootPublicValues}, + machine::{assert_complete, assert_root_public_values_valid, RootPublicValues}, stark::StarkVerifier, CircuitConfig, KoalaBearFriConfigVariable, }; @@ -78,6 +78,8 @@ where // Get the public values, and assert that they are valid. let public_values: &RootPublicValues> = proof.public_values.as_slice().borrow(); assert_root_public_values_valid::(builder, public_values); + builder.assert_felt_eq(public_values.inner.is_complete, C::F::ONE); + assert_complete(builder, &public_values.inner, public_values.inner.is_complete); // Reflect the public values to the next level. 
if zkm_imm_wrap_vk_mode() { diff --git a/crates/recursion/compiler/src/ir/collections.rs b/crates/recursion/compiler/src/ir/collections.rs index dca00fa37..ccad4cbd7 100644 --- a/crates/recursion/compiler/src/ir/collections.rs +++ b/crates/recursion/compiler/src/ir/collections.rs @@ -30,9 +30,7 @@ impl> Array { /// Shifts the array by `shift` elements. pub fn shift(&self, builder: &mut Builder, shift: Var) -> Array { match self { - Self::Fixed(_) => { - todo!() - } + Self::Fixed(_) => unreachable!("Array::Fixed does not support shift()"), Self::Dyn(ptr, len) => { assert!(V::size_of() == 1, "only support variables of size 1"); let new_address = builder.eval(ptr.address + shift); @@ -47,9 +45,7 @@ impl> Array { /// Truncates the array to `len` elements. pub fn truncate(&self, builder: &mut Builder, len: Usize) { match self { - Self::Fixed(_) => { - todo!() - } + Self::Fixed(_) => unreachable!("Array::Fixed does not support truncate()"), Self::Dyn(_, old_len) => { builder.assign(*old_len, len); } @@ -157,9 +153,7 @@ impl Builder { let index = index.into(); match slice { - Array::Fixed(_) => { - todo!() - } + Array::Fixed(_) => unreachable!("Array::Fixed does not support get_ptr()"), Array::Dyn(ptr, len) => { if self.debug { let index_v = index.materialize(self); @@ -184,8 +178,13 @@ impl Builder { let index = index.into(); match slice { - Array::Fixed(_) => { - todo!() + Array::Fixed(slice) => { + if let Usize::Const(idx) = index { + let value: V = self.eval(value); + slice[idx] = value; + } else { + unreachable!("Array::Fixed does not support symbolic indices in set()") + } } Array::Dyn(ptr, len) => { if self.debug { @@ -210,8 +209,12 @@ impl Builder { let index = index.into(); match slice { - Array::Fixed(_) => { - todo!() + Array::Fixed(slice) => { + if let Usize::Const(idx) = index { + slice[idx] = value; + } else { + unreachable!("Array::Fixed does not support symbolic indices in set_value()") + } } Array::Dyn(ptr, _) => { let index = MemIndex { index, offset: 0, 
size: V::size_of() }; @@ -364,3 +367,61 @@ impl + MemVariable, const N: usize> FromConstan value.map(|x| V::constant(x, builder)) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::OuterConfig; + use crate::ir::Felt; + use p3_field::FieldAlgebra; + + type C = OuterConfig; + type N = ::N; + type Ff = ::F; + + #[test] + fn fixed_array_set_and_set_value_update_elements() { + let mut builder = Builder::::default(); + let one: Felt = builder.eval(Ff::from_canonical_u32(1)); + let two: Felt = builder.eval(Ff::from_canonical_u32(2)); + let three: Felt = builder.eval(Ff::from_canonical_u32(3)); + let seven: Felt = builder.eval(Ff::from_canonical_u32(7)); + let nine: Felt = builder.eval(Ff::from_canonical_u32(9)); + let mut array = builder.vec::>(vec![one, two, three]); + + builder.set_value(&mut array, 1, seven); + builder.set(&mut array, 2, nine); + + let values = array.vec(); + assert_eq!(values.len(), 3); + assert_eq!(array.len(), Usize::Const(3)); + } + + #[test] + #[should_panic(expected = "Array::Fixed does not support shift()")] + fn fixed_array_shift_panics_with_clear_message() { + let mut builder = Builder::::default(); + let one: Felt = builder.eval(Ff::from_canonical_u32(1)); + let shift = builder.eval(N::from_canonical_u32(1)); + let array = builder.vec::>(vec![one]); + let _ = array.shift(&mut builder, shift); + } + + #[test] + #[should_panic(expected = "Array::Fixed does not support truncate()")] + fn fixed_array_truncate_panics_with_clear_message() { + let mut builder = Builder::::default(); + let one: Felt = builder.eval(Ff::from_canonical_u32(1)); + let array = builder.vec::>(vec![one]); + array.truncate(&mut builder, Usize::Const(0)); + } + + #[test] + #[should_panic(expected = "Array::Fixed does not support get_ptr()")] + fn fixed_array_get_ptr_panics_with_clear_message() { + let mut builder = Builder::::default(); + let one: Felt = builder.eval(Ff::from_canonical_u32(1)); + let array = builder.vec::>(vec![one]); + let _ = 
builder.get_ptr(&array, 0usize); + } +} diff --git a/crates/recursion/core/include/poseidon2_skinny.hpp b/crates/recursion/core/include/poseidon2_skinny.hpp index a26a32093..ef0f7ae86 100644 --- a/crates/recursion/core/include/poseidon2_skinny.hpp +++ b/crates/recursion/core/include/poseidon2_skinny.hpp @@ -87,13 +87,16 @@ __ZKM_HOSTDEV__ void instr_to_row(const Poseidon2Instr& instr, size_t i, cols.round_counters_preprocessed.is_internal_round = F::from_bool(i == INTERNAL_ROUND_IDX); - for (size_t j = 0; j < WIDTH; j++) { - if (is_external_round) { + // The shared `round_constants` column holds WIDTH per-lane constants on external rows and + // NUM_INTERNAL_ROUNDS constants on the single internal row, so iterate over the full column + // width (NUM_ROUND_CONSTANTS) and fill each kind in its own range. + for (size_t j = 0; j < NUM_ROUND_CONSTANTS; j++) { + if (is_external_round && j < WIDTH) { size_t r = i - 1; size_t round = (i < INTERNAL_ROUND_IDX) ? r : r + NUM_INTERNAL_ROUNDS - 1; cols.round_counters_preprocessed.round_constants[j] = F(F::to_monty(RC_16_30_U32[round][j])); - } else if (i == INTERNAL_ROUND_IDX) { + } else if (i == INTERNAL_ROUND_IDX && j < NUM_INTERNAL_ROUNDS) { cols.round_counters_preprocessed.round_constants[j] = F(F::to_monty(RC_16_30_U32[NUM_EXTERNAL_ROUNDS / 2 + j][0])); } else { diff --git a/crates/recursion/core/src/chips/poseidon2_skinny/air.rs b/crates/recursion/core/src/chips/poseidon2_skinny/air.rs index f1ac4b2c4..b6f837432 100644 --- a/crates/recursion/core/src/chips/poseidon2_skinny/air.rs +++ b/crates/recursion/core/src/chips/poseidon2_skinny/air.rs @@ -10,7 +10,8 @@ use crate::{builder::ZKMRecursionAirBuilder, chips::poseidon2_skinny::columns::P use super::{ columns::{preprocessed::Poseidon2PreprocessedCols, NUM_POSEIDON2_COLS}, - external_linear_layer, internal_linear_layer, Poseidon2SkinnyChip, NUM_INTERNAL_ROUNDS, WIDTH, + external_linear_layer, internal_linear_layer, Poseidon2SkinnyChip, NUM_INTERNAL_ROUNDS, + 
NUM_ROUND_CONSTANTS, WIDTH, }; impl BaseAir for Poseidon2SkinnyChip { @@ -128,7 +129,7 @@ impl Poseidon2SkinnyChip { builder: &mut AB, local_row: &Poseidon2, next_row: &Poseidon2, - round_constants: [AB::Var; WIDTH], + round_constants: [AB::Var; NUM_ROUND_CONSTANTS], is_internal_row: AB::Var, ) { let local_state = local_row.state_var; diff --git a/crates/recursion/core/src/chips/poseidon2_skinny/columns/preprocessed.rs b/crates/recursion/core/src/chips/poseidon2_skinny/columns/preprocessed.rs index 8bcc0360e..29cb0372c 100644 --- a/crates/recursion/core/src/chips/poseidon2_skinny/columns/preprocessed.rs +++ b/crates/recursion/core/src/chips/poseidon2_skinny/columns/preprocessed.rs @@ -1,6 +1,9 @@ use zkm_derive::AlignedBorrow; -use crate::chips::{mem::MemoryAccessColsChips, poseidon2_skinny::WIDTH}; +use crate::chips::{ + mem::MemoryAccessColsChips, + poseidon2_skinny::{NUM_ROUND_CONSTANTS, WIDTH}, +}; #[derive(AlignedBorrow, Clone, Copy, Debug)] #[repr(C)] @@ -8,7 +11,7 @@ pub struct RoundCountersPreprocessedCols { pub is_input_round: T, pub is_external_round: T, pub is_internal_round: T, - pub round_constants: [T; WIDTH], + pub round_constants: [T; NUM_ROUND_CONSTANTS], } #[derive(AlignedBorrow, Clone, Copy, Debug)] diff --git a/crates/recursion/core/src/chips/poseidon2_skinny/mod.rs b/crates/recursion/core/src/chips/poseidon2_skinny/mod.rs index 19944ef00..2b409a640 100644 --- a/crates/recursion/core/src/chips/poseidon2_skinny/mod.rs +++ b/crates/recursion/core/src/chips/poseidon2_skinny/mod.rs @@ -13,9 +13,21 @@ pub const WIDTH: usize = 16; pub const RATE: usize = WIDTH / 2; pub const NUM_EXTERNAL_ROUNDS: usize = 8; -pub const NUM_INTERNAL_ROUNDS: usize = 13; +// Must match `poseidon2_init()` in zkm-primitives (KoalaBear/α=3 reference R_P = 20). Native and +// circuit (poseidon2_wide) chips must stay equal or the equivalence test breaks. 
+pub const NUM_INTERNAL_ROUNDS: usize = 20; pub const NUM_ROUNDS: usize = NUM_EXTERNAL_ROUNDS + NUM_INTERNAL_ROUNDS; +/// Number of slots in each preprocessed row's shared `round_constants` column. +/// +/// The column is reused by two kinds of rows: every external-round row uses its first `WIDTH` +/// entries as the per-lane round constants, while the single internal-rounds row packs all +/// `NUM_INTERNAL_ROUNDS` internal-round constants. The column must hold the larger of the two. For +/// KoalaBear/α=3, `NUM_INTERNAL_ROUNDS` (20) exceeds `WIDTH` (16); the assertion below guards the +/// invariant if either constant changes (with R_P=13 the old code relied on `WIDTH` being larger). +pub const NUM_ROUND_CONSTANTS: usize = NUM_INTERNAL_ROUNDS; +const _: () = assert!(NUM_ROUND_CONSTANTS >= WIDTH); + /// A chip that implements the Poseidon2 permutation in the skinny variant (one external round per /// row and one row for all internal rounds). pub struct Poseidon2SkinnyChip(PhantomData<()>); diff --git a/crates/recursion/core/src/chips/poseidon2_skinny/trace.rs b/crates/recursion/core/src/chips/poseidon2_skinny/trace.rs index 3087fae7c..a57170455 100644 --- a/crates/recursion/core/src/chips/poseidon2_skinny/trace.rs +++ b/crates/recursion/core/src/chips/poseidon2_skinny/trace.rs @@ -24,6 +24,8 @@ use crate::chips::poseidon2_skinny::internal_linear_layer; #[cfg(not(feature = "sys"))] use crate::chips::poseidon2_skinny::NUM_INTERNAL_ROUNDS; #[cfg(not(feature = "sys"))] +use crate::chips::poseidon2_skinny::NUM_ROUND_CONSTANTS; +#[cfg(not(feature = "sys"))] use crate::chips::poseidon2_skinny::WIDTH; use crate::{ chips::poseidon2_skinny::{ @@ -239,21 +241,25 @@ impl MachineAir for Poseidon2SkinnyChip cols.round_counters_preprocessed.is_internal_round = F::from_bool(i == INTERNAL_ROUND_IDX); - (0..WIDTH).for_each(|j| { - cols.round_counters_preprocessed.round_constants[j] = if is_external_round { - let r = i - 1; - let round = if i < INTERNAL_ROUND_IDX { - r + // The 
shared `round_constants` column holds `WIDTH` per-lane constants on + // external rows and `NUM_INTERNAL_ROUNDS` constants on the single internal row, + // so iterate over the full column width and fill each kind in its own range. + (0..NUM_ROUND_CONSTANTS).for_each(|j| { + cols.round_counters_preprocessed.round_constants[j] = + if is_external_round && j < WIDTH { + let r = i - 1; + let round = if i < INTERNAL_ROUND_IDX { + r + } else { + r + NUM_INTERNAL_ROUNDS - 1 + }; + + F::from_wrapped_u32(RC_16_30_U32[round][j]) + } else if i == INTERNAL_ROUND_IDX && j < NUM_INTERNAL_ROUNDS { + F::from_wrapped_u32(RC_16_30_U32[NUM_EXTERNAL_ROUNDS / 2 + j][0]) } else { - r + NUM_INTERNAL_ROUNDS - 1 + F::ZERO }; - - F::from_wrapped_u32(RC_16_30_U32[round][j]) - } else if i == INTERNAL_ROUND_IDX { - F::from_wrapped_u32(RC_16_30_U32[NUM_EXTERNAL_ROUNDS / 2 + j][0]) - } else { - F::ZERO - }; }); // Set the memory columns. We read once, at the first iteration, diff --git a/crates/recursion/core/src/chips/poseidon2_wide/mod.rs b/crates/recursion/core/src/chips/poseidon2_wide/mod.rs index 5316f57aa..0ef2c38ae 100644 --- a/crates/recursion/core/src/chips/poseidon2_wide/mod.rs +++ b/crates/recursion/core/src/chips/poseidon2_wide/mod.rs @@ -19,7 +19,9 @@ pub const WIDTH: usize = 16; pub const RATE: usize = WIDTH / 2; pub const NUM_EXTERNAL_ROUNDS: usize = 8; -pub const NUM_INTERNAL_ROUNDS: usize = 13; +// Must match `poseidon2_init()` in zkm-primitives (KoalaBear/α=3 reference R_P = 20). Native +// (poseidon2_skinny) and circuit chips must stay equal or the equivalence test breaks. +pub const NUM_INTERNAL_ROUNDS: usize = 20; pub const NUM_ROUNDS: usize = NUM_EXTERNAL_ROUNDS + NUM_INTERNAL_ROUNDS; /// A chip that implements addition for the opcode Poseidon2Wide. 
diff --git a/crates/stark/src/kb31_poseidon2.rs b/crates/stark/src/kb31_poseidon2.rs index 234241147..9109789de 100644 --- a/crates/stark/src/kb31_poseidon2.rs +++ b/crates/stark/src/kb31_poseidon2.rs @@ -183,7 +183,8 @@ pub mod koala_bear_poseidon2 { #[must_use] pub fn my_perm() -> Perm { const ROUNDS_F: usize = 8; - const ROUNDS_P: usize = 13; + // Must match `poseidon2_init()` in zkm-primitives (KoalaBear/α=3 reference R_P = 20). + const ROUNDS_P: usize = 20; let mut round_constants = RC_16_30.to_vec(); let internal_start = ROUNDS_F / 2; let internal_end = (ROUNDS_F / 2) + ROUNDS_P; diff --git a/crates/stark/src/machine.rs b/crates/stark/src/machine.rs index 32a96eb6f..c5ca72afd 100644 --- a/crates/stark/src/machine.rs +++ b/crates/stark/src/machine.rs @@ -10,7 +10,7 @@ use p3_uni_stark::{get_symbolic_constraints, SymbolicAirBuilder}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -use std::{cmp::Reverse, env, fmt::Debug, iter::once, time::Instant}; +use std::{cmp::Reverse, collections::HashSet, env, fmt::Debug, iter::once, time::Instant}; use tracing::instrument; use super::{debug_constraints, Dom}; @@ -177,6 +177,56 @@ impl>> StarkMachine { .sorted_by_key(|chip| chip_ordering.get(&chip.name())) } + fn validate_chip_ordering( + &self, + vk: &StarkVerifyingKey, + proof: &ShardProof, + ) -> Result<(), VerificationError> { + let n = proof.opened_values.chips.len(); + if proof.chip_ordering.len() != n { + return Err(VerificationError::InvalidChipOrdering( + "chip ordering length does not match opened values".to_string(), + )); + } + + let known_chips = self.chips.iter().map(|chip| chip.name()).collect::>(); + let mut seen_indices = vec![false; n]; + for (name, &index) in &proof.chip_ordering { + if !known_chips.contains(name) { + return Err(VerificationError::InvalidChipOrdering(format!( + "unknown chip in ordering: {name}" + ))); + } + if index >= n { + return Err(VerificationError::InvalidChipOrdering(format!( + "chip ordering index out of 
bounds for {name}: {index}" + ))); + } + if seen_indices[index] { + return Err(VerificationError::InvalidChipOrdering(format!( + "duplicate chip ordering index: {index}" + ))); + } + seen_indices[index] = true; + } + + if seen_indices.iter().any(|seen| !seen) { + return Err(VerificationError::InvalidChipOrdering( + "chip ordering indices are not contiguous".to_string(), + )); + } + + for (name, _, _) in &vk.chip_information { + if !proof.chip_ordering.contains_key(name) { + return Err(VerificationError::InvalidChipOrdering(format!( + "missing preprocessed chip in ordering: {name}" + ))); + } + } + + Ok(()) + } + /// Returns the config of the machine. pub const fn config(&self) -> &SC { &self.config @@ -637,6 +687,8 @@ impl> + Air>(); let mut shard_challenger = challenger.clone(); diff --git a/crates/stark/src/verifier.rs b/crates/stark/src/verifier.rs index 52dd7b81c..dcf174f22 100644 --- a/crates/stark/src/verifier.rs +++ b/crates/stark/src/verifier.rs @@ -1,9 +1,11 @@ use core::fmt::Display; use std::{ + collections::HashSet, fmt::{Debug, Formatter}, marker::PhantomData, }; +use hashbrown::HashMap; use itertools::Itertools; use num_traits::cast::ToPrimitive; use p3_air::{Air, BaseAir}; @@ -25,6 +27,58 @@ use crate::{ pub struct Verifier(PhantomData, PhantomData); impl>> Verifier { + fn validate_chip_ordering( + vk: &StarkVerifyingKey, + chips: &[&MachineChip], + opened_values_len: usize, + chip_ordering: &HashMap, + ) -> Result<(), VerificationError> { + if chip_ordering.len() != opened_values_len { + return Err(VerificationError::InvalidChipOrdering( + "chip ordering length does not match opened values".to_string(), + )); + } + if chips.len() != opened_values_len { + return Err(VerificationError::ChipOpeningLengthMismatch); + } + + let chip_names = chips.iter().map(|chip| chip.name()).collect::>(); + let mut seen_indices = vec![false; opened_values_len]; + for (name, &index) in chip_ordering { + if !chip_names.contains(name) { + return 
Err(VerificationError::InvalidChipOrdering(format!( + "unexpected chip in ordering: {name}" + ))); + } + if index >= opened_values_len { + return Err(VerificationError::InvalidChipOrdering(format!( + "chip ordering index out of bounds for {name}: {index}" + ))); + } + if seen_indices[index] { + return Err(VerificationError::InvalidChipOrdering(format!( + "duplicate chip ordering index: {index}" + ))); + } + seen_indices[index] = true; + } + if seen_indices.iter().any(|seen| !seen) { + return Err(VerificationError::InvalidChipOrdering( + "chip ordering indices are not contiguous".to_string(), + )); + } + + for (name, _, _) in &vk.chip_information { + if !chip_ordering.contains_key(name) { + return Err(VerificationError::InvalidChipOrdering(format!( + "missing preprocessed chip in ordering: {name}" + ))); + } + } + + Ok(()) + } + /// Verify a proof for a collection of air chips. #[allow(clippy::too_many_lines)] pub fn verify_shard( @@ -54,6 +108,8 @@ impl>> Verifier { return Err(VerificationError::ChipOpeningLengthMismatch); } + Self::validate_chip_ordering(vk, chips, opened_values.chips.len(), chip_ordering)?; + // Assert that the byte multiplicities don't overflow. let mut max_byte_lookup_mult = 0u64; chips.iter().zip(opened_values.chips.iter()).for_each(|(chip, val)| { @@ -466,6 +522,8 @@ pub enum VerificationError { MissingCpuChip, /// The length of the chip opening does not match the expected length. ChipOpeningLengthMismatch, + /// The prover-supplied chip ordering is malformed. 
+ InvalidChipOrdering(String), /// Cumulative sums error CumulativeSumsError(&'static str), } @@ -518,6 +576,9 @@ impl Debug for VerificationError { VerificationError::ChipOpeningLengthMismatch => { write!(f, "Chip opening length mismatch") } + VerificationError::InvalidChipOrdering(s) => { + write!(f, "Invalid chip ordering: {}", s) + } VerificationError::CumulativeSumsError(s) => write!(f, "cumulative sums error: {}", s), } } @@ -542,6 +603,9 @@ impl Display for VerificationError { VerificationError::ChipOpeningLengthMismatch => { write!(f, "Chip opening length mismatch") } + VerificationError::InvalidChipOrdering(s) => { + write!(f, "Invalid chip ordering: {}", s) + } VerificationError::CumulativeSumsError(s) => write!(f, "cumulative sums error: {}", s), } }