Batch Normalization — stabilní trénink sítí
Batch Normalization patří mezi nejdůležitější inovace v oblasti hlubokého učení posledních let. Tato technika normalizuje vstupy do každé vrstvy během tréninku, čímž výrazně stabilizuje a zrychluje celý proces učení neuronových sítí.
Batch Normalization — stabilní trénink sítí
Batch Normalization (BatchNorm) patří mezi nejdůležitější techniky moderního deep learningu. Vyřešila zásadní problémy s tréninkem hlubokých sítí a umožnila stavět modely s desítkami či stovkami vrstev. Pojďme si ukázat, jak funguje a proč je tak efektivní.
Problém nestability gradientů
Při tréninku hlubokých neuronových sítí narážíme na fenomén zvaný internal covariate shift. Aktivace v hlubších vrstvách se během tréninku dramaticky mění, což způsobuje:
- Exploding/vanishing gradients
- Pomalou konvergenci
- Nutnost velmi opatrné inicializace vah
- Citlivost na volbu learning rate
BatchNorm tyto problémy řeší normalizací aktivací na úrovni mini-batchů, což stabilizuje distribuci dat procházejících sítí.
Jak BatchNorm funguje
Batch Normalization aplikuje následující transformaci na každý feature v mini-batchi:
# Pro každý feature x_i v batchi μ = mean(x_i) # Průměr přes batch σ² = variance(x_i) # Rozptyl přes batch x̂_i = (x_i - μ) / √(σ² + ε) # Normalizace y_i = γ * x̂_i + β # Škálování a posun
Parametry γ (gamma) a β (beta) jsou učitelné - síť si může "naučit" optimální distribuci pro každou vrstvu.
Implementace v PyTorch
Použití BatchNorm je v PyTorch velmi jednoduché:
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x) # BatchNorm po konvoluci, před aktivací
x = self.relu(x)
return x
# Použití
model = nn.Sequential(
ConvBlock(3, 64),
ConvBlock(64, 128),
ConvBlock(128, 256),
)
Kdy a kde aplikovat BatchNorm
Standardní umístění je po lineární/konvoluční vrstvě, ale před aktivační funkcí. Tento pořádek se osvědčil v praxi, ačkoliv původní paper doporučoval opačné pořadí.
# Doporučené pořadí
x = conv_layer(x)
x = batch_norm(x)
x = activation(x)
# Pro fully connected vrstvy
class MLPBlock(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
self.bn = nn.BatchNorm1d(out_features)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.bn(self.linear(x)))
Training vs. Inference režim
BatchNorm se chová odlišně během tréninku a inference. Během tréninku používá statistiky aktuálního batche, při inferenci používá running averages:
model = YourModel()
# Training režim
model.train()
for batch in dataloader:
# BatchNorm používá statistiky aktuálního batche
output = model(batch)
# Inference režim
model.eval()
with torch.no_grad():
# BatchNorm používá uložené running averages
prediction = model(test_input)
BatchNorm vs. Layer Normalization
Layer Normalization (LayerNorm) je alternativa užitečná zejména pro sekvence a transformery:
# BatchNorm - normalizace přes batch dimension batch_norm = nn.BatchNorm1d(features) # LayerNorm - normalizace přes feature dimension layer_norm = nn.LayerNorm(features) # Rozdíl v chování: x = torch.randn(batch_size=32, features=128) # BatchNorm: μ, σ počítáno přes batch_size (32 vzorků) bn_out = batch_norm(x) # LayerNorm: μ, σ počítáno přes features (128 dimenzí) ln_out = layer_norm(x)
Praktické výhody BatchNorm
Implementace BatchNorm přináší několik konkrétních benefitů:
- Vyšší learning rates - síť je stabilnější, můžeme trénovat rychleji
- Menší závislost na inicializaci - BatchNorm "napravuje" špatnou inicializaci
- Regularizační efekt - šum z mini-batchů částečně nahrazuje dropout
- Rychlejší konvergence - typicky 2-3x méně epoch k dosažení stejné přesnosti
Možné problémy a řešení
BatchNorm má i svá úskalí. Při velmi malých batch size (< 4) může být nestabilní:
# Pro malé batche použijte GroupNorm nebo LayerNorm import torch.nn as nn # GroupNorm - rozdělí kanály do skupin group_norm = nn.GroupNorm(num_groups=8, num_channels=64) # Pro batch_size = 1 (inference) # BatchNorm automaticky přepne na stored statistiky model.eval() # Důležité!
Moderní varianty
Vývoj pokračuje novými technikami jako Batch Renormalization nebo FRN (Filter Response Normalization), které řeší specifické problémy BatchNorm v různých scénářích.
Shrnutí
Batch Normalization je klíčová technika pro trénink hlubokých sítí. Stabilizuje gradient flow, umožňuje vyšší learning rates a zrychluje konvergenci. V CNN umisťujeme BatchNorm za konvoluční vrstvu před aktivací, v RNN a transformerech preferujeme LayerNorm. Nezapomeňte přepnout model do eval() režimu při inferenci, jinak BatchNorm použije statistiky z posledního batche místo uložených running averages.