You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
168 lines
5.7 KiB
168 lines
5.7 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
|
|
def test_multi_loss_factory():
|
|
from mmpose.models import build_loss
|
|
|
|
# test heatmap loss
|
|
loss_cfg = dict(type='HeatmapLoss')
|
|
loss = build_loss(loss_cfg)
|
|
|
|
with pytest.raises(AssertionError):
|
|
fake_pred = torch.zeros((2, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
fake_mask = torch.zeros((1, 64, 64))
|
|
loss(fake_pred, fake_label, fake_mask)
|
|
|
|
fake_pred = torch.zeros((1, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
fake_mask = torch.zeros((1, 64, 64))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, fake_mask), torch.tensor(0.))
|
|
|
|
fake_pred = torch.ones((1, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
fake_mask = torch.zeros((1, 64, 64))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, fake_mask), torch.tensor(0.))
|
|
|
|
fake_pred = torch.ones((1, 3, 64, 64))
|
|
fake_label = torch.zeros((1, 3, 64, 64))
|
|
fake_mask = torch.ones((1, 64, 64))
|
|
assert torch.allclose(
|
|
loss(fake_pred, fake_label, fake_mask), torch.tensor(1.))
|
|
|
|
# test AE loss
|
|
fake_tags = torch.zeros((1, 18, 1))
|
|
fake_joints = torch.zeros((1, 3, 2, 2), dtype=torch.int)
|
|
|
|
loss_cfg = dict(type='AELoss', loss_type='exp')
|
|
loss = build_loss(loss_cfg)
|
|
assert torch.allclose(loss(fake_tags, fake_joints)[0], torch.tensor(0.))
|
|
assert torch.allclose(loss(fake_tags, fake_joints)[1], torch.tensor(0.))
|
|
|
|
fake_tags[0, 0, 0] = 1.
|
|
fake_tags[0, 10, 0] = 0.
|
|
fake_joints[0, 0, 0, :] = torch.IntTensor((0, 1))
|
|
fake_joints[0, 0, 1, :] = torch.IntTensor((10, 1))
|
|
loss_cfg = dict(type='AELoss', loss_type='exp')
|
|
loss = build_loss(loss_cfg)
|
|
assert torch.allclose(loss(fake_tags, fake_joints)[0], torch.tensor(0.))
|
|
assert torch.allclose(loss(fake_tags, fake_joints)[1], torch.tensor(0.25))
|
|
|
|
fake_tags[0, 0, 0] = 0
|
|
fake_tags[0, 7, 0] = 1.
|
|
fake_tags[0, 17, 0] = 1.
|
|
fake_joints[0, 1, 0, :] = torch.IntTensor((7, 1))
|
|
fake_joints[0, 1, 1, :] = torch.IntTensor((17, 1))
|
|
|
|
loss_cfg = dict(type='AELoss', loss_type='exp')
|
|
loss = build_loss(loss_cfg)
|
|
assert torch.allclose(loss(fake_tags, fake_joints)[1], torch.tensor(0.))
|
|
|
|
loss_cfg = dict(type='AELoss', loss_type='max')
|
|
loss = build_loss(loss_cfg)
|
|
assert torch.allclose(loss(fake_tags, fake_joints)[0], torch.tensor(0.))
|
|
|
|
with pytest.raises(ValueError):
|
|
loss_cfg = dict(type='AELoss', loss_type='min')
|
|
loss = build_loss(loss_cfg)
|
|
loss(fake_tags, fake_joints)
|
|
|
|
# test MultiLossFactory
|
|
with pytest.raises(AssertionError):
|
|
loss_cfg = dict(
|
|
type='MultiLossFactory',
|
|
num_joints=2,
|
|
num_stages=1,
|
|
ae_loss_type='exp',
|
|
with_ae_loss=True,
|
|
push_loss_factor=[0.001],
|
|
pull_loss_factor=[0.001],
|
|
with_heatmaps_loss=[True],
|
|
heatmaps_loss_factor=[1.0])
|
|
loss = build_loss(loss_cfg)
|
|
with pytest.raises(AssertionError):
|
|
loss_cfg = dict(
|
|
type='MultiLossFactory',
|
|
num_joints=2,
|
|
num_stages=1,
|
|
ae_loss_type='exp',
|
|
with_ae_loss=[True],
|
|
push_loss_factor=0.001,
|
|
pull_loss_factor=[0.001],
|
|
with_heatmaps_loss=[True],
|
|
heatmaps_loss_factor=[1.0])
|
|
loss = build_loss(loss_cfg)
|
|
with pytest.raises(AssertionError):
|
|
loss_cfg = dict(
|
|
type='MultiLossFactory',
|
|
num_joints=2,
|
|
num_stages=1,
|
|
ae_loss_type='exp',
|
|
with_ae_loss=[True],
|
|
push_loss_factor=[0.001],
|
|
pull_loss_factor=0.001,
|
|
with_heatmaps_loss=[True],
|
|
heatmaps_loss_factor=[1.0])
|
|
loss = build_loss(loss_cfg)
|
|
with pytest.raises(AssertionError):
|
|
loss_cfg = dict(
|
|
type='MultiLossFactory',
|
|
num_joints=2,
|
|
num_stages=1,
|
|
ae_loss_type='exp',
|
|
with_ae_loss=[True],
|
|
push_loss_factor=[0.001],
|
|
pull_loss_factor=[0.001],
|
|
with_heatmaps_loss=True,
|
|
heatmaps_loss_factor=[1.0])
|
|
loss = build_loss(loss_cfg)
|
|
with pytest.raises(AssertionError):
|
|
loss_cfg = dict(
|
|
type='MultiLossFactory',
|
|
num_joints=2,
|
|
num_stages=1,
|
|
ae_loss_type='exp',
|
|
with_ae_loss=[True],
|
|
push_loss_factor=[0.001],
|
|
pull_loss_factor=[0.001],
|
|
with_heatmaps_loss=[True],
|
|
heatmaps_loss_factor=1.0)
|
|
loss = build_loss(loss_cfg)
|
|
loss_cfg = dict(
|
|
type='MultiLossFactory',
|
|
num_joints=17,
|
|
num_stages=1,
|
|
ae_loss_type='exp',
|
|
with_ae_loss=[False],
|
|
push_loss_factor=[0.001],
|
|
pull_loss_factor=[0.001],
|
|
with_heatmaps_loss=[False],
|
|
heatmaps_loss_factor=[1.0])
|
|
loss = build_loss(loss_cfg)
|
|
fake_outputs = [torch.zeros((1, 34, 64, 64))]
|
|
fake_heatmaps = [torch.zeros((1, 17, 64, 64))]
|
|
fake_masks = [torch.ones((1, 64, 64))]
|
|
fake_joints = [torch.zeros((1, 30, 17, 2))]
|
|
heatmaps_losses, push_losses, pull_losses = \
|
|
loss(fake_outputs, fake_heatmaps, fake_masks, fake_joints)
|
|
assert heatmaps_losses == [None]
|
|
assert pull_losses == [None]
|
|
assert push_losses == [None]
|
|
loss_cfg = dict(
|
|
type='MultiLossFactory',
|
|
num_joints=17,
|
|
num_stages=1,
|
|
ae_loss_type='exp',
|
|
with_ae_loss=[True],
|
|
push_loss_factor=[0.001],
|
|
pull_loss_factor=[0.001],
|
|
with_heatmaps_loss=[True],
|
|
heatmaps_loss_factor=[1.0])
|
|
loss = build_loss(loss_cfg)
|
|
heatmaps_losses, push_losses, pull_losses = \
|
|
loss(fake_outputs, fake_heatmaps, fake_masks, fake_joints)
|
|
assert len(heatmaps_losses) == 1
|
|
|