# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


@LOSSES.register_module()
class SmoothL1Loss(nn.Module):
    """Smooth L1 loss.

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = F.smooth_l1_loss
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_keypoints: K
            - dimension of keypoints: D (D=2 or D=3)

        Args:
            output (torch.Tensor[N, K, D]): Output regression.
            target (torch.Tensor[N, K, D]): Target regression.
            target_weight (torch.Tensor[N, K, D]):
                Weights across different joint types.
        """
        if self.use_target_weight:
            assert target_weight is not None
            loss = self.criterion(output * target_weight,
                                  target * target_weight)
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight


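# A minimal usage sketch (illustrative only; the shapes and values below are
# assumptions, not part of this module). Apart from SemiSupervisionLoss, the
# regression losses in this file share the same forward signature:
#
#   loss_fn = SmoothL1Loss(use_target_weight=True)
#   output = torch.rand(2, 17, 2)          # [N, K, D] predictions
#   target = torch.rand(2, 17, 2)          # [N, K, D] ground truth
#   target_weight = torch.ones(2, 17, 2)   # per-joint weights
#   loss = loss_fn(output, target, target_weight)

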
@LOSSES.register_module()
class WingLoss(nn.Module):
    """Wing Loss. Paper ref: 'Wing Loss for Robust Facial Landmark
    Localisation with Convolutional Neural Networks' Feng et al. CVPR'2018.

    Args:
        omega (float): Also referred to as width.
        epsilon (float): Also referred to as curvature.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 omega=10.0,
                 epsilon=2.0,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.omega = omega
        self.epsilon = epsilon
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

        # constant that smoothly links the piecewise-defined linear
        # and nonlinear parts
        self.C = self.omega * (
            1.0 - math.log(1.0 + self.omega / self.epsilon))

    def criterion(self, pred, target):
        """Criterion of wing loss.

        Note:
            - batch_size: N
            - num_keypoints: K
            - dimension of keypoints: D (D=2 or D=3)

        Args:
            pred (torch.Tensor[N, K, D]): Output regression.
            target (torch.Tensor[N, K, D]): Target regression.
        """
        delta = (target - pred).abs()
        losses = torch.where(
            delta < self.omega,
            self.omega * torch.log(1.0 + delta / self.epsilon),
            delta - self.C)
        return torch.mean(torch.sum(losses, dim=[1, 2]), dim=0)

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_keypoints: K
            - dimension of keypoints: D (D=2 or D=3)

        Args:
            output (torch.Tensor[N, K, D]): Output regression.
            target (torch.Tensor[N, K, D]): Target regression.
            target_weight (torch.Tensor[N, K, D]):
                Weights across different joint types.
        """
        if self.use_target_weight:
            assert target_weight is not None
            loss = self.criterion(output * target_weight,
                                  target * target_weight)
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight


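# For reference, `WingLoss.criterion` above implements the piecewise function
#
#   wing(x) = omega * ln(1 + |x| / epsilon)   if |x| < omega
#             |x| - C                         otherwise,
#
# where C = omega * (1 - ln(1 + omega / epsilon)) makes the two branches meet
# at |x| = omega. The per-sample losses are summed over joints and keypoint
# dimensions before averaging over the batch.

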
@LOSSES.register_module()
class SoftWingLoss(nn.Module):
    """Soft Wing Loss. Paper ref: 'Structure-Coherent Deep Feature Learning
    for Robust Face Alignment' Lin et al. TIP'2021.

    loss =
        1. |x|                                , if |x| < omega1
        2. omega2 * ln(1 + |x| / epsilon) + B , if |x| >= omega1

    Args:
        omega1 (float): The first threshold.
        omega2 (float): The second threshold.
        epsilon (float): Also referred to as curvature.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 omega1=2.0,
                 omega2=20.0,
                 epsilon=0.5,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.omega1 = omega1
        self.omega2 = omega2
        self.epsilon = epsilon
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

        # constant that smoothly links the piecewise-defined linear
        # and nonlinear parts
        self.B = self.omega1 - self.omega2 * math.log(1.0 + self.omega1 /
                                                      self.epsilon)

    def criterion(self, pred, target):
        """Criterion of soft wing loss.

        Note:
            - batch_size: N
            - num_keypoints: K
            - dimension of keypoints: D (D=2 or D=3)

        Args:
            pred (torch.Tensor[N, K, D]): Output regression.
            target (torch.Tensor[N, K, D]): Target regression.
        """
        delta = (target - pred).abs()
        losses = torch.where(
            delta < self.omega1, delta,
            self.omega2 * torch.log(1.0 + delta / self.epsilon) + self.B)
        return torch.mean(torch.sum(losses, dim=[1, 2]), dim=0)

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_keypoints: K
            - dimension of keypoints: D (D=2 or D=3)

        Args:
            output (torch.Tensor[N, K, D]): Output regression.
            target (torch.Tensor[N, K, D]): Target regression.
            target_weight (torch.Tensor[N, K, D]):
                Weights across different joint types.
        """
        if self.use_target_weight:
            assert target_weight is not None
            loss = self.criterion(output * target_weight,
                                  target * target_weight)
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight


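# Sanity check of the piecewise definition above (using the default
# hyper-parameters purely as an illustrative assumption): at |x| = omega1 = 2.0
# the linear branch gives 2.0, and the nonlinear branch gives
# omega2 * ln(1 + omega1 / epsilon) + B = omega1 = 2.0 by construction of B,
# so the loss is continuous at the threshold.

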
@LOSSES.register_module()
class MPJPELoss(nn.Module):
    """MPJPE (Mean Per Joint Position Error) loss.

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_keypoints: K
            - dimension of keypoints: D (D=2 or D=3)

        Args:
            output (torch.Tensor[N, K, D]): Output regression.
            target (torch.Tensor[N, K, D]): Target regression.
            target_weight (torch.Tensor[N, K, D]):
                Weights across different joint types.
        """
        if self.use_target_weight:
            assert target_weight is not None
            loss = torch.mean(
                torch.norm((output - target) * target_weight, dim=-1))
        else:
            loss = torch.mean(torch.norm(output - target, dim=-1))

        return loss * self.loss_weight


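# A worked example of what MPJPELoss computes (illustrative tensors, not from
# any dataset): the loss is the Euclidean distance between predicted and
# target joints, averaged over all joints and samples.
#
#   mpjpe = MPJPELoss()
#   pred = torch.zeros(1, 2, 3)
#   gt = torch.ones(1, 2, 3)
#   # every joint is off by sqrt(3) ~= 1.732, so mpjpe(pred, gt) == sqrt(3)

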
@LOSSES.register_module()
class L1Loss(nn.Module):
    """L1 loss for coordinate regression.

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = F.l1_loss
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_keypoints: K

        Args:
            output (torch.Tensor[N, K, 2]): Output regression.
            target (torch.Tensor[N, K, 2]): Target regression.
            target_weight (torch.Tensor[N, K, 2]):
                Weights across different joint types.
        """
        if self.use_target_weight:
            assert target_weight is not None
            loss = self.criterion(output * target_weight,
                                  target * target_weight)
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight


@LOSSES.register_module()
class MSELoss(nn.Module):
    """MSE loss for coordinate regression.

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = F.mse_loss
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_keypoints: K

        Args:
            output (torch.Tensor[N, K, 2]): Output regression.
            target (torch.Tensor[N, K, 2]): Target regression.
            target_weight (torch.Tensor[N, K, 2]):
                Weights across different joint types.
        """
        if self.use_target_weight:
            assert target_weight is not None
            loss = self.criterion(output * target_weight,
                                  target * target_weight)
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight


@LOSSES.register_module()
class BoneLoss(nn.Module):
    """Bone length loss.

    Args:
        joint_parents (list): Indices of each joint's parent joint.
        use_target_weight (bool): Option to use weighted bone loss.
            Different bone types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 joint_parents,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.joint_parents = joint_parents
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

        # a joint is a root if it is its own parent; only non-root joints
        # define bones
        self.non_root_indices = []
        for i in range(len(self.joint_parents)):
            if i != self.joint_parents[i]:
                self.non_root_indices.append(i)

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_keypoints: K
            - dimension of keypoints: D (D=2 or D=3)

        Args:
            output (torch.Tensor[N, K, D]): Output regression.
            target (torch.Tensor[N, K, D]): Target regression.
            target_weight (torch.Tensor[N, K-1]):
                Weights across different bone types.
        """
        output_bone = torch.norm(
            output - output[:, self.joint_parents, :],
            dim=-1)[:, self.non_root_indices]
        target_bone = torch.norm(
            target - target[:, self.joint_parents, :],
            dim=-1)[:, self.non_root_indices]
        if self.use_target_weight:
            assert target_weight is not None
            loss = torch.mean(
                torch.abs((output_bone * target_weight).mean(dim=0) -
                          (target_bone * target_weight).mean(dim=0)))
        else:
            loss = torch.mean(
                torch.abs(output_bone.mean(dim=0) - target_bone.mean(dim=0)))

        return loss * self.loss_weight


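# A minimal sketch of how bones are formed (hypothetical 3-joint skeleton, not
# a real dataset layout):
#
#   bone_loss = BoneLoss(joint_parents=[0, 0, 1])
#   # joint 0 is its own parent (the root) and is excluded; the bones compared
#   # are joint 1 -> joint 0 and joint 2 -> joint 1, so target_weight, if
#   # used, has shape [N, 2] here (i.e. [N, K-1]).

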
@LOSSES.register_module()
class SemiSupervisionLoss(nn.Module):
    """Semi-supervision loss for unlabeled data. It is composed of projection
    loss and bone loss.

    Paper ref: `3D human pose estimation in video with temporal convolutions
    and semi-supervised training` Dario Pavllo et al. CVPR'2019.

    Args:
        joint_parents (list): Indices of each joint's parent joint.
        projection_loss_weight (float): Weight for projection loss.
        bone_loss_weight (float): Weight for bone loss.
        warmup_iterations (int): Number of warmup iterations. In the first
            `warmup_iterations` iterations, the model is trained only on
            labeled data, and the semi-supervision loss will be 0.
            This is a workaround since currently we cannot access the epoch
            number in loss functions. Note that the number of iterations in
            an epoch can change with the number of GPUs in multi-GPU
            settings, so please set this parameter carefully:
            warmup_iterations = dataset_size // samples_per_gpu // gpu_num
            * warmup_epochs
    """

    def __init__(self,
                 joint_parents,
                 projection_loss_weight=1.,
                 bone_loss_weight=1.,
                 warmup_iterations=0):
        super().__init__()
        self.criterion_projection = MPJPELoss(
            loss_weight=projection_loss_weight)
        self.criterion_bone = BoneLoss(
            joint_parents, loss_weight=bone_loss_weight)
        self.warmup_iterations = warmup_iterations
        self.num_iterations = 0

    @staticmethod
    def project_joints(x, intrinsics):
        """Project 3D joint coordinates to the 2D image plane using camera
        intrinsic parameters.

        Args:
            x (torch.Tensor[N, K, 3]): 3D joint coordinates.
            intrinsics (torch.Tensor[N, 4] | torch.Tensor[N, 9]): Camera
                intrinsics: f (2), c (2), k (3), p (2).
        """
        while intrinsics.dim() < x.dim():
            intrinsics.unsqueeze_(1)
        f = intrinsics[..., :2]
        c = intrinsics[..., 2:4]
        _x = torch.clamp(x[:, :, :2] / x[:, :, 2:], -1, 1)
        if intrinsics.shape[-1] == 9:
            k = intrinsics[..., 4:7]
            p = intrinsics[..., 7:9]

            r2 = torch.sum(_x[:, :, :2]**2, dim=-1, keepdim=True)
            radial = 1 + torch.sum(
                k * torch.cat((r2, r2**2, r2**3), dim=-1),
                dim=-1,
                keepdim=True)
            tan = torch.sum(p * _x, dim=-1, keepdim=True)
            _x = _x * (radial + tan) + p * r2
        _x = f * _x + c
        return _x
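
    # Illustrative intrinsics layout for `project_joints` (the numeric values
    # here are made-up assumptions, not taken from any dataset):
    #
    #   intrinsics = torch.tensor([[1145.0, 1143.0, 512.0, 515.0]])  # [N, 4]
    #   # f = (1145.0, 1143.0) focal lengths, c = (512.0, 515.0) centre;
    #   # with only 4 values the distortion branch (k, p) is skipped.
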
    def forward(self, output, target):
        losses = dict()

        self.num_iterations += 1
        if self.num_iterations <= self.warmup_iterations:
            return losses

        labeled_pose = output['labeled_pose']
        unlabeled_pose = output['unlabeled_pose']
        unlabeled_traj = output['unlabeled_traj']
        unlabeled_target_2d = target['unlabeled_target_2d']
        intrinsics = target['intrinsics']

        # projection loss
        unlabeled_output = unlabeled_pose + unlabeled_traj
        unlabeled_output_2d = self.project_joints(unlabeled_output,
                                                  intrinsics)
        loss_proj = self.criterion_projection(unlabeled_output_2d,
                                              unlabeled_target_2d, None)
        losses['proj_loss'] = loss_proj

        # bone loss
        loss_bone = self.criterion_bone(unlabeled_pose, labeled_pose, None)
        losses['bone_loss'] = loss_bone

        return losses
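

# Usage note (a sketch of the expected training-loop behaviour, inferred from
# `forward` above): during the first `warmup_iterations` calls the returned
# dict is empty, so the caller sees no 'proj_loss' / 'bone_loss' terms and the
# semi-supervised part contributes nothing until warmup has finished.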