Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Canon.lean
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ private def normOfNatArgs? (args : Array Expr) : MetaM (Option (Array Expr)) :=
@[export lean_grind_canon]
partial def canonImpl (e : Expr) : GoalM Expr := do profileitM Exception "grind canon" (← getOptions) do
trace_goal[grind.debug.canon] "{e}"
visit e |>.run' {}
let (r, cache') ← (visit e).run (← get').visitCache
modify' fun s => { s with visitCache := cache' }
return r
where
visit (e : Expr) : StateRefT (Std.HashMap ExprPtr Expr) GoalM Expr := do
unless e.isApp || e.isForall do return e
Expand Down
18 changes: 15 additions & 3 deletions src/Lean/Meta/Tactic/Grind/MarkNestedSubsingletons.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ import Lean.Meta.Tactic.Grind.Util
public section
namespace Lean.Meta.Grind

/-- Variant of `Sym.unfoldReducible` backed by the persistent
`unfoldReducibleCache` stored in the `grind` state. Reusing one
`transformWithCache` cache across calls keeps pointer identity for shared
sub-expressions. -/
def unfoldReducibleCached (e : Expr) : GrindM Expr := do
  let cache₀ := (← get).unfoldReducibleCache
  let (result, cache₁) ←
    Meta.transformWithCache e cache₀ (pre := fun expr => Sym.unfoldReducibleStep expr)
  modify fun s => { s with unfoldReducibleCache := cache₁ }
  return result

private abbrev M := StateRefT (Std.HashMap ExprPtr Expr) GrindM

