You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
410 lines
13 KiB
410 lines
13 KiB
# ------------------------------------------------------------------------------
# Adapted from https://github.com/princeton-vl/pose-ae-train/
# Original licence: Copyright (c) 2017, umich-vl, under BSD 3-Clause License.
# ------------------------------------------------------------------------------
|
|
|
|
import numpy as np
|
|
import torch
|
|
from munkres import Munkres
|
|
|
|
from mmpose.core.evaluation import post_dark_udp
|
|
|
|
|
|
def _py_max_match(scores):
    """Solve the assignment problem on ``scores`` with the Munkres algorithm.

    Args:
        scores (np.ndarray): Cost matrix; rows are the items to assign and
            columns the targets they may be assigned to.

    Returns:
        np.ndarray: Integer array of ``(row, col)`` index pairs that make up
        the minimum-cost assignment returned by ``Munkres.compute``.
    """
    pairs = Munkres().compute(scores)
    return np.array(pairs).astype(int)
|
|
|
|
|
|
def _match_by_tag(inp, params):
    """Group the keypoint candidates of one image into people using tags.

    Joint types are visited in ``params.joint_order``. Candidates whose score
    does not exceed ``params.detection_threshold`` are discarded. While no
    group exists yet, every surviving candidate opens a group keyed by its
    first tag value. After that, candidates are assigned to existing groups
    with the Munkres algorithm, using the L2 distance between a candidate's
    tag and the group's mean tag as the cost. A candidate whose match is a
    padding column, or whose raw tag distance is not below
    ``params.tag_threshold``, opens a new group instead.

    Note:
        number of keypoints: K
        max number of people in an image: M (M=30 by default)
        dim of tags: L
            If use flip testing, L=2; else L=1.

    Args:
        inp (tuple):
            tag_k (np.ndarray[KxMxL]): tag corresponding to the
                top k values of feature map per keypoint.
            loc_k (np.ndarray[KxMx2]): top k locations of the
                feature maps for keypoint.
            val_k (np.ndarray[KxM]): top k value of the
                feature maps per keypoint.
        params (_Params): Grouping hyper-parameters.

    Returns:
        np.ndarray: float32 array holding one ``(num_joints, 3 + L)`` slab per
        group, in group-creation order. Each row is ``(x, y, score, tag...)``;
        joints never assigned to a group stay all-zero. If no group was
        formed, the result is an empty 1-D array.
    """
    assert isinstance(params, _Params), 'params should be class _Params()'

    tag_k, loc_k, val_k = inp

    # Template copied for every new group.
    empty_person = np.zeros((params.num_joints, 3 + tag_k.shape[2]),
                            dtype=np.float32)

    people = {}  # group key -> (num_joints, 3 + L) joint array
    person_tags = {}  # group key -> list of tag vectors assigned so far

    def _open_group(tag, joint):
        # Store ``joint`` in the group keyed by ``tag[0]`` (creating the group
        # if needed) and reset that group's tag history to ``[tag]``.
        key = tag[0]
        people.setdefault(key, np.copy(empty_person))[joint_id] = joint
        person_tags[key] = [tag]

    for step in range(params.num_joints):
        joint_id = params.joint_order[step]

        cand_tags = tag_k[joint_id]
        # One row per candidate: (x, y, score, tag...).
        candidates = np.concatenate(
            (loc_k[joint_id], val_k[joint_id, :, None], cand_tags), 1)
        keep = candidates[:, 2] > params.detection_threshold
        cand_tags = cand_tags[keep]
        candidates = candidates[keep]

        if candidates.shape[0] == 0:
            continue

        if not people:
            # Nothing grouped yet: every candidate starts its own group.
            for tag, joint in zip(cand_tags, candidates):
                _open_group(tag, joint)
            continue

        group_keys = list(people)[:params.max_num_people]
        if (params.ignore_too_much
                and len(group_keys) == params.max_num_people):
            continue
        group_tags = np.array(
            [np.mean(person_tags[key], axis=0) for key in group_keys])

        # Raw tag distances (candidates x groups). This array is kept intact
        # for the threshold test; ``cost`` may be rescaled/padded and is the
        # matrix handed to the matcher.
        tag_dist = np.linalg.norm(
            candidates[:, None, 3:] - group_tags[None, :, :], ord=2, axis=2)
        cost = np.copy(tag_dist)

        if params.use_detection_val:
            # Quantise distances and subtract the detection score; presumably
            # so the score mainly separates equal rounded distances.
            cost = np.round(cost) * 100 - candidates[:, 2:3]

        num_added, num_grouped = tag_dist.shape
        if num_added > num_grouped:
            # Square the matrix with prohibitively expensive dummy columns.
            padding = np.full((num_added, num_added - num_grouped), 1e10,
                              dtype=np.float32)
            cost = np.concatenate((cost, padding), axis=1)

        for row, col in _py_max_match(cost):
            matched = (row < num_added and col < num_grouped
                       and tag_dist[row][col] < params.tag_threshold)
            if not matched:
                _open_group(cand_tags[row], candidates[row])
                continue
            key = group_keys[col]
            people[key][joint_id] = candidates[row]
            person_tags[key].append(cand_tags[row])

    return np.array(list(people.values())).astype(np.float32)
