You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
72 lines
2.4 KiB
72 lines
2.4 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
|
|
from mmpose.core.post_processing.group import HeatmapParser
|
|
|
|
|
|
def test_group():
    """Exercise ``HeatmapParser.nms`` and ``HeatmapParser.parse``.

    A parser is first built with ``tag_per_joint`` enabled and checked on a
    tiny NMS input and on a synthetic 17-joint heatmap/tag pair. The parser
    is then rebuilt with ``tag_per_joint`` disabled and ``parse`` is checked
    again with two other flag combinations.
    """
    cfg = {
        'num_joints': 17,
        'detection_threshold': 0.1,
        'tag_threshold': 1,
        'use_detection_val': True,
        'ignore_too_much': False,
        'nms_kernel': 5,
        'nms_padding': 2,
        'tag_per_joint': True,
        'max_num_people': 1,
    }
    parser = HeatmapParser(cfg)

    # A 1.0 peak with a weaker 0.8 response right next to it: the weaker
    # response is expected to be zeroed out by nms.
    nms_input = torch.zeros(1, 1, 5, 5)
    nms_input[0, 0, 3, 3] = 1
    nms_input[0, 0, 3, 2] = 0.8
    suppressed = parser.nms(nms_input)
    assert suppressed[0, 0, 3, 2] == 0

    heatmaps = torch.zeros(1, 17, 32, 32)
    tags = torch.zeros(1, 17, 32, 32, 1)
    # (joint index, row == column position, value); the same value is
    # written to both the heatmap and the tag map at each location.
    peaks = (
        (0, 10, 0.8),
        (1, 12, 0.9),
        (4, 8, 0.8),
        (8, 6, 0.9),
    )
    for joint, pos, value in peaks:
        heatmaps[0, joint, pos, pos] = value
        tags[0, joint, pos, pos] = value

    # Both trailing positional flags enabled.
    # NOTE(review): presumably these are `adjust` and `refine` -- confirm
    # against HeatmapParser.parse.
    grouped, scores = parser.parse(heatmaps, tags, True, True)
    first_coord = grouped[0][0, 0, 0]
    assert first_coord == 10.25
    assert abs(scores[0] - 0.2) < 0.001

    # Rebuild the parser with tag_per_joint switched off.
    cfg.update(tag_per_joint=False)
    parser = HeatmapParser(cfg)

    grouped, scores = parser.parse(heatmaps, tags, False, False)
    assert grouped[0][0, 0, 0] == 10.0

    grouped, scores = parser.parse(heatmaps, tags, False, True)
    assert grouped[0][0, 0, 0] == 10.0
|
|
|
|
|
|
def test_group_score_per_joint():
    """Check ``HeatmapParser.parse`` output when ``score_per_joint`` is set.

    With ``score_per_joint`` enabled in the config, the first entry of the
    returned scores is expected to have one element per joint (17 here).
    """
    cfg = {
        'num_joints': 17,
        'detection_threshold': 0.1,
        'tag_threshold': 1,
        'use_detection_val': True,
        'ignore_too_much': False,
        'nms_kernel': 5,
        'nms_padding': 2,
        'tag_per_joint': True,
        'max_num_people': 1,
        'score_per_joint': True,
    }
    parser = HeatmapParser(cfg)

    # A 1.0 peak with a weaker 0.8 response right next to it: the weaker
    # response is expected to be zeroed out by nms.
    nms_input = torch.zeros(1, 1, 5, 5)
    nms_input[0, 0, 3, 3] = 1
    nms_input[0, 0, 3, 2] = 0.8
    suppressed = parser.nms(nms_input)
    assert suppressed[0, 0, 3, 2] == 0

    heatmaps = torch.zeros(1, 17, 32, 32)
    tags = torch.zeros(1, 17, 32, 32, 1)
    # (joint index, row == column position, value); the same value is
    # written to both the heatmap and the tag map at each location.
    peaks = (
        (0, 10, 0.8),
        (1, 12, 0.9),
        (4, 8, 0.8),
        (8, 6, 0.9),
    )
    for joint, pos, value in peaks:
        heatmaps[0, joint, pos, pos] = value
        tags[0, joint, pos, pos] = value

    # Only the scores are inspected; the grouped result is discarded.
    _, scores = parser.parse(heatmaps, tags, True, True)
    assert len(scores[0]) == 17
|
|
|