Speculative Decoding — rychlejší LLM
Velké jazykové modely jsou mocné, ale pomalé při generování textu. Speculative Decoding představuje elegantní řešení tohoto problému pomocí menšího "draft" modelu, který navrhuje tokeny předem. Tato technika dokáže výrazně zrychlit inference bez ztráty kvality výstupu.
Co je Speculative Decoding
Speculative Decoding je pokročilá optimalizační technika, která dramaticky zrychluje inference velkých jazykových modelů (LLM) bez ztráty kvality výstupu. Princip spočívá v použití menšího, rychlejšího "draft" modelu, který předpovídá několik tokenů najednou, a větší "target" model tyto předpovědi následně verifikuje a případně opravuje.
Klíčová myšlenka je jednoduchá: místo sekvenčního generování token po tokenu využíváme paralelizaci - menší model rychle navrhne sekvenci tokenů a větší model je všechny najednou zkontroluje. Pokud jsou předpovědi správné, ušetříme významné množství času. Pokud ne, pokračujeme od místa, kde došlo k divergenci.
Jak Speculative Decoding funguje
Proces probíhá ve dvou hlavních krocích:
- Draft fáze: Menší model (např. 7B parametrů) rychle generuje K kandidátních tokenů
- Verification fáze: Větší model (např. 70B parametrů) paralelně vyhodnotí všechny kandidáty
Matematicky můžeme acceptance probability pro token na pozici i vyjádřit jako:
α_i = min(1, P_target(x_i | x_<i) / P_draft(x_i | x_<i))
Kde P_target a P_draft jsou pravděpodobnostní distribuce target a draft modelu. Token je přijat s pravděpodobností α_i, jinak se vzorkuje nový token z upravené distribuce.
Implementace v PyTorch
Základní implementace speculative decoding algoritmu:
import torch
import torch.nn.functional as F
def speculative_decode(draft_model, target_model, prompt_tokens,
max_new_tokens=50, draft_k=5, temperature=1.0):
"""
Implementace Speculative Decoding
Args:
draft_model: Menší, rychlejší model
target_model: Větší, přesnější model
prompt_tokens: Vstupní sekvence tokenů
max_new_tokens: Maximum nových tokenů
draft_k: Počet kandidátních tokenů z draft modelu
temperature: Teplota pro sampling
"""
device = next(target_model.parameters()).device
generated_tokens = prompt_tokens.clone()
for _ in range(max_new_tokens // draft_k):
# Draft fáze - generuj K kandidátů
draft_tokens = []
current_seq = generated_tokens
with torch.no_grad():
for _ in range(draft_k):
draft_logits = draft_model(current_seq)[:, -1, :]
draft_probs = F.softmax(draft_logits / temperature, dim=-1)
next_token = torch.multinomial(draft_probs, 1)
draft_tokens.append(next_token.item())
current_seq = torch.cat([current_seq, next_token], dim=-1)
# Verification fáze
candidate_seq = torch.cat([
generated_tokens,
torch.tensor(draft_tokens, device=device).unsqueeze(0)
], dim=-1)
with torch.no_grad():
target_logits = target_model(candidate_seq)
target_probs = F.softmax(target_logits / temperature, dim=-1)
# Pro původní sekvenci
draft_logits_full = draft_model(candidate_seq)
draft_probs_full = F.softmax(draft_logits_full / temperature, dim=-1)
# Acceptance/rejection sampling
accepted_tokens = []
for i, token in enumerate(draft_tokens):
pos = generated_tokens.size(-1) + i
target_prob = target_probs[0, pos - 1, token]
draft_prob = draft_probs_full[0, pos - 1, token]
acceptance_prob = min(1.0, target_prob / draft_prob)
if torch.rand(1).item() < acceptance_prob:
accepted_tokens.append(token)
else:
# Reject a vzorkuj nový token z residual distribution
residual_probs = torch.clamp(
target_probs[0, pos - 1] - draft_probs_full[0, pos - 1],
min=0
)
residual_probs = residual_probs / residual_probs.sum()
new_token = torch.multinomial(residual_probs, 1).item()
accepted_tokens.append(new_token)
break
# Přidej přijaté tokeny
if accepted_tokens:
new_tokens = torch.tensor(accepted_tokens, device=device).unsqueeze(0)
generated_tokens = torch.cat([generated_tokens, new_tokens], dim=-1)
if len(accepted_tokens) < draft_k:
break
return generated_tokens
Optimalizace a praktické aspekty
Výběr draft modelu
Úspěch speculative decoding silně závisí na kvalitě draft modelu. Ideální draft model by měl:
- Mít podobnou architekturu jako target model
- Být trénován na podobných datech
- Dosahovat rozumné alignment s target modelem
- Být alespoň 4-8× rychlejší než target model
V praxi se osvědčuje použití menších verzí stejné modelové rodiny nebo specializovaných "distilled" modelů:
# Příklad konfigurace pro různé kombinace modelů
model_pairs = [
{
"draft": "microsoft/DialoGPT-small", # 117M parametrů
"target": "microsoft/DialoGPT-large", # 762M parametrů
"speedup": "3-5x"
},
{
"draft": "distilbert-base-uncased", # 66M parametrů
"target": "bert-large-uncased", # 340M parametrů
"speedup": "2-4x"
}
]
Batch processing a memory management
Pro produkční nasazení je kritické efektivní využití paměti a batch processing:
class SpeculativeDecodingEngine:
def __init__(self, draft_model, target_model, max_batch_size=8):
self.draft_model = draft_model
self.target_model = target_model
self.max_batch_size = max_batch_size
# Předalokace bufferů pro lepší memory management
self.draft_cache = {}
self.target_cache = {}
def decode_batch(self, prompts, max_length=512, draft_k=4):
"""Batch inference s memory-efficient implementací"""
batch_size = min(len(prompts), self.max_batch_size)
results = []
for i in range(0, len(prompts), batch_size):
batch = prompts[i:i+batch_size]
# Použij gradient checkpointing pro úsporu paměti
with torch.cuda.amp.autocast():
batch_results = self._process_batch(batch, max_length, draft_k)
results.extend(batch_results)
# Explicitní garbage collection
torch.cuda.empty_cache()
return results
def _process_batch(self, batch, max_length, draft_k):
# Implementace batch speculative decoding
# s KV-cache optimalizacemi
pass
Měření výkonnosti a benchmarky
Pro správné vyhodnocení efektivity speculative decoding je potřeba měřit několik klíčových metrik:
import time
from collections import defaultdict
class SpeculativeDecodingProfiler:
def __init__(self):
self.metrics = defaultdict(list)
def benchmark_comparison(self, models, test_prompts, runs=10):
"""Porovnání standardního a speculative decoding"""
results = {
'standard': {'time': [], 'tokens_per_sec': [], 'acceptance_rate': []},
'speculative': {'time': [], 'tokens_per_sec': [], 'acceptance_rate': []}
}
for run in range(runs):
for prompt in test_prompts:
# Standard decoding
start_time = time.time()
standard_output = models['target'].generate(
prompt, max_length=100, do_sample=True
)
standard_time = time.time() - start_time
# Speculative decoding
start_time = time.time()
spec_output, acceptance_rate = speculative_decode(
models['draft'], models['target'], prompt
)
spec_time = time.time() - start_time
# Compute metrics
standard_tps = len(standard_output[0]) / standard_time
spec_tps = len(spec_output[0]) / spec_time
results['standard']['time'].append(standard_time)
results['standard']['tokens_per_sec'].append(standard_tps)
results['speculative']['time'].append(spec_time)
results['speculative']['tokens_per_sec'].append(spec_tps)
results['speculative']['acceptance_rate'].append(acceptance_rate)
return self._compute_statistics(results)
def _compute_statistics(self, results):
stats = {}
for method in results:
stats[method] = {
'avg_time': np.mean(results[method]['time']),
'avg_tps': np.mean(results[method]['tokens_per_sec']),
'speedup': np.mean(results['speculative']['tokens_per_sec']) /
np.mean(results['standard']['tokens_per_sec'])
}
return stats
Produkční nasazení
V produkčním prostředí je potřeba zvážit několik praktických aspektů implementace speculative decoding:
Infrastrukturní požadavky
- Paměťové nároky: Potřebujete dostatek GPU paměti pro oba modely současně
- Latence: Draft model běží na rychlejším hardware, target model může běžet na specializovaných akcelerátorech
- Load balancing: Různé typy požadavků mohou využívat různé kombinace modelů
# Docker Compose konfigurace pro produkční deployment
version: '3.8'
services:
draft-service:
image: your-registry/draft-model:latest
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
environment:
- MODEL_SIZE=7b
- BATCH_SIZE=32
target-service:
image: your-registry/target-model:latest
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 2
capabilities: [gpu]
environment:
- MODEL_SIZE=70b
- BATCH_SIZE=8
speculative-coordinator:
image: your-registry/speculative-engine:latest
depends_on:
- draft-service
- target-service
environment:
- DRAFT_ENDPOINT=http://draft-service:8000
- TARGET_ENDPOINT=http://target-service:8000
- MAX_DRAFT_TOKENS=6
Shrnutí
Speculative Decoding představuje výrazný pokrok v optimalizaci LLM inference s typickými zrychlením 2-4× při zachování kvality výstupu. Klíčem k úspěchu je správný výběr draft modelu, efektivní implementace acceptance/rejection sampling a pečlivé ladění hyperparametrů. V produkčním nasazení je důležité zvážit memory management, batch processing a infrastrukturní požadavky. Tato technika se stává standardem pro high-throughput aplikace využívající velké jazykové modely.