You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
98 lines
3.5 KiB
98 lines
3.5 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmpose.models import build_loss
|
|
|
|
|
|
def test_adaptive_wing_loss():
|
|
# test Adaptive WingLoss without target weight
|
|
loss_cfg = dict(type='AdaptiveWingLoss')
|
|
loss = build_loss(loss_cfg)
|
|
|
|
fake_pred = torch.zeros((1, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.))
|
|
|
|
# test WingLoss with target weight
|
|
loss_cfg = dict(type='AdaptiveWingLoss', use_target_weight=True)
|
|
loss = build_loss(loss_cfg)
|
|
|
|
fake_pred = torch.ones((1, 3, 64, 64))
|
|
fake_label = torch.ones((1, 3, 64, 64))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, torch.ones([1, 3, 1])), torch.tensor(0.))
|
|
|
|
|
|
def test_mse_loss():
|
|
# test MSE loss without target weight
|
|
loss_cfg = dict(type='JointsMSELoss')
|
|
loss = build_loss(loss_cfg)
|
|
|
|
fake_pred = torch.zeros((1, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.))
|
|
|
|
fake_pred = torch.ones((1, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(1.))
|
|
|
|
fake_pred = torch.zeros((1, 2, 64, 64))
|
|
fake_pred[0, 0] += 1
|
|
fake_label = torch.zeros((1, 2, 64, 64))
|
|
assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.5))
|
|
|
|
with pytest.raises(ValueError):
|
|
loss_cfg = dict(type='JointsOHKMMSELoss')
|
|
loss = build_loss(loss_cfg)
|
|
fake_pred = torch.zeros((1, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, None), torch.tensor(0.))
|
|
|
|
with pytest.raises(AssertionError):
|
|
loss_cfg = dict(type='JointsOHKMMSELoss', topk=-1)
|
|
loss = build_loss(loss_cfg)
|
|
fake_pred = torch.zeros((1, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, None), torch.tensor(0.))
|
|
|
|
loss_cfg = dict(type='JointsOHKMMSELoss', topk=2)
|
|
loss = build_loss(loss_cfg)
|
|
fake_pred = torch.ones((1, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(1.))
|
|
|
|
loss_cfg = dict(type='JointsOHKMMSELoss', topk=2)
|
|
loss = build_loss(loss_cfg)
|
|
fake_pred = torch.zeros((1, 3, 64, 64))
|
|
fake_pred[0, 0] += 1
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.5))
|
|
|
|
loss_cfg = dict(type='CombinedTargetMSELoss', use_target_weight=True)
|
|
loss = build_loss(loss_cfg)
|
|
fake_pred = torch.ones((1, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
target_weight = torch.ones((1, 1, 1))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, target_weight), torch.tensor(0.5))
|
|
|
|
loss_cfg = dict(type='CombinedTargetMSELoss', use_target_weight=True)
|
|
loss = build_loss(loss_cfg)
|
|
fake_pred = torch.ones((1, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
target_weight = torch.zeros((1, 1, 1))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, target_weight), torch.tensor(0.))
|
|
|
|
|
|
def test_smoothl1_loss():
|
|
# test MSE loss without target weight
|
|
loss_cfg = dict(type='SmoothL1Loss')
|
|
loss = build_loss(loss_cfg)
|
|
|
|
fake_pred = torch.zeros((1, 3))
|
|
fake_label = torch.zeros((1, 3))
|
|
assert torch.allclose(loss(fake_pred, fake_label, None), torch.tensor(0.))
|
|
|