You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
2.9 KiB
102 lines
2.9 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import numpy as np
|
|
import torch
|
|
|
|
from mmpose.core import (aggregate_scale, aggregate_stage_flip,
|
|
flip_feature_maps, get_group_preds, split_ae_outputs)
|
|
|
|
|
|
def test_split_ae_outputs():
    """Check split_ae_outputs on a single stage that emits only AE tags."""
    fake_outputs = [torch.zeros((1, 4, 2, 2))]
    heatmaps, tags = split_ae_outputs(
        fake_outputs,
        num_joints=4,
        with_heatmaps=[False],
        with_ae=[True],
        select_output_index=[0])
    # The only stage has with_heatmaps=False, so nothing should be split out
    # as heatmaps; with_ae=True, so the whole output is expected as tags.
    assert len(heatmaps) == 0
    assert len(tags) == 1
    assert tags[0].shape == fake_outputs[0].shape
|
|
|
|
|
|
def test_flip_feature_maps():
    """Check flip_feature_maps with and without a channel flip index."""
    # Distinct values per element so a flip or channel swap is observable
    # (an all-zeros input would pass regardless of what was flipped).
    fake_outputs = [
        torch.arange(16, dtype=torch.float32).reshape(1, 4, 2, 2)
    ]
    # Reference: the input flipped along dim 3.
    flipped_ref = torch.flip(fake_outputs[0], [3])

    # Without flip_index only the spatial flip is expected.
    flipped = flip_feature_maps(fake_outputs, None)
    assert isinstance(flipped, list) and len(flipped) == 1
    assert torch.equal(flipped[0], flipped_ref)

    # With flip_index the channels are additionally gathered in that order.
    flip_index = [1, 0]
    flipped = flip_feature_maps(fake_outputs, flip_index=flip_index)
    assert isinstance(flipped, list) and len(flipped) == 1
    assert torch.equal(flipped[0], flipped_ref[:, flip_index])
|
|
|
|
|
|
def test_aggregate_stage_flip():
    """Check every (aggregate_stage, aggregate_flip) mode combination.

    Each combination is run on the same single-stage inputs with
    ``size_projected=(4, 4)`` and must return a list.
    """
    fake_outputs = [torch.zeros((1, 4, 2, 2))]
    fake_flip_outputs = [torch.ones((1, 4, 2, 2))]

    mode_pairs = [
        ('concat', 'average'),
        ('average', 'average'),
        ('average', 'concat'),
        ('concat', 'concat'),
    ]
    for aggregate_stage, aggregate_flip in mode_pairs:
        output = aggregate_stage_flip(
            fake_outputs,
            fake_flip_outputs,
            index=-1,
            project2image=True,
            size_projected=(4, 4),
            align_corners=False,
            aggregate_stage=aggregate_stage,
            aggregate_flip=aggregate_flip)
        # All combinations share one test, so name the one that failed.
        assert isinstance(output, list), (
            'aggregate_stage={!r}, aggregate_flip={!r}'.format(
                aggregate_stage, aggregate_flip))
|
|
|
|
|
|
def test_aggregate_scale():
    """Check the 'average' and 'unsqueeze_concat' modes of aggregate_scale."""
    per_scale_maps = [torch.zeros((1, 4, 2, 2)) for _ in range(2)]
    base_shape = per_scale_maps[0].shape

    # 'average' is expected to return a tensor shaped like one input map.
    averaged = aggregate_scale(
        per_scale_maps, align_corners=False, aggregate_scale='average')
    assert isinstance(averaged, torch.Tensor)
    assert averaged.shape == base_shape

    # 'unsqueeze_concat' is expected to add exactly one extra dimension.
    stacked = aggregate_scale(
        per_scale_maps,
        align_corners=False,
        aggregate_scale='unsqueeze_concat')
    assert isinstance(stacked, torch.Tensor)
    assert stacked.dim() == len(base_shape) + 1
|
|
|
|
|
|
def test_get_group_preds():
    """Check get_group_preds with the default path and with use_udp=True."""
    fake_grouped_joints = [np.array([[[0, 0], [1, 1]]])]

    # Default call first, then use_udp=True; both calls receive the same
    # fake_grouped_joints object, in this order.
    for extra_kwargs in ({}, {'use_udp': True}):
        results = get_group_preds(
            fake_grouped_joints,
            center=np.array([0, 0]),
            scale=np.array([1, 1]),
            heatmap_size=np.array([2, 2]),
            **extra_kwargs)
        # grouped_joints[0] has shape (1, 2, 2), i.e. one entry, so one
        # prediction is expected. len() also works if an array is returned,
        # unlike comparing against [].
        assert len(results) == len(fake_grouped_joints[0]), extra_kwargs
|
|
|