Transformer architektura — kompletní průvodce
Transformer architektura představuje přelomovou technologii v oblasti umělé inteligence, která stojí za nejmodernějšími jazykovými modely jako GPT či BERT. Tento průvodce vám jednoduše vysvětlí, jak fungují klíčové mechanismy jako attention a self-attention, a proč jsou Transformery tak efektivní.
Co je Transformer architektura
Transformer architektura představuje převratný přístup k zpracování sekvencí, který od roku 2017 dominuje oblasti Natural Language Processing. Klíčovou inovací je mechanismus self-attention, který umožňuje modelu sledovat vztahy mezi všemi pozicemi v sekvenci současně, na rozdíl od sekvenčního zpracování v RNN nebo LSTM sítích.
Základní princip spočívá v transformaci vstupní sekvence tokenů na vektory pomocí attention mechanismu, který váží důležitost jednotlivých pozic pro každý prvek sekvence. Tím se model naučí kontextové reprezentace slov, kde význam závisí na celém kontextu věty.
Architektura Transformeru
Encoder-Decoder struktura
Originální Transformer se skládá ze dvou hlavních částí:
- Encoder - zpracovává vstupní sekvenci a vytváří kontextové reprezentace
- Decoder - generuje výstupní sekvenci na základě enkodovaných reprezentací
Každá část obsahuje stack několika identických vrstev (typicky 6), přičemž každá vrstva má dva hlavní komponenty: multi-head attention a feed-forward síť.
Self-Attention mechanismus
Srdcem Transformeru je scaled dot-product attention:
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q: Query matice [batch_size, seq_len, d_model]
K: Key matice [batch_size, seq_len, d_model]
V: Value matice [batch_size, seq_len, d_model]
"""
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
return output, attention_weights
Mechanismus funguje tak, že pro každý token vytvoří tři vektory - Query (co hledám), Key (co nabízím) a Value (co předávám). Attention score se počítá jako dot product mezi Query a Key vektory, normalizovaný délkou vektoru.
Multi-Head Attention
Multi-head attention umožňuje modelu sledovat různé typy vztahů současně:
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# Rozdělení do hlav
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# Attention pro každou hlavu
attention, _ = scaled_dot_product_attention(Q, K, V, mask)
# Spojení hlav
attention = attention.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)
return self.W_o(attention)
Pozičního kódování (Positional Encoding)
Protože Transformer nemá inherentní pojem pořadí, musíme pozici tokenů kódovat explicitně. Používají se sinusoidní funkce:
def positional_encoding(seq_len, d_model):
pe = torch.zeros(seq_len, d_model)
position = torch.arange(0, seq_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe
Toto kódování umožňuje modelu rozlišovat pozice a učit se vztahy závislé na vzdálenosti mezi tokeny.
Varianty Transformer architektury
BERT (Bidirectional Encoder Representations)
BERT používá pouze encoder část a trénuje se bidirectionálně pomocí masked language modeling:
- Maskuje náhodně 15% tokenů ve vstupní sekvenci
- Učí se předpovídat zamaskované tokeny na základě celého kontextu
- Výborný pro úlohy porozumění textu (classification, NER, QA)
GPT (Generative Pre-trained Transformer)
GPT používá pouze decoder část s causal masking:
def create_causal_mask(seq_len):
"""Vytvoří masku pro autoregresivní generování"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]
Tato maska zajišťuje, že při předpovídání tokenu na pozici i model vidí pouze tokeny na pozicích 0 až i-1.
Implementace základního Transformeru
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.attention = MultiHeadAttention(d_model, num_heads)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Self-attention s residual connection
attn_output = self.attention(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# Feed-forward s residual connection
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x
Výhody a nevýhody
Výhody:
- Paralelizace - na rozdíl od RNN lze trénovat všechny pozice současně
- Dlouhé závislosti - attention mechanismus umožňuje přímé propojení vzdálených pozic
- Interpretovatelnost - attention weights poskytují insight do toho, na co se model zaměřuje
- Transfer learning - pre-trained modely se dají fine-tunovat pro specifické úlohy
Nevýhody:
- Výpočetní složitost - O(n²) vzhledem k délce sekvence
- Paměťové nároky - attention matice roste kvadraticky
- Množství dat - vyžaduje velké množství trénovacích dat pro dobrou výkonnost
Shrnutí
Transformer architektura představuje zásadní pokrok v oblasti zpracování sekvencí. Její klíčové inovace - self-attention mechanismus a paralelizovatelnost - umožnily vznik pokročilých modelů jako GPT, BERT a jejich následovníků. Pro praxi je důležité pochopit, že různé varianty (encoder-only, decoder-only, encoder-decoder) se hodí pro různé typy úloh. Zatímco implementace může být komplexní, principy jsou elegantní a poskytují pevný základ pro pochopení moderních AI systémů.