From bf9ef512c8fa8680553daa60cd1b68c84678bf76 Mon Sep 17 00:00:00 2001 From: semar Date: Mon, 6 Oct 2025 10:00:40 -0300 Subject: [PATCH 01/18] adds temp (lessComparators.lean) file --- Clean/Circomlib/lessComparators.lean | 214 +++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 Clean/Circomlib/lessComparators.lean diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean new file mode 100644 index 000000000..816099625 --- /dev/null +++ b/Clean/Circomlib/lessComparators.lean @@ -0,0 +1,214 @@ +/- import Std.Data.Vector.Basic -/ +import Clean.Circuit +import Clean.Utils.Bits +import Clean.Circomlib.Bitify +import Mathlib.Data.Nat.Bitwise +import Mathlib.Data.ZMod.Basic + +/- +Original source code: +https://github.com/iden3/circomlib/blob/35e54ea21da3e8762557234298dbb553c175ea8d/circuits/comparators.circom +-/ + +namespace Circomlib +open Utils.Bits +variable {p : ℕ} [Fact p.Prime] [Fact (p > 2)] + + +namespace LessThan +/- +template LessThan(n) { + signal input in[2]; + signal output out; + + component n2b = Num2Bits(n+1); + + n2b.in <== in[0]+ (1 << n) - in[1]; + + out <== 1-n2b.out[n]; +} +-/ + +lemma a_lt_b_eq_sum_lt_2n {p : ℕ} [Fact p.Prime] + (n : ℕ) (a b : F p) (hn : ZMod.val a < ZMod.val b) : + (ZMod.val a : ℤ) + 2^n - (ZMod.val b : ℤ) < 2^n := by + rw [Int.add_sub_assoc, Int.add_comm] + + linarith [Int.ofNat_lt.mpr hn] + +lemma val_pow_of_lt {p n : ℕ} [NeZero p] (h : 2^n < p) : + ((2^n : ℕ): ZMod p).val = 2^n := by + rw [ZMod.val_natCast_of_lt (a := 2^n) (n := p)] + exact h + +lemma test_lemma {p : ℕ} [Fact p.Prime] + (n : ℕ) (a b : F p) (hab : ZMod.val a < ZMod.val b) + (ha : ZMod.val a < 2^n) + (hb : ZMod.val b < 2^n) + (hp : 2^n < p) + (hp' : 2^(n+1) < p) + (h_val: (ZMod.val a + 2^n - ZMod.val b) < 2^n) : + (a + 2 ^ n - b).val < 2 ^ n := by + + have h2n : ((2^n : ℕ) : ZMod p).val = 2^n := by + change (((2^n : ℕ) : ZMod p).val) = 2^n + simp only [ZMod.val_natCast] + simp only [Nat.mod_eq_of_lt hp] + + have h_int : (a.val : ℤ) + 2^n - (b.val : ℤ) < 2^n := by + linarith + + have h_aa : ZMod.val a = a := by + have hap : a.val < p := by + linarith + + sorry + /- rw [ZMod.val_natCast p (a.val: ℕ)] -/ + /- rw [FieldUtils.natToField a hap] -/ + /- simp only [instFieldFOfFactPrime.congr_simp] -/ + /- apply FieldUtils.val_lt_p -/ + /- linarith -/ + + + have h_a2n : (a.val : ℕ) + 2^n < p + 0 := by + have hap : a.val < p := by + linarith + + sorry + /- rw [ZMod.val_cast_of_lt hap] -/ + /- simp_all only [Nat.add_lt_add_of_le_of_lt (a := a.val) (b := p) (c := 2^n) (d := 0) hap ] -/ + /- linarith -/ + + /- ring_nf at h_int -/ + /- ring_nf -/ + /- simp_all only [Nat.cast_pow, Nat.cast_ofNat, ZMod.natCast_val] -/ + simp_all only [Nat.cast_pow, Nat.cast_ofNat, ZMod.natCast_val] + /- simp_all only [ZMod.val_sub, ZMod.cast_sub, Int.cast_sub] -/ + sorry + +lemma zmod_def {p : ℕ} [Fact p.Prime] (x : ZMod p) : + x.val = ZMod.val x := rfl + +lemma bit_is_clear {p : ℕ} [Fact p.Prime] + (n : ℕ) (a : ZMod p) + (hlt : ZMod.val a < 2^n) : + (ZMod.val a).testBit n = false := by + rw [Nat.testBit_eq_decide_div_mod_eq] + have hbpos : 0 < 2^n := pow_pos (by decide) n + have hdiv : ZMod.val a / 2^n = 0 := + Nat.div_eq_of_lt hlt + rw [hdiv, Nat.zero_mod] + simp only [zero_ne_one, decide_false] + +lemma bit_is_set {p : ℕ} [Fact p.Prime] + /- (n : ℕ) (a b : ℕ) -/ + (n : ℕ) (x : F p) + (hlt: ZMod.val x < 2^(n+1)) + (hge: ZMod.val x > 2^n) : + (ZMod.val x).testBit n = true := by + rw [Nat.testBit_eq_decide_div_mod_eq ] + + /- ⊢ ZMod.val a / 2 ^ n % 2 = 1 -/ + simp only [decide_eq_true_eq] + set x := ZMod.val x + -- lower bound: 1 ≤ x / 2^n + have hbpos : 0 < 2^n := pow_pos (by decide) n + have h1 : 1 ≤ x / 2^n := by + simp only [Nat.le_div_iff_mul_le hbpos, one_mul] + apply Nat.le_of_lt at hge + exact hge + + -- upper bound: x / 2^n < 2 + have h2 : x / 2^n < 2 := by + rw [Nat.div_lt_iff_lt_mul hbpos] + + rw [← Nat.pow_add_one'] + exact hlt + + rw [le_antisymm (Nat.lt_succ_iff.mp h2) h1] + +def main (n : ℕ) (hn : 2^(n+1) < p) (input : Expression (F p) × Expression (F p)) := do + let diff := input.1 + (2^n : F p) - input.2 + let bits ← Num2Bits.circuit (n + 1) hn diff + let out <==1 - bits[n] + return out + +def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field where + main := main n hn + localLength _ := n + 2 + localLength_eq := by simp [circuit_norm, main, Num2Bits.circuit] + output _ i := var ⟨ i + n + 1 ⟩ + output_eq := by simp +arith [circuit_norm, main, Num2Bits.circuit] + + Assumptions := fun (x, y) => x.val < 2^n ∧ y.val < 2^n -- TODO: ∧ n <= 252 + + Spec := fun (x, y) output => + output = (if x.val < y.val then 1 else 0) + + soundness := by + intro i₀ env input_var input h_input h_assumptions h_holds + unfold main at * + simp only [circuit_norm, Num2Bits.circuit] at h_holds + simp only [circuit_norm] at * +-- + rw [← h_input] +-- + have h1 : Expression.eval env input_var.1 = input.1 := by + rw [← h_input] + have h2 : Expression.eval env input_var.2 = input.2 := by + rw [← h_input] + set a := input.1 + set b := input.2 + rw [h1, h2] at h_holds + rw [h1, h2] + simp only [id_eq] + set summation := ((ZMod.val a : ℤ) + 2 ^ n + -(ZMod.val b : ℤ)) + rw [← Nat.add_assoc] at h_holds + rw [h_holds.right] + obtain ⟨⟨h_holds1, h_holds2⟩, h_holds3⟩ := h_holds + simp only [Vector.ext_iff] at h_holds2 + + specialize h_holds2 n (Nat.lt_succ_self n) + + rw [Vector.getElem_map, Vector.getElem_mapRange] at h_holds2 + simp only [circuit_norm] at h_holds2 + simp only [fieldToBits, toBits] at h_holds2 + rw [Vector.getElem_map, Vector.getElem_mapRange] at h_holds2 + simp only [Nat.cast_ite, Nat.cast_one, Nat.cast_zero] at h_holds2 + rw [h_holds2] + by_cases hab : ZMod.val a < ZMod.val b + . + have h_assumptions_a' : ZMod.val a < 2 ^ (n+1) := + lt_trans h_assumptions.left (Nat.pow_lt_pow_succ (a:=2) (by decide : 1 < 2) ) + + have hn := a_lt_b_eq_sum_lt_2n n a b hab + + simp only at hn + /- change (a + 2^n - b).val = 2^n -/ + repeat rw [Mathlib.Tactic.RingNF.add_neg] + + have hb := bit_is_clear n (a + 2 ^ n - b) hn + + rw [← zmod_def] at hb + + have hc : (ZMod.val (a + 2 ^ n - b)).testBit n = false := by + exact hb + + + rw [hc] + simp + exact hab + + . + + have h_assumptions_b' : ZMod.val b < 2 ^ (n+1) := + lt_trans h_assumptions.right (Nat.pow_lt_pow_succ (a:=2) (by decide : 1 < 2) ) + have hn := a_lt_b_eq_sum_lt_2n n a b hab + + sorry + + completeness := by + simp only [circuit_norm, main] + sorry + +end LessThan From 0b9e55120d5b6fca538cf3981038df5ad7ad50b2 Mon Sep 17 00:00:00 2001 From: semar Date: Wed, 8 Oct 2025 17:07:48 -0300 Subject: [PATCH 02/18] some progress on lessThan implementation --- Clean/Circomlib/lessComparators.lean | 221 ++++++++++++++++++++------- 1 file changed, 164 insertions(+), 57 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index 816099625..57d6ab1a5 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -29,6 +29,29 @@ template LessThan(n) { } -/ +lemma a_lt_b_eq_sum_lt_2n_nat {p : ℕ} [Fact p.Prime] + (n : ℕ) (a b : F p) (hn : ZMod.val a < ZMod.val b) : + ZMod.val a + 2^n - ZMod.val b < 2^n := by + have h_diff_pos : ZMod.val b - ZMod.val a > 0 := Nat.sub_pos_of_lt hn + have h_eq : ZMod.val a + 2 ^ n - ZMod.val b = 2 ^ n - (ZMod.val b - ZMod.val a) := by + -- rewrite b as a + (b - a), then cancel the common +a on both sides of the subtraction + have hb : ZMod.val b = ZMod.val a + (ZMod.val b - ZMod.val a) := + (Nat.add_sub_of_le (Nat.le_of_lt hn)).symm + calc + ZMod.val a + 2 ^ n - ZMod.val b + = ZMod.val a + 2 ^ n - (ZMod.val a + (ZMod.val b - ZMod.val a)) := by + rw [hb] + simp only [Nat.add_sub_cancel_left] + _ = (2 ^ n + ZMod.val a) - ((ZMod.val b - ZMod.val a) + ZMod.val a) := by + ac_rfl + _ = 2 ^ n - (ZMod.val b - ZMod.val a) := by + simp only [Nat.add_sub_add_right] + -- exact Nat.add_sub_add_right (2 ^ n) (ZMod.val b - ZMod.val a) (ZMod.val a) + + rw [h_eq] + have hpos : 0 < ZMod.val b - ZMod.val a := Nat.sub_pos_of_lt hn + simp_all only [gt_iff_lt, tsub_pos_iff_lt, tsub_lt_self_iff, Nat.ofNat_pos, pow_pos, and_self] + lemma a_lt_b_eq_sum_lt_2n {p : ℕ} [Fact p.Prime] (n : ℕ) (a b : F p) (hn : ZMod.val a < ZMod.val b) : (ZMod.val a : ℤ) + 2^n - (ZMod.val b : ℤ) < 2^n := by @@ -37,58 +60,132 @@ lemma a_lt_b_eq_sum_lt_2n {p : ℕ} [Fact p.Prime] linarith [Int.ofNat_lt.mpr hn] lemma val_pow_of_lt {p n : ℕ} [NeZero p] (h : 2^n < p) : - ((2^n : ℕ): ZMod p).val = 2^n := by + (2^n: ZMod p).val = 2^n := by + rw [ZMod.val_natCast_of_lt (a := 2^n) (n := p)] exact h -lemma test_lemma {p : ℕ} [Fact p.Prime] - (n : ℕ) (a b : F p) (hab : ZMod.val a < ZMod.val b) +-- Helper: no wrap on a + 2^n +omit [Fact (p > 2)] +private lemma add_twoPow_val (a : ZMod p) (n : ℕ) + (ha : a.val < 2 ^ n) (hp : 2 ^ n < p) (hp' : 2 ^ (n+1) < p) : + (a + 2 ^ n).val = a.val + 2 ^ n := by + + have h2n : (2^n: ZMod p).val = 2^n := by + exact val_pow_of_lt hp + + -- symm at h2n + + have hlt : a.val + 2 ^ n < p := lt_trans + (by + have : a.val + 2 ^ n < 2 ^ n + 2 ^ n := Nat.add_lt_add_right ha _ + simp [pow_succ] + rw [Nat.mul_two] + exact this + ) + hp' + have hlt' : a.val + (2^n : ZMod p).val < p := by + rw [← h2n] + + sorry + -- simp only [ZMod.val_natCast] at * + -- exact hlt + + rw [ZMod.val_add_of_lt hlt'] + /- rw [← ZMod.val_add] -/ + /- simp_all only [ZMod.val_le] -/ + + rw [Nat.mod_eq_iff_lt] + rw [ZMod.val_natCast_of_lt (a := 2^n) (n := p) hp] + -- (x + y).val = (x.val + y.val) % p, and mod is identity if < p + simp [Nat.mod_eq_of_lt hlt] + + -- Helper: no wrap on (a + 2^n) - b because b.val ≤ a.val + 2^n +private lemma sub_no_wrap_val (n : ℕ) (a b : ZMod p) + (ha : a.val < 2 ^ n) (hb : b.val < 2 ^ n) (hp : 2^n < p) (hp' : 2 ^ (n+1) < p) : + ((a + 2 ^ n) - b).val = (a.val + 2 ^ n) - b.val := by + -- First compute (a + 2^n).val without wrap + have hadd : (a + 2 ^ n).val = a.val + 2 ^ n := + add_twoPow_val (n:=n) a ha hp hp' + -- b.val ≤ 2^n ≤ 2^n + a.val = (a + 2^n).val + have hb_le_twoPow : b.val ≤ 2 ^ n := Nat.le_of_lt hb + have twoPow_le_sum : 2 ^ n ≤ (a.val + 2 ^ n) := by + simp [Nat.add_comm] + have hble : b.val ≤ (a.val + 2 ^ n) := le_trans hb_le_twoPow twoPow_le_sum + -- For subtraction in ZMod: if x.val ≥ y.val then (x - y).val = x.val - y.val + -- Rewrite x.val using hadd, then finish. + have hres : ((a + 2 ^ n) - b).val = (a + (2 ^ n : ℕ)).val - b.val := by + rw [ZMod.val_sub] + rw [hadd] + exact hble + + rw [hres] + rw [hadd] + +/- set_option maxRecDepth 1_000_000 -/ +/- set_option maxRecDepth 500_000 -/ +/- set_option maxHeartbeats 400_000 -/ +/- set_option diagnostics true -/ + +lemma test_lemma_nocast {p : ℕ} [Fact p.Prime] + (n : ℕ) (a b : F p) (ha : ZMod.val a < 2^n) (hb : ZMod.val b < 2^n) (hp : 2^n < p) (hp' : 2^(n+1) < p) + /- (hab : ZMod.val a < ZMod.val b) -/ (h_val: (ZMod.val a + 2^n - ZMod.val b) < 2^n) : (a + 2 ^ n - b).val < 2 ^ n := by + have h2n : ((2^n : ℕ) : ZMod p).val = 2^n := by + change (((2^n : ℕ) : ZMod p).val) = 2^n + simp only [ZMod.val_natCast] + simp only [Nat.mod_eq_of_lt hp] + + symm at h2n + simp only [ZMod.val_natCast] at h2n + + rw [sub_no_wrap_val n a b ha hb hp hp'] + + exact h_val +lemma test_lemma {p : ℕ} [Fact p.Prime] + (n : ℕ) (a b : F p) + (ha : ZMod.val a < 2^n) + (hb : ZMod.val b < 2^n) + (hp : 2^n < p) + (hp' : 2^(n+1) < p) + /- (hab : ZMod.val a < ZMod.val b) -/ + (h_val: (ZMod.val a + 2^n - ZMod.val b) < 2^n) : + (a + (2 ^ n : ℕ) - b).val < 2 ^ n := by have h2n : ((2^n : ℕ) : ZMod p).val = 2^n := by change (((2^n : ℕ) : ZMod p).val) = 2^n simp only [ZMod.val_natCast] simp only [Nat.mod_eq_of_lt hp] - have h_int : (a.val : ℤ) + 2^n - (b.val : ℤ) < 2^n := by - linarith - - have h_aa : ZMod.val a = a := by - have hap : a.val < p := by - linarith - - sorry - /- rw [ZMod.val_natCast p (a.val: ℕ)] -/ - /- rw [FieldUtils.natToField a hap] -/ - /- simp only [instFieldFOfFactPrime.congr_simp] -/ - /- apply FieldUtils.val_lt_p -/ - /- linarith -/ - - - have h_a2n : (a.val : ℕ) + 2^n < p + 0 := by - have hap : a.val < p := by - linarith - - sorry - /- rw [ZMod.val_cast_of_lt hap] -/ - /- simp_all only [Nat.add_lt_add_of_le_of_lt (a := a.val) (b := p) (c := 2^n) (d := 0) hap ] -/ - /- linarith -/ - - /- ring_nf at h_int -/ - /- ring_nf -/ - /- simp_all only [Nat.cast_pow, Nat.cast_ofNat, ZMod.natCast_val] -/ - simp_all only [Nat.cast_pow, Nat.cast_ofNat, ZMod.natCast_val] - /- simp_all only [ZMod.val_sub, ZMod.cast_sub, Int.cast_sub] -/ - sorry + symm at h2n + + rw [sub_no_wrap_val n a b ha hb hp hp'] + + exact h_val + +lemma solve_if {p : ℕ} [Fact p.Prime] + (n : ℕ) (a b : F p) + : + (1 + -if (ZMod.val (a + 2 ^ n + -b)).testBit n = true then 1 else 0) = if ZMod.val a < ZMod.val b then 1 else 0 := by + sorry + lemma zmod_def {p : ℕ} [Fact p.Prime] (x : ZMod p) : x.val = ZMod.val x := rfl +lemma remove_casting {p : ℕ} [Fact p.Prime] + (n : ℕ) (a b : F p) + (h_val: (((ZMod.val a) : ℕ) + 2^n - ((ZMod.val b): ℕ)) < 2^n) : + (ZMod.val a + 2^n - ZMod.val b) < 2^n := by + + simp_all only + + lemma bit_is_clear {p : ℕ} [Fact p.Prime] (n : ℕ) (a : ZMod p) (hlt : ZMod.val a < 2^n) : @@ -100,10 +197,10 @@ lemma bit_is_clear {p : ℕ} [Fact p.Prime] rw [hdiv, Nat.zero_mod] simp only [zero_ne_one, decide_false] -lemma bit_is_set {p : ℕ} [Fact p.Prime] +lemma bit_is_set {p : ℕ} [Fact p.Prime] /- (n : ℕ) (a b : ℕ) -/ - (n : ℕ) (x : F p) - (hlt: ZMod.val x < 2^(n+1)) + (n : ℕ) (x : F p) + (hlt: ZMod.val x < 2^(n+1)) (hge: ZMod.val x > 2^n) : (ZMod.val x).testBit n = true := by rw [Nat.testBit_eq_decide_div_mod_eq ] @@ -130,7 +227,7 @@ lemma bit_is_set {p : ℕ} [Fact p.Prime] def main (n : ℕ) (hn : 2^(n+1) < p) (input : Expression (F p) × Expression (F p)) := do let diff := input.1 + (2^n : F p) - input.2 let bits ← Num2Bits.circuit (n + 1) hn diff - let out <==1 - bits[n] + let out <== 1 - bits[n] return out def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field where @@ -150,9 +247,14 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w unfold main at * simp only [circuit_norm, Num2Bits.circuit] at h_holds simp only [circuit_norm] at * --- +-- rw [← h_input] --- + + + have hn' : 2 ^ n < p := by + apply lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) + exact hn + have h1 : Expression.eval env input_var.1 = input.1 := by rw [← h_input] have h2 : Expression.eval env input_var.2 = input.2 := by @@ -162,7 +264,7 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w rw [h1, h2] at h_holds rw [h1, h2] simp only [id_eq] - set summation := ((ZMod.val a : ℤ) + 2 ^ n + -(ZMod.val b : ℤ)) + set summation := ((ZMod.val a : ℤ) + 2 ^ n + -(ZMod.val b : ℤ)) rw [← Nat.add_assoc] at h_holds rw [h_holds.right] obtain ⟨⟨h_holds1, h_holds2⟩, h_holds3⟩ := h_holds @@ -177,30 +279,35 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w simp only [Nat.cast_ite, Nat.cast_one, Nat.cast_zero] at h_holds2 rw [h_holds2] by_cases hab : ZMod.val a < ZMod.val b - . - have h_assumptions_a' : ZMod.val a < 2 ^ (n+1) := - lt_trans h_assumptions.left (Nat.pow_lt_pow_succ (a:=2) (by decide : 1 < 2) ) + . + -- sum is < 2^n, so nth bit is 0 + have h_lt : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ n := + a_lt_b_eq_sum_lt_2n_nat n a b hab + have h_val_lt : ZMod.val (a + (2 ^ n : F p) - b) < 2 ^ n := by + /- simp only [ZMod.val_natCast] at h_lt -/ + exact test_lemma_nocast n a b h_assumptions.left h_assumptions.right hn' hn h_lt + have h_bit_clear : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = false := + bit_is_clear n (a + (2 ^ n : F p) - b) h_val_lt + have h_bit_clear' : (ZMod.val (a + 2 ^ n + -b)).testBit n = false := by + simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_bit_clear + simp [h_bit_clear', hab] - have hn := a_lt_b_eq_sum_lt_2n n a b hab - - simp only at hn - /- change (a + 2^n - b).val = 2^n -/ - repeat rw [Mathlib.Tactic.RingNF.add_neg] + -- have h_lt := a_lt_b_eq_sum_lt_2n_nat n a b hab - have hb := bit_is_clear n (a + 2 ^ n - b) hn + -- simp only [ZMod.val_cast_of_lt] at h_lt - rw [← zmod_def] at hb + -- have h_lt' := test_lemma n a b h_assumptions.left h_assumptions.right hn' hn h_lt - have hc : (ZMod.val (a + 2 ^ n - b)).testBit n = false := by - exact hb + -- /- repeat rw [Mathlib.Tactic.RingNF.add_neg] -/ + -- field_simp at h_lt' + -- have hb := bit_is_clear n (a + 2 ^ n - b) h_lt' + -- /- have hb := bit_is_clear n (a + 2 ^ n - b) h_lt -/ - rw [hc] - simp - exact hab + -- have hc : (ZMod.val (a + 2 ^ n - b)).testBit n = false := by + -- exact hb + . - . - have h_assumptions_b' : ZMod.val b < 2 ^ (n+1) := lt_trans h_assumptions.right (Nat.pow_lt_pow_succ (a:=2) (by decide : 1 < 2) ) have hn := a_lt_b_eq_sum_lt_2n n a b hab From 183cfea35efcbee177f41b38cbc33886f37e53e3 Mon Sep 17 00:00:00 2001 From: semar Date: Wed, 8 Oct 2025 17:22:08 -0300 Subject: [PATCH 03/18] more progress --- Clean/Circomlib/lessComparators.lean | 55 ++++------------------------ 1 file changed, 7 insertions(+), 48 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index 57d6ab1a5..fbef3f38c 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -85,20 +85,10 @@ private lemma add_twoPow_val (a : ZMod p) (n : ℕ) ) hp' have hlt' : a.val + (2^n : ZMod p).val < p := by - rw [← h2n] - - sorry - -- simp only [ZMod.val_natCast] at * - -- exact hlt + simp_all only rw [ZMod.val_add_of_lt hlt'] - /- rw [← ZMod.val_add] -/ - /- simp_all only [ZMod.val_le] -/ - - rw [Nat.mod_eq_iff_lt] - rw [ZMod.val_natCast_of_lt (a := 2^n) (n := p) hp] - -- (x + y).val = (x.val + y.val) % p, and mod is identity if < p - simp [Nat.mod_eq_of_lt hlt] + rw [h2n] -- Helper: no wrap on (a + 2^n) - b because b.val ≤ a.val + 2^n private lemma sub_no_wrap_val (n : ℕ) (a b : ZMod p) @@ -116,12 +106,15 @@ private lemma sub_no_wrap_val (n : ℕ) (a b : ZMod p) -- Rewrite x.val using hadd, then finish. have hres : ((a + 2 ^ n) - b).val = (a + (2 ^ n : ℕ)).val - b.val := by rw [ZMod.val_sub] + /- rw [hadd] -/ + /- aesop? -/ + simp_all only [le_add_iff_nonneg_left, zero_le, Nat.cast_pow, Nat.cast_ofNat] rw [hadd] exact hble rw [hres] - rw [hadd] - + simp_all only [le_add_iff_nonneg_left, zero_le, Nat.cast_pow, Nat.cast_ofNat] + /- set_option maxRecDepth 1_000_000 -/ /- set_option maxRecDepth 500_000 -/ /- set_option maxHeartbeats 400_000 -/ @@ -145,25 +138,6 @@ lemma test_lemma_nocast {p : ℕ} [Fact p.Prime] simp only [ZMod.val_natCast] at h2n - rw [sub_no_wrap_val n a b ha hb hp hp'] - - exact h_val -lemma test_lemma {p : ℕ} [Fact p.Prime] - (n : ℕ) (a b : F p) - (ha : ZMod.val a < 2^n) - (hb : ZMod.val b < 2^n) - (hp : 2^n < p) - (hp' : 2^(n+1) < p) - /- (hab : ZMod.val a < ZMod.val b) -/ - (h_val: (ZMod.val a + 2^n - ZMod.val b) < 2^n) : - (a + (2 ^ n : ℕ) - b).val < 2 ^ n := by - have h2n : ((2^n : ℕ) : ZMod p).val = 2^n := by - change (((2^n : ℕ) : ZMod p).val) = 2^n - simp only [ZMod.val_natCast] - simp only [Nat.mod_eq_of_lt hp] - - symm at h2n - rw [sub_no_wrap_val n a b ha hb hp hp'] exact h_val @@ -291,21 +265,6 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w have h_bit_clear' : (ZMod.val (a + 2 ^ n + -b)).testBit n = false := by simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_bit_clear simp [h_bit_clear', hab] - - -- have h_lt := a_lt_b_eq_sum_lt_2n_nat n a b hab - - -- simp only [ZMod.val_cast_of_lt] at h_lt - - -- have h_lt' := test_lemma n a b h_assumptions.left h_assumptions.right hn' hn h_lt - - -- /- repeat rw [Mathlib.Tactic.RingNF.add_neg] -/ - -- field_simp at h_lt' - - -- have hb := bit_is_clear n (a + 2 ^ n - b) h_lt' - -- /- have hb := bit_is_clear n (a + 2 ^ n - b) h_lt -/ - - -- have hc : (ZMod.val (a + 2 ^ n - b)).testBit n = false := by - -- exact hb . have h_assumptions_b' : ZMod.val b < 2 ^ (n+1) := From 97308a4fb100601a0159ff4c4f39f316e6614779 Mon Sep 17 00:00:00 2001 From: semar Date: Wed, 8 Oct 2025 18:13:36 -0300 Subject: [PATCH 04/18] even more progress --- Clean/Circomlib/lessComparators.lean | 54 +++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index fbef3f38c..7a6c36a41 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -59,20 +59,58 @@ lemma a_lt_b_eq_sum_lt_2n {p : ℕ} [Fact p.Prime] linarith [Int.ofNat_lt.mpr hn] -lemma val_pow_of_lt {p n : ℕ} [NeZero p] (h : 2^n < p) : - (2^n: ZMod p).val = 2^n := by +/- lemma val_pow_of_lt {p n : ℕ} [NeZero p] (h : 2^n < p) : -/ +/- (2^n: ZMod p).val = 2^n := by -/ +/- rw [Nat.cast_pow, Nat.cast_ofNat] -/ +/- rw [ZMod.val_natCast_of_lt (a := 2^n) (n := p)] -/ +/- exact h -/ - rw [ZMod.val_natCast_of_lt (a := 2^n) (n := p)] - exact h + +@[simp] +lemma val_pow_of_lt' {p n : ℕ} [NeZero p] (h : 2 ^ n < p) : + (((2 ^ n: ℕ) : ZMod p).val) = 2 ^ n := + ZMod.val_natCast_of_lt h + + +lemma val_pow_of_lt_nat {p n : ℕ} [NeZero p] [Fact p.Prime] (h : 2 ^ n < p) (hp : p > 2): + (2 ^ n : ZMod p).val = 2 ^ n := by + + have p_ne_zero := NeZero.ne p + rw [ZMod.val_pow] at * + rw [← Nat.cast_ofNat] + rw [ZMod.val_natCast] + . -- (2 % p) ^ n = 2 ^ n + have h_mod := Nat.mod_eq_iff_lt (m := 2) (n := p) p_ne_zero + + have h_mod' : 2 % p = 2 := by + simp_all only [gt_iff_lt, ne_eq, iff_true] + + rw [h_mod'] + + . + rw [← Nat.cast_ofNat] + rw [ZMod.val_natCast] + + have h_mod := Nat.mod_eq_iff_lt (m := 2) (n := p) p_ne_zero + have h_mod' : 2 % p = 2 := by + simp_all only [gt_iff_lt, ne_eq, iff_true] + rw [h_mod'] + exact h -- Helper: no wrap on a + 2^n -omit [Fact (p > 2)] -private lemma add_twoPow_val (a : ZMod p) (n : ℕ) +/- omit [Fact (p > 2)] -/ +private lemma add_twoPow_val {p : ℕ} [Fact (p > 2)] [Fact p.Prime] (a : ZMod p) (n : ℕ) (ha : a.val < 2 ^ n) (hp : 2 ^ n < p) (hp' : 2 ^ (n+1) < p) : (a + 2 ^ n).val = a.val + 2 ^ n := by + + + /- have h2n : ((↑(2 ^ n) : ZMod p).val) = 2 ^ n := val_pow_of_lt' hp -/ + have hp2 := Fact.out (p := p > 2) + have h2n : (2^n: ZMod p).val = 2^n := by - exact val_pow_of_lt hp + /- rw [← Nat.cast_pow] -- turns (2 ^ n) into (↑2 ^ n : ZMod p) -/ + exact val_pow_of_lt_nat hp hp2 -- symm at h2n @@ -120,7 +158,7 @@ private lemma sub_no_wrap_val (n : ℕ) (a b : ZMod p) /- set_option maxHeartbeats 400_000 -/ /- set_option diagnostics true -/ -lemma test_lemma_nocast {p : ℕ} [Fact p.Prime] +lemma test_lemma_nocast {p : ℕ} [Fact p.Prime] [Fact (p > 2)] (n : ℕ) (a b : F p) (ha : ZMod.val a < 2^n) (hb : ZMod.val b < 2^n) From cb2e848b46e94103a63c05321d280aee15282449 Mon Sep 17 00:00:00 2001 From: semar Date: Wed, 8 Oct 2025 18:48:21 -0300 Subject: [PATCH 05/18] soudness almost done for lessthan --- Clean/Circomlib/lessComparators.lean | 76 +++++++++++++++++++++++----- 1 file changed, 64 insertions(+), 12 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index 7a6c36a41..6dce127ad 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -29,9 +29,20 @@ template LessThan(n) { } -/ +lemma a_ge_b_eq_sum_ge_2n_nat {p : ℕ} [Fact p.Prime] + (n : ℕ) (a b : F p) (hn : ZMod.val a ≥ ZMod.val b) : + /- (n : ℕ) (a b : F p) (hn : ¬(ZMod.val a < ZMod.val b)) : -/ + ZMod.val a + 2^n - ZMod.val b ≥ 2^n := by + + have hab' : ZMod.val a ≥ ZMod.val b := by + simp_all only [not_lt, ge_iff_le] + + sorry + lemma a_lt_b_eq_sum_lt_2n_nat {p : ℕ} [Fact p.Prime] (n : ℕ) (a b : F p) (hn : ZMod.val a < ZMod.val b) : ZMod.val a + 2^n - ZMod.val b < 2^n := by + have h_diff_pos : ZMod.val b - ZMod.val a > 0 := Nat.sub_pos_of_lt hn have h_eq : ZMod.val a + 2 ^ n - ZMod.val b = 2 ^ n - (ZMod.val b - ZMod.val a) := by -- rewrite b as a + (b - a), then cancel the common +a on both sides of the subtraction @@ -158,7 +169,7 @@ private lemma sub_no_wrap_val (n : ℕ) (a b : ZMod p) /- set_option maxHeartbeats 400_000 -/ /- set_option diagnostics true -/ -lemma test_lemma_nocast {p : ℕ} [Fact p.Prime] [Fact (p > 2)] +lemma val_sum_eq_sum_val_lt {p : ℕ} [Fact p.Prime] [Fact (p > 2)] (n : ℕ) (a b : F p) (ha : ZMod.val a < 2^n) (hb : ZMod.val b < 2^n) @@ -167,6 +178,8 @@ lemma test_lemma_nocast {p : ℕ} [Fact p.Prime] [Fact (p > 2)] /- (hab : ZMod.val a < ZMod.val b) -/ (h_val: (ZMod.val a + 2^n - ZMod.val b) < 2^n) : (a + 2 ^ n - b).val < 2 ^ n := by + + have h2n : ((2^n : ℕ) : ZMod p).val = 2^n := by change (((2^n : ℕ) : ZMod p).val) = 2^n simp only [ZMod.val_natCast] @@ -180,11 +193,30 @@ lemma test_lemma_nocast {p : ℕ} [Fact p.Prime] [Fact (p > 2)] exact h_val -lemma solve_if {p : ℕ} [Fact p.Prime] +-- TODO: Remove duplicated code. They look the exact same +lemma val_sum_eq_sum_val_ge {p : ℕ} [Fact p.Prime] [Fact (p > 2)] (n : ℕ) (a b : F p) - : - (1 + -if (ZMod.val (a + 2 ^ n + -b)).testBit n = true then 1 else 0) = if ZMod.val a < ZMod.val b then 1 else 0 := by - sorry + (ha : ZMod.val a < 2^n) + (hb : ZMod.val b < 2^n) + (hp : 2^n < p) + (hp' : 2^(n+1) < p) + /- (hab : ZMod.val a < ZMod.val b) -/ + (h_val: (ZMod.val a + 2^n - ZMod.val b) ≥ 2^n) : + (a + 2 ^ n - b).val ≥ 2 ^ n := by + + + have h2n : ((2^n : ℕ) : ZMod p).val = 2^n := by + change (((2^n : ℕ) : ZMod p).val) = 2^n + simp only [ZMod.val_natCast] + simp only [Nat.mod_eq_of_lt hp] + + symm at h2n + simp only [ZMod.val_natCast] at h2n + + + rw [sub_no_wrap_val n a b ha hb hp hp'] + + exact h_val lemma zmod_def {p : ℕ} [Fact p.Prime] (x : ZMod p) : @@ -213,7 +245,7 @@ lemma bit_is_set {p : ℕ} [Fact p.Prime] /- (n : ℕ) (a b : ℕ) -/ (n : ℕ) (x : F p) (hlt: ZMod.val x < 2^(n+1)) - (hge: ZMod.val x > 2^n) : + (hge: ZMod.val x ≥ 2^n) : (ZMod.val x).testBit n = true := by rw [Nat.testBit_eq_decide_div_mod_eq ] @@ -224,7 +256,7 @@ lemma bit_is_set {p : ℕ} [Fact p.Prime] have hbpos : 0 < 2^n := pow_pos (by decide) n have h1 : 1 ≤ x / 2^n := by simp only [Nat.le_div_iff_mul_le hbpos, one_mul] - apply Nat.le_of_lt at hge + /- apply Nat.le_of_lt at hge -/ exact hge -- upper bound: x / 2^n < 2 @@ -297,7 +329,7 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w a_lt_b_eq_sum_lt_2n_nat n a b hab have h_val_lt : ZMod.val (a + (2 ^ n : F p) - b) < 2 ^ n := by /- simp only [ZMod.val_natCast] at h_lt -/ - exact test_lemma_nocast n a b h_assumptions.left h_assumptions.right hn' hn h_lt + exact val_sum_eq_sum_val_lt n a b h_assumptions.left h_assumptions.right hn' hn h_lt have h_bit_clear : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = false := bit_is_clear n (a + (2 ^ n : F p) - b) h_val_lt have h_bit_clear' : (ZMod.val (a + 2 ^ n + -b)).testBit n = false := by @@ -305,11 +337,31 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w simp [h_bit_clear', hab] . - have h_assumptions_b' : ZMod.val b < 2 ^ (n+1) := - lt_trans h_assumptions.right (Nat.pow_lt_pow_succ (a:=2) (by decide : 1 < 2) ) - have hn := a_lt_b_eq_sum_lt_2n n a b hab + have hab' : ZMod.val a ≥ ZMod.val b := by + simp_all only [not_lt, ge_iff_le] - sorry + have h_ge : ZMod.val a + 2 ^ n - ZMod.val b ≥ 2 ^ n := + a_ge_b_eq_sum_ge_2n_nat n a b hab' + + have h_val_ge : ZMod.val (a + (2 ^ n : F p) - b) ≥ 2 ^ n := by + /- simp only [ZMod.val_natCast] at h_lt -/ + exact val_sum_eq_sum_val_ge n a b h_assumptions.left h_assumptions.right hn' hn h_ge + + have h_holds1' : ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1) := by + -- h_holds1 : ZMod.val (a + 2 ^ n + -b) < 2 ^ (n + 1) + rw [← sub_eq_add_neg] at h_holds1 + -- now: h_holds1 : ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1) + exact h_holds1 + + + have h_bit_set : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = true := + bit_is_set n (a + (2 ^ n : F p) - b) h_holds1' h_val_ge + /- have hn := a_lt_b_eq_sum_lt_2n n a b hab -/ + + have h_bit_set' : (ZMod.val (a + 2 ^ n + -b)).testBit n = true := by + simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_bit_set + + simp [h_bit_set', hab] completeness := by simp only [circuit_norm, main] From b72017b921ad5ead2bde769765847fd8ed2532d2 Mon Sep 17 00:00:00 2001 From: semar Date: Wed, 8 Oct 2025 19:14:24 -0300 Subject: [PATCH 06/18] soudness done for lessthan --- Clean/Circomlib/lessComparators.lean | 34 +++++++++++++--------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index 6dce127ad..c8a39fc13 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -29,21 +29,29 @@ template LessThan(n) { } -/ +-- TODO: Needs cleanup as well lemma a_ge_b_eq_sum_ge_2n_nat {p : ℕ} [Fact p.Prime] (n : ℕ) (a b : F p) (hn : ZMod.val a ≥ ZMod.val b) : - /- (n : ℕ) (a b : F p) (hn : ¬(ZMod.val a < ZMod.val b)) : -/ ZMod.val a + 2^n - ZMod.val b ≥ 2^n := by - have hab' : ZMod.val a ≥ ZMod.val b := by - simp_all only [not_lt, ge_iff_le] + have hn' : ZMod.val b ≤ ZMod.val a := by + simp_all only [ge_iff_le] - sorry + have h_eq : ZMod.val a + 2 ^ n - ZMod.val b = 2 ^ n + (ZMod.val a - ZMod.val b) := by + calc + ZMod.val a + 2 ^ n - ZMod.val b + = (2 ^ n + ZMod.val a) - ZMod.val b := by ac_rfl + _ = (2 ^ n + ZMod.val a) - (ZMod.val b + 0) := by rfl + _ = 2 ^ n + (ZMod.val a - ZMod.val b) := by + simp only [Nat.add_zero, Nat.add_sub_assoc hn'] + + rw [h_eq] + exact Nat.le_add_right _ _ lemma a_lt_b_eq_sum_lt_2n_nat {p : ℕ} [Fact p.Prime] (n : ℕ) (a b : F p) (hn : ZMod.val a < ZMod.val b) : ZMod.val a + 2^n - ZMod.val b < 2^n := by - have h_diff_pos : ZMod.val b - ZMod.val a > 0 := Nat.sub_pos_of_lt hn have h_eq : ZMod.val a + 2 ^ n - ZMod.val b = 2 ^ n - (ZMod.val b - ZMod.val a) := by -- rewrite b as a + (b - a), then cancel the common +a on both sides of the subtraction have hb : ZMod.val b = ZMod.val a + (ZMod.val b - ZMod.val a) := @@ -328,7 +336,6 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w have h_lt : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ n := a_lt_b_eq_sum_lt_2n_nat n a b hab have h_val_lt : ZMod.val (a + (2 ^ n : F p) - b) < 2 ^ n := by - /- simp only [ZMod.val_natCast] at h_lt -/ exact val_sum_eq_sum_val_lt n a b h_assumptions.left h_assumptions.right hn' hn h_lt have h_bit_clear : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = false := bit_is_clear n (a + (2 ^ n : F p) - b) h_val_lt @@ -339,25 +346,16 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w have hab' : ZMod.val a ≥ ZMod.val b := by simp_all only [not_lt, ge_iff_le] + have h_holds1' : ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1) := by + rw [← sub_eq_add_neg] at h_holds1 + exact h_holds1 have h_ge : ZMod.val a + 2 ^ n - ZMod.val b ≥ 2 ^ n := a_ge_b_eq_sum_ge_2n_nat n a b hab' - have h_val_ge : ZMod.val (a + (2 ^ n : F p) - b) ≥ 2 ^ n := by - /- simp only [ZMod.val_natCast] at h_lt -/ exact val_sum_eq_sum_val_ge n a b h_assumptions.left h_assumptions.right hn' hn h_ge - - have h_holds1' : ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1) := by - -- h_holds1 : ZMod.val (a + 2 ^ n + -b) < 2 ^ (n + 1) - rw [← sub_eq_add_neg] at h_holds1 - -- now: h_holds1 : ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1) - exact h_holds1 - - have h_bit_set : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = true := bit_is_set n (a + (2 ^ n : F p) - b) h_holds1' h_val_ge - /- have hn := a_lt_b_eq_sum_lt_2n n a b hab -/ - have h_bit_set' : (ZMod.val (a + 2 ^ n + -b)).testBit n = true := by simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_bit_set From 0bfa5d0b91bff18892e35269bbccd980dd7771e4 Mon Sep 17 00:00:00 2001 From: semar Date: Thu, 9 Oct 2025 10:07:10 -0300 Subject: [PATCH 07/18] add completeness for lessThan - code needs cleanup --- Clean/Circomlib/lessComparators.lean | 80 +++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 14 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index c8a39fc13..e359d0fd9 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -78,6 +78,14 @@ lemma a_lt_b_eq_sum_lt_2n {p : ℕ} [Fact p.Prime] linarith [Int.ofNat_lt.mpr hn] +lemma a_lt_b_eq_sum_lt_2n_plus_one_nat {p : ℕ} [Fact p.Prime] + (n : ℕ) (a b : F p) + (ha : ZMod.val a < 2^n) + (hb : ZMod.val b < 2^n) : + ZMod.val a + 2^n - ZMod.val b < 2^(n+1) := by + sorry + + /- lemma val_pow_of_lt {p n : ℕ} [NeZero p] (h : 2^n < p) : -/ /- (2^n: ZMod p).val = 2^n := by -/ /- rw [Nat.cast_pow, Nat.cast_ofNat] -/ @@ -122,17 +130,11 @@ private lemma add_twoPow_val {p : ℕ} [Fact (p > 2)] [Fact p.Prime] (a : ZMod p (ha : a.val < 2 ^ n) (hp : 2 ^ n < p) (hp' : 2 ^ (n+1) < p) : (a + 2 ^ n).val = a.val + 2 ^ n := by - - - /- have h2n : ((↑(2 ^ n) : ZMod p).val) = 2 ^ n := val_pow_of_lt' hp -/ have hp2 := Fact.out (p := p > 2) have h2n : (2^n: ZMod p).val = 2^n := by - /- rw [← Nat.cast_pow] -- turns (2 ^ n) into (↑2 ^ n : ZMod p) -/ exact val_pow_of_lt_nat hp hp2 - -- symm at h2n - have hlt : a.val + 2 ^ n < p := lt_trans (by have : a.val + 2 ^ n < 2 ^ n + 2 ^ n := Nat.add_lt_add_right ha _ @@ -163,8 +165,6 @@ private lemma sub_no_wrap_val (n : ℕ) (a b : ZMod p) -- Rewrite x.val using hadd, then finish. have hres : ((a + 2 ^ n) - b).val = (a + (2 ^ n : ℕ)).val - b.val := by rw [ZMod.val_sub] - /- rw [hadd] -/ - /- aesop? -/ simp_all only [le_add_iff_nonneg_left, zero_le, Nat.cast_pow, Nat.cast_ofNat] rw [hadd] exact hble @@ -172,11 +172,6 @@ private lemma sub_no_wrap_val (n : ℕ) (a b : ZMod p) rw [hres] simp_all only [le_add_iff_nonneg_left, zero_le, Nat.cast_pow, Nat.cast_ofNat] -/- set_option maxRecDepth 1_000_000 -/ -/- set_option maxRecDepth 500_000 -/ -/- set_option maxHeartbeats 400_000 -/ -/- set_option diagnostics true -/ - lemma val_sum_eq_sum_val_lt {p : ℕ} [Fact p.Prime] [Fact (p > 2)] (n : ℕ) (a b : F p) (ha : ZMod.val a < 2^n) @@ -226,6 +221,30 @@ lemma val_sum_eq_sum_val_ge {p : ℕ} [Fact p.Prime] [Fact (p > 2)] exact h_val +-- TODO: Remove duplicated code. They look the exact same +lemma val_sum_eq_sum_val_lt_n_plus_1 {p : ℕ} [Fact p.Prime] [Fact (p > 2)] + (n : ℕ) (a b : F p) + (ha : ZMod.val a < 2^n) + (hb : ZMod.val b < 2^n) + (hp : 2^n < p) + (hp' : 2^(n+1) < p) + /- (hab : ZMod.val a < ZMod.val b) -/ + (h_val: (ZMod.val a + 2^n - ZMod.val b) ≤ 2^(n+1)) : + (a + 2 ^ n - b).val ≤ 2 ^ (n+1) := by + + + have h2n : ((2^n : ℕ) : ZMod p).val = 2^n := by + change (((2^n : ℕ) : ZMod p).val) = 2^n + simp only [ZMod.val_natCast] + simp only [Nat.mod_eq_of_lt hp] + + symm at h2n + simp only [ZMod.val_natCast] at h2n + + + rw [sub_no_wrap_val n a b ha hb hp hp'] + + exact h_val lemma zmod_def {p : ℕ} [Fact p.Prime] (x : ZMod p) : x.val = ZMod.val x := rfl @@ -343,6 +362,7 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_bit_clear simp [h_bit_clear', hab] . + -- sum is ≥ 2^n, so nth bit is 1 have hab' : ZMod.val a ≥ ZMod.val b := by simp_all only [not_lt, ge_iff_le] @@ -362,7 +382,39 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w simp [h_bit_set', hab] completeness := by - simp only [circuit_norm, main] + /- circuit_proof_all -/ + circuit_proof_start + simp only [Num2Bits.circuit] at * + simp only [circuit_norm] at * + simp_all only [gt_iff_lt, true_and, id_eq, and_true] + + /- obtain ⟨fst, snd⟩ := input_var -/ + obtain ⟨h_envl, h_envr⟩ := h_env + obtain ⟨left, right⟩ := h_assumptions + + have h1 : Expression.eval env input_var.1 = input.1 := by + rw [← h_input] + have h2 : Expression.eval env input_var.2 = input.2 := by + rw [← h_input] + + set a := input.1 + set b := input.2 + rw [h1, h2] + rw [h1, h2] at h_envl + + rw [← sub_eq_add_neg (a:=(a+ 2 ^ n)) ] + + have hn' : 2 ^ n < p := by + apply lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) + exact hn + + have h_val := val_sum_eq_sum_val_lt n a b left right hn' hn + /- exact h_val -/ + have h_comp := val_sum_eq_sum_val_lt_n_plus_1 n a b left right hn' hn + /- apply? -/ + /- /- exact h_envl -/ -/ + /- aesop? -/ + sorry end LessThan From d92e3bac6584a070ad2427a9d8d76f2100e239ac Mon Sep 17 00:00:00 2001 From: semar Date: Thu, 9 Oct 2025 13:16:54 -0300 Subject: [PATCH 08/18] some simplification --- Clean/Circomlib/lessComparators.lean | 204 ++++++++++----------------- 1 file changed, 76 insertions(+), 128 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index e359d0fd9..f1bdde61c 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -28,6 +28,18 @@ template LessThan(n) { out <== 1-n2b.out[n]; } -/ +structure Input (p : ℕ) [Fact p.Prime] [Fact (p > 2)] where + a : F p + b : F p + +structure Bounds + (p n : ℕ) [Fact p.Prime] [Fact (p > 2)] + (a b : F p) + where + ha : ZMod.val a < 2 ^ n + hb : ZMod.val b < 2 ^ n + hp : 2 ^ (n + 1) < p + hp' : 2 ^ n < p -- TODO: Needs cleanup as well lemma a_ge_b_eq_sum_ge_2n_nat {p : ℕ} [Fact p.Prime] @@ -69,35 +81,7 @@ lemma a_lt_b_eq_sum_lt_2n_nat {p : ℕ} [Fact p.Prime] rw [h_eq] have hpos : 0 < ZMod.val b - ZMod.val a := Nat.sub_pos_of_lt hn - simp_all only [gt_iff_lt, tsub_pos_iff_lt, tsub_lt_self_iff, Nat.ofNat_pos, pow_pos, and_self] - -lemma a_lt_b_eq_sum_lt_2n {p : ℕ} [Fact p.Prime] - (n : ℕ) (a b : F p) (hn : ZMod.val a < ZMod.val b) : - (ZMod.val a : ℤ) + 2^n - (ZMod.val b : ℤ) < 2^n := by - rw [Int.add_sub_assoc, Int.add_comm] - - linarith [Int.ofNat_lt.mpr hn] - -lemma a_lt_b_eq_sum_lt_2n_plus_one_nat {p : ℕ} [Fact p.Prime] - (n : ℕ) (a b : F p) - (ha : ZMod.val a < 2^n) - (hb : ZMod.val b < 2^n) : - ZMod.val a + 2^n - ZMod.val b < 2^(n+1) := by - sorry - - -/- lemma val_pow_of_lt {p n : ℕ} [NeZero p] (h : 2^n < p) : -/ -/- (2^n: ZMod p).val = 2^n := by -/ -/- rw [Nat.cast_pow, Nat.cast_ofNat] -/ -/- rw [ZMod.val_natCast_of_lt (a := 2^n) (n := p)] -/ -/- exact h -/ - - -@[simp] -lemma val_pow_of_lt' {p n : ℕ} [NeZero p] (h : 2 ^ n < p) : - (((2 ^ n: ℕ) : ZMod p).val) = 2 ^ n := - ZMod.val_natCast_of_lt h - + simp_all only [ tsub_pos_iff_lt, tsub_lt_self_iff, Nat.ofNat_pos, pow_pos, and_self] lemma val_pow_of_lt_nat {p n : ℕ} [NeZero p] [Fact p.Prime] (h : 2 ^ n < p) (hp : p > 2): (2 ^ n : ZMod p).val = 2 ^ n := by @@ -171,92 +155,45 @@ private lemma sub_no_wrap_val (n : ℕ) (a b : ZMod p) rw [hres] simp_all only [le_add_iff_nonneg_left, zero_le, Nat.cast_pow, Nat.cast_ofNat] - -lemma val_sum_eq_sum_val_lt {p : ℕ} [Fact p.Prime] [Fact (p > 2)] - (n : ℕ) (a b : F p) - (ha : ZMod.val a < 2^n) - (hb : ZMod.val b < 2^n) - (hp : 2^n < p) - (hp' : 2^(n+1) < p) - /- (hab : ZMod.val a < ZMod.val b) -/ - (h_val: (ZMod.val a + 2^n - ZMod.val b) < 2^n) : - (a + 2 ^ n - b).val < 2 ^ n := by - - - have h2n : ((2^n : ℕ) : ZMod p).val = 2^n := by - change (((2^n : ℕ) : ZMod p).val) = 2^n - simp only [ZMod.val_natCast] - simp only [Nat.mod_eq_of_lt hp] - - symm at h2n - simp only [ZMod.val_natCast] at h2n - - - rw [sub_no_wrap_val n a b ha hb hp hp'] - - exact h_val - --- TODO: Remove duplicated code. They look the exact same -lemma val_sum_eq_sum_val_ge {p : ℕ} [Fact p.Prime] [Fact (p > 2)] - (n : ℕ) (a b : F p) - (ha : ZMod.val a < 2^n) - (hb : ZMod.val b < 2^n) - (hp : 2^n < p) - (hp' : 2^(n+1) < p) - /- (hab : ZMod.val a < ZMod.val b) -/ - (h_val: (ZMod.val a + 2^n - ZMod.val b) ≥ 2^n) : - (a + 2 ^ n - b).val ≥ 2 ^ n := by - - - have h2n : ((2^n : ℕ) : ZMod p).val = 2^n := by - change (((2^n : ℕ) : ZMod p).val) = 2^n - simp only [ZMod.val_natCast] - simp only [Nat.mod_eq_of_lt hp] - - symm at h2n - simp only [ZMod.val_natCast] at h2n - - - rw [sub_no_wrap_val n a b ha hb hp hp'] - - exact h_val - --- TODO: Remove duplicated code. They look the exact same -lemma val_sum_eq_sum_val_lt_n_plus_1 {p : ℕ} [Fact p.Prime] [Fact (p > 2)] - (n : ℕ) (a b : F p) - (ha : ZMod.val a < 2^n) - (hb : ZMod.val b < 2^n) - (hp : 2^n < p) - (hp' : 2^(n+1) < p) - /- (hab : ZMod.val a < ZMod.val b) -/ - (h_val: (ZMod.val a + 2^n - ZMod.val b) ≤ 2^(n+1)) : - (a + 2 ^ n - b).val ≤ 2 ^ (n+1) := by - - - have h2n : ((2^n : ℕ) : ZMod p).val = 2^n := by - change (((2^n : ℕ) : ZMod p).val) = 2^n - simp only [ZMod.val_natCast] - simp only [Nat.mod_eq_of_lt hp] - - symm at h2n - simp only [ZMod.val_natCast] at h2n - - - rw [sub_no_wrap_val n a b ha hb hp hp'] - - exact h_val +lemma val_sum_eq_sum_val_rel_threshold {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} + (R : ℕ → ℕ → Prop) (threshold : ℕ) + (h_bounds : Bounds p n a b) + (h_val : R (ZMod.val a + 2 ^ n - ZMod.val b) threshold) : + R (ZMod.val (a + 2 ^ n - b)) threshold := by + + have hp' : 2 ^ n < p := by + apply lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) + exact h_bounds.hp + + rw [sub_no_wrap_val n a b h_bounds.ha h_bounds.hb h_bounds.hp' h_bounds.hp] + exact h_val + +lemma val_sum_eq_sum_val_lt {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} : + (Bounds p n a b) -> + (ZMod.val a + 2 ^ n - ZMod.val b) < (2 ^ n) -> + (ZMod.val (a + 2 ^ n - b)) < (2 ^ n) := + val_sum_eq_sum_val_rel_threshold (· < ·) (2 ^ n) + +lemma val_sum_eq_sum_val_ge {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} : + (Bounds p n a b) -> + (ZMod.val a + 2 ^ n - ZMod.val b) ≥ (2 ^ n) -> + (ZMod.val (a + 2 ^ n - b)) ≥ (2 ^ n) := + val_sum_eq_sum_val_rel_threshold (· ≥ ·) (2 ^ n) + +lemma val_sum_no_wrap {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} : + (Bounds p n a b) -> + (ZMod.val a + 2 ^ n - ZMod.val b) < (2 ^ (n+1)) -> + (ZMod.val (a + 2 ^ n - b)) < (2 ^ (n+1)) := + val_sum_eq_sum_val_rel_threshold (· < ·) (2 ^ (n+1)) + lemma zmod_def {p : ℕ} [Fact p.Prime] (x : ZMod p) : x.val = ZMod.val x := rfl -lemma remove_casting {p : ℕ} [Fact p.Prime] - (n : ℕ) (a b : F p) - (h_val: (((ZMod.val a) : ℕ) + 2^n - ((ZMod.val b): ℕ)) < 2^n) : - (ZMod.val a + 2^n - ZMod.val b) < 2^n := by - - simp_all only - - lemma bit_is_clear {p : ℕ} [Fact p.Prime] (n : ℕ) (a : ZMod p) (hlt : ZMod.val a < 2^n) : @@ -295,6 +232,7 @@ lemma bit_is_set {p : ℕ} [Fact p.Prime] rw [le_antisymm (Nat.lt_succ_iff.mp h2) h1] + def main (n : ℕ) (hn : 2^(n+1) < p) (input : Expression (F p) × Expression (F p)) := do let diff := input.1 + (2^n : F p) - input.2 let bits ← Num2Bits.circuit (n + 1) hn diff @@ -322,16 +260,20 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w rw [← h_input] - have hn' : 2 ^ n < p := by - apply lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) - exact hn - have h1 : Expression.eval env input_var.1 = input.1 := by rw [← h_input] have h2 : Expression.eval env input_var.2 = input.2 := by rw [← h_input] set a := input.1 set b := input.2 + set hp := hn + + have hp' : 2 ^ n < p := by + apply lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) + exact hp + have h_bounds : Bounds p n a b := ⟨h_assumptions.left, h_assumptions.right, hp, hp'⟩ + + rw [h1, h2] at h_holds rw [h1, h2] simp only [id_eq] @@ -355,7 +297,7 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w have h_lt : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ n := a_lt_b_eq_sum_lt_2n_nat n a b hab have h_val_lt : ZMod.val (a + (2 ^ n : F p) - b) < 2 ^ n := by - exact val_sum_eq_sum_val_lt n a b h_assumptions.left h_assumptions.right hn' hn h_lt + exact val_sum_eq_sum_val_lt h_bounds h_lt have h_bit_clear : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = false := bit_is_clear n (a + (2 ^ n : F p) - b) h_val_lt have h_bit_clear' : (ZMod.val (a + 2 ^ n + -b)).testBit n = false := by @@ -373,7 +315,7 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w have h_ge : ZMod.val a + 2 ^ n - ZMod.val b ≥ 2 ^ n := a_ge_b_eq_sum_ge_2n_nat n a b hab' have h_val_ge : ZMod.val (a + (2 ^ n : F p) - b) ≥ 2 ^ n := by - exact val_sum_eq_sum_val_ge n a b h_assumptions.left h_assumptions.right hn' hn h_ge + exact val_sum_eq_sum_val_ge h_bounds h_ge have h_bit_set : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = true := bit_is_set n (a + (2 ^ n : F p) - b) h_holds1' h_val_ge have h_bit_set' : (ZMod.val (a + 2 ^ n + -b)).testBit n = true := by @@ -388,9 +330,10 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w simp only [circuit_norm] at * simp_all only [gt_iff_lt, true_and, id_eq, and_true] - /- obtain ⟨fst, snd⟩ := input_var -/ obtain ⟨h_envl, h_envr⟩ := h_env - obtain ⟨left, right⟩ := h_assumptions + obtain ⟨ha, hb⟩ := h_assumptions + + set hp := hn have h1 : Expression.eval env input_var.1 = input.1 := by rw [← h_input] @@ -402,19 +345,24 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w rw [h1, h2] rw [h1, h2] at h_envl - rw [← sub_eq_add_neg (a:=(a+ 2 ^ n)) ] + rw [← sub_eq_add_neg (a:=(a+ 2 ^ n))] - have hn' : 2 ^ n < p := by + have hp' : 2 ^ n < p := by apply lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) exact hn - have h_val := val_sum_eq_sum_val_lt n a b left right hn' hn - /- exact h_val -/ - have h_comp := val_sum_eq_sum_val_lt_n_plus_1 n a b left right hn' hn - /- apply? -/ - /- /- exact h_envl -/ -/ - /- aesop? -/ + have h_bounds : Bounds p n a b := ⟨ha, hb, hp, hp'⟩ + have h_comp := val_sum_no_wrap h_bounds + + have h_sum_lt_2n1 : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ (n + 1) := by + calc + ZMod.val a + 2 ^ n - ZMod.val b + ≤ ZMod.val a + 2 ^ n := Nat.sub_le _ _ + _ < 2 ^ n + 2 ^ n := by + apply Nat.add_lt_add_right ha + _ = 2 ^ (n + 1) := by + rw [Nat.pow_succ, Nat.mul_two] - sorry + exact h_comp h_sum_lt_2n1 end LessThan From 5748175576bbff82b1be3e27fe1128532c4734e4 Mon Sep 17 00:00:00 2001 From: semar Date: Thu, 9 Oct 2025 13:30:44 -0300 Subject: [PATCH 09/18] more simplification --- Clean/Circomlib/lessComparators.lean | 32 ++++++++++++++-------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index f1bdde61c..3da52fb9e 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -83,7 +83,7 @@ lemma a_lt_b_eq_sum_lt_2n_nat {p : ℕ} [Fact p.Prime] have hpos : 0 < ZMod.val b - ZMod.val a := Nat.sub_pos_of_lt hn simp_all only [ tsub_pos_iff_lt, tsub_lt_self_iff, Nat.ofNat_pos, pow_pos, and_self] -lemma val_pow_of_lt_nat {p n : ℕ} [NeZero p] [Fact p.Prime] (h : 2 ^ n < p) (hp : p > 2): +lemma ZMod.val_two_pow_of_lt {p n : ℕ} [NeZero p] [Fact p.Prime] (h : 2 ^ n < p) (hp : p > 2): (2 ^ n : ZMod p).val = 2 ^ n := by have p_ne_zero := NeZero.ne p @@ -110,14 +110,14 @@ lemma val_pow_of_lt_nat {p n : ℕ} [NeZero p] [Fact p.Prime] (h : 2 ^ n < p) (h -- Helper: no wrap on a + 2^n /- omit [Fact (p > 2)] -/ -private lemma add_twoPow_val {p : ℕ} [Fact (p > 2)] [Fact p.Prime] (a : ZMod p) (n : ℕ) +private lemma add_two_pow_no_wrap_val {p : ℕ} [Fact (p > 2)] [Fact p.Prime] (a : ZMod p) (n : ℕ) (ha : a.val < 2 ^ n) (hp : 2 ^ n < p) (hp' : 2 ^ (n+1) < p) : (a + 2 ^ n).val = a.val + 2 ^ n := by have hp2 := Fact.out (p := p > 2) have h2n : (2^n: ZMod p).val = 2^n := by - exact val_pow_of_lt_nat hp hp2 + exact ZMod.val_two_pow_of_lt hp hp2 have hlt : a.val + 2 ^ n < p := lt_trans (by @@ -134,12 +134,12 @@ private lemma add_twoPow_val {p : ℕ} [Fact (p > 2)] [Fact p.Prime] (a : ZMod p rw [h2n] -- Helper: no wrap on (a + 2^n) - b because b.val ≤ a.val + 2^n -private lemma sub_no_wrap_val (n : ℕ) (a b : ZMod p) +private lemma ZMod.val_sub_add_two_pow_of_no_wrap (n : ℕ) (a b : ZMod p) (ha : a.val < 2 ^ n) (hb : b.val < 2 ^ n) (hp : 2^n < p) (hp' : 2 ^ (n+1) < p) : ((a + 2 ^ n) - b).val = (a.val + 2 ^ n) - b.val := by -- First compute (a + 2^n).val without wrap have hadd : (a + 2 ^ n).val = a.val + 2 ^ n := - add_twoPow_val (n:=n) a ha hp hp' + add_two_pow_no_wrap_val (n:=n) a ha hp hp' -- b.val ≤ 2^n ≤ 2^n + a.val = (a + 2^n).val have hb_le_twoPow : b.val ≤ 2 ^ n := Nat.le_of_lt hb have twoPow_le_sum : 2 ^ n ≤ (a.val + 2 ^ n) := by @@ -156,7 +156,7 @@ private lemma sub_no_wrap_val (n : ℕ) (a b : ZMod p) rw [hres] simp_all only [le_add_iff_nonneg_left, zero_le, Nat.cast_pow, Nat.cast_ofNat] -lemma val_sum_eq_sum_val_rel_threshold {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] +lemma ZMod.val_sub_add_two_pow_rel {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] {a b : F p} (R : ℕ → ℕ → Prop) (threshold : ℕ) (h_bounds : Bounds p n a b) @@ -167,29 +167,29 @@ lemma val_sum_eq_sum_val_rel_threshold {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] apply lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) exact h_bounds.hp - rw [sub_no_wrap_val n a b h_bounds.ha h_bounds.hb h_bounds.hp' h_bounds.hp] + rw [ZMod.val_sub_add_two_pow_of_no_wrap n a b h_bounds.ha h_bounds.hb h_bounds.hp' h_bounds.hp] exact h_val -lemma val_sum_eq_sum_val_lt {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] +lemma ZMod.val_sub_add_two_pow_lt {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] {a b : F p} : (Bounds p n a b) -> (ZMod.val a + 2 ^ n - ZMod.val b) < (2 ^ n) -> (ZMod.val (a + 2 ^ n - b)) < (2 ^ n) := - val_sum_eq_sum_val_rel_threshold (· < ·) (2 ^ n) + ZMod.val_sub_add_two_pow_rel (· < ·) (2 ^ n) -lemma val_sum_eq_sum_val_ge {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] +lemma ZMod.val_sub_add_two_pow_ge {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] {a b : F p} : (Bounds p n a b) -> (ZMod.val a + 2 ^ n - ZMod.val b) ≥ (2 ^ n) -> (ZMod.val (a + 2 ^ n - b)) ≥ (2 ^ n) := - val_sum_eq_sum_val_rel_threshold (· ≥ ·) (2 ^ n) + ZMod.val_sub_add_two_pow_rel (· ≥ ·) (2 ^ n) -lemma val_sum_no_wrap {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] +lemma ZMod.val_sub_add_two_pow_no_wrap {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] {a b : F p} : (Bounds p n a b) -> (ZMod.val a + 2 ^ n - ZMod.val b) < (2 ^ (n+1)) -> (ZMod.val (a + 2 ^ n - b)) < (2 ^ (n+1)) := - val_sum_eq_sum_val_rel_threshold (· < ·) (2 ^ (n+1)) + ZMod.val_sub_add_two_pow_rel (· < ·) (2 ^ (n+1)) lemma zmod_def {p : ℕ} [Fact p.Prime] (x : ZMod p) : x.val = ZMod.val x := rfl @@ -297,7 +297,7 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w have h_lt : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ n := a_lt_b_eq_sum_lt_2n_nat n a b hab have h_val_lt : ZMod.val (a + (2 ^ n : F p) - b) < 2 ^ n := by - exact val_sum_eq_sum_val_lt h_bounds h_lt + exact ZMod.val_sub_add_two_pow_lt h_bounds h_lt have h_bit_clear : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = false := bit_is_clear n (a + (2 ^ n : F p) - b) h_val_lt have h_bit_clear' : (ZMod.val (a + 2 ^ n + -b)).testBit n = false := by @@ -315,7 +315,7 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w have h_ge : ZMod.val a + 2 ^ n - ZMod.val b ≥ 2 ^ n := a_ge_b_eq_sum_ge_2n_nat n a b hab' have h_val_ge : ZMod.val (a + (2 ^ n : F p) - b) ≥ 2 ^ n := by - exact val_sum_eq_sum_val_ge h_bounds h_ge + exact ZMod.val_sub_add_two_pow_ge h_bounds h_ge have h_bit_set : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = true := bit_is_set n (a + (2 ^ n : F p) - b) h_holds1' h_val_ge have h_bit_set' : (ZMod.val (a + 2 ^ n + -b)).testBit n = true := by @@ -352,7 +352,7 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w exact hn have h_bounds : Bounds p n a b := ⟨ha, hb, hp, hp'⟩ - have h_comp := val_sum_no_wrap h_bounds + have h_comp := ZMod.val_sub_add_two_pow_no_wrap h_bounds have h_sum_lt_2n1 : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ (n + 1) := by calc From 1492b9cc1753db81f46c32c7158c4f96e28d0b46 Mon Sep 17 00:00:00 2001 From: semar Date: Thu, 9 Oct 2025 16:56:13 -0300 Subject: [PATCH 10/18] start simplification on soundness --- Clean/Circomlib/lessComparators.lean | 186 +++++++++++++++++---------- 1 file changed, 115 insertions(+), 71 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index 3da52fb9e..13eede66e 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -41,47 +41,72 @@ structure Bounds hp : 2 ^ (n + 1) < p hp' : 2 ^ n < p --- TODO: Needs cleanup as well -lemma a_ge_b_eq_sum_ge_2n_nat {p : ℕ} [Fact p.Prime] - (n : ℕ) (a b : F p) (hn : ZMod.val a ≥ ZMod.val b) : - ZMod.val a + 2^n - ZMod.val b ≥ 2^n := by - - have hn' : ZMod.val b ≤ ZMod.val a := by - simp_all only [ge_iff_le] - - have h_eq : ZMod.val a + 2 ^ n - ZMod.val b = 2 ^ n + (ZMod.val a - ZMod.val b) := by - calc - ZMod.val a + 2 ^ n - ZMod.val b - = (2 ^ n + ZMod.val a) - ZMod.val b := by ac_rfl - _ = (2 ^ n + ZMod.val a) - (ZMod.val b + 0) := by rfl - _ = 2 ^ n + (ZMod.val a - ZMod.val b) := by - simp only [Nat.add_zero, Nat.add_sub_assoc hn'] - - rw [h_eq] - exact Nat.le_add_right _ _ - -lemma a_lt_b_eq_sum_lt_2n_nat {p : ℕ} [Fact p.Prime] - (n : ℕ) (a b : F p) (hn : ZMod.val a < ZMod.val b) : - ZMod.val a + 2^n - ZMod.val b < 2^n := by - - have h_eq : ZMod.val a + 2 ^ n - ZMod.val b = 2 ^ n - (ZMod.val b - ZMod.val a) := by - -- rewrite b as a + (b - a), then cancel the common +a on both sides of the subtraction - have hb : ZMod.val b = ZMod.val a + (ZMod.val b - ZMod.val a) := - (Nat.add_sub_of_le (Nat.le_of_lt hn)).symm - calc - ZMod.val a + 2 ^ n - ZMod.val b - = ZMod.val a + 2 ^ n - (ZMod.val a + (ZMod.val b - ZMod.val a)) := by - rw [hb] - simp only [Nat.add_sub_cancel_left] - _ = (2 ^ n + ZMod.val a) - ((ZMod.val b - ZMod.val a) + ZMod.val a) := by - ac_rfl - _ = 2 ^ n - (ZMod.val b - ZMod.val a) := by - simp only [Nat.add_sub_add_right] - -- exact Nat.add_sub_add_right (2 ^ n) (ZMod.val b - ZMod.val a) (ZMod.val a) - - rw [h_eq] - have hpos : 0 < ZMod.val b - ZMod.val a := Nat.sub_pos_of_lt hn - simp_all only [ tsub_pos_iff_lt, tsub_lt_self_iff, Nat.ofNat_pos, pow_pos, and_self] +/-- From `2^(n+1) < p` get `2^n < p`. -/ +lemma pow_lt_of_succ {n p : ℕ} (hn : 2^(n+1) < p) : 2^n < p := by + exact lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) hn + +/-- `testBit` is insensitive to rewriting `x - y` as `x + -y`. -/ +lemma testBit_sub_eq_addNeg (x y : ZMod p) (k : ℕ) : + (ZMod.val (x - y)).testBit k = (ZMod.val (x + -y)).testBit k := by + simp [sub_eq_add_neg] + +lemma testBit_sum_sub_eq_sum_addNeg + {p : ℕ} [Fact p.Prime] [Fact (p > 2)] + (a b : ZMod p) (n k : ℕ) : + (ZMod.val (a + 2^n - b)).testBit k + = (ZMod.val (a + 2^n + -b)).testBit k := by + simp [sub_eq_add_neg, add_comm, add_left_comm] + /- using -/ + /- (testBit_sub_eq_addNeg (a + 2^n) b k) -/ + +/-- Pack the repeated bounds (`ha hb hp hp'`) into your structure in one shot. -/ +def Bounds.of_assumptions + {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : ZMod p} + (ha : a.val < 2^n) (hb : b.val < 2^n) + (hp_succ : 2^(n+1) < p) : Bounds p n a b := + ⟨ha, hb, hp_succ, pow_lt_of_succ hp_succ⟩ + +lemma add_two_pow_sub_eq_add_diff {n a b : ℕ} (h : b ≤ a) : + a + 2 ^ n - b = 2 ^ n + (a - b) := by + calc + a + 2 ^ n - b + = (2 ^ n + a) - b := by ac_rfl + _ = (2 ^ n + a) - (b + 0) := by rfl + _ = 2 ^ n + (a - b) := by + simp only [Nat.add_zero, Nat.add_sub_assoc h] + +lemma add_two_pow_sub_eq_sub_diff {n a b : ℕ} (h : a < b) : + a + 2 ^ n - b = 2 ^ n - (b - a) := by + have hb : b = a + (b - a) := (Nat.add_sub_of_le (Nat.le_of_lt h)).symm + calc + a + 2 ^ n - b + = a + 2 ^ n - (a + (b - a)) := by rw [hb]; simp only [Nat.add_sub_cancel_left] + _ = (2 ^ n + a) - ((b - a) + a) := by ac_rfl + _ = 2 ^ n - (b - a) := by simp only [Nat.add_sub_add_right] + +lemma add_two_pow_sub_lt_pow_succ {n a b : ℕ} (ha : a < 2 ^ n) : + a + 2 ^ n - b < 2 ^ (n + 1) := by + calc + a + 2 ^ n - b ≤ a + 2 ^ n := Nat.sub_le _ _ + _ < 2 ^ n + 2 ^ n := Nat.add_lt_add_right ha _ + _ = 2 ^ (n + 1) := by rw [Nat.pow_succ, Nat.mul_two] + +lemma ZMod.val_add_two_pow_sub_rel + {p : ℕ} [Fact p.Prime] + (n : ℕ) (a b : F p) : + + if ZMod.val a < ZMod.val b then + ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ n + else + ZMod.val a + 2 ^ n - ZMod.val b ≥ 2 ^ n := by + + split_ifs with h_lt + · rw [add_two_pow_sub_eq_sub_diff h_lt] + have h_pos : 0 < ZMod.val b - ZMod.val a := Nat.sub_pos_of_lt h_lt + simp_all [tsub_lt_self_iff, pow_pos] + · rw [add_two_pow_sub_eq_add_diff (le_of_not_gt h_lt)] + exact Nat.le_add_right _ _ lemma ZMod.val_two_pow_of_lt {p n : ℕ} [NeZero p] [Fact p.Prime] (h : 2 ^ n < p) (hp : p > 2): (2 ^ n : ZMod p).val = 2 ^ n := by @@ -163,9 +188,7 @@ lemma ZMod.val_sub_add_two_pow_rel {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] (h_val : R (ZMod.val a + 2 ^ n - ZMod.val b) threshold) : R (ZMod.val (a + 2 ^ n - b)) threshold := by - have hp' : 2 ^ n < p := by - apply lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) - exact h_bounds.hp + have hp' : 2 ^ n < p := pow_lt_of_succ h_bounds.hp rw [ZMod.val_sub_add_two_pow_of_no_wrap n a b h_bounds.ha h_bounds.hb h_bounds.hp' h_bounds.hp] exact h_val @@ -191,18 +214,17 @@ lemma ZMod.val_sub_add_two_pow_no_wrap {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] (ZMod.val (a + 2 ^ n - b)) < (2 ^ (n+1)) := ZMod.val_sub_add_two_pow_rel (· < ·) (2 ^ (n+1)) -lemma zmod_def {p : ℕ} [Fact p.Prime] (x : ZMod p) : - x.val = ZMod.val x := rfl - lemma bit_is_clear {p : ℕ} [Fact p.Prime] (n : ℕ) (a : ZMod p) (hlt : ZMod.val a < 2^n) : (ZMod.val a).testBit n = false := by rw [Nat.testBit_eq_decide_div_mod_eq] + -- ⊢ decide (a.val / 2 ^ n % 2 = 1) = false have hbpos : 0 < 2^n := pow_pos (by decide) n have hdiv : ZMod.val a / 2^n = 0 := Nat.div_eq_of_lt hlt rw [hdiv, Nat.zero_mod] + -- ⊢ decide (0 = 1) = false simp only [zero_ne_one, decide_false] lemma bit_is_set {p : ℕ} [Fact p.Prime] @@ -213,14 +235,14 @@ lemma bit_is_set {p : ℕ} [Fact p.Prime] (ZMod.val x).testBit n = true := by rw [Nat.testBit_eq_decide_div_mod_eq ] - /- ⊢ ZMod.val a / 2 ^ n % 2 = 1 -/ simp only [decide_eq_true_eq] + /- ⊢ ZMod.val a / 2 ^ n % 2 = 1 -/ set x := ZMod.val x - -- lower bound: 1 ≤ x / 2^n have hbpos : 0 < 2^n := pow_pos (by decide) n + + -- lower bound: 1 ≤ x / 2^n have h1 : 1 ≤ x / 2^n := by simp only [Nat.le_div_iff_mul_le hbpos, one_mul] - /- apply Nat.le_of_lt at hge -/ exact hge -- upper bound: x / 2^n < 2 @@ -232,6 +254,35 @@ lemma bit_is_set {p : ℕ} [Fact p.Prime] rw [le_antisymm (Nat.lt_succ_iff.mp h2) h1] +-- The nth bit of (a + 2^n - b) is 0 iff a.val < b.val, otherwise 1. +lemma testBit_n_of_add_two_pow_sub + {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} + (hB : Bounds p n a b) + (h_sum_lt : ZMod.val (a + (2 ^ n : F p) - b) < 2 ^ (n + 1)) : + (ZMod.val (a + (2 ^ n : F p) - b)).testBit n + = (if ZMod.val a < ZMod.val b then false else true) := by + -- relational split for a.val vs b.val + have h_rel := ZMod.val_add_two_pow_sub_rel n a b + -- branch on (a.val < b.val) + by_cases hlt : ZMod.val a < ZMod.val b + · -- then sum < 2^n, bit is 0 + have h_lt : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ n := by + simpa [hlt] using h_rel + have hx_lt : ZMod.val (a + (2 ^ n : F p) - b) < 2 ^ n := + ZMod.val_sub_add_two_pow_lt hB h_lt + have h_clear : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = false := + bit_is_clear n (a + (2 ^ n : F p) - b) hx_lt + simpa [hlt, sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_clear + · -- else sum ≥ 2^n, bit is 1 + have h_ge : ZMod.val a + 2 ^ n - ZMod.val b ≥ 2 ^ n := by + simpa [hlt] using h_rel + have hx_ge : ZMod.val (a + (2 ^ n : F p) - b) ≥ 2 ^ n := + ZMod.val_sub_add_two_pow_ge hB h_ge + have h_set : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = true := + bit_is_set n (a + (2 ^ n : F p) - b) h_sum_lt hx_ge + simpa [hlt, sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_set + def main (n : ℕ) (hn : 2^(n+1) < p) (input : Expression (F p) × Expression (F p)) := do let diff := input.1 + (2^n : F p) - input.2 @@ -267,11 +318,10 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w set a := input.1 set b := input.2 set hp := hn + + have hp' : 2 ^ n < p := pow_lt_of_succ hp - have hp' : 2 ^ n < p := by - apply lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) - exact hp - have h_bounds : Bounds p n a b := ⟨h_assumptions.left, h_assumptions.right, hp, hp'⟩ + have h_bounds : Bounds p n a b := Bounds.of_assumptions h_assumptions.left h_assumptions.right hp rw [h1, h2] at h_holds @@ -294,8 +344,10 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w by_cases hab : ZMod.val a < ZMod.val b . -- sum is < 2^n, so nth bit is 0 - have h_lt : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ n := - a_lt_b_eq_sum_lt_2n_nat n a b hab + have h_rel := ZMod.val_add_two_pow_sub_rel n a b + simp [hab] at h_rel + have h_lt := h_rel + have h_val_lt : ZMod.val (a + (2 ^ n : F p) - b) < 2 ^ n := by exact ZMod.val_sub_add_two_pow_lt h_bounds h_lt have h_bit_clear : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = false := @@ -312,8 +364,9 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w rw [← sub_eq_add_neg] at h_holds1 exact h_holds1 - have h_ge : ZMod.val a + 2 ^ n - ZMod.val b ≥ 2 ^ n := - a_ge_b_eq_sum_ge_2n_nat n a b hab' + have h_rel := ZMod.val_add_two_pow_sub_rel n a b + simp [hab] at h_rel + have h_ge := h_rel have h_val_ge : ZMod.val (a + (2 ^ n : F p) - b) ≥ 2 ^ n := by exact ZMod.val_sub_add_two_pow_ge h_bounds h_ge have h_bit_set : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = true := @@ -324,7 +377,6 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w simp [h_bit_set', hab] completeness := by - /- circuit_proof_all -/ circuit_proof_start simp only [Num2Bits.circuit] at * simp only [circuit_norm] at * @@ -347,21 +399,13 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w rw [← sub_eq_add_neg (a:=(a+ 2 ^ n))] - have hp' : 2 ^ n < p := by - apply lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) - exact hn + have hp' : 2 ^ n < p := pow_lt_of_succ hp - have h_bounds : Bounds p n a b := ⟨ha, hb, hp, hp'⟩ + have h_bounds : Bounds p n a b := Bounds.of_assumptions ha hb hp have h_comp := ZMod.val_sub_add_two_pow_no_wrap h_bounds - have h_sum_lt_2n1 : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ (n + 1) := by - calc - ZMod.val a + 2 ^ n - ZMod.val b - ≤ ZMod.val a + 2 ^ n := Nat.sub_le _ _ - _ < 2 ^ n + 2 ^ n := by - apply Nat.add_lt_add_right ha - _ = 2 ^ (n + 1) := by - rw [Nat.pow_succ, Nat.mul_two] + have h_sum_lt_2n1 : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ (n + 1) := + add_two_pow_sub_lt_pow_succ ha exact h_comp h_sum_lt_2n1 From 1144f99b31d40f6e5e554bfc7709f4ea684bcfa6 Mon Sep 17 00:00:00 2001 From: semar Date: Thu, 9 Oct 2025 17:27:23 -0300 Subject: [PATCH 11/18] cleaned up by_cases --- Clean/Circomlib/lessComparators.lean | 96 ++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 25 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index 13eede66e..f90c94346 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -56,8 +56,6 @@ lemma testBit_sum_sub_eq_sum_addNeg (ZMod.val (a + 2^n - b)).testBit k = (ZMod.val (a + 2^n + -b)).testBit k := by simp [sub_eq_add_neg, add_comm, add_left_comm] - /- using -/ - /- (testBit_sub_eq_addNeg (a + 2^n) b k) -/ /-- Pack the repeated bounds (`ha hb hp hp'`) into your structure in one shot. -/ def Bounds.of_assumptions @@ -283,6 +281,66 @@ lemma testBit_n_of_add_two_pow_sub bit_is_set n (a + (2 ^ n : F p) - b) h_sum_lt hx_ge simpa [hlt, sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_set +/-- From `<` on naturals to `<` on `.val (a + 2^n - b)` using `Bounds`. -/ +lemma zval_lt_of_nat_lt + {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : ZMod p} + (bd : Bounds p n a b) + (h : a.val + 2^n - b.val < 2^n) : + (a + 2^n - b).val < 2^n := + ZMod.val_sub_add_two_pow_lt bd h + +/-- From `≥` on naturals to `≥` on `.val (a + 2^n - b)` using `Bounds`. -/ +lemma zval_ge_of_nat_ge + {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : ZMod p} + (bd : Bounds p n a b) + (h : a.val + 2^n - b.val ≥ 2^n) : + (a + 2^n - b).val ≥ 2^n := + ZMod.val_sub_add_two_pow_ge bd h + +/-- If `a.val < b.val`, the nth bit of `ZMod.val (a + 2^n - b)` is `false`. -/ +lemma nth_bit_clear_of_val_lt + {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} + (bd : Bounds p n a b) + (hab : a.val < b.val) : + (ZMod.val (a + 2^n - b)).testBit n = false := by + -- nat-level bound + have hnlt := ZMod.val_add_two_pow_sub_rel n a b + simp [hab] at hnlt + have h_lt := hnlt + -- lift through no-wrap + have hzlt := zval_lt_of_nat_lt bd hnlt + exact bit_is_clear n (a + 2^n - b) hzlt + +/-- If `a.val ≥ b.val` and the sum is `< 2^(n+1)`, the nth bit is `true`. -/ +lemma nth_bit_set_of_val_ge + {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} + (bd : Bounds p n a b) + (h_sum_lt : (a + 2^n - b).val < 2^(n+1)) + /- (hab : a.val ≥ b.val) : -/ + (hab : ¬(a.val < b.val)) : + (ZMod.val (a + 2^n - b)).testBit n = true := by + -- nat-level bound + have h_rel := ZMod.val_add_two_pow_sub_rel n a b + simp [hab] at h_rel + have h_ge := h_rel + -- lift through no-wrap + have hzge := zval_ge_of_nat_ge bd h_ge + exact bit_is_set n (a + 2^n - b) h_sum_lt hzge + +/-- If `a.val ≥ b.val` then `a.val + 2^n - b.val ≥ 2^n`. -/ +lemma val_add_two_pow_sub_ge_of_val_ge + {p : ℕ} [Fact p.Prime] (n : ℕ) (a b : ZMod p) + (hge : a.val ≥ b.val) : + a.val + 2^n - b.val ≥ 2^n := by + have hb_le_a : b.val ≤ a.val := hge + -- reuse the ≥ arithmetic lemma + have h := add_two_pow_sub_eq_add_diff (n := n) (a := a.val) (b := b.val) hb_le_a + -- `a.val + 2^n - b.val = 2^n + (a.val - b.val)` ≥ `2^n` + simp [h] def main (n : ℕ) (hn : 2^(n+1) < p) (input : Expression (F p) × Expression (F p)) := do let diff := input.1 + (2^n : F p) - input.2 @@ -344,36 +402,24 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w by_cases hab : ZMod.val a < ZMod.val b . -- sum is < 2^n, so nth bit is 0 - have h_rel := ZMod.val_add_two_pow_sub_rel n a b - simp [hab] at h_rel - have h_lt := h_rel - - have h_val_lt : ZMod.val (a + (2 ^ n : F p) - b) < 2 ^ n := by - exact ZMod.val_sub_add_two_pow_lt h_bounds h_lt - have h_bit_clear : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = false := - bit_is_clear n (a + (2 ^ n : F p) - b) h_val_lt - have h_bit_clear' : (ZMod.val (a + 2 ^ n + -b)).testBit n = false := by - simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_bit_clear + have h_bit_clear' : + (ZMod.val (a + 2 ^ n + -b)).testBit n = false := by + have h_bit := nth_bit_clear_of_val_lt (p:=p) (n:=n) (a:=a) (b:=b) h_bounds hab + rw [sub_eq_add_neg] at h_bit + exact h_bit + simp [h_bit_clear', hab] . -- sum is ≥ 2^n, so nth bit is 1 - - have hab' : ZMod.val a ≥ ZMod.val b := by - simp_all only [not_lt, ge_iff_le] have h_holds1' : ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1) := by rw [← sub_eq_add_neg] at h_holds1 exact h_holds1 - have h_rel := ZMod.val_add_two_pow_sub_rel n a b - simp [hab] at h_rel - have h_ge := h_rel - have h_val_ge : ZMod.val (a + (2 ^ n : F p) - b) ≥ 2 ^ n := by - exact ZMod.val_sub_add_two_pow_ge h_bounds h_ge - have h_bit_set : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = true := - bit_is_set n (a + (2 ^ n : F p) - b) h_holds1' h_val_ge - have h_bit_set' : (ZMod.val (a + 2 ^ n + -b)).testBit n = true := by - simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_bit_set - + have h_bit_set' : + (ZMod.val (a + 2 ^ n + -b)).testBit n = true := by + have h_bit := nth_bit_set_of_val_ge (p:=p) (n:=n) (a:=a) (b:=b) h_bounds h_holds1' hab + rw [sub_eq_add_neg] at h_bit + exact h_bit simp [h_bit_set', hab] completeness := by From 866471767480cf779b8697fb3d58dd6d353b08ec Mon Sep 17 00:00:00 2001 From: semar Date: Thu, 9 Oct 2025 18:32:26 -0300 Subject: [PATCH 12/18] clean up --- Clean/Circomlib/lessComparators.lean | 48 +++++++++++++++++++--------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index f90c94346..d50f98619 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -341,6 +341,26 @@ lemma val_add_two_pow_sub_ge_of_val_ge have h := add_two_pow_sub_eq_add_diff (n := n) (a := a.val) (b := b.val) hb_le_a -- `a.val + 2^n - b.val = 2^n + (a.val - b.val)` ≥ `2^n` simp [h] + +lemma num2bits_bit_eq_testBit + {p n i₀ : ℕ} [Fact p.Prime] [Fact (p > 2)] + {env : Environment (F p)} + {a b : F p} + (h_holds : + Vector.map (Expression.eval env) + (Vector.mapRange (n + 1) (fun i ↦ var { index := i₀ + i })) + = fieldToBits (n + 1) (a + 2 ^ n + -b)) : + (Vector.map (Expression.eval env) + (Vector.mapRange (n + 1) (fun i ↦ var { index := i₀ + i })))[n]'(Nat.lt_succ_self n) + = + (if (ZMod.val (a + 2 ^ n + -b)).testBit n then (1 : F p) else 0) := by + simp only [Vector.ext_iff] at h_holds + specialize h_holds n (Nat.lt_succ_self n) + simp only [Vector.getElem_map, Vector.getElem_mapRange, + fieldToBits, toBits, Nat.cast_ite, Nat.cast_one, Nat.cast_zero] at h_holds + + simp only [Vector.getElem_map, Vector.getElem_mapRange] + exact h_holds def main (n : ℕ) (hn : 2^(n+1) < p) (input : Expression (F p) × Expression (F p)) := do let diff := input.1 + (2^n : F p) - input.2 @@ -368,7 +388,7 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w -- rw [← h_input] - + -- ① evaluate inputs have h1 : Expression.eval env input_var.1 = input.1 := by rw [← h_input] have h2 : Expression.eval env input_var.2 = input.2 := by @@ -376,29 +396,25 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w set a := input.1 set b := input.2 set hp := hn + rw [h1, h2] at h_holds + rw [h1, h2] + simp only [id_eq] + -- ② prepare bounds have hp' : 2 ^ n < p := pow_lt_of_succ hp have h_bounds : Bounds p n a b := Bounds.of_assumptions h_assumptions.left h_assumptions.right hp - - rw [h1, h2] at h_holds - rw [h1, h2] - simp only [id_eq] - set summation := ((ZMod.val a : ℤ) + 2 ^ n + -(ZMod.val b : ℤ)) + -- ③ extract nth bit rw [← Nat.add_assoc] at h_holds - rw [h_holds.right] obtain ⟨⟨h_holds1, h_holds2⟩, h_holds3⟩ := h_holds - simp only [Vector.ext_iff] at h_holds2 - specialize h_holds2 n (Nat.lt_succ_self n) + rw [h_holds3] - rw [Vector.getElem_map, Vector.getElem_mapRange] at h_holds2 - simp only [circuit_norm] at h_holds2 - simp only [fieldToBits, toBits] at h_holds2 - rw [Vector.getElem_map, Vector.getElem_mapRange] at h_holds2 - simp only [Nat.cast_ite, Nat.cast_one, Nat.cast_zero] at h_holds2 - rw [h_holds2] + have h_nbit := num2bits_bit_eq_testBit h_holds2 + simp only [circuit_norm] at h_nbit + rw [h_nbit] + by_cases hab : ZMod.val a < ZMod.val b . -- sum is < 2^n, so nth bit is 0 @@ -411,6 +427,8 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w simp [h_bit_clear', hab] . -- sum is ≥ 2^n, so nth bit is 1 + have hab' : ZMod.val a ≥ ZMod.val b := by + simpa [not_lt, ge_iff_le] using hab have h_holds1' : ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1) := by rw [← sub_eq_add_neg] at h_holds1 exact h_holds1 From 55cb0ce40918e1fdb1619eab1f46f14c13e740d7 Mon Sep 17 00:00:00 2001 From: semar Date: Thu, 9 Oct 2025 18:46:35 -0300 Subject: [PATCH 13/18] removing unused lemmas --- Clean/Circomlib/lessComparators.lean | 158 ++++++++++----------------- 1 file changed, 60 insertions(+), 98 deletions(-) diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean index d50f98619..51aa3c6cd 100644 --- a/Clean/Circomlib/lessComparators.lean +++ b/Clean/Circomlib/lessComparators.lean @@ -28,10 +28,6 @@ template LessThan(n) { out <== 1-n2b.out[n]; } -/ -structure Input (p : ℕ) [Fact p.Prime] [Fact (p > 2)] where - a : F p - b : F p - structure Bounds (p n : ℕ) [Fact p.Prime] [Fact (p > 2)] (a b : F p) @@ -45,18 +41,6 @@ structure Bounds lemma pow_lt_of_succ {n p : ℕ} (hn : 2^(n+1) < p) : 2^n < p := by exact lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) hn -/-- `testBit` is insensitive to rewriting `x - y` as `x + -y`. -/ -lemma testBit_sub_eq_addNeg (x y : ZMod p) (k : ℕ) : - (ZMod.val (x - y)).testBit k = (ZMod.val (x + -y)).testBit k := by - simp [sub_eq_add_neg] - -lemma testBit_sum_sub_eq_sum_addNeg - {p : ℕ} [Fact p.Prime] [Fact (p > 2)] - (a b : ZMod p) (n k : ℕ) : - (ZMod.val (a + 2^n - b)).testBit k - = (ZMod.val (a + 2^n + -b)).testBit k := by - simp [sub_eq_add_neg, add_comm, add_left_comm] - /-- Pack the repeated bounds (`ha hb hp hp'`) into your structure in one shot. -/ def Bounds.of_assumptions {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] @@ -158,13 +142,13 @@ private lemma add_two_pow_no_wrap_val {p : ℕ} [Fact (p > 2)] [Fact p.Prime] (a -- Helper: no wrap on (a + 2^n) - b because b.val ≤ a.val + 2^n private lemma ZMod.val_sub_add_two_pow_of_no_wrap (n : ℕ) (a b : ZMod p) - (ha : a.val < 2 ^ n) (hb : b.val < 2 ^ n) (hp : 2^n < p) (hp' : 2 ^ (n+1) < p) : + (bounds: Bounds p n a b) : ((a + 2 ^ n) - b).val = (a.val + 2 ^ n) - b.val := by -- First compute (a + 2^n).val without wrap have hadd : (a + 2 ^ n).val = a.val + 2 ^ n := - add_two_pow_no_wrap_val (n:=n) a ha hp hp' + add_two_pow_no_wrap_val (n:=n) a bounds.ha bounds.hp' bounds.hp -- b.val ≤ 2^n ≤ 2^n + a.val = (a + 2^n).val - have hb_le_twoPow : b.val ≤ 2 ^ n := Nat.le_of_lt hb + have hb_le_twoPow : b.val ≤ 2 ^ n := Nat.le_of_lt bounds.hb have twoPow_le_sum : 2 ^ n ≤ (a.val + 2 ^ n) := by simp [Nat.add_comm] have hble : b.val ≤ (a.val + 2 ^ n) := le_trans hb_le_twoPow twoPow_le_sum @@ -182,13 +166,13 @@ private lemma ZMod.val_sub_add_two_pow_of_no_wrap (n : ℕ) (a b : ZMod p) lemma ZMod.val_sub_add_two_pow_rel {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] {a b : F p} (R : ℕ → ℕ → Prop) (threshold : ℕ) - (h_bounds : Bounds p n a b) + (bounds : Bounds p n a b) (h_val : R (ZMod.val a + 2 ^ n - ZMod.val b) threshold) : R (ZMod.val (a + 2 ^ n - b)) threshold := by - have hp' : 2 ^ n < p := pow_lt_of_succ h_bounds.hp + have hp' : 2 ^ n < p := pow_lt_of_succ bounds.hp - rw [ZMod.val_sub_add_two_pow_of_no_wrap n a b h_bounds.ha h_bounds.hb h_bounds.hp' h_bounds.hp] + rw [ZMod.val_sub_add_two_pow_of_no_wrap n a b bounds] exact h_val lemma ZMod.val_sub_add_two_pow_lt {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] @@ -252,58 +236,29 @@ lemma bit_is_set {p : ℕ} [Fact p.Prime] rw [le_antisymm (Nat.lt_succ_iff.mp h2) h1] --- The nth bit of (a + 2^n - b) is 0 iff a.val < b.val, otherwise 1. -lemma testBit_n_of_add_two_pow_sub - {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] - {a b : F p} - (hB : Bounds p n a b) - (h_sum_lt : ZMod.val (a + (2 ^ n : F p) - b) < 2 ^ (n + 1)) : - (ZMod.val (a + (2 ^ n : F p) - b)).testBit n - = (if ZMod.val a < ZMod.val b then false else true) := by - -- relational split for a.val vs b.val - have h_rel := ZMod.val_add_two_pow_sub_rel n a b - -- branch on (a.val < b.val) - by_cases hlt : ZMod.val a < ZMod.val b - · -- then sum < 2^n, bit is 0 - have h_lt : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ n := by - simpa [hlt] using h_rel - have hx_lt : ZMod.val (a + (2 ^ n : F p) - b) < 2 ^ n := - ZMod.val_sub_add_two_pow_lt hB h_lt - have h_clear : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = false := - bit_is_clear n (a + (2 ^ n : F p) - b) hx_lt - simpa [hlt, sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_clear - · -- else sum ≥ 2^n, bit is 1 - have h_ge : ZMod.val a + 2 ^ n - ZMod.val b ≥ 2 ^ n := by - simpa [hlt] using h_rel - have hx_ge : ZMod.val (a + (2 ^ n : F p) - b) ≥ 2 ^ n := - ZMod.val_sub_add_two_pow_ge hB h_ge - have h_set : (ZMod.val (a + (2 ^ n : F p) - b)).testBit n = true := - bit_is_set n (a + (2 ^ n : F p) - b) h_sum_lt hx_ge - simpa [hlt, sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h_set - /-- From `<` on naturals to `<` on `.val (a + 2^n - b)` using `Bounds`. -/ lemma zval_lt_of_nat_lt {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] {a b : ZMod p} - (bd : Bounds p n a b) + (bounds : Bounds p n a b) (h : a.val + 2^n - b.val < 2^n) : (a + 2^n - b).val < 2^n := - ZMod.val_sub_add_two_pow_lt bd h + ZMod.val_sub_add_two_pow_lt bounds h /-- From `≥` on naturals to `≥` on `.val (a + 2^n - b)` using `Bounds`. -/ lemma zval_ge_of_nat_ge {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] {a b : ZMod p} - (bd : Bounds p n a b) + (bounds : Bounds p n a b) (h : a.val + 2^n - b.val ≥ 2^n) : (a + 2^n - b).val ≥ 2^n := - ZMod.val_sub_add_two_pow_ge bd h + ZMod.val_sub_add_two_pow_ge bounds h /-- If `a.val < b.val`, the nth bit of `ZMod.val (a + 2^n - b)` is `false`. -/ lemma nth_bit_clear_of_val_lt {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] {a b : F p} - (bd : Bounds p n a b) + (bounds : Bounds p n a b) (hab : a.val < b.val) : (ZMod.val (a + 2^n - b)).testBit n = false := by -- nat-level bound @@ -311,14 +266,14 @@ lemma nth_bit_clear_of_val_lt simp [hab] at hnlt have h_lt := hnlt -- lift through no-wrap - have hzlt := zval_lt_of_nat_lt bd hnlt + have hzlt := zval_lt_of_nat_lt bounds hnlt exact bit_is_clear n (a + 2^n - b) hzlt /-- If `a.val ≥ b.val` and the sum is `< 2^(n+1)`, the nth bit is `true`. -/ lemma nth_bit_set_of_val_ge {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] {a b : F p} - (bd : Bounds p n a b) + (bounds : Bounds p n a b) (h_sum_lt : (a + 2^n - b).val < 2^(n+1)) /- (hab : a.val ≥ b.val) : -/ (hab : ¬(a.val < b.val)) : @@ -328,20 +283,9 @@ lemma nth_bit_set_of_val_ge simp [hab] at h_rel have h_ge := h_rel -- lift through no-wrap - have hzge := zval_ge_of_nat_ge bd h_ge + have hzge := zval_ge_of_nat_ge bounds h_ge exact bit_is_set n (a + 2^n - b) h_sum_lt hzge -/-- If `a.val ≥ b.val` then `a.val + 2^n - b.val ≥ 2^n`. -/ -lemma val_add_two_pow_sub_ge_of_val_ge - {p : ℕ} [Fact p.Prime] (n : ℕ) (a b : ZMod p) - (hge : a.val ≥ b.val) : - a.val + 2^n - b.val ≥ 2^n := by - have hb_le_a : b.val ≤ a.val := hge - -- reuse the ≥ arithmetic lemma - have h := add_two_pow_sub_eq_add_diff (n := n) (a := a.val) (b := b.val) hb_le_a - -- `a.val + 2^n - b.val = 2^n + (a.val - b.val)` ≥ `2^n` - simp [h] - lemma num2bits_bit_eq_testBit {p n i₀ : ℕ} [Fact p.Prime] [Fact (p > 2)] {env : Environment (F p)} @@ -362,6 +306,43 @@ lemma num2bits_bit_eq_testBit simp only [Vector.getElem_map, Vector.getElem_mapRange] exact h_holds +lemma nth_bit_from_val + {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} + (bounds : Bounds p n a b) + (h_holds1 : ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1)) + : (ZMod.val (a + 2 ^ n + -b)).testBit n + = decide (¬ (ZMod.val a < ZMod.val b)) := by + by_cases hab : ZMod.val a < ZMod.val b + · -- Case 1: a < b → bit n is 0 + have h_bit_clear : + (ZMod.val (a + 2 ^ n + -b)).testBit n = false := by + have h_bit := nth_bit_clear_of_val_lt (p := p) (n := n) (a := a) (b := b) bounds hab + rw [sub_eq_add_neg] at h_bit + exact h_bit + simp [h_bit_clear, hab] + · -- Case 2: a ≥ b → bit n is 1 + have h_bit_set : + (ZMod.val (a + 2 ^ n + -b)).testBit n = true := by + have h_ge : ZMod.val a ≥ ZMod.val b := by simpa [not_lt, ge_iff_le] using hab + have h_bit := nth_bit_set_of_val_ge (p := p) (n := n) (a := a) (b := b) + bounds h_holds1 hab + rw [sub_eq_add_neg] at h_bit + exact h_bit + simp [h_bit_set, hab] + +/-- In any field, `1 - [y ≤ x] = [x < y]` where brackets mean 1/0-as-a-field. -/ +lemma one_sub_if_le_eq_if_lt {F} [Field F] {x y : ℕ} : + (1 : F) + - (if y ≤ x then 1 else 0) + = (if x < y then 1 else 0) := by + by_cases h : y ≤ x + · -- then `¬ (x < y)` + have hxlt : ¬ x < y := not_lt.mpr h + simp [h, hxlt] + · -- so `x < y` + have hxlt : x < y := lt_of_not_ge h + simp [h, hxlt] + def main (n : ℕ) (hn : 2^(n+1) < p) (input : Expression (F p) × Expression (F p)) := do let diff := input.1 + (2^n : F p) - input.2 let bits ← Num2Bits.circuit (n + 1) hn diff @@ -403,7 +384,7 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w -- ② prepare bounds have hp' : 2 ^ n < p := pow_lt_of_succ hp - have h_bounds : Bounds p n a b := Bounds.of_assumptions h_assumptions.left h_assumptions.right hp + have bounds := Bounds.of_assumptions h_assumptions.left h_assumptions.right hp -- ③ extract nth bit rw [← Nat.add_assoc] at h_holds @@ -415,30 +396,12 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w simp only [circuit_norm] at h_nbit rw [h_nbit] - by_cases hab : ZMod.val a < ZMod.val b - . - -- sum is < 2^n, so nth bit is 0 - have h_bit_clear' : - (ZMod.val (a + 2 ^ n + -b)).testBit n = false := by - have h_bit := nth_bit_clear_of_val_lt (p:=p) (n:=n) (a:=a) (b:=b) h_bounds hab - rw [sub_eq_add_neg] at h_bit - exact h_bit - - simp [h_bit_clear', hab] - . - -- sum is ≥ 2^n, so nth bit is 1 - have hab' : ZMod.val a ≥ ZMod.val b := by - simpa [not_lt, ge_iff_le] using hab - have h_holds1' : ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1) := by - rw [← sub_eq_add_neg] at h_holds1 - exact h_holds1 - - have h_bit_set' : - (ZMod.val (a + 2 ^ n + -b)).testBit n = true := by - have h_bit := nth_bit_set_of_val_ge (p:=p) (n:=n) (a:=a) (b:=b) h_bounds h_holds1' hab - rw [sub_eq_add_neg] at h_bit - exact h_bit - simp [h_bit_set', hab] + -- ④ core logic: bit is set iff a ≥ b + rw [← sub_eq_add_neg] at h_holds1 + have h_bit := nth_bit_from_val bounds h_holds1 + simp only [h_bit, circuit_norm] + simp only [not_lt, decide_eq_true_eq] + simpa using (one_sub_if_le_eq_if_lt (F := F p) (x := ZMod.val a) (y := ZMod.val b)) completeness := by circuit_proof_start @@ -465,12 +428,11 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w have hp' : 2 ^ n < p := pow_lt_of_succ hp - have h_bounds : Bounds p n a b := Bounds.of_assumptions ha hb hp - have h_comp := ZMod.val_sub_add_two_pow_no_wrap h_bounds + have bounds := Bounds.of_assumptions ha hb hp have h_sum_lt_2n1 : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ (n + 1) := add_two_pow_sub_lt_pow_succ ha - exact h_comp h_sum_lt_2n1 + exact ZMod.val_sub_add_two_pow_no_wrap bounds h_sum_lt_2n1 end LessThan From 6966532a33a5e75126015b563ac1df307670046a Mon Sep 17 00:00:00 2001 From: semar Date: Fri, 10 Oct 2025 08:26:13 -0300 Subject: [PATCH 14/18] move code out of lessComparators back into original file --- Clean/Circomlib/Comparators.lean | 397 ++++++++++++++++++++++-- Clean/Circomlib/lessComparators.lean | 438 --------------------------- 2 files changed, 370 insertions(+), 465 deletions(-) delete mode 100644 Clean/Circomlib/lessComparators.lean diff --git a/Clean/Circomlib/Comparators.lean b/Clean/Circomlib/Comparators.lean index ebad5a97a..00b1c3a3f 100644 --- a/Clean/Circomlib/Comparators.lean +++ b/Clean/Circomlib/Comparators.lean @@ -1,7 +1,6 @@ import Clean.Circuit import Clean.Utils.Bits import Clean.Circomlib.Bitify -import Mathlib.Tactic /- Original source code: @@ -166,26 +165,8 @@ def circuit : FormalAssertion (F p) Inputs where enabled = 1 → inp.1 = inp.2 soundness := by - circuit_proof_start - intro h_ie - simp_all only [gt_iff_lt, one_ne_zero, or_true, id_eq, one_mul] - cases h_input with - | intro h_enabled h_inp => - rw [← h_inp] - simp only - cases h_holds with - | intro h1 h2 => - rw [h1] at h2 - rw [add_comm] at h2 - simp only [id_eq] at h2 - split_ifs at h2 with h_ifs - . simp_all only [neg_add_cancel] - rw [add_comm, neg_add_eq_zero] at h_ifs - exact h_ifs - . simp_all only [neg_zero, zero_add, one_ne_zero] - rw [add_comm, neg_add_eq_zero] at h2 - rw [h2] at h1 - trivial + simp only [circuit_norm, main] + sorry completeness := by simp only [circuit_norm, main] @@ -195,7 +176,6 @@ end ForceEqualIfEnabled namespace LessThan /- template LessThan(n) { - assert(n <= 252); signal input in[2]; signal output out; @@ -206,6 +186,303 @@ template LessThan(n) { out <== 1-n2b.out[n]; } -/ +structure Bounds + (p n : ℕ) [Fact p.Prime] [Fact (p > 2)] + (a b : F p) + where + ha : ZMod.val a < 2 ^ n + hb : ZMod.val b < 2 ^ n + hp : 2 ^ (n + 1) < p + hp' : 2 ^ n < p + +/- From `2^(n+1) < p` get `2^n < p`. -/ +lemma pow_lt_of_succ {n p : ℕ} (hn : 2^(n+1) < p) : 2^n < p := by + exact lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) hn + +/- Pack the repeated bounds (`ha hb hp hp'`) into your structure in one shot. -/ +def Bounds.of_assumptions + {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : ZMod p} + (ha : a.val < 2^n) (hb : b.val < 2^n) + (hp_succ : 2^(n+1) < p) : Bounds p n a b := + ⟨ha, hb, hp_succ, pow_lt_of_succ hp_succ⟩ + +lemma add_two_pow_sub_eq_add_diff {n a b : ℕ} (h : b ≤ a) : + a + 2 ^ n - b = 2 ^ n + (a - b) := by + calc + a + 2 ^ n - b + = (2 ^ n + a) - b := by ac_rfl + _ = (2 ^ n + a) - (b + 0) := by rfl + _ = 2 ^ n + (a - b) := by + simp only [Nat.add_zero, Nat.add_sub_assoc h] + +lemma add_two_pow_sub_eq_sub_diff {n a b : ℕ} (h : a < b) : + a + 2 ^ n - b = 2 ^ n - (b - a) := by + have hb : b = a + (b - a) := (Nat.add_sub_of_le (Nat.le_of_lt h)).symm + calc + a + 2 ^ n - b + = a + 2 ^ n - (a + (b - a)) := by rw [hb]; simp only [Nat.add_sub_cancel_left] + _ = (2 ^ n + a) - ((b - a) + a) := by ac_rfl + _ = 2 ^ n - (b - a) := by simp only [Nat.add_sub_add_right] + +lemma add_two_pow_sub_lt_pow_succ {n a b : ℕ} (ha : a < 2 ^ n) : + a + 2 ^ n - b < 2 ^ (n + 1) := by + calc + a + 2 ^ n - b ≤ a + 2 ^ n := Nat.sub_le _ _ + _ < 2 ^ n + 2 ^ n := Nat.add_lt_add_right ha _ + _ = 2 ^ (n + 1) := by rw [Nat.pow_succ, Nat.mul_two] + +lemma ZMod.val_add_two_pow_sub_rel + {p : ℕ} [Fact p.Prime] + (n : ℕ) (a b : F p) : + + if ZMod.val a < ZMod.val b then + ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ n + else + ZMod.val a + 2 ^ n - ZMod.val b ≥ 2 ^ n := by + + split_ifs with h_lt + · rw [add_two_pow_sub_eq_sub_diff h_lt] + have h_pos : 0 < ZMod.val b - ZMod.val a := Nat.sub_pos_of_lt h_lt + simp_all [tsub_lt_self_iff, pow_pos] + · rw [add_two_pow_sub_eq_add_diff (le_of_not_gt h_lt)] + exact Nat.le_add_right _ _ + +lemma ZMod.val_two_pow_of_lt {p n : ℕ} [NeZero p] [Fact p.Prime] (h : 2 ^ n < p) (hp : p > 2): + (2 ^ n : ZMod p).val = 2 ^ n := by + + have p_ne_zero := NeZero.ne p + rw [ZMod.val_pow] at * + rw [← Nat.cast_ofNat] + rw [ZMod.val_natCast] + . -- (2 % p) ^ n = 2 ^ n + have h_mod := Nat.mod_eq_iff_lt (m := 2) (n := p) p_ne_zero + + have h_mod' : 2 % p = 2 := by + simp_all only [gt_iff_lt, ne_eq, iff_true] + + rw [h_mod'] + + . + rw [← Nat.cast_ofNat] + rw [ZMod.val_natCast] + + have h_mod := Nat.mod_eq_iff_lt (m := 2) (n := p) p_ne_zero + have h_mod' : 2 % p = 2 := by + simp_all only [gt_iff_lt, ne_eq, iff_true] + rw [h_mod'] + exact h + +-- Helper: no wrap on a + 2^n +private lemma add_two_pow_no_wrap_val {p : ℕ} [Fact (p > 2)] [Fact p.Prime] (a : ZMod p) (n : ℕ) + (ha : a.val < 2 ^ n) (hp : 2 ^ n < p) (hp' : 2 ^ (n+1) < p) : + (a + 2 ^ n).val = a.val + 2 ^ n := by + + have hp2 := Fact.out (p := p > 2) + + have h2n : (2^n: ZMod p).val = 2^n := by + exact ZMod.val_two_pow_of_lt hp hp2 + + have hlt : a.val + 2 ^ n < p := lt_trans + (by + have : a.val + 2 ^ n < 2 ^ n + 2 ^ n := Nat.add_lt_add_right ha _ + simp [pow_succ] + rw [Nat.mul_two] + exact this + ) + hp' + have hlt' : a.val + (2^n : ZMod p).val < p := by + simp_all only + + rw [ZMod.val_add_of_lt hlt'] + rw [h2n] + + -- Helper: no wrap on (a + 2^n) - b because b.val ≤ a.val + 2^n +private lemma ZMod.val_sub_add_two_pow_of_no_wrap (n : ℕ) (a b : ZMod p) + (bounds: Bounds p n a b) : + ((a + 2 ^ n) - b).val = (a.val + 2 ^ n) - b.val := by + -- First compute (a + 2^n).val without wrap + have hadd : (a + 2 ^ n).val = a.val + 2 ^ n := + add_two_pow_no_wrap_val (n:=n) a bounds.ha bounds.hp' bounds.hp + -- b.val ≤ 2^n ≤ 2^n + a.val = (a + 2^n).val + have hb_le_twoPow : b.val ≤ 2 ^ n := Nat.le_of_lt bounds.hb + have twoPow_le_sum : 2 ^ n ≤ (a.val + 2 ^ n) := by + simp [Nat.add_comm] + have hble : b.val ≤ (a.val + 2 ^ n) := le_trans hb_le_twoPow twoPow_le_sum + -- For subtraction in ZMod: if x.val ≥ y.val then (x - y).val = x.val - y.val + -- Rewrite x.val using hadd, then finish. + have hres : ((a + 2 ^ n) - b).val = (a + (2 ^ n : ℕ)).val - b.val := by + rw [ZMod.val_sub] + simp_all only [le_add_iff_nonneg_left, zero_le, Nat.cast_pow, Nat.cast_ofNat] + rw [hadd] + exact hble + + rw [hres] + simp_all only [le_add_iff_nonneg_left, zero_le, Nat.cast_pow, Nat.cast_ofNat] + +lemma ZMod.val_sub_add_two_pow_rel {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} + (R : ℕ → ℕ → Prop) (threshold : ℕ) + (bounds : Bounds p n a b) + (h_val : R (ZMod.val a + 2 ^ n - ZMod.val b) threshold) : + R (ZMod.val (a + 2 ^ n - b)) threshold := by + + have hp' : 2 ^ n < p := pow_lt_of_succ bounds.hp + + rw [ZMod.val_sub_add_two_pow_of_no_wrap n a b bounds] + exact h_val + +-- Helper: (ZMod.val (a + 2 ^ n - b)) < (2 ^ n) +lemma ZMod.val_sub_add_two_pow_lt {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} : + (Bounds p n a b) -> + (ZMod.val a + 2 ^ n - ZMod.val b) < (2 ^ n) -> + (ZMod.val (a + 2 ^ n - b)) < (2 ^ n) := + ZMod.val_sub_add_two_pow_rel (· < ·) (2 ^ n) + +-- Helper: (ZMod.val (a + 2 ^ n - b)) ≥ (2 ^ n) +lemma ZMod.val_sub_add_two_pow_ge {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} : + (Bounds p n a b) -> + (ZMod.val a + 2 ^ n - ZMod.val b) ≥ (2 ^ n) -> + (ZMod.val (a + 2 ^ n - b)) ≥ (2 ^ n) := + ZMod.val_sub_add_two_pow_rel (· ≥ ·) (2 ^ n) + +-- Helper: (ZMod.val (a + 2 ^ n - b)) < (2 ^ (n+1)) +lemma ZMod.val_sub_add_two_pow_no_wrap {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} : + (Bounds p n a b) -> + (ZMod.val a + 2 ^ n - ZMod.val b) < (2 ^ (n+1)) -> + (ZMod.val (a + 2 ^ n - b)) < (2 ^ (n+1)) := + ZMod.val_sub_add_two_pow_rel (· < ·) (2 ^ (n+1)) + +lemma bit_is_clear {p : ℕ} [Fact p.Prime] + (n : ℕ) (a : ZMod p) + (hlt : ZMod.val a < 2^n) : + (ZMod.val a).testBit n = false := by + rw [Nat.testBit_eq_decide_div_mod_eq] + -- ⊢ decide (a.val / 2 ^ n % 2 = 1) = false + have hbpos : 0 < 2^n := pow_pos (by decide) n + have hdiv : ZMod.val a / 2^n = 0 := + Nat.div_eq_of_lt hlt + rw [hdiv, Nat.zero_mod] + -- ⊢ decide (0 = 1) = false + simp only [zero_ne_one, decide_false] + +lemma bit_is_set {p : ℕ} [Fact p.Prime] + /- (n : ℕ) (a b : ℕ) -/ + (n : ℕ) (x : F p) + (hlt: ZMod.val x < 2^(n+1)) + (hge: ZMod.val x ≥ 2^n) : + (ZMod.val x).testBit n = true := by + rw [Nat.testBit_eq_decide_div_mod_eq ] + + simp only [decide_eq_true_eq] + /- ⊢ ZMod.val a / 2 ^ n % 2 = 1 -/ + set x := ZMod.val x + have hbpos : 0 < 2^n := pow_pos (by decide) n + + -- lower bound: 1 ≤ x / 2^n + have h1 : 1 ≤ x / 2^n := by + simp only [Nat.le_div_iff_mul_le hbpos, one_mul] + exact hge + + -- upper bound: x / 2^n < 2 + have h2 : x / 2^n < 2 := by + rw [Nat.div_lt_iff_lt_mul hbpos] + + rw [← Nat.pow_add_one'] + exact hlt + + rw [le_antisymm (Nat.lt_succ_iff.mp h2) h1] + +/- If `a.val < b.val`, the nth bit of `ZMod.val (a + 2^n - b)` is `false`. -/ +lemma nth_bit_clear_of_val_lt + {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} + (bounds : Bounds p n a b) + (hab : a.val < b.val) : + (ZMod.val (a + 2^n - b)).testBit n = false := by + -- nat-level bound + have hnlt := ZMod.val_add_two_pow_sub_rel n a b + simp [hab] at hnlt + -- lift through no-wrap + have hzlt := ZMod.val_sub_add_two_pow_lt bounds hnlt + exact bit_is_clear n (a + 2^n - b) hzlt + +/- If `a.val ≥ b.val` and the sum is `< 2^(n+1)`, the nth bit is `true`. -/ +lemma nth_bit_set_of_val_ge + {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} + (bounds : Bounds p n a b) + (h_sum_lt : (a + 2^n - b).val < 2^(n+1)) + /- (hab : a.val ≥ b.val) : -/ + (hab : ¬(a.val < b.val)) : + (ZMod.val (a + 2^n - b)).testBit n = true := by + -- nat-level bound + have h_rel := ZMod.val_add_two_pow_sub_rel n a b + simp [hab] at h_rel + -- lift through no-wrap + have hzge := ZMod.val_sub_add_two_pow_ge bounds h_rel + exact bit_is_set n (a + 2^n - b) h_sum_lt hzge + +lemma num2bits_bit_eq_testBit + {p n i₀ : ℕ} [Fact p.Prime] [Fact (p > 2)] + {env : Environment (F p)} + {a b : F p} + (h_holds : + Vector.map (Expression.eval env) + (Vector.mapRange (n + 1) (fun i ↦ var { index := i₀ + i })) + = fieldToBits (n + 1) (a + 2 ^ n + -b)) : + (Vector.map (Expression.eval env) + (Vector.mapRange (n + 1) (fun i ↦ var { index := i₀ + i })))[n]'(Nat.lt_succ_self n) + = + (if (ZMod.val (a + 2 ^ n + -b)).testBit n then (1 : F p) else 0) := by + simp only [Vector.ext_iff] at h_holds + specialize h_holds n (Nat.lt_succ_self n) + simp only [Vector.getElem_map, Vector.getElem_mapRange, + fieldToBits, toBits, Nat.cast_ite, Nat.cast_one, Nat.cast_zero] at h_holds + + simp only [Vector.getElem_map, Vector.getElem_mapRange] + exact h_holds + +lemma nth_bit_from_val + {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] + {a b : F p} + (bounds : Bounds p n a b) + (h_holds1 : ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1)) + : (ZMod.val (a + 2 ^ n + -b)).testBit n + = decide (¬ (ZMod.val a < ZMod.val b)) := by + by_cases hab : ZMod.val a < ZMod.val b + · -- Case 1: a < b → bit n is 0 + have h_bit_clear : + (ZMod.val (a + 2 ^ n + -b)).testBit n = false := by + have h_bit := nth_bit_clear_of_val_lt (p := p) (n := n) (a := a) (b := b) bounds hab + rw [sub_eq_add_neg] at h_bit + exact h_bit + simp [h_bit_clear, hab] + · -- Case 2: a ≥ b → bit n is 1 + have h_bit_set : + (ZMod.val (a + 2 ^ n + -b)).testBit n = true := by + have h_ge : ZMod.val a ≥ ZMod.val b := by simpa [not_lt, ge_iff_le] using hab + have h_bit := nth_bit_set_of_val_ge (p := p) (n := n) (a := a) (b := b) + bounds h_holds1 hab + rw [sub_eq_add_neg] at h_bit + exact h_bit + simp [h_bit_set, hab] + +/- In any field, `1 - [y ≤ x] = [x < y]` where brackets mean 1/0-as-a-field. -/ +lemma one_sub_if_le_eq_if_lt {F} [Field F] {x y : ℕ} : + (1 : F) + - (if y ≤ x then 1 else 0) + = (if x < y then 1 else 0) := by + by_cases h : y ≤ x + · -- then `¬ (x < y)` + have hxlt : ¬ x < y := not_lt.mpr h + simp [h, hxlt] + · -- so `x < y` + have hxlt : x < y := lt_of_not_ge h + simp [h, hxlt] + def main (n : ℕ) (hn : 2^(n+1) < p) (input : Expression (F p) × Expression (F p)) := do let diff := input.1 + (2^n : F p) - input.2 let bits ← Num2Bits.circuit (n + 1) hn diff @@ -219,18 +496,84 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w output _ i := var ⟨ i + n + 1 ⟩ output_eq := by simp +arith [circuit_norm, main, Num2Bits.circuit] - Assumptions := fun (x, y) => x.val < 2^n ∧ y.val < 2^n + Assumptions := fun (x, y) => x.val < 2^n ∧ y.val < 2^n -- TODO: ∧ n <= 252 Spec := fun (x, y) output => output = (if x.val < y.val then 1 else 0) soundness := by - simp only [circuit_norm, main] - sorry + intro i₀ env input_var input h_input h_assumptions h_holds + unfold main at * + simp only [circuit_norm, Num2Bits.circuit] at h_holds + simp only [circuit_norm] at * +-- + rw [← h_input] + + -- ① evaluate inputs + have h1 : Expression.eval env input_var.1 = input.1 := by + rw [← h_input] + have h2 : Expression.eval env input_var.2 = input.2 := by + rw [← h_input] + set a := input.1 + set b := input.2 + set hp := hn + rw [h1, h2] at h_holds + rw [h1, h2] + simp only [id_eq] + + -- ② prepare bounds + have hp' : 2 ^ n < p := pow_lt_of_succ hp + + have bounds := Bounds.of_assumptions h_assumptions.left h_assumptions.right hp + + -- ③ extract nth bit + rw [← Nat.add_assoc] at h_holds + obtain ⟨⟨h_holds1, h_holds2⟩, h_holds3⟩ := h_holds + + rw [h_holds3] + + have h_nbit := num2bits_bit_eq_testBit h_holds2 + simp only [circuit_norm] at h_nbit + rw [h_nbit] + + -- ④ core logic: bit is set iff a ≥ b + rw [← sub_eq_add_neg] at h_holds1 + have h_bit := nth_bit_from_val bounds h_holds1 + simp only [h_bit, circuit_norm] + simp only [not_lt, decide_eq_true_eq] + simpa using (one_sub_if_le_eq_if_lt (F := F p) (x := ZMod.val a) (y := ZMod.val b)) completeness := by - simp only [circuit_norm, main] - sorry + circuit_proof_start + simp only [Num2Bits.circuit] at * + simp only [circuit_norm] at * + simp_all only [gt_iff_lt, true_and, id_eq, and_true] + + obtain ⟨h_envl, h_envr⟩ := h_env + obtain ⟨ha, hb⟩ := h_assumptions + + set hp := hn + + -- ① evaluate inputs + have h1 : Expression.eval env input_var.1 = input.1 := by + rw [← h_input] + have h2 : Expression.eval env input_var.2 = input.2 := by + rw [← h_input] + set a := input.1 + set b := input.2 + rw [h1, h2] + rw [h1, h2] at h_envl + + -- ② prepare bounds + rw [← sub_eq_add_neg (a:=(a+ 2 ^ n))] + have bounds := Bounds.of_assumptions ha hb hp + + -- ③ core logic: ⊢ ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1) + have h_sum_lt_2n1 : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ (n + 1) := + add_two_pow_sub_lt_pow_succ ha + + exact ZMod.val_sub_add_two_pow_no_wrap bounds h_sum_lt_2n1 + end LessThan namespace LessEqThan diff --git a/Clean/Circomlib/lessComparators.lean b/Clean/Circomlib/lessComparators.lean deleted file mode 100644 index 51aa3c6cd..000000000 --- a/Clean/Circomlib/lessComparators.lean +++ /dev/null @@ -1,438 +0,0 @@ -/- import Std.Data.Vector.Basic -/ -import Clean.Circuit -import Clean.Utils.Bits -import Clean.Circomlib.Bitify -import Mathlib.Data.Nat.Bitwise -import Mathlib.Data.ZMod.Basic - -/- -Original source code: -https://github.com/iden3/circomlib/blob/35e54ea21da3e8762557234298dbb553c175ea8d/circuits/comparators.circom --/ - -namespace Circomlib -open Utils.Bits -variable {p : ℕ} [Fact p.Prime] [Fact (p > 2)] - - -namespace LessThan -/- -template LessThan(n) { - signal input in[2]; - signal output out; - - component n2b = Num2Bits(n+1); - - n2b.in <== in[0]+ (1 << n) - in[1]; - - out <== 1-n2b.out[n]; -} --/ -structure Bounds - (p n : ℕ) [Fact p.Prime] [Fact (p > 2)] - (a b : F p) - where - ha : ZMod.val a < 2 ^ n - hb : ZMod.val b < 2 ^ n - hp : 2 ^ (n + 1) < p - hp' : 2 ^ n < p - -/-- From `2^(n+1) < p` get `2^n < p`. -/ -lemma pow_lt_of_succ {n p : ℕ} (hn : 2^(n+1) < p) : 2^n < p := by - exact lt_trans (Nat.pow_lt_pow_right (by decide) (Nat.lt_succ_self n)) hn - -/-- Pack the repeated bounds (`ha hb hp hp'`) into your structure in one shot. -/ -def Bounds.of_assumptions - {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] - {a b : ZMod p} - (ha : a.val < 2^n) (hb : b.val < 2^n) - (hp_succ : 2^(n+1) < p) : Bounds p n a b := - ⟨ha, hb, hp_succ, pow_lt_of_succ hp_succ⟩ - -lemma add_two_pow_sub_eq_add_diff {n a b : ℕ} (h : b ≤ a) : - a + 2 ^ n - b = 2 ^ n + (a - b) := by - calc - a + 2 ^ n - b - = (2 ^ n + a) - b := by ac_rfl - _ = (2 ^ n + a) - (b + 0) := by rfl - _ = 2 ^ n + (a - b) := by - simp only [Nat.add_zero, Nat.add_sub_assoc h] - -lemma add_two_pow_sub_eq_sub_diff {n a b : ℕ} (h : a < b) : - a + 2 ^ n - b = 2 ^ n - (b - a) := by - have hb : b = a + (b - a) := (Nat.add_sub_of_le (Nat.le_of_lt h)).symm - calc - a + 2 ^ n - b - = a + 2 ^ n - (a + (b - a)) := by rw [hb]; simp only [Nat.add_sub_cancel_left] - _ = (2 ^ n + a) - ((b - a) + a) := by ac_rfl - _ = 2 ^ n - (b - a) := by simp only [Nat.add_sub_add_right] - -lemma add_two_pow_sub_lt_pow_succ {n a b : ℕ} (ha : a < 2 ^ n) : - a + 2 ^ n - b < 2 ^ (n + 1) := by - calc - a + 2 ^ n - b ≤ a + 2 ^ n := Nat.sub_le _ _ - _ < 2 ^ n + 2 ^ n := Nat.add_lt_add_right ha _ - _ = 2 ^ (n + 1) := by rw [Nat.pow_succ, Nat.mul_two] - -lemma ZMod.val_add_two_pow_sub_rel - {p : ℕ} [Fact p.Prime] - (n : ℕ) (a b : F p) : - - if ZMod.val a < ZMod.val b then - ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ n - else - ZMod.val a + 2 ^ n - ZMod.val b ≥ 2 ^ n := by - - split_ifs with h_lt - · rw [add_two_pow_sub_eq_sub_diff h_lt] - have h_pos : 0 < ZMod.val b - ZMod.val a := Nat.sub_pos_of_lt h_lt - simp_all [tsub_lt_self_iff, pow_pos] - · rw [add_two_pow_sub_eq_add_diff (le_of_not_gt h_lt)] - exact Nat.le_add_right _ _ - -lemma ZMod.val_two_pow_of_lt {p n : ℕ} [NeZero p] [Fact p.Prime] (h : 2 ^ n < p) (hp : p > 2): - (2 ^ n : ZMod p).val = 2 ^ n := by - - have p_ne_zero := NeZero.ne p - rw [ZMod.val_pow] at * - rw [← Nat.cast_ofNat] - rw [ZMod.val_natCast] - . -- (2 % p) ^ n = 2 ^ n - have h_mod := Nat.mod_eq_iff_lt (m := 2) (n := p) p_ne_zero - - have h_mod' : 2 % p = 2 := by - simp_all only [gt_iff_lt, ne_eq, iff_true] - - rw [h_mod'] - - . - rw [← Nat.cast_ofNat] - rw [ZMod.val_natCast] - - have h_mod := Nat.mod_eq_iff_lt (m := 2) (n := p) p_ne_zero - have h_mod' : 2 % p = 2 := by - simp_all only [gt_iff_lt, ne_eq, iff_true] - rw [h_mod'] - exact h - --- Helper: no wrap on a + 2^n -/- omit [Fact (p > 2)] -/ -private lemma add_two_pow_no_wrap_val {p : ℕ} [Fact (p > 2)] [Fact p.Prime] (a : ZMod p) (n : ℕ) - (ha : a.val < 2 ^ n) (hp : 2 ^ n < p) (hp' : 2 ^ (n+1) < p) : - (a + 2 ^ n).val = a.val + 2 ^ n := by - - have hp2 := Fact.out (p := p > 2) - - have h2n : (2^n: ZMod p).val = 2^n := by - exact ZMod.val_two_pow_of_lt hp hp2 - - have hlt : a.val + 2 ^ n < p := lt_trans - (by - have : a.val + 2 ^ n < 2 ^ n + 2 ^ n := Nat.add_lt_add_right ha _ - simp [pow_succ] - rw [Nat.mul_two] - exact this - ) - hp' - have hlt' : a.val + (2^n : ZMod p).val < p := by - simp_all only - - rw [ZMod.val_add_of_lt hlt'] - rw [h2n] - - -- Helper: no wrap on (a + 2^n) - b because b.val ≤ a.val + 2^n -private lemma ZMod.val_sub_add_two_pow_of_no_wrap (n : ℕ) (a b : ZMod p) - (bounds: Bounds p n a b) : - ((a + 2 ^ n) - b).val = (a.val + 2 ^ n) - b.val := by - -- First compute (a + 2^n).val without wrap - have hadd : (a + 2 ^ n).val = a.val + 2 ^ n := - add_two_pow_no_wrap_val (n:=n) a bounds.ha bounds.hp' bounds.hp - -- b.val ≤ 2^n ≤ 2^n + a.val = (a + 2^n).val - have hb_le_twoPow : b.val ≤ 2 ^ n := Nat.le_of_lt bounds.hb - have twoPow_le_sum : 2 ^ n ≤ (a.val + 2 ^ n) := by - simp [Nat.add_comm] - have hble : b.val ≤ (a.val + 2 ^ n) := le_trans hb_le_twoPow twoPow_le_sum - -- For subtraction in ZMod: if x.val ≥ y.val then (x - y).val = x.val - y.val - -- Rewrite x.val using hadd, then finish. - have hres : ((a + 2 ^ n) - b).val = (a + (2 ^ n : ℕ)).val - b.val := by - rw [ZMod.val_sub] - simp_all only [le_add_iff_nonneg_left, zero_le, Nat.cast_pow, Nat.cast_ofNat] - rw [hadd] - exact hble - - rw [hres] - simp_all only [le_add_iff_nonneg_left, zero_le, Nat.cast_pow, Nat.cast_ofNat] - -lemma ZMod.val_sub_add_two_pow_rel {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] - {a b : F p} - (R : ℕ → ℕ → Prop) (threshold : ℕ) - (bounds : Bounds p n a b) - (h_val : R (ZMod.val a + 2 ^ n - ZMod.val b) threshold) : - R (ZMod.val (a + 2 ^ n - b)) threshold := by - - have hp' : 2 ^ n < p := pow_lt_of_succ bounds.hp - - rw [ZMod.val_sub_add_two_pow_of_no_wrap n a b bounds] - exact h_val - -lemma ZMod.val_sub_add_two_pow_lt {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] - {a b : F p} : - (Bounds p n a b) -> - (ZMod.val a + 2 ^ n - ZMod.val b) < (2 ^ n) -> - (ZMod.val (a + 2 ^ n - b)) < (2 ^ n) := - ZMod.val_sub_add_two_pow_rel (· < ·) (2 ^ n) - -lemma ZMod.val_sub_add_two_pow_ge {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] - {a b : F p} : - (Bounds p n a b) -> - (ZMod.val a + 2 ^ n - ZMod.val b) ≥ (2 ^ n) -> - (ZMod.val (a + 2 ^ n - b)) ≥ (2 ^ n) := - ZMod.val_sub_add_two_pow_rel (· ≥ ·) (2 ^ n) - -lemma ZMod.val_sub_add_two_pow_no_wrap {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] - {a b : F p} : - (Bounds p n a b) -> - (ZMod.val a + 2 ^ n - ZMod.val b) < (2 ^ (n+1)) -> - (ZMod.val (a + 2 ^ n - b)) < (2 ^ (n+1)) := - ZMod.val_sub_add_two_pow_rel (· < ·) (2 ^ (n+1)) - -lemma bit_is_clear {p : ℕ} [Fact p.Prime] - (n : ℕ) (a : ZMod p) - (hlt : ZMod.val a < 2^n) : - (ZMod.val a).testBit n = false := by - rw [Nat.testBit_eq_decide_div_mod_eq] - -- ⊢ decide (a.val / 2 ^ n % 2 = 1) = false - have hbpos : 0 < 2^n := pow_pos (by decide) n - have hdiv : ZMod.val a / 2^n = 0 := - Nat.div_eq_of_lt hlt - rw [hdiv, Nat.zero_mod] - -- ⊢ decide (0 = 1) = false - simp only [zero_ne_one, decide_false] - -lemma bit_is_set {p : ℕ} [Fact p.Prime] - /- (n : ℕ) (a b : ℕ) -/ - (n : ℕ) (x : F p) - (hlt: ZMod.val x < 2^(n+1)) - (hge: ZMod.val x ≥ 2^n) : - (ZMod.val x).testBit n = true := by - rw [Nat.testBit_eq_decide_div_mod_eq ] - - simp only [decide_eq_true_eq] - /- ⊢ ZMod.val a / 2 ^ n % 2 = 1 -/ - set x := ZMod.val x - have hbpos : 0 < 2^n := pow_pos (by decide) n - - -- lower bound: 1 ≤ x / 2^n - have h1 : 1 ≤ x / 2^n := by - simp only [Nat.le_div_iff_mul_le hbpos, one_mul] - exact hge - - -- upper bound: x / 2^n < 2 - have h2 : x / 2^n < 2 := by - rw [Nat.div_lt_iff_lt_mul hbpos] - - rw [← Nat.pow_add_one'] - exact hlt - - rw [le_antisymm (Nat.lt_succ_iff.mp h2) h1] - -/-- From `<` on naturals to `<` on `.val (a + 2^n - b)` using `Bounds`. -/ -lemma zval_lt_of_nat_lt - {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] - {a b : ZMod p} - (bounds : Bounds p n a b) - (h : a.val + 2^n - b.val < 2^n) : - (a + 2^n - b).val < 2^n := - ZMod.val_sub_add_two_pow_lt bounds h - -/-- From `≥` on naturals to `≥` on `.val (a + 2^n - b)` using `Bounds`. -/ -lemma zval_ge_of_nat_ge - {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] - {a b : ZMod p} - (bounds : Bounds p n a b) - (h : a.val + 2^n - b.val ≥ 2^n) : - (a + 2^n - b).val ≥ 2^n := - ZMod.val_sub_add_two_pow_ge bounds h - -/-- If `a.val < b.val`, the nth bit of `ZMod.val (a + 2^n - b)` is `false`. -/ -lemma nth_bit_clear_of_val_lt - {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] - {a b : F p} - (bounds : Bounds p n a b) - (hab : a.val < b.val) : - (ZMod.val (a + 2^n - b)).testBit n = false := by - -- nat-level bound - have hnlt := ZMod.val_add_two_pow_sub_rel n a b - simp [hab] at hnlt - have h_lt := hnlt - -- lift through no-wrap - have hzlt := zval_lt_of_nat_lt bounds hnlt - exact bit_is_clear n (a + 2^n - b) hzlt - -/-- If `a.val ≥ b.val` and the sum is `< 2^(n+1)`, the nth bit is `true`. -/ -lemma nth_bit_set_of_val_ge - {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] - {a b : F p} - (bounds : Bounds p n a b) - (h_sum_lt : (a + 2^n - b).val < 2^(n+1)) - /- (hab : a.val ≥ b.val) : -/ - (hab : ¬(a.val < b.val)) : - (ZMod.val (a + 2^n - b)).testBit n = true := by - -- nat-level bound - have h_rel := ZMod.val_add_two_pow_sub_rel n a b - simp [hab] at h_rel - have h_ge := h_rel - -- lift through no-wrap - have hzge := zval_ge_of_nat_ge bounds h_ge - exact bit_is_set n (a + 2^n - b) h_sum_lt hzge - -lemma num2bits_bit_eq_testBit - {p n i₀ : ℕ} [Fact p.Prime] [Fact (p > 2)] - {env : Environment (F p)} - {a b : F p} - (h_holds : - Vector.map (Expression.eval env) - (Vector.mapRange (n + 1) (fun i ↦ var { index := i₀ + i })) - = fieldToBits (n + 1) (a + 2 ^ n + -b)) : - (Vector.map (Expression.eval env) - (Vector.mapRange (n + 1) (fun i ↦ var { index := i₀ + i })))[n]'(Nat.lt_succ_self n) - = - (if (ZMod.val (a + 2 ^ n + -b)).testBit n then (1 : F p) else 0) := by - simp only [Vector.ext_iff] at h_holds - specialize h_holds n (Nat.lt_succ_self n) - simp only [Vector.getElem_map, Vector.getElem_mapRange, - fieldToBits, toBits, Nat.cast_ite, Nat.cast_one, Nat.cast_zero] at h_holds - - simp only [Vector.getElem_map, Vector.getElem_mapRange] - exact h_holds - -lemma nth_bit_from_val - {p n : ℕ} [Fact p.Prime] [Fact (p > 2)] - {a b : F p} - (bounds : Bounds p n a b) - (h_holds1 : ZMod.val (a + 2 ^ n - b) < 2 ^ (n + 1)) - : (ZMod.val (a + 2 ^ n + -b)).testBit n - = decide (¬ (ZMod.val a < ZMod.val b)) := by - by_cases hab : ZMod.val a < ZMod.val b - · -- Case 1: a < b → bit n is 0 - have h_bit_clear : - (ZMod.val (a + 2 ^ n + -b)).testBit n = false := by - have h_bit := nth_bit_clear_of_val_lt (p := p) (n := n) (a := a) (b := b) bounds hab - rw [sub_eq_add_neg] at h_bit - exact h_bit - simp [h_bit_clear, hab] - · -- Case 2: a ≥ b → bit n is 1 - have h_bit_set : - (ZMod.val (a + 2 ^ n + -b)).testBit n = true := by - have h_ge : ZMod.val a ≥ ZMod.val b := by simpa [not_lt, ge_iff_le] using hab - have h_bit := nth_bit_set_of_val_ge (p := p) (n := n) (a := a) (b := b) - bounds h_holds1 hab - rw [sub_eq_add_neg] at h_bit - exact h_bit - simp [h_bit_set, hab] - -/-- In any field, `1 - [y ≤ x] = [x < y]` where brackets mean 1/0-as-a-field. -/ -lemma one_sub_if_le_eq_if_lt {F} [Field F] {x y : ℕ} : - (1 : F) + - (if y ≤ x then 1 else 0) - = (if x < y then 1 else 0) := by - by_cases h : y ≤ x - · -- then `¬ (x < y)` - have hxlt : ¬ x < y := not_lt.mpr h - simp [h, hxlt] - · -- so `x < y` - have hxlt : x < y := lt_of_not_ge h - simp [h, hxlt] - -def main (n : ℕ) (hn : 2^(n+1) < p) (input : Expression (F p) × Expression (F p)) := do - let diff := input.1 + (2^n : F p) - input.2 - let bits ← Num2Bits.circuit (n + 1) hn diff - let out <== 1 - bits[n] - return out - -def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field where - main := main n hn - localLength _ := n + 2 - localLength_eq := by simp [circuit_norm, main, Num2Bits.circuit] - output _ i := var ⟨ i + n + 1 ⟩ - output_eq := by simp +arith [circuit_norm, main, Num2Bits.circuit] - - Assumptions := fun (x, y) => x.val < 2^n ∧ y.val < 2^n -- TODO: ∧ n <= 252 - - Spec := fun (x, y) output => - output = (if x.val < y.val then 1 else 0) - - soundness := by - intro i₀ env input_var input h_input h_assumptions h_holds - unfold main at * - simp only [circuit_norm, Num2Bits.circuit] at h_holds - simp only [circuit_norm] at * --- - rw [← h_input] - - -- ① evaluate inputs - have h1 : Expression.eval env input_var.1 = input.1 := by - rw [← h_input] - have h2 : Expression.eval env input_var.2 = input.2 := by - rw [← h_input] - set a := input.1 - set b := input.2 - set hp := hn - rw [h1, h2] at h_holds - rw [h1, h2] - simp only [id_eq] - - -- ② prepare bounds - have hp' : 2 ^ n < p := pow_lt_of_succ hp - - have bounds := Bounds.of_assumptions h_assumptions.left h_assumptions.right hp - - -- ③ extract nth bit - rw [← Nat.add_assoc] at h_holds - obtain ⟨⟨h_holds1, h_holds2⟩, h_holds3⟩ := h_holds - - rw [h_holds3] - - have h_nbit := num2bits_bit_eq_testBit h_holds2 - simp only [circuit_norm] at h_nbit - rw [h_nbit] - - -- ④ core logic: bit is set iff a ≥ b - rw [← sub_eq_add_neg] at h_holds1 - have h_bit := nth_bit_from_val bounds h_holds1 - simp only [h_bit, circuit_norm] - simp only [not_lt, decide_eq_true_eq] - simpa using (one_sub_if_le_eq_if_lt (F := F p) (x := ZMod.val a) (y := ZMod.val b)) - - completeness := by - circuit_proof_start - simp only [Num2Bits.circuit] at * - simp only [circuit_norm] at * - simp_all only [gt_iff_lt, true_and, id_eq, and_true] - - obtain ⟨h_envl, h_envr⟩ := h_env - obtain ⟨ha, hb⟩ := h_assumptions - - set hp := hn - - have h1 : Expression.eval env input_var.1 = input.1 := by - rw [← h_input] - have h2 : Expression.eval env input_var.2 = input.2 := by - rw [← h_input] - - set a := input.1 - set b := input.2 - rw [h1, h2] - rw [h1, h2] at h_envl - - rw [← sub_eq_add_neg (a:=(a+ 2 ^ n))] - - have hp' : 2 ^ n < p := pow_lt_of_succ hp - - have bounds := Bounds.of_assumptions ha hb hp - - have h_sum_lt_2n1 : ZMod.val a + 2 ^ n - ZMod.val b < 2 ^ (n + 1) := - add_two_pow_sub_lt_pow_succ ha - - exact ZMod.val_sub_add_two_pow_no_wrap bounds h_sum_lt_2n1 - -end LessThan From 2e00b97c76f9c240cfe68853c967cde4b512ea86 Mon Sep 17 00:00:00 2001 From: semar Date: Fri, 10 Oct 2025 08:32:58 -0300 Subject: [PATCH 15/18] adds ForceEqualIfEnabled soundness back --- Clean/Circomlib/Comparators.lean | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/Clean/Circomlib/Comparators.lean b/Clean/Circomlib/Comparators.lean index 00b1c3a3f..6ae609c7e 100644 --- a/Clean/Circomlib/Comparators.lean +++ b/Clean/Circomlib/Comparators.lean @@ -165,8 +165,26 @@ def circuit : FormalAssertion (F p) Inputs where enabled = 1 → inp.1 = inp.2 soundness := by - simp only [circuit_norm, main] - sorry + circuit_proof_start + intro h_ie + simp_all only [gt_iff_lt, one_ne_zero, or_true, id_eq, one_mul] + cases h_input with + | intro h_enabled h_inp => + rw [← h_inp] + simp only + cases h_holds with + | intro h1 h2 => + rw [h1] at h2 + rw [add_comm] at h2 + simp only [id_eq] at h2 + split_ifs at h2 with h_ifs + . simp_all only [neg_add_cancel] + rw [add_comm, neg_add_eq_zero] at h_ifs + exact h_ifs + . simp_all only [neg_zero, zero_add, one_ne_zero] + rw [add_comm, neg_add_eq_zero] at h2 + rw [h2] at h1 + trivial completeness := by simp only [circuit_norm, main] From d7fe0505fbabc1ee55a5f845d3f5aeba7cc9aae3 Mon Sep 17 00:00:00 2001 From: semar Date: Fri, 10 Oct 2025 08:39:23 -0300 Subject: [PATCH 16/18] adds assert comment back --- Clean/Circomlib/Comparators.lean | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Clean/Circomlib/Comparators.lean b/Clean/Circomlib/Comparators.lean index 6ae609c7e..fb8d79d26 100644 --- a/Clean/Circomlib/Comparators.lean +++ b/Clean/Circomlib/Comparators.lean @@ -194,6 +194,7 @@ end ForceEqualIfEnabled namespace LessThan /- template LessThan(n) { + assert(n <= 252); signal input in[2]; signal output out; @@ -514,7 +515,7 @@ def circuit (n : ℕ) (hn : 2^(n+1) < p) : FormalCircuit (F p) fieldPair field w output _ i := var ⟨ i + n + 1 ⟩ output_eq := by simp +arith [circuit_norm, main, Num2Bits.circuit] - Assumptions := fun (x, y) => x.val < 2^n ∧ y.val < 2^n -- TODO: ∧ n <= 252 + Assumptions := fun (x, y) => x.val < 2^n ∧ y.val < 2^n Spec := fun (x, y) output => output = (if x.val < y.val then 1 else 0) From af155302773eedc5a75edd2b4cfba5b09924b5f5 Mon Sep 17 00:00:00 2001 From: semar Date: Fri, 10 Oct 2025 08:51:30 -0300 Subject: [PATCH 17/18] reverts changes to CircuitProofStart --- Clean/Utils/Tactics/CircuitProofStart.lean | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/Clean/Utils/Tactics/CircuitProofStart.lean b/Clean/Utils/Tactics/CircuitProofStart.lean index d744f6324..dadac0a65 100644 --- a/Clean/Utils/Tactics/CircuitProofStart.lean +++ b/Clean/Utils/Tactics/CircuitProofStart.lean @@ -107,11 +107,10 @@ elab_rules : tactic circuitProofStartCore -- try to unfold main, Assumptions and Spec as local definitions - evalTactic (← `(tactic| try dsimp only [$(mkIdent `Assumptions):ident] at *)) - evalTactic (← `(tactic| try dsimp only [$(mkIdent `Spec):ident] at *)) - evalTactic (← `(tactic| try dsimp only [$(mkIdent `elaborated):ident] at *)) -- sometimes `main` is hidden behind `elaborated` - evalTactic (← `(tactic| try dsimp only [$(mkIdent `main):ident] at *)) - + try (evalTactic (← `(tactic| unfold $(mkIdent `Assumptions):ident at *))) catch _ => pure () + try (evalTactic (← `(tactic| unfold $(mkIdent `Spec):ident at *))) catch _ => pure () + try (evalTactic (← `(tactic| unfold $(mkIdent `elaborated):ident at *))) catch _ => pure () -- sometimes `main` is hidden behind `elaborated` + try (evalTactic (← `(tactic| unfold $(mkIdent `main):ident at *))) catch _ => pure () -- simplify structs / eval first try (evalTactic (← `(tactic| provable_struct_simp))) catch _ => pure () From 1750271bdc9b4824f23fbd65ffc20a57e55545f8 Mon Sep 17 00:00:00 2001 From: semar Date: Fri, 10 Oct 2025 10:37:42 -0300 Subject: [PATCH 18/18] reverts changes to CircuitProofStart --- Clean/Utils/Tactics/CircuitProofStart.lean | 1 + 1 file changed, 1 insertion(+) diff --git a/Clean/Utils/Tactics/CircuitProofStart.lean b/Clean/Utils/Tactics/CircuitProofStart.lean index dadac0a65..c993138a3 100644 --- a/Clean/Utils/Tactics/CircuitProofStart.lean +++ b/Clean/Utils/Tactics/CircuitProofStart.lean @@ -111,6 +111,7 @@ elab_rules : tactic try (evalTactic (← `(tactic| unfold $(mkIdent `Spec):ident at *))) catch _ => pure () try (evalTactic (← `(tactic| unfold $(mkIdent `elaborated):ident at *))) catch _ => pure () -- sometimes `main` is hidden behind `elaborated` try (evalTactic (← `(tactic| unfold $(mkIdent `main):ident at *))) catch _ => pure () + -- simplify structs / eval first try (evalTactic (← `(tactic| provable_struct_simp))) catch _ => pure ()