Přeskočit na obsah
AI Základy

Speculative Decoding — rychlejší LLM

4 min čtení
Speculative DecodingInferenceOptimalizace

Velké jazykové modely jsou mocné, ale pomalé při generování textu. Speculative Decoding představuje elegantní řešení tohoto problému pomocí menšího "draft" modelu, který navrhuje tokeny předem. Tato technika dokáže výrazně zrychlit inference bez ztráty kvality výstupu.

Co je Speculative Decoding

Speculative Decoding je pokročilá optimalizační technika, která dramaticky zrychluje inference velkých jazykových modelů (LLM) bez ztráty kvality výstupu. Princip spočívá v použití menšího, rychlejšího "draft" modelu, který předpovídá několik tokenů najednou, a větší "target" model tyto předpovědi následně verifikuje a případně opravuje.

Klíčová myšlenka je jednoduchá: místo sekvenčního generování token po tokenu využíváme paralelizaci - menší model rychle navrhne sekvenci tokenů a větší model je všechny najednou zkontroluje. Pokud jsou předpovědi správné, ušetříme významné množství času. Pokud ne, pokračujeme od místa, kde došlo k divergenci.

Jak Speculative Decoding funguje

Proces probíhá ve dvou hlavních krocích:

  1. Draft fáze: Menší model (např. 7B parametrů) rychle generuje K kandidátních tokenů
  2. Verification fáze: Větší model (např. 70B parametrů) paralelně vyhodnotí všechny kandidáty

Matematicky můžeme acceptance probability pro token na pozici i vyjádřit jako:

α_i = min(1, P_target(x_i | x_<i) / P_draft(x_i | x_<i))

Kde P_target a P_draft jsou pravděpodobnostní distribuce target a draft modelu. Token je přijat s pravděpodobností α_i, jinak se vzorkuje nový token z upravené distribuce.

Implementace v PyTorch

Základní implementace speculative decoding algoritmu:

import torch
import torch.nn.functional as F

def speculative_decode(draft_model, target_model, prompt_tokens, 
                      max_new_tokens=50, draft_k=5, temperature=1.0):
    """
    Implementace Speculative Decoding
    
    Args:
        draft_model: Menší, rychlejší model
        target_model: Větší, přesnější model  
        prompt_tokens: Vstupní sekvence tokenů
        max_new_tokens: Maximum nových tokenů
        draft_k: Počet kandidátních tokenů z draft modelu
        temperature: Teplota pro sampling
    """
    
    device = next(target_model.parameters()).device
    generated_tokens = prompt_tokens.clone()
    
    for _ in range(max_new_tokens // draft_k):
        # Draft fáze - generuj K kandidátů
        draft_tokens = []
        current_seq = generated_tokens
        
        with torch.no_grad():
            for _ in range(draft_k):
                draft_logits = draft_model(current_seq)[:, -1, :]
                draft_probs = F.softmax(draft_logits / temperature, dim=-1)
                next_token = torch.multinomial(draft_probs, 1)
                draft_tokens.append(next_token.item())
                current_seq = torch.cat([current_seq, next_token], dim=-1)
        
        # Verification fáze
        candidate_seq = torch.cat([
            generated_tokens, 
            torch.tensor(draft_tokens, device=device).unsqueeze(0)
        ], dim=-1)
        
        with torch.no_grad():
            target_logits = target_model(candidate_seq)
            target_probs = F.softmax(target_logits / temperature, dim=-1)
            
            # Pro původní sekvenci
            draft_logits_full = draft_model(candidate_seq)
            draft_probs_full = F.softmax(draft_logits_full / temperature, dim=-1)
        
        # Acceptance/rejection sampling
        accepted_tokens = []
        for i, token in enumerate(draft_tokens):
            pos = generated_tokens.size(-1) + i
            
            target_prob = target_probs[0, pos - 1, token]
            draft_prob = draft_probs_full[0, pos - 1, token]
            
            acceptance_prob = min(1.0, target_prob / draft_prob)
            
            if torch.rand(1).item() < acceptance_prob:
                accepted_tokens.append(token)
            else:
                # Reject a vzorkuj nový token z residual distribution
                residual_probs = torch.clamp(
                    target_probs[0, pos - 1] - draft_probs_full[0, pos - 1], 
                    min=0
                )
                residual_probs = residual_probs / residual_probs.sum()
                new_token = torch.multinomial(residual_probs, 1).item()
                accepted_tokens.append(new_token)
                break
        
        # Přidej přijaté tokeny
        if accepted_tokens:
            new_tokens = torch.tensor(accepted_tokens, device=device).unsqueeze(0)
            generated_tokens = torch.cat([generated_tokens, new_tokens], dim=-1)
        
        if len(accepted_tokens) < draft_k:
            break
            
    return generated_tokens

Optimalizace a praktické aspekty

Výběr draft modelu

Úspěch speculative decoding silně závisí na kvalitě draft modelu. Ideální draft model by měl:

  • Mít podobnou architekturu jako target model
  • Být trénován na podobných datech
  • Dosahovat rozumné alignment s target modelem
  • Být alespoň 4-8× rychlejší než target model

V praxi se osvědčuje použití menších verzí stejné modelové rodiny nebo specializovaných "distilled" modelů:

# Příklad konfigurace pro různé kombinace modelů
model_pairs = [
    {
        "draft": "microsoft/DialoGPT-small",      # 117M parametrů
        "target": "microsoft/DialoGPT-large",    # 762M parametrů
        "speedup": "3-5x"
    },
    {
        "draft": "distilbert-base-uncased",      # 66M parametrů  
        "target": "bert-large-uncased",          # 340M parametrů
        "speedup": "2-4x"
    }
]

Batch processing a memory management

Pro produkční nasazení je kritické efektivní využití paměti a batch processing:

class SpeculativeDecodingEngine:
    def __init__(self, draft_model, target_model, max_batch_size=8):
        self.draft_model = draft_model
        self.target_model = target_model
        self.max_batch_size = max_batch_size
        
        # Předalokace bufferů pro lepší memory management
        self.draft_cache = {}
        self.target_cache = {}
    
    def decode_batch(self, prompts, max_length=512, draft_k=4):
        """Batch inference s memory-efficient implementací"""
        batch_size = min(len(prompts), self.max_batch_size)
        results = []
        
        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i+batch_size]
            
            # Použij gradient checkpointing pro úsporu paměti
            with torch.cuda.amp.autocast():
                batch_results = self._process_batch(batch, max_length, draft_k)
            
            results.extend(batch_results)
            
            # Explicitní garbage collection
            torch.cuda.empty_cache()
            
        return results
    
    def _process_batch(self, batch, max_length, draft_k):
        # Implementace batch speculative decoding
        # s KV-cache optimalizacemi
        pass

