You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
86 lines
2.7 KiB
86 lines
2.7 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from ..builder import LOSSES
|
|
|
|
|
|
@LOSSES.register_module()
class AdaptiveWingLoss(nn.Module):
    """Adaptive wing loss. paper ref: 'Adaptive Wing Loss for Robust Face
    Alignment via Heatmap Regression' Wang et al. ICCV'2019.

    Args:
        alpha (float), omega (float), epsilon (float), theta (float)
            are hyper-parameters.
        use_target_weight (bool): Option to multiply both the predicted
            and target heatmaps by ``target_weight`` before computing the
            adaptive wing loss. Different joint types may have different
            target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 alpha=2.1,
                 omega=14,
                 epsilon=1,
                 theta=0.5,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.alpha = float(alpha)
        self.omega = float(omega)
        self.epsilon = float(epsilon)
        self.theta = float(theta)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def criterion(self, pred, target):
        """Criterion of wingloss.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            pred (torch.Tensor[NxKxHxW]): Predicted heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.

        Returns:
            torch.Tensor: Scalar mean of the element-wise adaptive wing loss.
        """
        delta = (target - pred).abs()

        # The exponent (alpha - y) varies element-wise with the target value.
        exponent = self.alpha - target
        ratio = self.theta / self.epsilon
        ratio_pow = torch.pow(ratio, exponent)

        # A is the slope of the log branch at delta == theta, and C is chosen
        # so that A * theta - C equals the log branch's value there. Together
        # they make the linear branch join the log branch continuously in
        # both value and slope.
        A = self.omega * (1 / (1 + ratio_pow)) * exponent * torch.pow(
            ratio, exponent - 1) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log1p(ratio_pow)

        # log1p(x) is more accurate than log(1 + x) when x is small, which
        # happens for small residuals delta.
        losses = torch.where(
            delta < self.theta,
            self.omega * torch.log1p(torch.pow(delta / self.epsilon,
                                               exponent)),
            A * delta - C)

        return torch.mean(losses)

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            output (torch.Tensor[NxKxHxW]): Output heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.
            target_weight (torch.Tensor[NxKx1] | torch.Tensor[NxK] | None):
                Weights across different joint types. Only used (and then
                required) when ``use_target_weight`` is True.

        Returns:
            torch.Tensor: The loss scaled by ``loss_weight``.

        Raises:
            ValueError: If ``use_target_weight`` is True but no
                ``target_weight`` is given.
        """
        if self.use_target_weight:
            if target_weight is None:
                raise ValueError(
                    'target_weight is required when use_target_weight=True')
            # Bring the weights to NxKx1x1 so they broadcast over the HxW
            # heatmap. NxKx1 input takes the original unsqueeze path; NxK
            # input would otherwise mis-broadcast against NxKxHxW.
            if target_weight.dim() == 2:
                weight = target_weight[:, :, None, None]
            else:
                weight = target_weight.unsqueeze(-1)
            loss = self.criterion(output * weight, target * weight)
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
|
|
|