Attention mechanismus — klíč k moderním AI
Attention mechanismus představuje jeden z nejvýznamnějších průlomů v oblasti umělé inteligence posledních let. Tato technologie umožňuje AI systémům selektivně se zaměřovat na důležité části informací, podobně jako lidský mozek. Díky attention mechanismu vznikly pokročilé modely jako GPT a BERT, které dnes pohánějí ChatGPT a další moderní AI aplikace.
Co je Attention mechanismus?
Attention mechanismus představuje revoluční přístup v oblasti neuronových sítí, který umožňuje modelům dynamicky se zaměřovat na relevantní části vstupních dat. Místo postupného zpracování sekvence jako u tradičních RNN či LSTM, attention umožňuje modelu "nahlédnout" na všechny pozice současně a vybrat ty nejdůležitější.
Základní myšlenka je jednoduchá: při zpracování každého prvku sekvence vypočítáme skóre důležitosti pro všechny ostatní prvky. Tyto skóre následně použijeme jako váhy pro vytvoření kontextově bohaté reprezentace.
Matematické základy
Attention mechanismus lze formálně popsat jako funkci, která mapuje dotaz (query) a množinu párů klíč-hodnota (key-value) na výstup. Pro pozici i a kontext C vypočítáme attention váhy α následovně:
# Základní attention výpočet
def attention_weights(query, keys):
scores = torch.matmul(query, keys.transpose(-2, -1))
weights = torch.softmax(scores / math.sqrt(keys.size(-1)), dim=-1)
return weights
def attention(query, keys, values):
weights = attention_weights(query, keys)
context = torch.matmul(weights, values)
return context, weights
Škálování faktorem √d_k (kde d_k je dimenze klíčů) je kritické pro stabilitu gradientů při větších dimenzích.
Self-Attention: Revoluce v NLP
Self-attention představuje speciální případ, kde query, keys i values pocházejí ze stejné vstupní sekvence. Tento přístup umožňuje každému tokenu "komunikovat" se všemi ostatními tokeny v sekvenci současně.
import torch
import torch.nn as nn
import math
class SelfAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size, seq_len = x.size(0), x.size(1)
# Projekce na Q, K, V
Q = self.w_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
K = self.w_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
V = self.w_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
# Transpose pro multi-head attention
Q = Q.transpose(1, 2) # [batch, heads, seq_len, head_dim]
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
attention_weights = torch.softmax(scores, dim=-1)
context = torch.matmul(attention_weights, V)
# Concatenate heads
context = context.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.d_model
)
return self.w_o(context)
Multi-Head Attention
Multi-head attention rozšiřuje základní mechanismus tím, že paralelně spouští několik attention "hlav" s různými naučenými projekcemi. Každá hlava se může zaměřit na jiné aspekty vstupních dat - syntaktické vztahy, sémantické podobnosti, či pozční informace.
Klíčové výhody multi-head přístupu:
- Paralelizace: Výpočty napříč hlavami jsou nezávislé
- Diverzifikace: Různé hlavy zachycují odlišné vzorce
- Kapacita: Zvýšená expresivita modelu bez dramatického nárůstu parametrů
Praktická implementace s optimalizacemi
V produkčních systémech je kritické optimalizovat attention výpočty. Zde je rozšířená implementace s podporou masking a dropout:
class OptimizedMultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.qkv = nn.Linear(d_model, 3 * d_model) # Spojená projekce
self.output = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch_size, seq_len = x.shape[:2]
# Efektivní výpočet Q, K, V najednou
qkv = self.qkv(x).chunk(3, dim=-1)
q, k, v = [tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
for tensor in qkv]
# Attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# Aplikace masky (pro padding, causal attention apod.)
if mask is not None:
scores.masked_fill_(mask == 0, float('-inf'))
attention_weights = torch.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# Aplikace attention na values
out = torch.matmul(attention_weights, v)
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
return self.output(out), attention_weights
Attention v Transformer architektuře
Transformer architektura využívá attention mechanismus trojím způsobem:
- Encoder Self-Attention: Umožňuje každému slovu v sekvenci vztahovat se ke všem ostatním slovům
- Decoder Self-Attention: S causal masking pro autoregresivní generování
- Cross-Attention: Propojuje encoder a decoder reprezentace
# Příklad causal masking pro decoder
def create_causal_mask(seq_len):
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
return mask == 0 # True pro povolené pozice
# Použití v decoder layeru
class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.self_attention = OptimizedMultiHeadAttention(d_model, num_heads)
self.cross_attention = OptimizedMultiHeadAttention(d_model, num_heads)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.ReLU(),
nn.Linear(4 * d_model, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
def forward(self, x, encoder_output, causal_mask=None):
# Self-attention s causal masking
attn_out, _ = self.self_attention(x, causal_mask)
x = self.norm1(x + attn_out)
# Cross-attention s encoder výstupem
cross_out, _ = self.cross_attention(x) # Implementace zkrácena
x = self.norm2(x + cross_out)
# Feed-forward
ff_out = self.feed_forward(x)
return self.norm3(x + ff_out)
Výkonnostní aspekty a optimalizace
Attention mechanismus má kvadratickou složitost O(n²) vzhledem k délce sekvence, což představuje výzvu pro dlouhé sekvence. Moderní přístupy zahrnují:
- Flash Attention: Memory-efficient implementace s kernel fusion
- Sparse Attention: Omezení na lokální či strukturované vzorce
- Linear Attention: Aproximace s lineární složitostí
- Gradient Checkpointing: Trade-off mezi pamětí a výpočetním časem
# Příklad sparse attention pattern
def create_local_attention_mask(seq_len, window_size=128):
mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
for i in range(seq_len):
start = max(0, i - window_size // 2)
end = min(seq_len, i + window_size // 2 + 1)
mask[i, start:end] = True
return mask
# Využití při dlouhých sekvencích
local_mask = create_local_attention_mask(4096, 256)
attention_output, weights = model.attention(x, mask=local_mask)
Shrnutí
Attention mechanismus představuje fundamentální průlom v architektuře neuronových sítí, který umožnil vznik moderních jazykových modelů. Jeho schopnost dynamicky se zaměřovat na relevantní části vstupu při zachování paralelizovatelnosti činí z něj klíčový stavební kámen současných AI systémů. Pro praktické nasazení je důležité porozumět jak základním principům, tak optimalizačním technikám pro efektivní škálování na produkční úlohy.