
# Speculative Decoding Guide

Behnam Ebrahimi edited this page Mar 29, 2026 · 1 revision


Speculative decoding is an experimental feature that can provide an additional 2-3x speedup on top of batched decoding.

## Concept

The idea is simple: use a small, fast model to "draft" tokens, then verify them with a large, accurate model in a single forward pass.

```
Traditional decoding:
  Token 1 → Forward Pass → Token 2 → Forward Pass → Token 3 → ...
  (1 forward pass per token)

Speculative decoding:
  Draft model: Token 1, 2, 3, 4, 5 (very fast)
  Target model: Verify all 5 in ONE forward pass
  (1 forward pass per 5 tokens, if all accepted)
```

When the draft model's predictions match the target model's, you get multiple tokens for the cost of one forward pass. The output is identical to running the target model alone — speculative decoding never sacrifices accuracy.
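The expected payoff can be estimated with a back-of-the-envelope model. If each draft token is accepted independently with probability `p` (an idealization; real acceptances are correlated) and `k` tokens are drafted per round, the target's single forward pass yields on average `1 + p + p² + … + pᵏ` tokens, since the target always contributes one token itself (either a correction or a bonus):

```python
def expected_tokens_per_pass(p: float, k: int) -> float:
    """Expected tokens per target forward pass, assuming each of the k
    draft tokens is accepted independently with probability p and the
    target model always contributes one token of its own."""
    # accepted drafts contribute p + p^2 + ... + p^k; the target adds 1
    return 1 + sum(p**i for i in range(1, k + 1))

for p in (0.7, 0.85, 0.9):
    print(f"p={p:.2f}, k=5 -> {expected_tokens_per_pass(p, 5):.2f} tokens/pass")
```

This is why acceptance rate matters so much: the gain grows quickly as the draft model tracks the target more closely.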

## Basic Usage

```python
from whisper_mlx import SpeculativeDecoder

decoder = SpeculativeDecoder(
    draft_model_path="mlx-community/whisper-tiny-mlx",
    target_model_path="mlx-community/whisper-large-v3-mlx",
)

result = decoder.decode_segment(mel_spectrogram, language="en")
print(result["text"])
```

## Using `speculative_transcribe()`

For a complete transcription pipeline with speculative decoding:

```python
from whisper_mlx.speculative import speculative_transcribe

result = speculative_transcribe(
    "audio.mp3",
    draft_model="tiny",
    target_model="large-v3",
    language="en",
    verbose=True,
)

print(result["text"])
print(f"Speedup: {result['stats']['speedup_factor']:.1f}x")
print(f"Acceptance rate: {result['stats']['acceptance_rate']:.0%}")
print(f"Elapsed: {result['elapsed']:.1f}s")
```

## Parameters

### `SpeculativeDecoder`

| Parameter | Default | Description |
|---|---|---|
| `draft_model_path` | `"mlx-community/whisper-tiny-mlx"` | Fast model for drafting tokens |
| `target_model_path` | `"mlx-community/whisper-large-v3-mlx"` | Accurate model for verification |
| `num_draft_tokens` | `5` | Number of tokens to draft before each verification pass |
| `dtype` | `mx.float16` | Data type for computation |

### `speculative_transcribe()`

| Parameter | Default | Description |
|---|---|---|
| `audio` | (required) | File path or audio array |
| `draft_model` | `"tiny"` | Draft model name |
| `target_model` | `"large-v3"` | Target model name |
| `language` | `"en"` | Language code |
| `verbose` | `True` | Print progress |

## How Verification Works

1. The draft model generates `num_draft_tokens` tokens autoregressively (fast, since the draft model is small).
2. The target model receives all draft tokens and produces logits for every position in a single forward pass.
3. At each position, if the target model's prediction agrees with the draft token, the token is accepted.
4. At the first disagreement, the target model's own token is used instead, and drafting resumes from there.
5. Statistics track the acceptance rate and speedup factor.
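The loop above can be sketched in plain Python with stand-in functions for the two models (hypothetical names, not the library's API; a real implementation scores all draft positions with one batched forward pass rather than one call per position):

```python
from typing import Callable, List

def speculative_step(
    draft_next: Callable[[List[int]], int],       # fast model: greedy next token
    target_next: Callable[[List[int], int], int], # stub for the target's prediction
    context: List[int],
    num_draft: int = 5,
) -> List[int]:
    """One draft-and-verify round (greedy acceptance sketch)."""
    # Step 1: draft num_draft tokens autoregressively with the fast model.
    drafted, ctx = [], list(context)
    for _ in range(num_draft):
        tok = draft_next(ctx)
        drafted.append(tok)
        ctx.append(tok)

    # Steps 2-4: check every drafted position against the target's prediction.
    # Accept while they agree; on the first mismatch, emit the target's token
    # and stop so drafting can resume from the corrected context.
    accepted, ctx = [], list(context)
    for i, tok in enumerate(drafted):
        target_tok = target_next(ctx, i)
        if target_tok == tok:
            accepted.append(tok)
            ctx.append(tok)
        else:
            accepted.append(target_tok)
            break
    return accepted

# Toy models: both count upward, but the "target" diverges at draft position 2.
draft = lambda ctx: ctx[-1] + 1
target = lambda ctx, i: ctx[-1] + 1 if i < 2 else 99
print(speculative_step(draft, target, [0]))  # [1, 2, 99]
```

With full agreement, one round yields `num_draft` tokens for a single (stubbed) target pass; with an early disagreement, the target's correction still makes forward progress.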

## Performance Statistics

```python
stats = decoder.get_stats()
print(stats)
# {
#     "draft_tokens": 1250,        # Total tokens drafted
#     "accepted_tokens": 1050,     # Tokens verified correct
#     "total_forward_passes": 280, # Target model forward passes
#     "acceptance_rate": 0.84,     # 84% of drafts accepted
#     "speedup_factor": 4.46       # ~4.5 tokens per forward pass
# }
```

- **Acceptance rate**: higher is better; it depends on how closely the draft model's predictions match the target model's for your audio.
- **Speedup factor**: `draft_tokens / total_forward_passes`, the theoretical speedup over sequential decoding.
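Both derived values follow directly from the raw counters; recomputing them from the example stats above:

```python
stats = {"draft_tokens": 1250, "accepted_tokens": 1050, "total_forward_passes": 280}

acceptance_rate = stats["accepted_tokens"] / stats["draft_tokens"]
speedup_factor = stats["draft_tokens"] / stats["total_forward_passes"]

print(f"{acceptance_rate:.0%}")   # 84%
print(f"{speedup_factor:.2f}")    # 4.46
```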

## Model Pairing Recommendations

| Draft Model | Target Model | Expected Acceptance Rate | Notes |
|---|---|---|---|
| `tiny` | `large-v3` | 70-85% | Best overall speedup |
| `tiny` | `turbo` | 75-90% | Good balance |
| `base` | `large-v3` | 80-90% | Higher acceptance, slower draft |
| `tiny` | `distil-large-v3` | 75-90% | Fast target + fast draft |

Tip: The tiny model is the best draft model because it's extremely fast. A slightly larger draft model (e.g., base) improves acceptance rate but adds drafting overhead.

## Voice Activity Detection (VAD)

Vayu includes a `VADProcessor` for skipping silence before speculative decoding:

```python
from whisper_mlx.speculative import VADProcessor

vad = VADProcessor(energy_threshold=0.01, min_speech_duration=0.5)

# Detect speech regions
regions = vad.detect_speech_regions(audio_array, sample_rate=16000)
# Returns: [(start_sample, end_sample), ...]

# Check how much silence can be skipped
skip_ratio = vad.get_skip_ratio(audio_array)
print(f"Can skip {skip_ratio:.0%} of audio (silence)")
```

| Parameter | Default | Description |
|---|---|---|
| `energy_threshold` | `0.01` | Energy level below which frames are considered silent |
| `min_speech_duration` | `0.5` | Minimum duration (seconds) to count as speech |
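The energy-gating idea can be approximated in a few lines. This is a standalone sketch, not `VADProcessor`'s actual implementation; the `frame_ms` frame size and the function names are assumptions:

```python
def detect_speech_regions(audio, sample_rate=16000, energy_threshold=0.01,
                          min_speech_duration=0.5, frame_ms=30):
    """Return (start_sample, end_sample) regions whose mean absolute
    amplitude exceeds energy_threshold for at least min_speech_duration."""
    frame_len = max(1, int(sample_rate * frame_ms / 1000))
    min_samples = int(sample_rate * min_speech_duration)

    regions, start = [], None
    for i in range(0, len(audio), frame_len):
        frame = audio[i:i + frame_len]
        energy = sum(abs(x) for x in frame) / len(frame)
        if energy >= energy_threshold:
            if start is None:
                start = i                 # speech run begins
        elif start is not None:
            if i - start >= min_samples:  # keep runs that are long enough
                regions.append((start, i))
            start = None
    if start is not None and len(audio) - start >= min_samples:
        regions.append((start, len(audio)))
    return regions

def get_skip_ratio(audio, **kwargs):
    """Fraction of the audio that is silence (i.e., skippable)."""
    speech = sum(end - start for start, end in detect_speech_regions(audio, **kwargs))
    return 1 - speech / len(audio)

# 1 s of silence followed by 1 s of "speech" at 16 kHz:
audio = [0.0] * 16000 + [0.5] * 16000
print(detect_speech_regions(audio))  # [(16000, 32000)]
```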

## Parallel Chunk Transcription

For even more throughput, process audio chunks in parallel:

```python
from whisper_mlx.speculative import parallel_chunk_transcribe

result = parallel_chunk_transcribe(
    "long_audio.mp3",
    model_path="mlx-community/whisper-turbo",
    chunk_duration=30.0,
    overlap_duration=2.0,
    language="en",
)
```

| Parameter | Default | Description |
|---|---|---|
| `chunk_duration` | `30.0` | Length of each chunk in seconds |
| `overlap_duration` | `2.0` | Overlap between chunks for seamless merging |
| `model_path` | `"mlx-community/whisper-turbo"` | Model to use |
| `language` | `"en"` | Language code |

Overlapping chunks are automatically merged by keeping the longer text in overlap regions.
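That merge rule can be sketched as follows, assuming each chunk yields `(start_sec, end_sec, text)` segments in absolute time (`merge_overlap` is a hypothetical helper, not the library's API):

```python
def merge_overlap(prev_segments, next_segments, overlap_start, overlap_end):
    """Merge two adjacent chunks' segments; within the overlap window,
    keep whichever chunk produced the longer text."""
    # Everything before the overlap comes from the first chunk,
    # everything after it from the second.
    merged = [s for s in prev_segments if s[1] <= overlap_start]
    prev_ov = [s for s in prev_segments if s[1] > overlap_start]
    next_ov = [s for s in next_segments if s[0] < overlap_end]
    # In the overlap window, prefer the chunk that transcribed more text.
    prev_text = " ".join(t for _, _, t in prev_ov)
    next_text = " ".join(t for _, _, t in next_ov)
    merged += prev_ov if len(prev_text) >= len(next_text) else next_ov
    merged += [s for s in next_segments if s[0] >= overlap_end]
    return merged

# Two chunks sharing a 2 s overlap at 28-30 s; the second chunk heard
# the overlapped words more completely, so its text wins.
prev_chunk = [(0.0, 27.5, "the quick brown fox"), (28.0, 30.0, "jumps ov")]
next_chunk = [(28.0, 30.0, "jumps over"), (30.0, 40.0, "the lazy dog")]
print(merge_overlap(prev_chunk, next_chunk, 28.0, 30.0))
```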

## Limitations

- **Experimental**: the API may change in future versions.
- **No word timestamps**: speculative decoding does not currently support word-level timestamps.
- **No batched decoding**: segments are processed sequentially; the speedup comes from token-level parallelism instead.
- **Memory**: two models (draft + target) are loaded simultaneously, which uses more memory.
- **Language**: currently optimized for single-language transcription.
