You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
91 lines
2.6 KiB
91 lines
2.6 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import numpy as np
|
|
import torch
|
|
|
|
from mmpose.models import Interhand3DHead
|
|
|
|
|
|
def test_interhand_3d_head():
    """Run Interhand3DHead through forward, loss, inference and decode.

    Random features are fed to the head; the test only checks container
    types, output shapes and the presence of the expected result keys.
    """
    batch_size = 4
    # Shapes the three head outputs are expected to have, in return order.
    expected_shapes = (
        (batch_size, 42, 64, 64, 64),
        (batch_size, 1),
        (batch_size, 2),
    )

    def _check_outputs(outputs):
        """Assert ``outputs`` is a list of three correctly shaped tensors."""
        assert isinstance(outputs, list)
        assert len(outputs) == 3
        for tensor, shape in zip(outputs, expected_shapes):
            assert tensor.shape == shape

    # Inputs are created before the head, as in the original statement order.
    feats = torch.rand((batch_size, 2048, 8, 8), dtype=torch.float32)

    # All-zero targets shaped like the outputs, with all-one weights.
    gt = [feats.new_zeros(*shape) for shape in expected_shapes]
    gt_weight = [
        feats.new_ones(batch_size, 42, 1),
        feats.new_ones(batch_size, 1),
        feats.new_ones(batch_size),
    ]

    # Build a fresh metadata dict per sample so nothing mutable is shared.
    metas = []
    for _ in range(batch_size):
        metas.append({
            'img_shape': (256, 256, 3),
            'center': np.array([112, 112]),
            'scale': np.array([0.5, 0.5]),
            'bbox_score': 1.0,
            'bbox_id': 0,
            'flip_pairs': [],
            'inference_channel': np.arange(42),
            'image_file': '<demo>.png',
            'heatmap3d_depth_bound': 400.0,
            'root_depth_bound': 400.0,
        })

    keypoint_cfg = {
        'in_channels': 2048,
        'out_channels': 21 * 64,
        'depth_size': 64,
        'num_deconv_layers': 3,
        'num_deconv_filters': (256, 256, 256),
        'num_deconv_kernels': (4, 4, 4),
    }
    root_cfg = {
        'in_channels': 2048,
        'heatmap_size': 64,
        'hidden_dims': (512, ),
    }
    hand_type_cfg = {
        'in_channels': 2048,
        'num_labels': 2,
        'hidden_dims': (512, ),
    }

    head = Interhand3DHead(
        keypoint_head_cfg=keypoint_cfg,
        root_head_cfg=root_cfg,
        hand_type_head_cfg=hand_type_cfg,
        loss_keypoint={'type': 'JointsMSELoss', 'use_target_weight': True},
        loss_root_depth={'type': 'L1Loss'},
        loss_hand_type={'type': 'BCELoss', 'use_target_weight': True},
        train_cfg={},
        test_cfg={},
    )
    head.init_weights()

    # Forward pass.
    preds = head(feats)
    _check_outputs(preds)

    # Loss computation: every loss term must be reported.
    losses = head.get_loss(preds, gt, gt_weight)
    for loss_name in ('hand_loss', 'rel_root_loss', 'hand_type_loss'):
        assert loss_name in losses

    # Inference with flip pairs that match index i with index 21 + i.
    pairs = [[idx, 21 + idx] for idx in range(21)]
    preds = head.inference_model(feats, pairs)
    _check_outputs(preds)

    # Decoding of the inference outputs.
    decoded = head.decode(metas, preds)
    for result_key in ('preds', 'rel_root_depth', 'hand_type'):
        assert result_key in decoded
|
|
|