Merged
**CONTRIBUTION.md** (new file, 103 additions)

# Contribution: True O(1) Stateful Loop Engine

## Architectural Change

The existing latent loop implementation rebuilds SSM state from scratch every iteration:

```python
# Original — O(n) per loop, n grows each step
for lp in range(MAX_LOOPS):
toks = tok(prompt + "=" * lp, ...) # re-tokenize expanding string
h = model(**toks, ...).hidden_states[-1] # full forward pass on entire sequence
```

This is functionally equivalent to pause tokens with a growing prompt. The SSM state
is rebuilt from scratch each time — no recurrent state is carried forward.

The new `stateful_engine.py` uses MambaCache for true O(1) recurrent iteration:

```python
# Stateful — O(1) per loop, constant regardless of history
out = model(input_ids=prompt_ids, use_cache=True, ...) # prefill once
cache = out.cache_params

for lp in range(MAX_LOOPS):
step = model(input_ids=spacer, cache_params=cache, # single-token step
cache_position=pos, use_cache=True, ...)
    h = step.hidden_states[-1][0, -1, :]     # last-token hidden state
```

Each loop is a single-token recurrent step. Sequence length never grows.
Memory usage is constant. This is the correct way to use an SSM recurrently.
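The asymptotic gap is easy to see with a token-count model (a sketch with illustrative numbers; `PROMPT_LEN` is an assumption, not a measured value from the benchmark):

```python
# Tokens pushed through the model per approach, ignoring the one-time prefill.
# Growing-prompt: loop lp re-runs the forward pass over prompt + lp spacers.
# Stateful: loop lp processes exactly one spacer token against the cache.

PROMPT_LEN = 32   # illustrative prompt length in tokens
MAX_LOOPS = 7

growing = sum(PROMPT_LEN + lp for lp in range(MAX_LOOPS))  # O(loops*prompt + loops^2)
stateful = MAX_LOOPS                                       # O(loops)

print(growing, stateful)  # 245 7
```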

## Key API Finding

The plan assumed the standard `past_key_values` transformer API. Mamba uses a
different interface:

- `cache_params` (not `past_key_values`) — passes MambaCache to model
- `cache_position` — **required** when passing cache manually; shape determines
prefill (shape=conv_kernel) vs decode (shape=1) mode
- Cache is updated **in-place** — same object, mutated

See `docs/cache_api_findings.md` for full details.
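One practical consequence of the in-place update: the cache object returned from a step is the very object that was passed in, so snapshotting it (e.g. for a session cartridge) requires an explicit copy. A toy stand-in (not the real `MambaCache`) makes the aliasing visible:

```python
import copy

class ToyCache:
    """Stand-in for MambaCache: state tensors are mutated in place."""
    def __init__(self):
        self.ssm_states = [0.0]

def toy_step(cache):
    """Stand-in for one model step: advances state and returns the same object."""
    cache.ssm_states[0] += 1.0
    return cache

cache = ToyCache()
snapshot = copy.deepcopy(cache)   # copy BEFORE stepping to preserve old state
returned = toy_step(cache)

assert returned is cache              # same object, not a new cache
assert cache.ssm_states == [1.0]      # mutated in place
assert snapshot.ssm_states == [0.0]   # the deep copy kept the old state
```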

## Results

### Mamba-2.8B (CPU, 64 layers, 2560 hidden)

| Approach | Avg Loop ms | LLPS | Speedup |
|----------|-------------|------|---------|
| Original (re-tokenize) | 1100.71 | 0.9 | — |
| Stateful cache | 468.79 | 2.1 | **2.35x** |

### Mamba-130M (CPU, 24 layers, 768 hidden)

| Approach | Avg Loop ms | LLPS | Speedup |
|----------|-------------|------|---------|
| Original (re-tokenize) | 103.43 | 9.7 | — |
| Stateful cache | 32.64 | 30.6 | **3.17x** |
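The LLPS and speedup columns follow directly from the average latencies (a quick arithmetic check against the two tables):

```python
# LLPS = 1000 / avg_loop_ms; speedup = original_ms / stateful_ms
rows = {"2.8B": (1100.71, 468.79), "130M": (103.43, 32.64)}

for name, (orig_ms, stateful_ms) in rows.items():
    print(f"{name}: LLPS {1000/orig_ms:.1f} -> {1000/stateful_ms:.1f}, "
          f"speedup {orig_ms/stateful_ms:.2f}x")
# 2.8B: LLPS 0.9 -> 2.1, speedup 2.35x
# 130M: LLPS 9.7 -> 30.6, speedup 3.17x
```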

### Correctness

- **Prefill match**: Loop 0 hidden states identical (cosine sim = 1.0000)
- **Generate from cache**: Works without fallback (no kill switch triggered)
- **ACT proportionality**: Hard prompt h-delta (50.6) > Easy (19.2) — 2.6x ratio

On GPU with CUDA kernels, the speedup should be significantly higher: each original
iteration re-processes the entire growing sequence, so its total cost is
O(loops * prompt + loops^2), while the stateful approach stays O(loops). The CUDA
selective-scan kernel also specifically optimizes the `seq_len=1` decode path.

## Files Changed

| File | Status | Description |
|------|--------|-------------|
| `stateful_engine.py` | **NEW** | O(1) StatefulLoopEngine implementation |
| `validate_stateful.py` | **NEW** | Phase 2 correctness validation script |
| `benchmark_llps.py` | **NEW** | Phase 3 LLPS benchmark script |
| `session_memory.py` | **MODIFIED** | `latent_turn()` upgraded to O(1) cache iteration |
| `mamba_engine.py` | **UNCHANGED** | Original training engine preserved |
| `docs/cache_api_findings.md` | **NEW** | Phase 0 MambaCache API documentation |
| `docs/correctness_validation.md` | **NEW** | Phase 2 comparison results |
| `docs/llps_benchmark.md` | **NEW** | Phase 3 latency measurements |
| `docs/blockers.md` | **NEW** | Kill switch status and environment blockers |

## What Remains

### Requires Fine-Tuned Checkpoint + GPU

- [ ] Proof 3 validation: variable tracking W=8
- [ ] ACT proportionality with HaltingHead loop counts
- [ ] Kill-shot ablation (full run vs 2-loop lobotomy)
- [ ] GPU LLPS benchmark with 2.8B model (expected >10x speedup)

### Future Work (from plan)

- [ ] **NPU HaltingHead dispatch**: Move halting decision to NPU for latency hiding
- [ ] **State delta encoding**: Save only the delta between cache states for more
compact session cartridges (currently ~5MB full cache, could be <1KB deltas)
- [ ] **Batched iteration**: Process multiple conversations simultaneously using
`max_batch_size > 1` in MambaCache
- [ ] **torch.compile integration**: MambaCache marks tensors as static addresses;
should be compatible with `torch.compile` for additional speedup
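The state-delta item in the list above can be sketched without the real cache tensors: keep only the entries that changed between two flat state snapshots. `encode_delta` and `apply_delta` are hypothetical names, and real `ssm_states` would need a tolerance and dtype-aware serialization:

```python
def encode_delta(old, new, tol=0.0):
    """Sparse delta: (index -> new value) for entries that actually changed."""
    return {i: b for i, (a, b) in enumerate(zip(old, new)) if abs(a - b) > tol}

def apply_delta(old, delta):
    """Reconstruct the new state from the old state plus the sparse delta."""
    return [delta.get(i, v) for i, v in enumerate(old)]

old_state = [0.0, 1.5, -2.0, 3.0]
new_state = [0.0, 1.5, -2.5, 3.0]   # one entry moved after a loop step

delta = encode_delta(old_state, new_state)
assert delta == {2: -2.5}                        # far smaller than the full state
assert apply_delta(old_state, delta) == new_state
```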
**benchmark_llps.py** (new file, 222 additions)

"""
benchmark_llps.py — Latent Loops Per Second (LLPS) Benchmark
=============================================================
Phase 3: measure per-loop latency of the stateful cache vs the original re-tokenize approach.

Usage:
python benchmark_llps.py [engine_dir] [--runs N] [--loops N]
"""

import torch
import time
import sys
import os
import statistics
from transformers import AutoTokenizer, AutoModelForCausalLM


def benchmark_original(model, tok, prompt, max_loops, device, runs=20):
"""Benchmark original re-tokenize approach."""
all_loop_times = []

for run in range(runs):
with torch.no_grad():
for lp in range(max_loops):
text = prompt + "=" * lp
toks = tok(text, return_tensors="pt",
truncation=True, max_length=256)
input_ids = toks.input_ids.to(device)

t0 = time.perf_counter()
out = model(input_ids=input_ids, output_hidden_states=True)
_ = out.hidden_states[-1][0, -1, :]
elapsed = (time.perf_counter() - t0) * 1000
all_loop_times.append(elapsed)

return all_loop_times


def benchmark_stateful(model, tok, prompt, spacer_id, max_loops, device, runs=20):
"""Benchmark stateful cache approach."""
all_loop_times = []

for run in range(runs):
with torch.no_grad():
# Prefill (not counted in loop latency)
toks = tok(prompt, return_tensors="pt",
truncation=True, max_length=256)
input_ids = toks.input_ids.to(device)
seq_len = input_ids.shape[1]

out = model(input_ids=input_ids, use_cache=True,
output_hidden_states=True)
cache = out.cache_params

# Measure loop iterations only
spacer = torch.tensor([[spacer_id]], device=device)
for lp in range(max_loops):
cache_pos = torch.tensor([seq_len + lp], device=device)

t0 = time.perf_counter()
step_out = model(
input_ids=spacer,
cache_params=cache,
cache_position=cache_pos,
use_cache=True,
output_hidden_states=True
)
_ = step_out.hidden_states[-1][0, -1, :]
elapsed = (time.perf_counter() - t0) * 1000
all_loop_times.append(elapsed)

return all_loop_times


def compute_stats(times):
"""Compute benchmark statistics."""
times_sorted = sorted(times)
n = len(times_sorted)
return {
"avg": statistics.mean(times),
"median": statistics.median(times),
"p95": times_sorted[int(n * 0.95)] if n > 20 else times_sorted[-1],
"p99": times_sorted[int(n * 0.99)] if n > 100 else times_sorted[-1],
"min": min(times),
"max": max(times),
"stdev": statistics.stdev(times) if n > 1 else 0,
"n": n,
}


def main():
engine_dir = "checkpoints/mamba-2.8b-latent"
runs = 20
max_loops = 7

args = sys.argv[1:]
skip_next = False
for i, arg in enumerate(args):
if skip_next:
skip_next = False
continue
if arg == "--runs" and i + 1 < len(args):
runs = int(args[i + 1])
skip_next = True
elif arg == "--loops" and i + 1 < len(args):
max_loops = int(args[i + 1])
skip_next = True
elif not arg.startswith("--"):
engine_dir = arg

if not os.path.isdir(engine_dir):
engine_dir = "state-spaces/mamba-130m-hf"
print(f"[INFO] Checkpoint not found, using base model: {engine_dir}")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INIT] Loading {engine_dir} on {device}...")

tok = AutoTokenizer.from_pretrained(engine_dir, trust_remote_code=True)
if tok.pad_token is None:
tok.pad_token = tok.eos_token

model = AutoModelForCausalLM.from_pretrained(
engine_dir,
dtype=torch.bfloat16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None,
trust_remote_code=True
)
model.eval()

spacer_id = tok.convert_tokens_to_ids("=")
prompt = "[LOGIC] X=5. Y=X*2. Z=Y+3. W=Z-X. Output W. ===="

print(f"[CONFIG] runs={runs}, max_loops={max_loops}, device={device}")
print(f"[CONFIG] prompt tokens: {len(tok(prompt).input_ids)}")

# Warmup
print("\n[WARMUP] Running warmup passes...")
with torch.no_grad():
toks = tok(prompt, return_tensors="pt").to(device)
for _ in range(3):
model(input_ids=toks.input_ids, output_hidden_states=True)

# Benchmark original
print(f"\n[BENCH] Original (re-tokenize): {runs} runs x {max_loops} loops...")
orig_times = benchmark_original(model, tok, prompt, max_loops, device, runs)
orig_stats = compute_stats(orig_times)

# Benchmark stateful
print(f"[BENCH] Stateful (cache step): {runs} runs x {max_loops} loops...")
stat_times = benchmark_stateful(model, tok, prompt, spacer_id,
max_loops, device, runs)
stat_stats = compute_stats(stat_times)

# Results
banner = "=" * 70
print(f"\n{banner}")
print(f" LLPS BENCHMARK RESULTS")
print(f" Model: {engine_dir}")
print(f" Device: {device}")
print(f" Runs: {runs}, Loops per run: {max_loops}")
print(f"{banner}\n")

print(f" {'Metric':<20} | {'Original (ms)':<16} | {'Stateful (ms)':<16} | Speedup")
print(f" {'-'*20}-+-{'-'*16}-+-{'-'*16}-+--------")
for metric in ["avg", "median", "p95", "min", "max", "stdev"]:
o = orig_stats[metric]
s = stat_stats[metric]
speedup = o / s if s > 0 else float('inf')
print(f" {metric:<20} | {o:>16.2f} | {s:>16.2f} | {speedup:>5.2f}x")

orig_llps = 1000 / orig_stats["avg"] if orig_stats["avg"] > 0 else 0
stat_llps = 1000 / stat_stats["avg"] if stat_stats["avg"] > 0 else 0
print(f"\n Original LLPS: {orig_llps:>8.1f} loops/sec")
print(f" Stateful LLPS: {stat_llps:>8.1f} loops/sec")
print(f" Throughput gain: {stat_llps/orig_llps:.2f}x" if orig_llps > 0 else "")

print(f"\n Samples: original={orig_stats['n']}, stateful={stat_stats['n']}")
print(f"{banner}\n")

# Write results to file
results_path = "docs/llps_benchmark.md"
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, "w") as f:
f.write("# Phase 3: LLPS Benchmark Results\n\n")
f.write(f"## Environment\n\n")
f.write(f"- **Model**: {engine_dir}\n")
f.write(f"- **Device**: {device}\n")
f.write(f"- **Runs**: {runs}\n")
f.write(f"- **Loops per run**: {max_loops}\n")
f.write(f"- **Prompt tokens**: {len(tok(prompt).input_ids)}\n\n")
f.write(f"## Results\n\n")
f.write(f"| Approach | Avg Loop ms | Median ms | p95 ms | LLPS | Notes |\n")
f.write(f"|----------|------------|-----------|--------|------|-------|\n")
f.write(f"| Original (re-tokenize) | {orig_stats['avg']:.2f} | {orig_stats['median']:.2f} | {orig_stats['p95']:.2f} | {orig_llps:.1f} | Sequence grows each loop |\n")
f.write(f"| Stateful cache | {stat_stats['avg']:.2f} | {stat_stats['median']:.2f} | {stat_stats['p95']:.2f} | {stat_llps:.1f} | Single-token recurrent step |\n\n")
f.write(f"**Speedup: {orig_stats['avg']/stat_stats['avg']:.2f}x** (avg latency)\n\n")
f.write(f"**Throughput gain: {stat_llps/orig_llps:.2f}x** (LLPS)\n\n")
f.write(f"## Detailed Statistics\n\n")
f.write(f"| Metric | Original (ms) | Stateful (ms) | Speedup |\n")
f.write(f"|--------|--------------|---------------|--------|\n")
for metric in ["avg", "median", "p95", "min", "max", "stdev"]:
o = orig_stats[metric]
s = stat_stats[metric]
sp = o / s if s > 0 else float('inf')
f.write(f"| {metric} | {o:.2f} | {s:.2f} | {sp:.2f}x |\n")
f.write(f"| n (samples) | {orig_stats['n']} | {stat_stats['n']} | — |\n")
f.write(f"\n## Analysis\n\n")
f.write(f"The stateful approach processes a single token per iteration "
f"(O(1) per step), while the original re-tokenizes the entire "
f"prompt + spacers each loop (O(n) per step where n grows).\n\n")
f.write(f"On GPU with the 2.8B model, the speedup should be significantly "
f"larger because:\n")
f.write(f"1. The original approach's cost scales with prompt length "
f"(more tokens = more compute)\n")
f.write(f"2. The stateful approach is constant regardless of prompt length\n")
f.write(f"3. GPU kernel launch overhead is amortized better with "
f"single-token steps\n")

print(f"[DONE] Results written to {results_path}")


if __name__ == "__main__":
main()
**docs/blockers.md** (new file, 29 additions)

# Blockers

## No GPU Available

- **Phase**: All
- **Impact**: Cannot measure GPU-specific latency; no CUDA fast-path kernels
- **Mitigation**: Both 2.8B and 130M models tested on CPU (slow path). API correctness confirmed, benchmarks captured.
- **Resolution**: Re-run benchmarks on GPU for production-representative numbers

## No Fine-Tuned Checkpoint

- **Phase**: 2, 3
- **Impact**: Cannot validate Proof 3 (W=8), ACT proportionality with HaltingHead
- **Mitigation**: Structural correctness confirmed; checkpoint-dependent tests documented as pending
- **Resolution**: Run `validate_stateful.py` and `benchmark_llps.py` with `checkpoints/mamba-2.8b-latent`

## transformers 5.3.0 API Change

- **Phase**: 0
- **Impact**: `MambaCache` moved from `transformers.cache_utils` to `transformers.models.mamba.modeling_mamba`
- **Resolution**: Import via `from transformers import MambaCache` (auto-import works)
- **Note**: The plan assumed `past_key_values` API; actual Mamba API uses `cache_params` + `cache_position`

## No Kill Switches Triggered

All kill switch conditions in the plan were avoided:
- `use_cache=True` works correctly
- `generate()` accepts `cache_params` with pre-built cache
- MambaCache exposes `conv_states` and `ssm_states` as public attributes