Loss funkce — přehled a výběr pro váš model
Loss funkce jsou klíčovým prvkem tréninku machine learning modelů - určují, jak model měří své chyby a učí se je opravovat. Správný výběr loss funkce může dramaticky ovlivnit výkon vašeho modelu. Projdeme si nejpoužívanější typy a ukážeme, kdy kterou použít.
Co jsou loss funkce a proč jsou klíčové
Loss funkce (ztrátové funkce) představují srdce každého machine learning modelu. Definují, jak měříme "vzdálenost" mezi predikovanými a skutečnými hodnotami, a tím přímo ovlivňují, jak se model učí. Správný výběr loss funkce může znamenat rozdíl mezi modelem, který funguje výborně, a tím, který nikdy nekonverguje.
V podstatě jde o matematickou formulaci toho, co považujeme za "chybu". Model během trénování minimalizuje tuto chybu pomocí optimalizačních algoritmů jako SGD nebo Adam. Různé typy problémů vyžadují různé loss funkce – to, co funguje pro klasifikaci, nemusí být vhodné pro regresi.
Regression loss funkce
Mean Squared Error (MSE)
MSE je nejpopulárnější loss funkcí pro regresní úlohy. Počítá průměr čtverců rozdílů mezi predikovanými a skutečnými hodnotami:
import torch
import torch.nn as nn
# PyTorch implementace
mse_loss = nn.MSELoss()
predictions = torch.tensor([2.5, 0.0, 2.1])
targets = torch.tensor([3.0, -0.5, 2.0])
loss = mse_loss(predictions, targets)
print(f"MSE Loss: {loss.item()}")
# Manuální implementace
def mse_manual(y_pred, y_true):
return torch.mean((y_pred - y_true) ** 2)
Výhody MSE: Jednoduché na implementaci, silně penalizuje velké chyby díky umocnění na druhou. Nevýhody: Citlivé na outliery, které mohou výrazně zkreslit trénování.
Mean Absolute Error (MAE)
MAE počítá průměr absolutních hodnot rozdílů. Je robustnější vůči outlierům než MSE:
mae_loss = nn.L1Loss() # L1Loss = MAE v PyTorch
loss = mae_loss(predictions, targets)
# Manuální implementace
def mae_manual(y_pred, y_true):
return torch.mean(torch.abs(y_pred - y_true))
Huber Loss
Huber Loss kombinuje výhody MSE a MAE. Chová se jako MSE pro malé chyby a jako MAE pro velké chyby:
huber_loss = nn.HuberLoss(delta=1.0)
loss = huber_loss(predictions, targets)
# Manuální implementace
def huber_loss_manual(y_pred, y_true, delta=1.0):
error = torch.abs(y_pred - y_true)
is_small_error = error <= delta
squared_loss = 0.5 * error ** 2
linear_loss = delta * error - 0.5 * delta ** 2
return torch.mean(torch.where(is_small_error, squared_loss, linear_loss))
Classification loss funkce
Cross-Entropy Loss
Cross-entropy je standardní volbou pro klasifikační úlohy. Měří "vzdálenost" mezi pravděpodobnostními distribucemi:
# Binární klasifikace binary_ce = nn.BCELoss() sigmoid_output = torch.sigmoid(torch.tensor([0.8, -1.2, 2.1])) binary_targets = torch.tensor([1.0, 0.0, 1.0]) loss = binary_ce(sigmoid_output, binary_targets) # Multi-class klasifikace ce_loss = nn.CrossEntropyLoss() logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 0.2]]) targets = torch.tensor([0, 1]) # Class indices loss = ce_loss(logits, targets)
Cross-entropy má důležitou vlastnost – rychle konverguje, když je model velmi špatný, ale zpomaluje, když se blíží k optimu. To vede k stabilnímu učení.
Focal Loss
Focal Loss řeší problém nevyvážených datasetů tím, že snižuje váhu "jednoduchých" příkladů a zaměřuje se na složité případy:
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2):
super().__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, inputs, targets):
ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
return focal_loss.mean()
# Použití
focal_loss = FocalLoss(alpha=1, gamma=2)
loss = focal_loss(logits, targets)
Pokročilé loss funkce
Contrastive Loss
Používá se v metric learning, kde chceme naučit model rozpoznávat podobnost mezi objekty:
def contrastive_loss(output1, output2, label, margin=1.0):
euclidean_distance = nn.functional.pairwise_distance(output1, output2)
loss_contrastive = torch.mean(
(1-label) * torch.pow(euclidean_distance, 2) +
label * torch.pow(torch.clamp(margin - euclidean_distance, min=0.0), 2)
)
return loss_contrastive
Dice Loss
Populární v segmentačních úlohách, zejména v medicínském imaging:
def dice_loss(inputs, targets, smooth=1):
inputs = torch.sigmoid(inputs)
inputs = inputs.view(-1)
targets = targets.view(-1)
intersection = (inputs * targets).sum()
dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
return 1 - dice
Praktické tipy pro výběr loss funkce
Výběr správné loss funkce závisí na několika faktorech:
- Typ problému: Regrese vs. klasifikace vs. ranking
- Distribuce dat: Vyvážené vs. nevyvážené třídy
- Citlivost na outliery: MSE vs. MAE pro regresi
- Interpretabilita: Některé loss funkce mají jasný statistický význam
Pro debugging a monitoring doporučuji sledovat více metrik současně:
# Kombinace více loss funkcí pro lepší insight
class CombinedLoss(nn.Module):
def __init__(self, weights={'mse': 0.7, 'mae': 0.3}):
super().__init__()
self.weights = weights
self.mse = nn.MSELoss()
self.mae = nn.L1Loss()
def forward(self, predictions, targets):
mse_loss = self.mse(predictions, targets)
mae_loss = self.mae(predictions, targets)
total_loss = (self.weights['mse'] * mse_loss +
self.weights['mae'] * mae_loss)
return total_loss, {'mse': mse_loss.item(), 'mae': mae_loss.item()}
Nezapomeňte experimentovat! Často se vyplatí začít se standardními funkcemi (MSE pro regresi, Cross-Entropy pro klasifikaci) a postupně optimalizovat podle specifických potřeb vašeho problému.
Shrnutí
Loss funkce jsou fundamentálním stavebním kamenem machine learning modelů. Pro regresní úlohy začněte s MSE nebo MAE, pro klasifikaci s Cross-Entropy. Pokročilé funkce jako Focal Loss nebo Dice Loss řeší specifické problémy jako nevyvážené datasety nebo segmentační úlohy. Klíčem je experimentování a porozumění tomu, jak různé funkce ovlivňují chování vašeho modelu. Sledujte více metrik současně a nezapomeňte na validační data při vyhodnocování výkonu.