diff --git a/src/Lean/Meta/Tactic/Grind/Canon.lean b/src/Lean/Meta/Tactic/Grind/Canon.lean index 9cfe6e3960f8..582331b73b9c 100644 --- a/src/Lean/Meta/Tactic/Grind/Canon.lean +++ b/src/Lean/Meta/Tactic/Grind/Canon.lean @@ -242,7 +242,9 @@ private def normOfNatArgs? (args : Array Expr) : MetaM (Option (Array Expr)) := @[export lean_grind_canon] partial def canonImpl (e : Expr) : GoalM Expr := do profileitM Exception "grind canon" (← getOptions) do trace_goal[grind.debug.canon] "{e}" - visit e |>.run' {} + let (r, cache') ← (visit e).run (← get').visitCache + modify' fun s => { s with visitCache := cache' } + return r where visit (e : Expr) : StateRefT (Std.HashMap ExprPtr Expr) GoalM Expr := do unless e.isApp || e.isForall do return e diff --git a/src/Lean/Meta/Tactic/Grind/MarkNestedSubsingletons.lean b/src/Lean/Meta/Tactic/Grind/MarkNestedSubsingletons.lean index b2b4d323417a..66be423cabf3 100644 --- a/src/Lean/Meta/Tactic/Grind/MarkNestedSubsingletons.lean +++ b/src/Lean/Meta/Tactic/Grind/MarkNestedSubsingletons.lean @@ -12,6 +12,14 @@ import Lean.Meta.Tactic.Grind.Util public section namespace Lean.Meta.Grind +/-- Cached variant of `Sym.unfoldReducible` that persists the `transformWithCache` cache across calls, +ensuring pointer stability for shared sub-expressions. -/ +def unfoldReducibleCached (e : Expr) : GrindM Expr := do + let cache := (← get).unfoldReducibleCache + let (e', cache') ← Meta.transformWithCache e cache (pre := fun e => Sym.unfoldReducibleStep e) + modify fun s => { s with unfoldReducibleCache := cache' } + return e' + private abbrev M := StateRefT (Std.HashMap ExprPtr Expr) GrindM def isMarkedSubsingletonConst (e : Expr) : Bool := Id.run do @@ -41,7 +49,9 @@ Recall that the congruence closure module has special support for them. -- TODO: consider other subsingletons in the future? We decided to not support them to avoid the overhead of -- synthesizing `Subsingleton` instances. 
partial def markNestedSubsingletons (e : Expr) : GrindM Expr := do profileitM Exception "grind mark subsingleton" (← getOptions) do - visit e |>.run' {} + let (r, cache') ← (visit e).run (← get).markSubsingletonCache + modify fun s => { s with markSubsingletonCache := cache' } + return r where visit (e : Expr) : M Expr := do if isMarkedSubsingletonApp e then @@ -103,7 +113,7 @@ where -/ /- We must also apply beta-reduction to improve the effectiveness of the congruence closure procedure. -/ let e ← Core.betaReduce e - let e ← Sym.unfoldReducible e + let e ← unfoldReducibleCached e /- We must mask proofs occurring in `prop` too. -/ let e ← visit e let e ← eraseIrrelevantMData e @@ -123,6 +133,8 @@ def markProof (e : Expr) : GrindM Expr := do if e.isAppOf ``Grind.nestedProof then return e -- `e` is already marked else - markNestedProof e |>.run' {} + let (r, cache') ← (markNestedProof e).run (← get).markSubsingletonCache + modify fun s => { s with markSubsingletonCache := cache' } + return r end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Simp.lean b/src/Lean/Meta/Tactic/Grind/Simp.lean index 1c8f2f86b806..e940826f9908 100644 --- a/src/Lean/Meta/Tactic/Grind/Simp.lean +++ b/src/Lean/Meta/Tactic/Grind/Simp.lean @@ -58,8 +58,9 @@ def preprocessImpl (e : Expr) : GoalM Simp.Result := do let e' ← instantiateMVars r.expr -- Remark: `simpCore` unfolds reducible constants, but it does not consistently visit all possible subterms. -- So, we must use the following `unfoldReducible` step. 
It is non-op in most cases - let e' ← Sym.unfoldReducible e' + let e' ← unfoldReducibleCached e' let e' ← abstractNestedProofs e' + let e' ← shareCommon e' let e' ← markNestedSubsingletons e' let e' ← eraseIrrelevantMData e' let e' ← foldProjs e' @@ -70,6 +71,7 @@ def preprocessImpl (e : Expr) : GoalM Simp.Result := do let r' ← replacePreMatchCond e' let r ← r.mkEqTrans r' let e' := r'.expr + let e' ← shareCommon e' let e' ← canon e' let e' ← shareCommon e' trace_goal[grind.simp] "{e}\n===>\n{e'}" @@ -98,6 +100,6 @@ but ensures assumptions made by `grind` are satisfied. -/ def preprocessLight (e : Expr) : GoalM Expr := do let e ← instantiateMVars e - shareCommon (← canon (← normalizeLevels (← foldProjs (← eraseIrrelevantMData (← markNestedSubsingletons (← Sym.unfoldReducible e)))))) + shareCommon (← canon (← normalizeLevels (← foldProjs (← eraseIrrelevantMData (← markNestedSubsingletons (← unfoldReducibleCached e)))))) end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 6f33c9117487..78d644b2d9b7 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -232,6 +232,10 @@ structure State where Cached anchors (aka stable hash codes) for terms in the `grind` state. -/ anchors : PHashMap ExprPtr UInt64 := {} + /-- Persistent cache for `markNestedSubsingletons` and `markProof` traversals. -/ + markSubsingletonCache : Std.HashMap ExprPtr Expr := {} + /-- Persistent cache for `unfoldReducible` via `transformWithCache`, ensuring pointer stability. -/ + unfoldReducibleCache : Std.HashMap ExprStructEq Expr := {} instance : Nonempty State := .intro {} @@ -712,6 +716,8 @@ structure Canon.State where canon : PHashMap Expr Expr := {} proofCanon : PHashMap Expr Expr := {} canonArg : PHashMap CanonArgKey Expr := {} + /-- Persistent cache for `canonImpl` visit traversals. -/ + visitCache : Std.HashMap ExprPtr Expr := {} deriving Inhabited /-- Trace information for a case split. 
-/ diff --git a/tests/lean/run/grind_add_sub_cancel_fvar.lean b/tests/lean/run/grind_add_sub_cancel_fvar.lean new file mode 100644 index 000000000000..6a6ef7362e46 --- /dev/null +++ b/tests/lean/run/grind_add_sub_cancel_fvar.lean @@ -0,0 +1,45 @@ +import Lean + +/-! +Regression test: `grind` on nested Nat subtraction chains. + +After `simp only [Goal, loop]`, the goal becomes: +``` +post (s₁ + s₂ - s₂ + s₂ - s₂ + ⋯ + s₂ - s₂) s₂ +``` +with `n` nested `(+ s₂ - s₂)` operations. Grind handles each Nat subtraction by +generating a `natCast_sub` fact, processed inside-out. The i-th fact has expression +size O(i). Before persistent caches were added, the preprocessing steps +`markNestedSubsingletons` and `canonImpl` used fresh per-call caches (traversing +all O(i) subexpressions each time), giving O(n²) total work. + +This test checks that the scaling from n=25 to n=100 (4x) is at most 10x in wall-clock +time, which passes for scaling up to about O(n^1.6) (4^1.6 ≈ 9.2 < 10) but fails for +O(n²) (which would give ~16x). +-/ + +def loop : Nat → (Nat → Nat → Prop) → (Nat → Nat → Prop) + | 0, post => post + | n+1, post => fun s₁ s₂ => loop n post (s₁ + s₂ - s₂) s₂ +def Goal (n : Nat) : Prop := ∀ post s₁ s₂, post s₁ s₂ → loop n post s₁ s₂ + +set_option maxRecDepth 10000 +set_option maxHeartbeats 10000000 + +open Lean Elab Command in +elab "#test_grind_scaling" : command => do + let solveAt (n : Nat) : CommandElabM Nat := do + let start ← IO.monoNanosNow + elabCommand (← `(command| + example : Goal $(Syntax.mkNumLit (toString n)) := by intros; simp only [Goal, loop]; grind + )) + let stop ← IO.monoNanosNow + return stop - start + let t_small ← solveAt 25 + let t_large ← solveAt 100 + let ratio := t_large.toFloat / t_small.toFloat + -- Linear: expect ~4x for 4x problem size. Quadratic: would be ~16x. + -- Use 10x as threshold (generous for noise, catches quadratic). 
+ if ratio > 10.0 then + throwError "grind preprocessing scaling regression: 100/25 time ratio is {ratio}x (expected < 10x for linear scaling)" + +#test_grind_scaling