You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

91 lines
2.6 KiB

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmpose.models import Interhand3DHead
def test_interhand_3d_head():
    """Smoke-test ``Interhand3DHead``: forward, loss, inference and decode.

    A head is built from a fixed config and fed a random feature tensor.
    Only the container type, length and shapes of the outputs, and the
    presence of the expected keys in the loss and decode results, are
    checked; no numerical values are verified.
    """
    batch_size = 4

    # Shapes the head is expected to emit, in output order. The same
    # shapes are reused to build the all-zero training targets.
    expected_shapes = [
        (batch_size, 42, 64, 64, 64),
        (batch_size, 1),
        (batch_size, 2),
    ]

    def _check_outputs(outputs):
        """Assert ``outputs`` is a 3-item list with the expected shapes."""
        assert isinstance(outputs, list)
        assert len(outputs) == 3
        for out, shape in zip(outputs, expected_shapes):
            assert out.shape == shape

    def _make_img_meta():
        """Return a fresh image-meta dict (new arrays on every call)."""
        return {
            'img_shape': (256, 256, 3),
            'center': np.array([112, 112]),
            'scale': np.array([0.5, 0.5]),
            'bbox_score': 1.0,
            'bbox_id': 0,
            'flip_pairs': [],
            'inference_channel': np.arange(42),
            'image_file': '<demo>.png',
            'heatmap3d_depth_bound': 400.0,
            'root_depth_bound': 400.0,
        }

    # Random input features; created before the head so the order of
    # random-number consumption stays the same as it has always been.
    feats = torch.rand((batch_size, 2048, 8, 8), dtype=torch.float32)

    targets = [feats.new_zeros(*shape) for shape in expected_shapes]
    weight_shapes = [(batch_size, 42, 1), (batch_size, 1), (batch_size, )]
    target_weights = [feats.new_ones(*shape) for shape in weight_shapes]

    img_metas = [_make_img_meta() for _ in range(batch_size)]

    keypoint_cfg = {
        'in_channels': 2048,
        'out_channels': 21 * 64,
        'depth_size': 64,
        'num_deconv_layers': 3,
        'num_deconv_filters': (256, 256, 256),
        'num_deconv_kernels': (4, 4, 4),
    }
    root_cfg = {
        'in_channels': 2048,
        'heatmap_size': 64,
        'hidden_dims': (512, ),
    }
    hand_type_cfg = {
        'in_channels': 2048,
        'num_labels': 2,
        'hidden_dims': (512, ),
    }

    head = Interhand3DHead(
        keypoint_head_cfg=keypoint_cfg,
        root_head_cfg=root_cfg,
        hand_type_head_cfg=hand_type_cfg,
        loss_keypoint={'type': 'JointsMSELoss', 'use_target_weight': True},
        loss_root_depth={'type': 'L1Loss'},
        loss_hand_type={'type': 'BCELoss', 'use_target_weight': True},
        train_cfg={},
        test_cfg={},
    )
    head.init_weights()

    # Forward pass.
    forward_out = head(feats)
    _check_outputs(forward_out)

    # Loss computation on the forward outputs.
    losses = head.get_loss(forward_out, targets, target_weights)
    for loss_name in ('hand_loss', 'rel_root_loss', 'hand_type_loss'):
        assert loss_name in losses

    # Inference with flip pairs: index i is paired with index 21 + i.
    flip_pairs = [[idx, 21 + idx] for idx in range(21)]
    inference_out = head.inference_model(feats, flip_pairs)
    _check_outputs(inference_out)

    # Decoding of the inference outputs.
    decoded = head.decode(img_metas, inference_out)
    for result_key in ('preds', 'rel_root_depth', 'hand_type'):
        assert result_key in decoded