Měření výkonnosti a benchmarky

Pro správné vyhodnocení efektivity speculative decoding je potřeba měřit několik klíčových metrik:

import time
from collections import defaultdict

class SpeculativeDecodingProfiler:
    def __init__(self):
        self.metrics = defaultdict(list)
    
    def benchmark_comparison(self, models, test_prompts, runs=10):
        """Porovnání standardního a speculative decoding"""
        
        results = {
            'standard': {'time': [], 'tokens_per_sec': [], 'acceptance_rate': []},
            'speculative': {'time': [], 'tokens_per_sec': [], 'acceptance_rate': []}
        }
        
        for run in range(runs):
            for prompt in test_prompts:
                # Standard decoding
                start_time = time.time()
                standard_output = models['target'].generate(
                    prompt, max_length=100, do_sample=True
                )
                standard_time = time.time() - start_time
                
                # Speculative decoding  
                start_time = time.time()
                spec_output, acceptance_rate = speculative_decode(
                    models['draft'], models['target'], prompt
                )
                spec_time = time.time() - start_time
                
                # Compute metrics
                standard_tps = len(standard_output[0]) / standard_time
                spec_tps = len(spec_output[0]) / spec_time
                
                results['standard']['time'].append(standard_time)
                results['standard']['tokens_per_sec'].append(standard_tps)
                
                results['speculative']['time'].append(spec_time)  
                results['speculative']['tokens_per_sec'].append(spec_tps)
                results['speculative']['acceptance_rate'].append(acceptance_rate)
        
        return self._compute_statistics(results)
    
    def _compute_statistics(self, results):
        stats = {}
        for method in results:
            stats[method] = {
                'avg_time': np.mean(results[method]['time']),
                'avg_tps': np.mean(results[method]['tokens_per_sec']),
                'speedup': np.mean(results['speculative']['tokens_per_sec']) / 
                          np.mean(results['standard']['tokens_per_sec'])
            }
        return stats

Produkční nasazení

V produkčním prostředí je potřeba zvážit několik praktických aspektů implementace speculative decoding:

Infrastrukturní požadavky

  • Paměťové nároky: Potřebujete dostatek GPU paměti pro oba modely současně
  • Latence: Draft model běží na rychlejším hardware, target model může běžet na specializovaných akcelerátorech
  • Load balancing: Různé typy požadavků mohou využívat různé kombinace modelů
# Docker Compose konfigurace pro produkční deployment
version: '3.8'
services:
  draft-service:
    image: your-registry/draft-model:latest
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    environment:
      - MODEL_SIZE=7b
      - BATCH_SIZE=32
      
  target-service:  
    image: your-registry/target-model:latest
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia  
              count: 2
              capabilities: [gpu]
    environment:
      - MODEL_SIZE=70b
      - BATCH_SIZE=8
      
  speculative-coordinator:
    image: your-registry/speculative-engine:latest
    depends_on:
      - draft-service
      - target-service
    environment:
      - DRAFT_ENDPOINT=http://draft-service:8000
      - TARGET_ENDPOINT=http://target-service:8000
      - MAX_DRAFT_TOKENS=6

Shrnutí

Speculative Decoding představuje výrazný pokrok v optimalizaci LLM inference s typickými zrychlením 2-4× při zachování kvality výstupu. Klíčem k úspěchu je správný výběr draft modelu, efektivní implementace acceptance/rejection sampling a pečlivé ladění hyperparametrů. V produkčním nasazení je důležité zvážit memory management, batch processing a infrastrukturní požadavky. Tato technika se stává standardem pro high-throughput aplikace využívající velké jazykové modely.

CORE SYSTEMS tým

Enterprise architekti a AI inženýři.