def isMarkedSubsingletonConst (e : Expr) : Bool := Id.run do
Expand Down Expand Up @@ -41,7 +49,9 @@ Recall that the congruence closure module has special support for them.
-- TODO: consider other subsingletons in the future? We decided to not support them to avoid the overhead of
-- synthesizing `Subsingleton` instances.
partial def markNestedSubsingletons (e : Expr) : GrindM Expr := do profileitM Exception "grind mark subsingleton" (← getOptions) do
visit e |>.run' {}
let (r, cache') ← (visit e).run (← get).markSubsingletonCache
modify fun s => { s with markSubsingletonCache := cache' }
return r
where
visit (e : Expr) : M Expr := do
if isMarkedSubsingletonApp e then
Expand Down Expand Up @@ -103,7 +113,7 @@ where
-/
/- We must also apply beta-reduction to improve the effectiveness of the congruence closure procedure. -/
let e ← Core.betaReduce e
let e ← Sym.unfoldReducible e
let e ← unfoldReducibleCached e
/- We must mask proofs occurring in `prop` too. -/
let e ← visit e
let e ← eraseIrrelevantMData e
Expand All @@ -123,6 +133,8 @@ def markProof (e : Expr) : GrindM Expr := do
if e.isAppOf ``Grind.nestedProof then
return e -- `e` is already marked
else
markNestedProof e |>.run' {}
let (r, cache') ← (markNestedProof e).run (← get).markSubsingletonCache
modify fun s => { s with markSubsingletonCache := cache' }
return r

end Lean.Meta.Grind
6 changes: 4 additions & 2 deletions src/Lean/Meta/Tactic/Grind/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def preprocessImpl (e : Expr) : GoalM Simp.Result := do
let e' ← instantiateMVars r.expr
-- Remark: `simpCore` unfolds reducible constants, but it does not consistently visit all possible subterms.
-- So, we must use the following `unfoldReducible` step. It is a no-op in most cases
let e' ← Sym.unfoldReducible e'
let e' ← unfoldReducibleCached e'
let e' ← abstractNestedProofs e'
let e' ← shareCommon e'
let e' ← markNestedSubsingletons e'
let e' ← eraseIrrelevantMData e'
let e' ← foldProjs e'
Expand All @@ -70,6 +71,7 @@ def preprocessImpl (e : Expr) : GoalM Simp.Result := do
let r' ← replacePreMatchCond e'
let r ← r.mkEqTrans r'
let e' := r'.expr
let e' ← shareCommon e'
let e' ← canon e'
let e' ← shareCommon e'
trace_goal[grind.simp] "{e}\n===>\n{e'}"
Expand Down Expand Up @@ -98,6 +100,6 @@ but ensures assumptions made by `grind` are satisfied.
-/
def preprocessLight (e : Expr) : GoalM Expr := do
let e ← instantiateMVars e
shareCommon (← canon (← normalizeLevels (← foldProjs (← eraseIrrelevantMData (← markNestedSubsingletons (← Sym.unfoldReducible e))))))
shareCommon (← canon (← normalizeLevels (← foldProjs (← eraseIrrelevantMData (← markNestedSubsingletons (← unfoldReducibleCached e))))))

end Lean.Meta.Grind
6 changes: 6 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ structure State where
Cached anchors (aka stable hash codes) for terms in the `grind` state.
-/
anchors : PHashMap ExprPtr UInt64 := {}
/-- Persistent cache for `markNestedSubsingletons` and `markProof` traversals. -/
markSubsingletonCache : Std.HashMap ExprPtr Expr := {}
/-- Persistent cache for `unfoldReducible` via `transformWithCache`, ensuring pointer stability. -/
unfoldReducibleCache : Std.HashMap ExprStructEq Expr := {}

instance : Nonempty State :=
.intro {}
Expand Down Expand Up @@ -712,6 +716,8 @@ structure Canon.State where
canon : PHashMap Expr Expr := {}
proofCanon : PHashMap Expr Expr := {}
canonArg : PHashMap CanonArgKey Expr := {}
/-- Persistent cache for `canonImpl` visit traversals. -/
visitCache : Std.HashMap ExprPtr Expr := {}
deriving Inhabited

/-- Trace information for a case split. -/
Expand Down
45 changes: 45 additions & 0 deletions tests/lean/run/grind_add_sub_cancel_fvar.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import Lean

/-!
Regression test: `grind` on nested Nat subtraction chains.

After `simp only [Goal, loop]`, the goal becomes:
```
post (s₁ + s₂ - s₂ + s₂ - s₂ + ⋯ + s₂ - s₂) s₂
```
with `n` nested `(+ s₂ - s₂)` operations. Grind handles each Nat subtraction by
generating a `natCast_sub` fact, processed inside-out. The i-th fact has expression
size O(i). Before persistent caches were added, the preprocessing steps
`markNestedSubsingletons` and `canonImpl` used fresh per-call caches (traversing
all O(i) subexpressions each time), giving O(n²) total work.

This test checks that the scaling from n=25 to n=100 (4x) is at most 10x in wall-clock
time, which passes for ≤ O(n^1.7) but fails for O(n²) (which would give ~16x).
-/

/-- `loop n post` wraps `post` in `n` layers; each layer replaces the first
state argument by `s₁ + s₂ - s₂` (a `Nat` add/sub cancellation) before
recursing. Unfolding produces the nested subtraction chain described in the
module docstring. -/
def loop : Nat → (Nat → Nat → Prop) → (Nat → Nat → Prop)
  | 0, post => post
  | n+1, post => fun s₁ s₂ => loop n post (s₁ + s₂ - s₂) s₂
/-- `Goal n`: for every postcondition `post` and states `s₁ s₂`, if
`post s₁ s₂` holds then so does `loop n post s₁ s₂`. -/
def Goal (n : Nat) : Prop := ∀ post s₁ s₂, post s₁ s₂ → loop n post s₁ s₂

-- The n=100 instance produces deeply nested goals; raise recursion depth and
-- heartbeats so elaboration/`grind` do not abort before the timing comparison.
set_option maxRecDepth 10000
set_option maxHeartbeats 10000000

open Lean Elab Command in
/-- Elaborates `Goal 25` and `Goal 100` with `grind` and fails if the
wall-clock ratio exceeds 10x, which would indicate quadratic preprocessing. -/
elab "#test_grind_scaling" : command => do
  -- Elaborate one `example : Goal n := by ...; grind` and return elapsed nanoseconds.
  let solveAt (n : Nat) : CommandElabM Nat := do
    let start ← IO.monoNanosNow
    elabCommand (← `(command|
      example : Goal $(Syntax.mkNumLit (toString n)) := by intros; simp only [Goal, loop]; grind
    ))
    let stop ← IO.monoNanosNow
    return stop - start
  let t_small ← solveAt 25
  let t_large ← solveAt 100
  -- Compare in Float to get a fractional ratio from the integer timings.
  let ratio := t_large.toFloat / t_small.toFloat
  -- Linear: expect ~4x for 4x problem size. Quadratic: would be ~16x.
  -- Use 10x as threshold (generous for noise, catches quadratic).
  if ratio > 10.0 then
    throwError "grind preprocessing scaling regression: 100/25 time ratio is {ratio}x (expected < 10x for linear scaling)"

#test_grind_scaling
Loading