# Speculative Decoding Guide
Speculative decoding is an experimental feature that can provide an additional 2-3x speedup on top of batched decoding.
The idea is simple: use a small, fast model to "draft" tokens, then verify them with a large, accurate model in a single forward pass.
Traditional decoding:

```
Token 1 → Forward Pass → Token 2 → Forward Pass → Token 3 → ...
(1 forward pass per token)
```

Speculative decoding:

```
Draft model:  Token 1, 2, 3, 4, 5   (very fast)
Target model: Verify all 5 in ONE forward pass
(1 forward pass per 5 tokens, if all accepted)
```
When the draft model's predictions match the target model's, you get multiple tokens for the cost of one forward pass. The output is identical to running the target model alone — speculative decoding never sacrifices accuracy.
```python
from whisper_mlx import SpeculativeDecoder

decoder = SpeculativeDecoder(
    draft_model_path="mlx-community/whisper-tiny-mlx",
    target_model_path="mlx-community/whisper-large-v3-mlx",
)

result = decoder.decode_segment(mel_spectrogram, language="en")
print(result["text"])
```

For a complete transcription pipeline with speculative decoding:
```python
from whisper_mlx.speculative import speculative_transcribe

result = speculative_transcribe(
    "audio.mp3",
    draft_model="tiny",
    target_model="large-v3",
    language="en",
    verbose=True,
)

print(result["text"])
print(f"Speedup: {result['stats']['speedup_factor']:.1f}x")
print(f"Acceptance rate: {result['stats']['acceptance_rate']:.0%}")
print(f"Elapsed: {result['elapsed']:.1f}s")
```

| Parameter | Default | Description |
|---|---|---|
| `draft_model_path` | `"mlx-community/whisper-tiny-mlx"` | Fast model for drafting tokens |
| `target_model_path` | `"mlx-community/whisper-large-v3-mlx"` | Accurate model for verification |
| `num_draft_tokens` | `5` | Tokens to draft before verification |
| `dtype` | `mx.float16` | Data type for computation |
| Parameter | Default | Description |
|---|---|---|
| `audio` | — | File path or audio array |
| `draft_model` | `"tiny"` | Draft model name |
| `target_model` | `"large-v3"` | Target model name |
| `language` | `"en"` | Language code |
| `verbose` | `True` | Print progress |
- The draft model generates `num_draft_tokens` tokens autoregressively (fast, since it's a tiny model)
- The target model receives ALL draft tokens and produces logits in one forward pass
- At each position, if the target model agrees with the draft → the token is accepted
- On the first disagreement → the target model's token is used instead, and drafting resumes
- Statistics track the acceptance rate and speedup factor
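The accept/reject loop described above can be sketched in plain Python. The helpers `draft_next_token` and `target_greedy_tokens` are hypothetical stand-ins for the real model calls; this is a minimal sketch of the greedy verification scheme, not the library's implementation:

```python
def speculative_step(tokens, draft_next_token, target_greedy_tokens,
                     num_draft_tokens=5):
    """One round of greedy speculative decoding.

    draft_next_token(ctx) -> next token id from the small model
    target_greedy_tokens(tokens, k) -> the target model's greedy token
        ids for the k positions after `tokens`, from ONE forward pass
    """
    # 1. Draft tokens autoregressively with the small model
    drafted, ctx = [], list(tokens)
    for _ in range(num_draft_tokens):
        t = draft_next_token(ctx)
        drafted.append(t)
        ctx.append(t)

    # 2. Verify all drafts with a single target forward pass
    target_choices = target_greedy_tokens(tokens, len(drafted))

    # 3. Accept the matching prefix; on the first mismatch,
    #    take the target's token and stop
    accepted = []
    for d, t in zip(drafted, target_choices):
        if d == t:
            accepted.append(d)
        else:
            accepted.append(t)  # target wins on disagreement
            break
    return tokens + accepted, len(accepted)
```

Each call consumes one target forward pass and yields between 1 and `num_draft_tokens` tokens, which is where the speedup comes from.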
```python
stats = decoder.get_stats()
print(stats)
# {
#   "draft_tokens": 1250,          # Total tokens drafted
#   "accepted_tokens": 1050,       # Tokens verified correct
#   "total_forward_passes": 280,   # Target model forward passes
#   "acceptance_rate": 0.84,       # 84% of drafts accepted
#   "speedup_factor": 4.46         # ~4.5 tokens per forward pass
# }
```

- Acceptance rate: higher is better; it depends on how closely the draft model matches the target on your audio
- Speedup factor: `draft_tokens / total_forward_passes`, the theoretical speedup over sequential decoding
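As a sanity check, both statistics can be recomputed from the raw counters in the example above:

```python
stats = {"draft_tokens": 1250, "accepted_tokens": 1050,
         "total_forward_passes": 280}

# Fraction of drafted tokens the target model agreed with
acceptance_rate = stats["accepted_tokens"] / stats["draft_tokens"]

# Tokens produced per target forward pass
speedup_factor = stats["draft_tokens"] / stats["total_forward_passes"]

print(f"{acceptance_rate:.0%}")  # 84%
print(f"{speedup_factor:.2f}")   # 4.46
```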
| Draft Model | Target Model | Expected Acceptance Rate | Notes |
|---|---|---|---|
| `tiny` | `large-v3` | 70-85% | Best overall speedup |
| `tiny` | `turbo` | 75-90% | Good balance |
| `base` | `large-v3` | 80-90% | Higher acceptance, slower draft |
| `tiny` | `distil-large-v3` | 75-90% | Fast target + fast draft |
Tip: The `tiny` model is the best draft model because it's extremely fast. A slightly larger draft model (e.g., `base`) improves the acceptance rate but adds drafting overhead.
Vayu includes a VADProcessor for skipping silence before speculative decoding:
```python
from whisper_mlx.speculative import VADProcessor

vad = VADProcessor(energy_threshold=0.01, min_speech_duration=0.5)

# Detect speech regions
regions = vad.detect_speech_regions(audio_array, sample_rate=16000)
# Returns: [(start_sample, end_sample), ...]

# Check how much silence can be skipped
skip_ratio = vad.get_skip_ratio(audio_array)
print(f"Can skip {skip_ratio:.0%} of audio (silence)")
```

| Parameter | Default | Description |
|---|---|---|
| `energy_threshold` | `0.01` | Energy level below which frames are considered silent |
| `min_speech_duration` | `0.5` | Minimum duration (seconds) to count as speech |
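The energy-threshold idea behind these two parameters can be illustrated with NumPy. This is a simplified frame-energy detector written for this guide, not `VADProcessor`'s actual implementation:

```python
import numpy as np

def detect_speech_regions(audio, sample_rate=16000,
                          energy_threshold=0.01,
                          min_speech_duration=0.5, frame_ms=20):
    """Return (start_sample, end_sample) pairs whose per-frame RMS
    energy stays above the threshold for at least min_speech_duration."""
    frame_len = int(sample_rate * frame_ms / 1000)
    n_frames = len(audio) // frame_len
    frames = audio[: n_frames * frame_len].reshape(n_frames, frame_len)
    energy = np.sqrt((frames ** 2).mean(axis=1))  # per-frame RMS

    # Collect contiguous runs of above-threshold frames
    regions, start = [], None
    for i, e in enumerate(energy):
        if e >= energy_threshold and start is None:
            start = i
        elif e < energy_threshold and start is not None:
            regions.append((start, i))
            start = None
    if start is not None:
        regions.append((start, n_frames))

    # Drop runs shorter than min_speech_duration
    min_frames = int(min_speech_duration * 1000 / frame_ms)
    return [(s * frame_len, e * frame_len)
            for s, e in regions if e - s >= min_frames]
```

On one second of silence followed by one second of tone at 16 kHz, this returns a single region covering samples 16000 to 32000.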
For even more throughput, process audio chunks in parallel:
```python
from whisper_mlx.speculative import parallel_chunk_transcribe

result = parallel_chunk_transcribe(
    "long_audio.mp3",
    model_path="mlx-community/whisper-turbo",
    chunk_duration=30.0,
    overlap_duration=2.0,
    language="en",
)
```

| Parameter | Default | Description |
|---|---|---|
| `chunk_duration` | `30.0` | Length of each chunk in seconds |
| `overlap_duration` | `2.0` | Overlap between chunks for seamless merging |
| `model_path` | `"mlx-community/whisper-turbo"` | Model to use |
| `language` | `"en"` | Language code |
Overlapping chunks are automatically merged by keeping the longer text in overlap regions.
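The chunking arithmetic can be sketched with a small helper: each chunk starts `chunk_duration - overlap_duration` seconds after the previous one. This is an illustrative sketch, not the library's internal splitter:

```python
def chunk_bounds(total_duration, chunk_duration=30.0,
                 overlap_duration=2.0):
    """Start/end times (seconds) for overlapping chunks."""
    step = chunk_duration - overlap_duration
    bounds, start = [], 0.0
    while start < total_duration:
        end = min(start + chunk_duration, total_duration)
        bounds.append((start, end))
        if end >= total_duration:
            break
        start += step  # advance by chunk length minus the overlap
    return bounds
```

For a 70-second file with the defaults, this yields `(0, 30)`, `(28, 58)`, and `(56, 70)`, so each pair of adjacent chunks shares a 2-second overlap for the merge step.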
- Experimental: The API may change in future versions
- No word timestamps: Speculative decoding does not currently support word-level timestamps
- No batched decoding: Speculative decoding processes segments sequentially (the speedup comes from token-level parallelism instead)
- Memory: Two models are loaded simultaneously (draft + target), which uses more memory
- Language: Currently optimized for single-language transcription