# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import HEADS
from .deconv_head import DeconvHead


@HEADS.register_module()
class AESimpleHead(DeconvHead):
    """Associative embedding simple head.

    paper ref: Alejandro Newell et al. "Associative
    Embedding: End-to-end Learning for Joint Detection
    and Grouping".

    Args:
        in_channels (int): Number of input channels.
        num_joints (int): Number of joints.
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should be >= 0. Note that 0 means
            no deconv layers.
        num_deconv_filters (list|tuple): Number of filters. If
            num_deconv_layers > 0, the length of num_deconv_filters
            should equal num_deconv_layers.
        num_deconv_kernels (list|tuple): Kernel sizes.
        tag_per_joint (bool): If tag_per_joint is True,
            the dimension of tags equals num_joints,
            else the dimension of tags is 1. Default: True.
        with_ae_loss (list[bool]): Option to use ae loss or not.
        loss_keypoint (dict): Config for loss. Default: None.
    """

    def __init__(self,
                 in_channels,
                 num_joints,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 tag_per_joint=True,
                 with_ae_loss=None,
                 extra=None,
                 loss_keypoint=None):

        dim_tag = num_joints if tag_per_joint else 1
        if with_ae_loss[0]:
            out_channels = num_joints + dim_tag
        else:
            out_channels = num_joints

        super().__init__(
            in_channels,
            out_channels,
            num_deconv_layers=num_deconv_layers,
            num_deconv_filters=num_deconv_filters,
            num_deconv_kernels=num_deconv_kernels,
            extra=extra,
            loss_keypoint=loss_keypoint)

    def get_loss(self, outputs, targets, masks, joints):
        """Calculate bottom-up keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K
            - num_outputs: O
            - heatmaps height: H
            - heatmaps width: W

        Args:
            outputs (list(torch.Tensor[N,K,H,W])): Multi-scale output heatmaps.
            targets (List(torch.Tensor[N,K,H,W])): Multi-scale target heatmaps.
            masks (List(torch.Tensor[N,H,W])): Masks of multi-scale target
                heatmaps.
            joints (List(torch.Tensor[N,M,K,2])): Joints of multi-scale target
                heatmaps for ae loss.
        """

        losses = dict()

        heatmaps_losses, push_losses, pull_losses = self.loss(
            outputs, targets, masks, joints)

        for idx in range(len(targets)):
            if heatmaps_losses[idx] is not None:
                heatmaps_loss = heatmaps_losses[idx].mean(dim=0)
                if 'heatmap_loss' not in losses:
                    losses['heatmap_loss'] = heatmaps_loss
                else:
                    losses['heatmap_loss'] += heatmaps_loss
            if push_losses[idx] is not None:
                push_loss = push_losses[idx].mean(dim=0)
                if 'push_loss' not in losses:
                    losses['push_loss'] = push_loss
                else:
                    losses['push_loss'] += push_loss
            if pull_losses[idx] is not None:
                pull_loss = pull_losses[idx].mean(dim=0)
                if 'pull_loss' not in losses:
                    losses['pull_loss'] = pull_loss
                else:
                    losses['pull_loss'] += pull_loss

        return losses
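
# A minimal usage sketch for get_loss (hypothetical shapes and values, not a test from
# this repo; assumes `torch` is imported and `head` was built as sketched above with
# with_ae_loss=[True], so the final layer outputs 17 heatmap + 17 tag channels). The
# list arguments follow the shapes documented in the get_loss docstring, one entry per
# output scale:
#
#     outputs = [torch.randn(2, 34, 128, 128)]      # predicted heatmaps + tags
#     targets = [torch.rand(2, 17, 128, 128)]       # target heatmaps
#     masks   = [torch.ones(2, 128, 128)]           # valid-region masks
#     joints  = [torch.zeros(2, 30, 17, 2).long()]  # joint locations/visibility for AE
#     losses  = head.get_loss(outputs, targets, masks, joints)
#     # -> dict with 'heatmap_loss', 'push_loss' and 'pull_loss' entries (when enabled)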