You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
185 lines
6.1 KiB
185 lines
6.1 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
|
|
from mmpose.models import build_loss
|
|
|
|
|
|
def test_smooth_l1_loss():
|
|
# test SmoothL1Loss without target weight(default None)
|
|
loss_cfg = dict(type='SmoothL1Loss')
|
|
loss = build_loss(loss_cfg)
|
|
|
|
fake_pred = torch.zeros((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.allclose(loss(fake_pred, fake_label), torch.tensor(0.))
|
|
|
|
fake_pred = torch.ones((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.allclose(loss(fake_pred, fake_label), torch.tensor(.5))
|
|
|
|
# test SmoothL1Loss with target weight
|
|
loss_cfg = dict(type='SmoothL1Loss', use_target_weight=True)
|
|
loss = build_loss(loss_cfg)
|
|
|
|
fake_pred = torch.zeros((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, torch.ones_like(fake_label)),
|
|
torch.tensor(0.))
|
|
|
|
fake_pred = torch.ones((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, torch.ones_like(fake_label)),
|
|
torch.tensor(.5))
|
|
|
|
|
|
def test_wing_loss():
|
|
# test WingLoss without target weight(default None)
|
|
loss_cfg = dict(type='WingLoss')
|
|
loss = build_loss(loss_cfg)
|
|
|
|
fake_pred = torch.zeros((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.allclose(loss(fake_pred, fake_label), torch.tensor(0.))
|
|
|
|
fake_pred = torch.ones((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.gt(loss(fake_pred, fake_label), torch.tensor(.5))
|
|
|
|
# test WingLoss with target weight
|
|
loss_cfg = dict(type='WingLoss', use_target_weight=True)
|
|
loss = build_loss(loss_cfg)
|
|
|
|
fake_pred = torch.zeros((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, torch.ones_like(fake_label)),
|
|
torch.tensor(0.))
|
|
|
|
fake_pred = torch.ones((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.gt(
|
|
loss(fake_pred, fake_label, torch.ones_like(fake_label)),
|
|
torch.tensor(.5))
|
|
|
|
|
|
def test_soft_wing_loss():
|
|
# test SoftWingLoss without target weight(default None)
|
|
loss_cfg = dict(type='SoftWingLoss')
|
|
loss = build_loss(loss_cfg)
|
|
|
|
fake_pred = torch.zeros((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.allclose(loss(fake_pred, fake_label), torch.tensor(0.))
|
|
|
|
fake_pred = torch.ones((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.gt(loss(fake_pred, fake_label), torch.tensor(.5))
|
|
|
|
# test SoftWingLoss with target weight
|
|
loss_cfg = dict(type='SoftWingLoss', use_target_weight=True)
|
|
loss = build_loss(loss_cfg)
|
|
|
|
fake_pred = torch.zeros((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, torch.ones_like(fake_label)),
|
|
torch.tensor(0.))
|
|
|
|
fake_pred = torch.ones((1, 3, 2))
|
|
fake_label = torch.zeros((1, 3, 2))
|
|
assert torch.gt(
|
|
loss(fake_pred, fake_label, torch.ones_like(fake_label)),
|
|
torch.tensor(.5))
|
|
|
|
|
|
def test_mse_regression_loss():
|
|
# w/o target weight(default None)
|
|
loss_cfg = dict(type='MSELoss')
|
|
loss = build_loss(loss_cfg)
|
|
fake_pred = torch.zeros((1, 3, 3))
|
|
fake_label = torch.zeros((1, 3, 3))
|
|
assert torch.allclose(loss(fake_pred, fake_label), torch.tensor(0.))
|
|
|
|
fake_pred = torch.ones((1, 3, 3))
|
|
fake_label = torch.zeros((1, 3, 3))
|
|
assert torch.allclose(loss(fake_pred, fake_label), torch.tensor(1.))
|
|
|
|
# w/ target weight
|
|
loss_cfg = dict(type='MSELoss', use_target_weight=True)
|
|
loss = build_loss(loss_cfg)
|
|
fake_pred = torch.zeros((1, 3, 3))
|
|
fake_label = torch.zeros((1, 3, 3))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, torch.ones_like(fake_label)),
|
|
torch.tensor(0.))
|
|
|
|
fake_pred = torch.ones((1, 3, 3))
|
|
fake_label = torch.zeros((1, 3, 3))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, torch.ones_like(fake_label)),
|
|
torch.tensor(1.))
|
|
|
|
|
|
def test_bone_loss():
|
|
# w/o target weight(default None)
|
|
loss_cfg = dict(type='BoneLoss', joint_parents=[0, 0, 1])
|
|
loss = build_loss(loss_cfg)
|
|
fake_pred = torch.zeros((1, 3, 3))
|
|
fake_label = torch.zeros((1, 3, 3))
|
|
assert torch.allclose(loss(fake_pred, fake_label), torch.tensor(0.))
|
|
|
|
fake_pred = torch.tensor([[[0, 0, 0], [1, 1, 1], [2, 2, 2]]],
|
|
dtype=torch.float32)
|
|
fake_label = fake_pred * 2
|
|
assert torch.allclose(loss(fake_pred, fake_label), torch.tensor(3**0.5))
|
|
|
|
# w/ target weight
|
|
loss_cfg = dict(
|
|
type='BoneLoss', joint_parents=[0, 0, 1], use_target_weight=True)
|
|
loss = build_loss(loss_cfg)
|
|
fake_pred = torch.zeros((1, 3, 3))
|
|
fake_label = torch.zeros((1, 3, 3))
|
|
fake_weight = torch.ones((1, 2))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, fake_weight), torch.tensor(0.))
|
|
|
|
fake_pred = torch.tensor([[[0, 0, 0], [1, 1, 1], [2, 2, 2]]],
|
|
dtype=torch.float32)
|
|
fake_label = fake_pred * 2
|
|
fake_weight = torch.ones((1, 2))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, fake_weight), torch.tensor(3**0.5))
|
|
|
|
|
|
def test_semi_supervision_loss():
|
|
loss_cfg = dict(
|
|
type='SemiSupervisionLoss',
|
|
joint_parents=[0, 0, 1],
|
|
warmup_iterations=1)
|
|
loss = build_loss(loss_cfg)
|
|
|
|
unlabeled_pose = torch.rand((1, 3, 3))
|
|
unlabeled_traj = torch.ones((1, 1, 3))
|
|
labeled_pose = unlabeled_pose.clone()
|
|
fake_pred = dict(
|
|
labeled_pose=labeled_pose,
|
|
unlabeled_pose=unlabeled_pose,
|
|
unlabeled_traj=unlabeled_traj)
|
|
|
|
intrinsics = torch.tensor([[1, 1, 1, 1, 0.1, 0.1, 0.1, 0, 0]],
|
|
dtype=torch.float32)
|
|
unlabled_target_2d = loss.project_joints(unlabeled_pose + unlabeled_traj,
|
|
intrinsics)
|
|
fake_label = dict(
|
|
unlabeled_target_2d=unlabled_target_2d, intrinsics=intrinsics)
|
|
|
|
# test warmup
|
|
losses = loss(fake_pred, fake_label)
|
|
assert not losses
|
|
|
|
# test semi-supervised loss
|
|
losses = loss(fake_pred, fake_label)
|
|
assert torch.allclose(losses['proj_loss'], torch.tensor(0.))
|
|
assert torch.allclose(losses['bone_loss'], torch.tensor(0.))
|
|
|