|
|
|
|
|
|
class _Params:
|
|
"""A class of parameter.
|
|
|
|
Args:
|
|
cfg(Config): config.
|
|
"""
|
|
|
|
def __init__(self, cfg):
|
|
self.num_joints = cfg['num_joints']
|
|
self.max_num_people = cfg['max_num_people']
|
|
|
|
self.detection_threshold = cfg['detection_threshold']
|
|
self.tag_threshold = cfg['tag_threshold']
|
|
self.use_detection_val = cfg['use_detection_val']
|
|
self.ignore_too_much = cfg['ignore_too_much']
|
|
|
|
if self.num_joints == 17:
|
|
self.joint_order = [
|
|
i - 1 for i in
|
|
[1, 2, 3, 4, 5, 6, 7, 12, 13, 8, 9, 10, 11, 14, 15, 16, 17]
|
|
]
|
|
else:
|
|
self.joint_order = list(np.arange(self.num_joints))
|
|
|
|
|
|
class HeatmapParser:
    """The heatmap parser for post processing.

    Turns heatmaps and tagmaps into grouped poses: heatmap NMS, per-joint
    top-k candidates, tag-based grouping, optional sub-pixel adjustment and
    optional recovery of missing joints.

    Args:
        cfg (Config): Config. Reads ``tag_per_joint``, ``nms_kernel`` and
            ``nms_padding``, plus the optional ``use_udp`` and
            ``score_per_joint`` flags (both default to False). The grouping
            hyper-parameters are read by ``_Params(cfg)``.
    """

    def __init__(self, cfg):
        self.params = _Params(cfg)
        self.tag_per_joint = cfg['tag_per_joint']
        # Stride-1 max pooling used by ``nms``.
        self.pool = torch.nn.MaxPool2d(cfg['nms_kernel'], 1,
                                       cfg['nms_padding'])
        self.use_udp = cfg.get('use_udp', False)
        self.score_per_joint = cfg.get('score_per_joint', False)

    def nms(self, heatmaps):
        """Non-Maximum Suppression for heatmaps.

        A value is kept only where it equals the max-pooled value of its
        neighbourhood; every other value is set to zero.

        Args:
            heatmaps (torch.Tensor): Heatmaps before nms.

        Returns:
            torch.Tensor: Heatmaps after nms.
        """
        maxm = self.pool(heatmaps)
        maxm = torch.eq(maxm, heatmaps).float()
        heatmaps = heatmaps * maxm

        return heatmaps

    def match(self, tag_k, loc_k, val_k):
        """Group keypoints to human poses in a batch.

        Args:
            tag_k (np.ndarray[NxKxMxL]): tag corresponding to the
                top k values of feature map per keypoint.
            loc_k (np.ndarray[NxKxMx2]): top k locations of the
                feature maps for keypoint.
            val_k (np.ndarray[NxKxM]): top k value of the
                feature maps per keypoint.

        Returns:
            list: One ``_match_by_tag`` result per image in the batch.
        """

        def _match(x):
            return _match_by_tag(x, self.params)

        return list(map(_match, zip(tag_k, loc_k, val_k)))

    def top_k(self, heatmaps, tags):
        """Find top_k values in an image.

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            max number of people: M
            dim of tags: L
                If use flip testing, L=2; else L=1.

        Args:
            heatmaps (torch.Tensor[NxKxHxW])
            tags (torch.Tensor[NxKxHxWxL])

        Returns:
            dict: A dict containing top_k values.

            - tag_k (np.ndarray[NxKxMxL]):
                tag corresponding to the top k values of
                feature map per keypoint.
            - loc_k (np.ndarray[NxKxMx2]):
                top k location of feature map per keypoint.
            - val_k (np.ndarray[NxKxM]):
                top k value of feature map per keypoint.
        """
        heatmaps = self.nms(heatmaps)
        N, K, H, W = heatmaps.size()
        heatmaps = heatmaps.view(N, K, -1)
        val_k, ind = heatmaps.topk(self.params.max_num_people, dim=2)

        tags = tags.view(tags.size(0), tags.size(1), W * H, -1)
        if not self.tag_per_joint:
            # Broadcast the (presumably single) tag channel to every joint.
            tags = tags.expand(-1, self.params.num_joints, -1, -1)

        tag_k = torch.stack(
            [torch.gather(tags[..., i], 2, ind) for i in range(tags.size(3))],
            dim=3)

        # Convert flat indices back to (x, y) pixel coordinates.
        x = ind % W
        y = ind // W

        ind_k = torch.stack((x, y), dim=3)

        results = {
            'tag_k': tag_k.cpu().numpy(),
            'loc_k': ind_k.cpu().numpy(),
            'val_k': val_k.cpu().numpy()
        }

        return results

    @staticmethod
    def adjust(results, heatmaps):
        """Adjust the coordinates for better accuracy.

        Every joint with a positive score is shifted a quarter pixel towards
        its higher neighbour along each axis, and then offset by 0.5 pixel.
        ``results`` is modified in place and also returned.

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W

        Args:
            results (list(np.ndarray)): Keypoint predictions.
            heatmaps (torch.Tensor[NxKxHxW]): Heatmaps.

        Returns:
            list(np.ndarray): The adjusted ``results``.
        """
        _, _, H, W = heatmaps.shape
        for batch_id, people in enumerate(results):
            for people_id, people_i in enumerate(people):
                for joint_id, joint in enumerate(people_i):
                    if joint[2] > 0:
                        x, y = joint[0:2]
                        xx, yy = int(x), int(y)
                        tmp = heatmaps[batch_id][joint_id]
                        if tmp[min(H - 1, yy + 1), xx] > tmp[max(0, yy - 1),
                                                             xx]:
                            y += 0.25
                        else:
                            y -= 0.25

                        if tmp[yy, min(W - 1, xx + 1)] > tmp[yy,
                                                             max(0, xx - 1)]:
                            x += 0.25
                        else:
                            x -= 0.25
                        results[batch_id][people_id, joint_id,
                                          0:2] = (x + 0.5, y + 0.5)
        return results

    @staticmethod
    def refine(heatmap, tag, keypoints, use_udp=False):
        """Given initial keypoint predictions, we identify missing joints.

        The mean tag of the person's detected joints is used as a reference.
        For every joint, each pixel is scored by its heatmap value minus the
        rounded distance of its tag to that reference. The best pixel fills
        the joint if it was not detected (score == 0) and the heatmap value
        there is positive. ``keypoints`` is modified in place.

        Note:
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            dim of tags: L
                If use flip testing, L=2; else L=1.

        Args:
            heatmap: np.ndarray(K, H, W).
            tag: np.ndarray(K, H, W) | np.ndarray(K, H, W, L)
            keypoints: np.ndarray of size (K, 3 + L)
                last dim is (x, y, score, tag).
            use_udp: bool-unbiased data processing

        Returns:
            np.ndarray: The refined keypoints. They are returned unchanged
            when no keypoint has a positive score, since there is then no
            reference tag to compare against.
        """
        K, H, W = heatmap.shape
        if tag.ndim == 3:
            tag = tag[..., None]

        # Tag values sampled at this person's detected keypoints.
        tags = []
        for i in range(K):
            if keypoints[i, 2] > 0:
                x, y = keypoints[i][:2].astype(int)
                x = np.clip(x, 0, W - 1)
                y = np.clip(y, 0, H - 1)
                tags.append(tag[i, y, x])

        if not tags:
            # np.mean([]) would yield a 0-d nan and the indexing below
            # would raise; with no reference tag there is nothing to refine.
            return keypoints

        # mean tag of current detected people
        prev_tag = np.mean(tags, axis=0)
        results = []

        for _heatmap, _tag in zip(heatmap, tag):
            # distance of all tag values with mean tag of
            # current detected people
            distance_tag = (((_tag -
                              prev_tag[None, None, :])**2).sum(axis=2)**0.5)
            norm_heatmap = _heatmap - np.round(distance_tag)

            # find maximum position
            y, x = np.unravel_index(np.argmax(norm_heatmap), _heatmap.shape)
            xx = x.copy()
            yy = y.copy()
            # detection score at maximum position
            val = _heatmap[y, x]
            if not use_udp:
                # offset by 0.5
                x += 0.5
                y += 0.5

            # add a quarter offset
            if _heatmap[yy, min(W - 1, xx + 1)] > _heatmap[yy, max(0, xx - 1)]:
                x += 0.25
            else:
                x -= 0.25

            if _heatmap[min(H - 1, yy + 1), xx] > _heatmap[max(0, yy - 1), xx]:
                y += 0.25
            else:
                y -= 0.25

            results.append((x, y, val))
        results = np.array(results)

        for i in range(K):
            # add keypoint if it is not detected
            if results[i, 2] > 0 and keypoints[i, 2] == 0:
                keypoints[i, :3] = results[i, :3]

        return keypoints

    def parse(self, heatmaps, tags, adjust=True, refine=True):
        """Group keypoints into poses given heatmap and tag.

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            dim of tags: L
                If use flip testing, L=2; else L=1.

            Scores and refinement only look at the first image
            (``results[0]``, ``heatmaps[0]``, ``tags[0]``), so this is
            presumably meant to be called with N == 1.

        Args:
            heatmaps (torch.Tensor[NxKxHxW]): model output heatmaps.
            tags (torch.Tensor[NxKxHxWxL]): model output tagmaps.
            adjust (bool): Apply sub-pixel coordinate adjustment.
            refine (bool): Try to recover missing joints with ``refine``.

        Returns:
            tuple: A tuple containing keypoint grouping results.

            - results (list(np.ndarray)): Pose results.
            - scores (list/list(np.ndarray)): Score of people.
        """
        results = self.match(**self.top_k(heatmaps, tags))

        if adjust:
            if self.use_udp:
                for i in range(len(results)):
                    if results[i].shape[0] > 0:
                        results[i][..., :2] = post_dark_udp(
                            results[i][..., :2].copy(), heatmaps[i:i + 1, :])
            else:
                results = self.adjust(results, heatmaps)

        # NOTE(review): when results[0] is an ndarray, the per-joint scores
        # are views into it, so the in-place refinement below also updates
        # them, whereas the mean scores are fixed here. Kept as in the
        # original; confirm whether this is intended.
        if self.score_per_joint:
            scores = [person[:, 2] for person in results[0]]
        else:
            scores = [person[:, 2].mean() for person in results[0]]

        if refine:
            people = results[0]
            if len(people) > 0:
                # These do not depend on the person, so convert (and copy to
                # host) once rather than once per detected person.
                heatmap_numpy = heatmaps[0].cpu().numpy()
                tag_numpy = tags[0].cpu().numpy()
                if not self.tag_per_joint:
                    tag_numpy = np.tile(tag_numpy,
                                        (self.params.num_joints, 1, 1, 1))
                # for every detected person
                for i in range(len(people)):
                    people[i] = self.refine(
                        heatmap_numpy, tag_numpy, people[i],
                        use_udp=self.use_udp)
            results = [people]

        return results, scores
|
|
|