# Copied from the mmpose test suite; web-UI paste residue removed.
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from mmpose.models.detectors import MultiTask

def test_multitask_forward():
    """Test multitask forward."""
    # MultiTask detector: ResNet-50 backbone feeding one regression head
    # through a global-average-pooling neck (head 0 -> neck 0).
    cfg = dict(
        backbone=dict(type='ResNet', depth=50),
        heads=[
            dict(
                type='DeepposeRegressionHead',
                in_channels=2048,
                num_joints=17,
                loss_keypoint=dict(
                    type='SmoothL1Loss', use_target_weight=False)),
        ],
        necks=[dict(type='GlobalAveragePooling')],
        head2neck={0: 0},
        pretrained=None,
    )
    detector = MultiTask(**cfg)

    # One shared demo batch drives both detector variants below.
    batch = _demo_mm_inputs()
    imgs = batch['img']
    kpt_targets = [batch['target_keypoints']]
    weights = [batch['target_weight']]
    metas = batch['img_metas']

    # Forward train: regression loss and pose accuracy must be reported.
    train_out = detector(imgs, kpt_targets, weights, return_loss=True)
    assert 'reg_loss' in train_out and 'acc_pose' in train_out

    # Forward test: inference must return predicted keypoints.
    test_out = detector(imgs, img_metas=metas, return_loss=False)
    assert 'preds' in test_out

    # Dummy forward: raw head output is (batch, joints, xy).
    dummy_out = detector.forward_dummy(imgs)
    assert dummy_out[0].shape == torch.Size([1, 17, 2])

    # Second variant: no neck, heatmap head directly on backbone features.
    cfg = dict(
        backbone=dict(type='ResNet', depth=50),
        heads=[
            dict(
                type='TopdownHeatmapSimpleHead',
                in_channels=2048,
                out_channels=17,
                num_deconv_layers=3,
                num_deconv_filters=(256, 256, 256),
                num_deconv_kernels=(4, 4, 4),
                loss_keypoint=dict(
                    type='JointsMSELoss', use_target_weight=True))
        ],
        pretrained=None,
    )
    detector = MultiTask(**cfg)

    # Heatmap targets replace the keypoint targets; weights are reused.
    hm_targets = [batch['target_heatmap']]

    # Forward train: heatmap loss and pose accuracy must be reported.
    train_out = detector(imgs, hm_targets, weights, return_loss=True)
    assert 'heatmap_loss' in train_out and 'acc_pose' in train_out

    # Forward test: inference must return predicted keypoints.
    test_out = detector(imgs, img_metas=metas, return_loss=False)
    assert 'preds' in test_out

    # Dummy forward: heatmaps at 1/4 input resolution (256 // 4 == 64).
    dummy_out = detector.forward_dummy(imgs)
    assert dummy_out[0].shape == torch.Size([1, 17, 64, 64])
def _demo_mm_inputs(input_shape=(1, 3, 256, 256)):
|
|
"""Create a superset of inputs needed to run test or train.
|
|
|
|
Args:
|
|
input_shape (tuple):
|
|
input batch dimensions
|
|
"""
|
|
(N, C, H, W) = input_shape
|
|
|
|
rng = np.random.RandomState(0)
|
|
|
|
imgs = rng.rand(*input_shape)
|
|
|
|
target_keypoints = np.zeros([N, 17, 2])
|
|
target_heatmap = np.zeros([N, 17, H // 4, W // 4])
|
|
target_weight = np.ones([N, 17, 1])
|
|
|
|
img_metas = [{
|
|
'img_shape': (H, W, C),
|
|
'center': np.array([W / 2, H / 2]),
|
|
'scale': np.array([0.5, 0.5]),
|
|
'bbox_score': 1.0,
|
|
'bbox_id': 0,
|
|
'flip_pairs': [],
|
|
'inference_channel': np.arange(17),
|
|
'image_file': '<demo>.png',
|
|
} for _ in range(N)]
|
|
|
|
mm_inputs = {
|
|
'img': torch.FloatTensor(imgs).requires_grad_(True),
|
|
'target_keypoints': torch.FloatTensor(target_keypoints),
|
|
'target_heatmap': torch.FloatTensor(target_heatmap),
|
|
'target_weight': torch.FloatTensor(target_weight),
|
|
'img_metas': img_metas,
|
|
}
|
|
return mm_inputs
|
|
|