You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
76 lines
2.5 KiB
76 lines
2.5 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from mmpose.models import HMRMeshHead
|
|
from mmpose.models.misc.discriminator import SMPLDiscriminator
|
|
|
|
|
|
def _check_mesh_head_output(out, batch_size=1):
    """Unpack an ``HMRMeshHead`` output and assert its tensor shapes.

    Args:
        out (tuple): The 3-tuple ``(smpl_rotmat, smpl_shape, camera)``
            returned by calling the head.
        batch_size (int): Expected size of the leading (batch) dimension.
            Default: 1.

    Returns:
        tuple: The unpacked ``(smpl_rotmat, smpl_shape, camera)`` so the
        caller can reuse the tensors after the shape check.
    """
    smpl_rotmat, smpl_shape, camera = out
    assert smpl_rotmat.shape == torch.Size([batch_size, 24, 3, 3])
    assert smpl_shape.shape == torch.Size([batch_size, 10])
    assert camera.shape == torch.Size([batch_size, 3])
    return smpl_rotmat, smpl_shape, camera


def test_mesh_hmr_head():
    """Test hmr mesh head and the SMPL discriminator applied to its output."""
    input_shape = (1, 512, 8, 8)

    # Default configuration.
    head = HMRMeshHead(in_channels=512)
    head.init_weights()
    _check_mesh_head_output(head(_demo_inputs(input_shape)))

    # Test hmr mesh head with assigned mean parameters and n_iter.
    # NOTE(review): the .npz path is relative, so this presumably assumes
    # pytest is launched from the repository root.
    head = HMRMeshHead(
        in_channels=512,
        smpl_mean_params='tests/data/smpl/smpl_mean_params.npz',
        n_iter=3)
    head.init_weights()
    smpl_rotmat, smpl_shape, camera = _check_mesh_head_output(
        head(_demo_inputs(input_shape)))

    disc = SMPLDiscriminator(
        beta_channel=(10, 10, 5, 1),
        per_joint_channel=(9, 32, 32, 16, 1),
        full_pose_channel=(23 * 16, 256, 1))
    # presumably 25 = 23 per-joint scores + 1 full-pose score + 1 shape
    # score (cf. the 23 * 16 full_pose_channel input); confirm against
    # SMPLDiscriminator.forward.
    expected_num_scores = 25

    # Discriminator with SMPL pose parameters in rotation matrix
    # representation, taken directly from the head's output.
    pred_score = disc((camera, smpl_rotmat, smpl_shape))
    assert pred_score.shape[1] == expected_num_scores

    # Discriminator with SMPL pose parameters in axis-angle representation
    # (presumably 72 = 24 joints x 3 components).
    pred_score = disc((camera, camera.new_zeros([1, 72]), smpl_shape))
    assert pred_score.shape[1] == expected_num_scores

    # A list (rather than a tuple) for beta_channel is expected to be rejected.
    with pytest.raises(TypeError):
        _ = SMPLDiscriminator(
            beta_channel=[10, 10, 5, 1],
            per_joint_channel=(9, 32, 32, 16, 1),
            full_pose_channel=(23 * 16, 256, 1))

    # A single-entry beta_channel is expected to be rejected.
    with pytest.raises(ValueError):
        _ = SMPLDiscriminator(
            beta_channel=(10, ),
            per_joint_channel=(9, 32, 32, 16, 1),
            full_pose_channel=(23 * 16, 256, 1))
|
|
|
|
|
|
def _demo_inputs(input_shape=(1, 3, 64, 64)):
|
|
"""Create a superset of inputs needed to run mesh head.
|
|
|
|
Args:
|
|
input_shape (tuple): input batch dimensions.
|
|
Default: (1, 3, 64, 64).
|
|
Returns:
|
|
Random input tensor with the size of input_shape.
|
|
"""
|
|
inps = np.random.random(input_shape)
|
|
inps = torch.FloatTensor(inps)
|
|
return inps
|
|
|