Přeskočit na obsah
AI Základy

Batch Normalization — stabilní trénink sítí

8 min čtení
BatchNormLayerNormDeep Learning

Batch Normalization patří mezi nejdůležitější inovace v oblasti hlubokého učení posledních let. Tato technika normalizuje vstupy do každé vrstvy během tréninku, čímž výrazně stabilizuje a zrychluje celý proces učení neuronových sítí.

Batch Normalization — stabilní trénink sítí

Batch Normalization (BatchNorm) patří mezi nejdůležitější techniky moderního deep learningu. Vyřešila zásadní problémy s tréninkem hlubokých sítí a umožnila stavět modely s desítkami či stovkami vrstev. Pojďme si ukázat, jak funguje a proč je tak efektivní.

Problém nestability gradientů

Při tréninku hlubokých neuronových sítí narážíme na fenomén zvaný internal covariate shift. Aktivace v hlubších vrstvách se během tréninku dramaticky mění, což způsobuje:

  • Exploding/vanishing gradients
  • Pomalou konvergenci
  • Nutnost velmi opatrné inicializace vah
  • Citlivost na volbu learning rate

BatchNorm tyto problémy řeší normalizací aktivací na úrovni mini-batchů, což stabilizuje distribuci dat procházejících sítí.

Jak BatchNorm funguje

Batch Normalization aplikuje následující transformaci na každý feature v mini-batchi:

# Pro každý feature x_i v batchi
μ = mean(x_i)           # Průměr přes batch
σ² = variance(x_i)      # Rozptyl přes batch
x̂_i = (x_i - μ) / √(σ² + ε)  # Normalizace
y_i = γ * x̂_i + β      # Škálování a posun

Parametry γ (gamma) a β (beta) jsou učitelné - síť si může "naučit" optimální distribuci pro každou vrstvu.

Implementace v PyTorch

Použití BatchNorm je v PyTorch velmi jednoduché:

import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)      # BatchNorm po konvoluci, před aktivací
        x = self.relu(x)
        return x

# Použití
model = nn.Sequential(
    ConvBlock(3, 64),
    ConvBlock(64, 128),
    ConvBlock(128, 256),
)

Kdy a kde aplikovat BatchNorm

Standardní umístění je po lineární/konvoluční vrstvě, ale před aktivační funkcí. Tento pořádek se osvědčil v praxi, ačkoliv původní paper doporučoval opačné pořadí.

# Doporučené pořadí
x = conv_layer(x)
x = batch_norm(x)
x = activation(x)

# Pro fully connected vrstvy
class MLPBlock(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.bn = nn.BatchNorm1d(out_features)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        return self.relu(self.bn(self.linear(x)))

Training vs. Inference režim

BatchNorm se chová odlišně během tréninku a inference. Během tréninku používá statistiky aktuálního batche, při inferenci používá running averages:

model = YourModel()

# Training režim
model.train()
for batch in dataloader:
    # BatchNorm používá statistiky aktuálního batche
    output = model(batch)

# Inference režim
model.eval()
with torch.no_grad():
    # BatchNorm používá uložené running averages
    prediction = model(test_input)

BatchNorm vs. Layer Normalization

Layer Normalization (LayerNorm) je alternativa užitečná zejména pro sekvence a transformery:

# BatchNorm - normalizace přes batch dimension
batch_norm = nn.BatchNorm1d(features)

# LayerNorm - normalizace přes feature dimension
layer_norm = nn.LayerNorm(features)

# Rozdíl v chování:
x = torch.randn(batch_size=32, features=128)

# BatchNorm: μ, σ počítáno přes batch_size (32 vzorků)
bn_out = batch_norm(x)

# LayerNorm: μ, σ počítáno přes features (128 dimenzí)
ln_out = layer_norm(x)

Praktické výhody BatchNorm

Implementace BatchNorm přináší několik konkrétních benefitů:

  • Vyšší learning rates - síť je stabilnější, můžeme trénovat rychleji
  • Menší závislost na inicializaci - BatchNorm "napravuje" špatnou inicializaci
  • Regularizační efekt - šum z mini-batchů částečně nahrazuje dropout
  • Rychlejší konvergence - typicky 2-3x méně epoch k dosažení stejné přesnosti

Možné problémy a řešení

BatchNorm má i svá úskalí. Při velmi malých batch size (< 4) může být nestabilní:

# Pro malé batche použijte GroupNorm nebo LayerNorm
import torch.nn as nn

# GroupNorm - rozdělí kanály do skupin
group_norm = nn.GroupNorm(num_groups=8, num_channels=64)

# Pro batch_size = 1 (inference)
# BatchNorm automaticky přepne na stored statistiky
model.eval()  # Důležité!

Moderní varianty

Vývoj pokračuje novými technikami jako Batch Renormalization nebo FRN (Filter Response Normalization), které řeší specifické problémy BatchNorm v různých scénářích.

Shrnutí

Batch Normalization je klíčová technika pro trénink hlubokých sítí. Stabilizuje gradient flow, umožňuje vyšší learning rates a zrychluje konvergenci. V CNN umisťujeme BatchNorm za konvoluční vrstvu před aktivací, v RNN a transformerech preferujeme LayerNorm. Nezapomeňte přepnout model do eval() režimu při inferenci, jinak BatchNorm použije statistiky z posledního batche místo uložených running averages.

CORE SYSTEMS tým

Enterprise architekti a AI inženýři.