You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
3.0 KiB
104 lines
3.0 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from mmpose.models import TemporalRegressionHead
|
|
|
|
|
|
def test_temporal_regression_head():
    """Test ``TemporalRegressionHead`` in three configurations.

    1. Pose regression without global position restoration: input-shape
       validation, output shape, zero loss / zero error when the prediction
       equals the target, and ``inference_model`` with and without an extra
       list of index pairs.
    2. Pose regression with global position restoration: output shape,
       accuracy with an explicit target weight, and ``decode`` driven by
       root-position metadata.
    3. Trajectory model (``is_trajectory=True`` with a single joint): output
       shape and a zero trajectory loss.
    """
    # w/o global position restoration
    head = TemporalRegressionHead(
        in_channels=1024,
        num_joints=17,
        loss_keypoint=dict(type='MPJPELoss', use_target_weight=True),
        test_cfg=dict(restore_global_position=False))

    head.init_weights()

    # The malformed inputs are built outside ``pytest.raises`` so that only
    # the forward pass itself is expected to raise; an AssertionError from
    # anywhere else would otherwise make these checks pass spuriously.

    # ndim of the input tensor should be 3
    inputs = _demo_inputs((1, 1024, 1, 1))
    with pytest.raises(AssertionError):
        _ = head(inputs)

    # size of the last dim should be 1
    inputs = _demo_inputs((1, 1024, 3))
    with pytest.raises(AssertionError):
        _ = head(inputs)

    input_shape = (1, 1024, 1)
    inputs = _demo_inputs(input_shape)
    out = head(inputs)
    assert out.shape == torch.Size([1, 17, 3])

    # Prediction == target, so the regression loss must vanish.
    loss = head.get_loss(out, out, None)
    assert torch.allclose(loss['reg_loss'], torch.tensor(0.))

    _ = head.inference_model(inputs)
    # NOTE(review): the second argument is presumably a list of flip pairs;
    # confirm against ``TemporalRegressionHead.inference_model``.
    _ = head.inference_model(inputs, [(0, 1), (2, 3)])
    metas = [{}]

    acc = head.get_accuracy(out, out, None, metas=metas)
    assert acc['mpjpe'] == 0.
    # P-MPJPE (Procrustes-aligned) is only ~0 up to float32 round-off.
    np.testing.assert_almost_equal(acc['p_mpjpe'], 0., decimal=6)

    # w/ global position restoration
    head = TemporalRegressionHead(
        in_channels=1024,
        num_joints=16,
        loss_keypoint=dict(type='MPJPELoss', use_target_weight=True),
        test_cfg=dict(restore_global_position=True))
    head.init_weights()

    input_shape = (1, 1024, 1)
    inputs = _demo_inputs(input_shape)
    metas = [{
        'root_position': np.zeros((1, 3)),
        'root_position_index': 0,
        'root_weight': 1.
    }]
    out = head(inputs)
    assert out.shape == torch.Size([1, 16, 3])

    inference_out = head.inference_model(inputs)
    acc = head.get_accuracy(out, out, torch.ones_like(out), metas=metas)
    assert acc['mpjpe'] == 0.
    # Same tolerance as above: the default (7 decimals) is tighter than
    # float32 round-off in the Procrustes alignment and can flake.
    np.testing.assert_almost_equal(acc['p_mpjpe'], 0., decimal=6)

    _ = head.decode(metas, inference_out)

    # trajectory model (only predict root position)
    head = TemporalRegressionHead(
        in_channels=1024,
        num_joints=1,
        loss_keypoint=dict(type='MPJPELoss', use_target_weight=True),
        is_trajectory=True,
        test_cfg=dict(restore_global_position=False))

    head.init_weights()

    input_shape = (1, 1024, 1)
    inputs = _demo_inputs(input_shape)
    out = head(inputs)
    assert out.shape == torch.Size([1, 1, 3])

    loss = head.get_loss(out, out.squeeze(1), torch.ones_like(out))
    assert torch.allclose(loss['traj_loss'], torch.tensor(0.))
|
|
|
|
|
|
def _demo_inputs(input_shape=(1, 1024, 1)):
    """Build a random input tensor for the head under test.

    Args:
        input_shape (tuple): Shape of the tensor to create.
            Default: (1, 1024, 1).

    Returns:
        torch.Tensor: A float32 tensor of shape ``input_shape`` filled with
        values drawn by ``np.random.random`` (uniform on [0, 1)).
    """
    # Sample from NumPy's global RNG (float64), then copy into float32.
    return torch.from_numpy(np.random.random(input_shape)).float()
|
|
|