Large language models are powerful but slow at text generation. Speculative Decoding presents an elegant solution to this problem using a smaller “draft” model that proposes tokens in advance. This technique can significantly speed up inference without losing output quality.
What is Speculative Decoding¶
Speculative Decoding is an advanced optimization technique that dramatically speeds up large language model (LLM) inference without losing output quality. The idea is to pair a smaller, faster “draft” model, which quickly proposes several tokens ahead, with a larger “target” model that then verifies (and, where necessary, corrects) all of those proposals in a single pass.
The key idea is simple: instead of strictly sequential token-by-token generation, we exploit parallelism. The smaller model quickly proposes a sequence of tokens, and the larger model checks them all at once. When the predictions are correct, we save significant time; when they are not, generation continues from the point of divergence.
How Speculative Decoding Works¶
The process occurs in two main steps:
- Draft phase: Smaller model (e.g., 7B parameters) quickly generates K candidate tokens
- Verification phase: Larger model (e.g., 70B parameters) evaluates all candidates in parallel
Mathematically, we can express the acceptance probability for the token at position i as:

α_i = min(1, P_target(x_i | x_<i) / P_draft(x_i | x_<i))

where P_target and P_draft are the probability distributions of the target and draft models. The token is accepted with probability α_i; otherwise a new token is sampled from the adjusted (residual) distribution, max(0, P_target − P_draft), renormalized.
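To make the rule concrete, here is a toy example with hand-picked distributions over a five-token vocabulary (the numbers are invented purely for illustration):

```python
import torch

torch.manual_seed(0)

# Hand-picked next-token distributions over a 5-token vocabulary
p_target = torch.tensor([0.50, 0.20, 0.15, 0.10, 0.05])
p_draft = torch.tensor([0.30, 0.40, 0.15, 0.10, 0.05])

# Suppose the draft model sampled token 1 (a token it over-weights)
proposed = 1
alpha = min(1.0, (p_target[proposed] / p_draft[proposed]).item())

if torch.rand(1).item() < alpha:
    token = proposed  # accepted: keep the draft's token
else:
    # Rejected: sample from the residual max(0, p_target - p_draft), renormalized
    residual = torch.clamp(p_target - p_draft, min=0)
    residual = residual / residual.sum()
    token = torch.multinomial(residual, 1).item()

print(round(alpha, 3))  # 0.5 -> this draft token is kept about half the time
```

Because the draft over-weights token 1 relative to the target (0.40 vs. 0.20), the acceptance probability drops to 0.5; on rejection, the residual distribution concentrates mass on the tokens the draft under-weighted.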
Implementation in PyTorch¶
Basic implementation of the speculative decoding algorithm:
```python
import torch
import torch.nn.functional as F


def speculative_decode(draft_model, target_model, prompt_tokens,
                       max_new_tokens=50, draft_k=5, temperature=1.0):
    """
    Speculative Decoding implementation.

    Args:
        draft_model: Smaller, faster model
        target_model: Larger, more accurate model
        prompt_tokens: Input token sequence of shape (1, seq_len)
        max_new_tokens: Maximum number of new tokens to generate
        draft_k: Number of candidate tokens from the draft model
        temperature: Temperature for sampling

    Returns:
        The extended token sequence and the draft-token acceptance rate.
    """
    device = next(target_model.parameters()).device
    generated_tokens = prompt_tokens.clone()
    start_len = generated_tokens.size(-1)
    total_proposed, total_accepted = 0, 0

    while generated_tokens.size(-1) - start_len < max_new_tokens:
        # Draft phase - generate K candidate tokens autoregressively
        draft_tokens = []
        current_seq = generated_tokens
        with torch.no_grad():
            for _ in range(draft_k):
                draft_logits = draft_model(current_seq)[:, -1, :]
                draft_probs = F.softmax(draft_logits / temperature, dim=-1)
                next_token = torch.multinomial(draft_probs, 1)
                draft_tokens.append(next_token.item())
                current_seq = torch.cat([current_seq, next_token], dim=-1)

        # Verification phase - score all candidates in one target forward pass
        candidate_seq = torch.cat([
            generated_tokens,
            torch.tensor([draft_tokens], device=device)
        ], dim=-1)
        with torch.no_grad():
            target_logits = target_model(candidate_seq)
            target_probs = F.softmax(target_logits / temperature, dim=-1)
            draft_logits_full = draft_model(candidate_seq)
            draft_probs_full = F.softmax(draft_logits_full / temperature, dim=-1)

        # Acceptance/rejection sampling
        accepted_tokens = []
        rejected = False
        for i, token in enumerate(draft_tokens):
            pos = generated_tokens.size(-1) + i
            # The logits at position pos - 1 predict the token at position pos
            target_prob = target_probs[0, pos - 1, token].item()
            draft_prob = draft_probs_full[0, pos - 1, token].item()
            acceptance_prob = min(1.0, target_prob / draft_prob)
            if torch.rand(1).item() < acceptance_prob:
                accepted_tokens.append(token)
            else:
                # Reject: sample a replacement from the residual distribution
                residual_probs = torch.clamp(
                    target_probs[0, pos - 1] - draft_probs_full[0, pos - 1],
                    min=0
                )
                residual_probs = residual_probs / residual_probs.sum()
                accepted_tokens.append(torch.multinomial(residual_probs, 1).item())
                rejected = True
                break  # everything after the first rejection is discarded

        total_proposed += len(draft_tokens)
        total_accepted += len(accepted_tokens) - (1 if rejected else 0)

        if not rejected:
            # All K drafts accepted: sample a free (K+1)-th token from the target
            accepted_tokens.append(torch.multinomial(target_probs[0, -1], 1).item())

        # Append the accepted tokens and continue from the divergence point
        new_tokens = torch.tensor([accepted_tokens], device=device)
        generated_tokens = torch.cat([generated_tokens, new_tokens], dim=-1)

    acceptance_rate = total_accepted / max(total_proposed, 1)
    return generated_tokens[:, :start_len + max_new_tokens], acceptance_rate
```

Note that a rejection ends only the current round, not the whole generation loop: the next round restarts drafting from the corrected position. When every draft token is accepted, the target's logits at the final position come for free, so we sample one extra token there.
Optimization and Practical Aspects¶
Draft Model Selection¶
The success of speculative decoding depends strongly on the quality of the draft model. An ideal draft model should:
- Have similar architecture to target model
- Be trained on similar data
- Achieve reasonable alignment with target model
- Be at least 4-8× faster than target model
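The last requirement can be made quantitative. Under the common simplifying assumption that each draft token is accepted independently with probability α, one round yields 1 + α + … + α^K expected tokens at the cost of K draft passes plus one target pass. A small helper (the function name and the example numbers are illustrative):

```python
def expected_speedup(alpha: float, k: int, c: float) -> float:
    """Back-of-the-envelope speedup of speculative vs. plain decoding.

    alpha: probability a single draft token is accepted (assumed i.i.d., < 1)
    k:     number of draft tokens per round
    c:     cost of one draft forward pass relative to one target pass
    """
    # Expected tokens per round: 1 + alpha + ... + alpha^k (geometric series)
    expected_tokens = (1 - alpha ** (k + 1)) / (1 - alpha)
    # One round costs k draft passes plus one target verification pass
    cost = c * k + 1
    return expected_tokens / cost


# A well-aligned draft (alpha = 0.8) that is 10x cheaper, with k = 5
print(round(expected_speedup(0.8, 5, 0.1), 2))  # 2.46
```

The formula makes the trade-off visible: a draft that is fast but poorly aligned (low α) wastes verification passes, while a well-aligned but slow draft (high c) eats the savings in draft compute.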
In practice, using smaller versions of the same model family or specialized “distilled” models works well:
```python
# Example configurations for various model combinations
model_pairs = [
    {
        "draft": "microsoft/DialoGPT-small",   # 117M parameters
        "target": "microsoft/DialoGPT-large",  # 762M parameters
        "speedup": "3-5x"
    },
    {
        "draft": "gpt2",      # 124M parameters
        "target": "gpt2-xl",  # 1.5B parameters
        "speedup": "2-4x"
    }
]
```
Batch Processing and Memory Management¶
For production deployment, efficient memory utilization and batch processing are critical:
```python
import torch


class SpeculativeDecodingEngine:
    def __init__(self, draft_model, target_model, max_batch_size=8):
        self.draft_model = draft_model
        self.target_model = target_model
        self.max_batch_size = max_batch_size
        # Pre-allocated KV caches for better memory management
        self.draft_cache = {}
        self.target_cache = {}

    def decode_batch(self, prompts, max_length=512, draft_k=4):
        """Batch inference with a memory-efficient implementation."""
        results = []
        for i in range(0, len(prompts), self.max_batch_size):
            batch = prompts[i:i + self.max_batch_size]
            # Mixed precision reduces memory traffic during inference
            with torch.autocast(device_type="cuda"):
                batch_results = self._process_batch(batch, max_length, draft_k)
            results.extend(batch_results)
            # Release cached allocations between batches
            torch.cuda.empty_cache()
        return results

    def _process_batch(self, batch, max_length, draft_k):
        # Batch speculative decoding with KV-cache optimizations
        # (omitted here for brevity)
        raise NotImplementedError
```
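Production systems also often adapt the draft length to the observed acceptance rate rather than fixing `draft_k`. A minimal, library-agnostic heuristic for this (the class name and thresholds are illustrative assumptions, not any particular framework's API):

```python
class DraftLengthController:
    """Adapt draft_k to the observed acceptance rate (illustrative heuristic)."""

    def __init__(self, k_min=1, k_max=8, target_rate=0.7):
        self.k = k_min
        self.k_min, self.k_max = k_min, k_max
        self.target_rate = target_rate

    def update(self, accepted: int, proposed: int) -> int:
        """Call after each round with the round's accepted/proposed counts."""
        rate = accepted / max(proposed, 1)
        if rate > self.target_rate and self.k < self.k_max:
            self.k += 1  # draft is doing well: speculate further ahead
        elif rate < self.target_rate and self.k > self.k_min:
            self.k -= 1  # too many rejections: shorten the draft
        return self.k
```

The idea is simply that long drafts only pay off when most of them survive verification; when acceptance drops, shorter drafts waste less draft-model compute.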
Performance Measurement and Benchmarks¶
To properly evaluate speculative decoding effectiveness, several key metrics need to be measured:
```python
import time
from collections import defaultdict

import numpy as np


class SpeculativeDecodingProfiler:
    def __init__(self):
        self.metrics = defaultdict(list)

    def benchmark_comparison(self, models, test_prompts, runs=10):
        """Compare standard and speculative decoding."""
        results = {
            'standard': {'time': [], 'tokens_per_sec': []},
            'speculative': {'time': [], 'tokens_per_sec': [], 'acceptance_rate': []}
        }
        for run in range(runs):
            for prompt in test_prompts:
                # Standard decoding
                start_time = time.perf_counter()
                standard_output = models['target'].generate(
                    prompt, max_length=100, do_sample=True
                )
                standard_time = time.perf_counter() - start_time

                # Speculative decoding
                start_time = time.perf_counter()
                spec_output, acceptance_rate = speculative_decode(
                    models['draft'], models['target'], prompt
                )
                spec_time = time.perf_counter() - start_time

                # Compute metrics
                standard_tps = len(standard_output[0]) / standard_time
                spec_tps = len(spec_output[0]) / spec_time

                results['standard']['time'].append(standard_time)
                results['standard']['tokens_per_sec'].append(standard_tps)
                results['speculative']['time'].append(spec_time)
                results['speculative']['tokens_per_sec'].append(spec_tps)
                results['speculative']['acceptance_rate'].append(acceptance_rate)
        return self._compute_statistics(results)

    def _compute_statistics(self, results):
        stats = {}
        for method in results:
            stats[method] = {
                'avg_time': np.mean(results[method]['time']),
                'avg_tps': np.mean(results[method]['tokens_per_sec']),
            }
        stats['speculative']['speedup'] = (
            np.mean(results['speculative']['tokens_per_sec']) /
            np.mean(results['standard']['tokens_per_sec'])
        )
        stats['speculative']['avg_acceptance_rate'] = np.mean(
            results['speculative']['acceptance_rate']
        )
        return stats
```
Production Deployment¶
In production environments, several practical aspects of speculative decoding implementation need consideration:
Infrastructure Requirements¶
- Memory requirements: you need enough GPU memory to hold both models simultaneously
- Latency: the draft model can be placed on low-latency hardware, while the target model can run on high-throughput accelerators
- Load balancing: different request types may be routed to different model combinations
```yaml
# Docker Compose configuration for production deployment
version: '3.8'

services:
  draft-service:
    image: your-registry/draft-model:latest
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    environment:
      - MODEL_SIZE=7b
      - BATCH_SIZE=32

  target-service:
    image: your-registry/target-model:latest
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 2
              capabilities: [gpu]
    environment:
      - MODEL_SIZE=70b
      - BATCH_SIZE=8

  speculative-coordinator:
    image: your-registry/speculative-engine:latest
    depends_on:
      - draft-service
      - target-service
    environment:
      - DRAFT_ENDPOINT=http://draft-service:8000
      - TARGET_ENDPOINT=http://target-service:8000
      - MAX_DRAFT_TOKENS=6
```
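The coordinator's core verification step, in the simple greedy-decoding variant, can be sketched as follows. The function name is an assumption, the endpoint wiring between services is deployment-specific and omitted, and a lossless sampling deployment would replace the exact-match test with the acceptance/rejection sampling shown earlier:

```python
def accept_with_correction(draft_tokens, target_argmax):
    """Greedy speculative verification: keep the longest prefix where the
    target's greedy (argmax) choice matches the draft, then substitute the
    target's token at the first mismatch and stop.

    draft_tokens:  token ids proposed by the draft service
    target_argmax: the target model's greedy choice at each of those positions
    """
    kept = []
    for d, t in zip(draft_tokens, target_argmax):
        if d == t:
            kept.append(d)
        else:
            kept.append(t)  # take the target's correction and discard the rest
            break
    return kept


print(accept_with_correction([5, 9, 2, 7], [5, 9, 4, 7]))  # [5, 9, 4]
```

Because every emitted token is either confirmed or supplied by the target model, the output is identical to what greedy decoding with the target alone would produce; the draft only determines how many tokens each round confirms.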
Summary¶
Speculative Decoding represents significant progress in LLM inference optimization with typical 2-4× speedups while maintaining output quality. The key to success is proper draft model selection, efficient acceptance/rejection sampling implementation, and careful hyperparameter tuning. In production deployment, it’s important to consider memory management, batch processing, and infrastructure requirements. This technique is becoming standard for high-throughput applications using large language models.