feat(gadget): add in-circuit blake3 hash#58
Conversation
|
Part of #52 |
Hakkush-07
left a comment
There was a problem hiding this comment.
Using the already existing and optimized functionality gives a much better gate count, for example Blake3 for 64 byte input drops from 21600 to 10752 (nonfree gates) according to my experiments (also total gate count drops slightly as well).
| fn xor_wire<C: CircuitContext>(ctx: &mut C, a: WireId, b: WireId) -> WireId { | ||
| let result = ctx.issue_wire(); | ||
| ctx.add_gate(crate::Gate::xor(a, b, result)); | ||
| result | ||
| } |
There was a problem hiding this comment.
I think this doesn't need a separate function, just use Gate::xor with the circuit context when needed.
| fn and_wire<C: CircuitContext>(ctx: &mut C, a: WireId, b: WireId) -> WireId { | ||
| let result = ctx.issue_wire(); | ||
| ctx.add_gate(crate::Gate::and(a, b, result)); | ||
| result | ||
| } |
There was a problem hiding this comment.
Similarly this function can be removed as well.
| fn and_u32<C: CircuitContext>(circuit: &mut C, a: U32, b: U32) -> U32 { | ||
| let c: Vec<WireId> = (0..32).map(|i| and_wire(circuit, a[i], b[i])).collect(); | ||
| c.try_into().unwrap() | ||
| } |
There was a problem hiding this comment.
This is only used in or_u32 as far as I can see. For or_u32, we can just use Gate::or as everything else, we dont need xor+and combo.
|
Gate counts after 5a76546 and variants: 10848 |
|
Thank you, looks good now. |
| let xpy = Self::xor(circuit, x, y); | ||
| let xmy = Self::and(circuit, x, y); | ||
| Self::xor(circuit, xpy, xmy) | ||
| } |
There was a problem hiding this comment.
Final comment, is there a reason for not just using Gate::or instead of 2XOR+1AND like this?
|
actually, i have one final remark |
|
@manishbista28 is this ready for merging? It is still in draft. |
| fn first_8_words(compression_output: [U32; 16]) -> [U32; 8] { | ||
| compression_output[0..8].try_into().unwrap() | ||
| } | ||
|
|
||
| fn words_from_little_endian_bytes(bytes: &[U8], words: &mut [U32]) { | ||
| debug_assert_eq!(bytes.len(), 4 * words.len()); | ||
| for (four_bytes, word) in bytes.chunks_exact(4).zip(words) { | ||
| let wire_vec: Vec<WireId> = four_bytes.iter().flat_map(|x| x.0).collect(); | ||
| let app_four_bytes: U32 = U32(wire_vec.try_into().unwrap()); | ||
| *word = app_four_bytes; | ||
| } | ||
| } | ||
|
|
||
| struct Output { | ||
| input_chaining_value: [U32; 8], | ||
| block_words: [U32; 16], | ||
| block_len: U32, | ||
| flags: U32, | ||
| } | ||
|
|
||
| impl Output { | ||
| fn root_output_bytes<C: CircuitContext>(&self, circuit: &mut C, out_slice: &mut [U8]) { | ||
| let root = U32::from_constant(ROOT); | ||
| for (output_block_counter, out_block) in out_slice.chunks_mut(2 * OUT_LEN).enumerate() { | ||
| let flags = U32::or(circuit, self.flags, root); | ||
| let words = compress( | ||
| circuit, | ||
| &self.input_chaining_value, | ||
| &self.block_words, | ||
| output_block_counter as u64, | ||
| self.block_len, | ||
| flags, | ||
| ); | ||
| for (word_bits, out_word_bits) in words.iter().zip(out_block.chunks_mut(4)) { | ||
| for (i, byte_bits) in out_word_bits.iter_mut().enumerate() { | ||
| let arr: U8 = U8(word_bits.0[8 * i..(i + 1) * 8].try_into().unwrap()); | ||
| *byte_bits = arr; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| struct ChunkState { | ||
| chaining_value: [U32; 8], | ||
| chunk_counter: u64, | ||
| block: [U8; BLOCK_LEN], | ||
| block_len: u8, | ||
| blocks_compressed: u8, | ||
| flags: U32, | ||
| } | ||
|
|
||
| impl ChunkState { | ||
| fn new(key_words: [U32; 8], chunk_counter: u64, flags: U32) -> Self { | ||
| Self { | ||
| chaining_value: key_words, | ||
| chunk_counter, | ||
| block: [U8([FALSE_WIRE; 8]); BLOCK_LEN], | ||
| block_len: 0, | ||
| blocks_compressed: 0, | ||
| flags, | ||
| } | ||
| } | ||
|
|
||
| fn len(&self) -> usize { | ||
| BLOCK_LEN * self.blocks_compressed as usize + self.block_len as usize | ||
| } | ||
|
|
||
| fn start_flag(&self) -> U32 { | ||
| let r = if self.blocks_compressed == 0 { | ||
| CHUNK_START | ||
| } else { | ||
| 0 | ||
| }; | ||
| U32::from_constant(r) | ||
| } | ||
|
|
||
| fn update<C: CircuitContext>(&mut self, circuit: &mut C, mut input: &[U8]) { | ||
| let zero_gate = FALSE_WIRE; | ||
| let block_len = U32::from_constant(BLOCK_LEN as u32); | ||
| while !input.is_empty() { | ||
| // If the block buffer is full, compress it and clear it. More | ||
| // input is coming, so this compression is not CHUNK_END. | ||
| if self.block_len as usize == BLOCK_LEN { | ||
| let mut block_words = [U32([zero_gate; 32]); 16]; | ||
| words_from_little_endian_bytes(&self.block, &mut block_words); | ||
| let start_flag = self.start_flag(); | ||
| let flags = U32::or(circuit, self.flags, start_flag); | ||
| let cmp = compress( | ||
| circuit, | ||
| &self.chaining_value, | ||
| &block_words, | ||
| self.chunk_counter, | ||
| block_len, | ||
| flags, | ||
| ); | ||
| self.chaining_value = first_8_words(cmp); | ||
| self.blocks_compressed += 1; | ||
| self.block = [U8([zero_gate; 8]); BLOCK_LEN]; | ||
| self.block_len = 0; | ||
| } | ||
|
|
||
| // Copy input bytes into the block buffer. | ||
| let want = BLOCK_LEN - self.block_len as usize; | ||
| let take = min(want, input.len()); | ||
| self.block[self.block_len as usize..][..take].copy_from_slice(&input[..take]); | ||
| self.block_len += take as u8; | ||
| input = &input[take..]; | ||
| } | ||
| } | ||
|
|
||
| fn output<C: CircuitContext>(&self, circuit: &mut C) -> Output { | ||
| let zero_gate = FALSE_WIRE; | ||
| let mut block_words = [U32([zero_gate; 32]); 16]; | ||
| words_from_little_endian_bytes(&self.block, &mut block_words); | ||
| let start_flag = self.start_flag(); | ||
| let flags = U32::or(circuit, self.flags, start_flag); | ||
| let chunk_end = U32::from_constant(CHUNK_END); | ||
| let flags = U32::or(circuit, flags, chunk_end); | ||
|
|
||
| Output { | ||
| input_chaining_value: self.chaining_value, | ||
| block_words, | ||
| block_len: U32::from_constant(self.block_len as u32), | ||
| flags, | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /// An incremental hasher that can accept any number of writes. | ||
| pub(crate) struct Hasher { | ||
| chunk_state: ChunkState, | ||
| } | ||
|
|
||
| impl Hasher { | ||
| fn new_internal(key_words: [U32; 8], flags: U32) -> Self { | ||
| Self { | ||
| chunk_state: ChunkState::new(key_words, 0, flags), | ||
| } | ||
| } | ||
|
|
||
| /// Construct a new `Hasher` for the regular hash function. | ||
| pub(crate) fn new() -> Self { | ||
| let zero_gate = FALSE_WIRE; | ||
| let iv = get_iv(); | ||
| let zero = U32([zero_gate; 32]); | ||
| Self::new_internal(iv, zero) | ||
| } | ||
|
|
||
| /// Add input to the hash state. This can be called any number of times. | ||
| pub(crate) fn update<C: CircuitContext>(&mut self, circuit: &mut C, mut input: &[U8]) { | ||
| while !input.is_empty() { | ||
| // Compress input bytes into the current chunk state. | ||
| let want = CHUNK_LEN - self.chunk_state.len(); | ||
| let take = min(want, input.len()); | ||
| self.chunk_state.update(circuit, &input[..take]); | ||
| input = &input[take..]; | ||
| } | ||
| } | ||
|
|
||
| /// Finalize the hash and write any number of output bytes. | ||
| pub(crate) fn finalize<C: CircuitContext>(&self, circuit: &mut C, out_slice: &mut [U8]) { | ||
| let output = self.chunk_state.output(circuit); | ||
| output.root_output_bytes(circuit, out_slice); | ||
| } | ||
| } | ||
|
|
There was a problem hiding this comment.
If we're anyway happy with the <512 bits limitation, we could just remove all of this code - just a call to compress would suffice. Its interface could then be changed to
pub fn compress<C: CircuitContext>(circuit: &mut C, input_bits: BigIntWires) -> BigIntWires {since the other values can all be hardcoded (and length can be set to input_bits length).
Edit: nvm, you have 1024 bytes limitation - I was assuming 512 bits limitation, i.e. a single call of compress
This PR adds circuit implementation of blake3 hash which is useful especially for validating groth16 proofs generated by ZKVMs.
Core implementation is referenced from official reference implementation.
Unit tests as well as integration with groth16 verifier has been added.
Note to Reviewers:
Current implementation truncates top 3 bits to fit 256-bit blake3 hash output into 253-bit scalar field element as is done by SP1 ZKVM. Other ZKVMs employ different approach for this hash-to-fr conversion.