You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
163 lines
5.7 KiB
163 lines
5.7 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
from numpy.testing import assert_almost_equal
|
|
|
|
from mmpose.models import build_loss
|
|
from mmpose.models.utils.geometry import batch_rodrigues
|
|
|
|
|
|
def test_mesh_loss():
    """Exercise every MeshLoss term on two hand-built prediction/target pairs.

    1. The prediction is the ground truth itself, so each loss term must be
       zero.
    2. The predicted pose, beta and vertices are offset by +1, joint 0 of the
       3D target is moved by +1 on every axis, and the 2D target is shifted
       by a further ``(img_res - 1) / 2``; each term is then compared with
       its reference value.
    """
    img_res = 256
    mesh_loss = build_loss(
        dict(
            type='MeshLoss',
            joints_2d_loss_weight=1,
            joints_3d_loss_weight=1,
            vertex_loss_weight=1,
            smpl_pose_loss_weight=1,
            smpl_beta_loss_weight=1,
            img_res=img_res,
            focal_length=5000))

    # Zero axis-angle pose for 24 joints, converted to 24 rotation matrices.
    pose = torch.zeros([1, 72], dtype=torch.float32)
    rotmat = batch_rodrigues(pose.view(-1, 3)).view(-1, 24, 3, 3)
    beta = torch.zeros([1, 10], dtype=torch.float32)
    cam = torch.tensor([[1, 0, 0]], dtype=torch.float32)
    # Random vertices are fine: prediction and target share this tensor.
    verts = torch.rand([1, 6890, 3], dtype=torch.float32)
    kpts_3d = torch.ones([1, 24, 3], dtype=torch.float32)
    # NOTE(review): the half-extent offset presumably moves the projection
    # origin from the image centre to the top-left corner -- confirm
    # against MeshLoss.
    center_offset = (img_res - 1) / 2
    kpts_2d = mesh_loss.project_points(kpts_3d, cam) + center_offset

    def make_target(target_3d, target_2d):
        """Build a ground-truth dict with SMPL present and all joints visible."""
        return {
            'pose': pose,
            'beta': beta,
            'vertices': verts,
            'has_smpl': torch.ones(1, dtype=torch.float32),
            'joints_3d': target_3d,
            'joints_3d_visible': torch.ones([1, 24, 1], dtype=torch.float32),
            'joints_2d': target_2d,
            'joints_2d_visible': torch.ones([1, 24, 1], dtype=torch.float32),
        }

    # Case 1: perfect prediction -> every term vanishes.
    exact_pred = {
        'pose': rotmat,
        'beta': beta,
        'camera': cam,
        'vertices': verts,
        'joints_3d': kpts_3d,
    }
    losses = mesh_loss(exact_pred, make_target(kpts_3d, kpts_2d))
    for term in ('vertex_loss', 'smpl_pose_loss', 'smpl_beta_loss',
                 'joints_3d_loss', 'joints_2d_loss'):
        assert torch.allclose(losses[term], torch.tensor(0.))

    # Case 2: known offsets on prediction and target.
    offset_pred = {
        'pose': rotmat + 1,
        'beta': beta + 1,
        'camera': cam,
        'vertices': verts + 1,
        'joints_3d': kpts_3d.clone(),
    }
    # Only joint 0 of the 3D target moves, by +1 along each coordinate.
    shifted_3d = kpts_3d.clone()
    shifted_3d[:, 0] += 1
    losses = mesh_loss(offset_pred,
                       make_target(shifted_3d, kpts_2d + center_offset))

    reference = {
        'vertex_loss': 1.,
        'smpl_pose_loss': 1.,
        'smpl_beta_loss': 1.,
        'joints_3d_loss': 0.5 / 24,
        'joints_2d_loss': 0.5,
    }
    for term, value in reference.items():
        assert torch.allclose(losses[term], torch.tensor(value))
|
|
|
|
|
|
def test_gan_loss():
    """Check GANLoss against reference values for each supported gan_type.

    An unknown ``gan_type`` must raise ``NotImplementedError`` when the loss
    is built. Each supported type is then built with ``loss_weight=2.0`` and
    evaluated for all four (target_is_real, is_disc) combinations.
    """
    unsupported_cfg = dict(
        type='GANLoss',
        gan_type='test',
        real_label_val=1.0,
        fake_label_val=0.0,
        loss_weight=1)
    with pytest.raises(NotImplementedError):
        _ = build_loss(unsupported_cfg)

    ones_1x1 = torch.ones(1, 1)
    twos_map = torch.ones(1, 3, 6, 6) * 2

    # (gan_type, input, ((target_is_real, is_disc, expected), ...)).
    # NOTE(review): judging by these values, loss_weight seems to be applied
    # only on the generator side (is_disc=False) -- confirm against GANLoss.
    reference = (
        ('vanilla', ones_1x1, (
            (True, False, 0.6265233),
            (False, False, 2.6265232),
            (True, True, 0.3132616),
            (False, True, 1.3132616),
        )),
        ('lsgan', twos_map, (
            (True, False, 2.0),
            (False, False, 8.0),
            (True, True, 1.0),
            (False, True, 4.0),
        )),
        ('wgan', twos_map, (
            (True, False, -4.0),
            (False, False, 4),
            (True, True, -2.0),
            (False, True, 2.0),
        )),
        ('hinge', twos_map, (
            (True, False, -4.0),
            (False, False, -4.0),
            (True, True, 0.0),
            (False, True, 3.0),
        )),
    )

    for gan_type, logits, expectations in reference:
        gan_loss = build_loss(
            dict(
                type='GANLoss',
                gan_type=gan_type,
                real_label_val=1.0,
                fake_label_val=0.0,
                loss_weight=2.0))
        for target_is_real, is_disc, expected in expectations:
            value = gan_loss(logits, target_is_real, is_disc=is_disc)
            assert_almost_equal(value.item(), expected)
|
|
|