You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
122 lines
3.7 KiB
122 lines
3.7 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import numpy as np
|
|
import torch
|
|
|
|
from mmpose.models.detectors import AssociativeEmbedding
|
|
|
|
|
|
def test_ae_forward():
    """Smoke-test ``AssociativeEmbedding`` in train, test and dummy mode.

    The detector is built from hand-written config dictionaries, its weights
    are initialised, and one synthetic batch from ``_demo_mm_inputs`` is run
    through ``forward`` with ``return_loss=True`` (the result must be a
    ``dict``), through ``forward`` with ``return_loss=False`` and through
    ``forward_dummy``.
    """
    # Loss settings nested inside the keypoint head config.
    loss_cfg = {
        'type': 'MultiLossFactory',
        'num_joints': 17,
        'num_stages': 1,
        'ae_loss_type': 'exp',
        'with_ae_loss': [True],
        'push_loss_factor': [0.001],
        'pull_loss_factor': [0.001],
        'with_heatmaps_loss': [True],
        'heatmaps_loss_factor': [1.0],
    }

    head_cfg = {
        'type': 'AESimpleHead',
        'in_channels': 512,
        'num_joints': 17,
        'num_deconv_layers': 0,
        'tag_per_joint': True,
        'with_ae_loss': [True],
        'extra': {'final_conv_kernel': 1},
        'loss_keypoint': loss_cfg,
    }

    # Options handed to the detector as its test configuration.
    inference_cfg = {
        'num_joints': 17,
        'max_num_people': 30,
        'scale_factor': [1],
        'with_heatmaps': [True],
        'with_ae': [True],
        'project2image': True,
        'nms_kernel': 5,
        'nms_padding': 2,
        'tag_per_joint': True,
        'detection_threshold': 0.1,
        'tag_threshold': 1,
        'use_detection_val': True,
        'ignore_too_much': False,
        'adjust': True,
        'refine': True,
        'soft_nms': False,
        'flip_test': True,
        'post_process': True,
        'shift_heatmap': True,
        'use_gt_bbox': True,
        'flip_pairs': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12],
                       [13, 14], [15, 16]],
    }

    model_cfg = {
        'type': 'AssociativeEmbedding',
        'pretrained': None,
        'backbone': {'type': 'ResNet', 'depth': 18},
        'keypoint_head': head_cfg,
        'train_cfg': {},
        'test_cfg': inference_cfg,
    }

    # The constructor receives its arguments positionally, in this order.
    ctor_keys = ('backbone', 'keypoint_head', 'train_cfg', 'test_cfg',
                 'pretrained')
    detector = AssociativeEmbedding(*[model_cfg[key] for key in ctor_keys])
    detector.init_weights()

    batch = _demo_mm_inputs((1, 3, 256, 256))
    imgs, target, mask, joints, img_metas = [
        batch.pop(key)
        for key in ('imgs', 'target', 'mask', 'joints', 'img_metas')
    ]

    # Training path: forward must hand back a dict.
    losses = detector.forward(
        imgs, target, mask, joints, img_metas, return_loss=True)
    assert isinstance(losses, dict)

    # Inference and dummy paths, run without gradient tracking.
    with torch.no_grad():
        _ = detector.forward(imgs, img_metas=img_metas, return_loss=False)
        _ = detector.forward_dummy(imgs)
|
|
|
|
|
|
def _demo_mm_inputs(input_shape=(1, 3, 256, 256)):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions, unpacked as
            ``(N, C, H, W)``.

    Returns:
        dict: Holds ``'imgs'`` (a float tensor of ``input_shape`` filled
        from a ``RandomState(0)`` generator, with gradients enabled),
        ``'target'``, ``'mask'`` and ``'joints'`` (each a one-element list
        wrapping a float tensor), and ``'img_metas'`` (a list with one
        meta dict per sample).
    """
    batch_size, _, height, width = input_shape
    # Target and mask use the input resolution divided by 32.
    grid_h = height // 32
    grid_w = width // 32

    # Fixed seed so every call yields the same image values.
    generator = np.random.RandomState(0)
    images = generator.rand(*input_shape)

    heatmap_target = np.zeros(
        [batch_size, 17, grid_h, grid_w], dtype=np.float32)
    valid_mask = np.ones([batch_size, grid_h, grid_w], dtype=np.float32)
    joint_coords = np.zeros([batch_size, 30, 17, 2], dtype=np.float32)

    # One meta dict per sample; the values do not depend on input_shape.
    img_metas = []
    for _ in range(batch_size):
        meta = {
            'image_file': 'test.jpg',
            'aug_data': [torch.zeros(1, 3, 256, 256)],
            'test_scale_factor': [1],
            'base_size': (256, 256),
            'center': np.array([128, 128]),
            'scale': np.array([1.28, 1.28]),
            'flip_index': [
                0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15
            ],
        }
        img_metas.append(meta)

    return {
        'imgs': torch.FloatTensor(images).requires_grad_(True),
        'target': [torch.FloatTensor(heatmap_target)],
        'mask': [torch.FloatTensor(valid_mask)],
        'joints': [torch.FloatTensor(joint_coords)],
        'img_metas': img_metas,
    }
|
|
|