# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
                      constant_init, normal_init)

from mmpose.core.evaluation.top_down_eval import (
    keypoints_from_heatmaps3d, multilabel_classification_accuracy)
from mmpose.core.post_processing import flip_back
from mmpose.models.builder import build_loss
from mmpose.models.necks import GlobalAveragePooling
from ..builder import HEADS


class Heatmap3DHead(nn.Module):
    """Heatmap3DHead is a sub-module of Interhand3DHead, and outputs 3D
    heatmaps. Heatmap3DHead is composed of (>=0) number of deconv layers and a
    simple conv2d layer.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels. Must be divisible by
            ``depth_size`` because the output channels are split into
            ``out_channels // depth_size`` groups of ``depth_size`` bins.
        depth_size (int): Number of depth discretization size
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should >= 0. Note that 0 means no deconv layers.
        num_deconv_filters (list|tuple): Number of filters.
        num_deconv_kernels (list|tuple): Kernel sizes.
        extra (dict): Configs for extra conv layers. Recognized keys are
            ``final_conv_kernel`` (0, 1 or 3; 0 replaces the final conv with
            an identity mapping), ``num_conv_layers`` and
            ``num_conv_kernels``. Default: None

    Raises:
        TypeError: If ``extra`` is neither a dict nor None.
        ValueError: If ``num_deconv_layers`` is negative, or if it does not
            match the length of ``num_deconv_filters``/``num_deconv_kernels``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 depth_size=64,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 extra=None):
        super().__init__()

        assert out_channels % depth_size == 0
        self.depth_size = depth_size
        # NOTE: ``_make_deconv_layer`` overwrites ``self.in_channels`` with
        # the channel count of the last deconv layer it builds.
        self.in_channels = in_channels

        if extra is not None and not isinstance(extra, dict):
            raise TypeError('extra should be dict or None.')

        if num_deconv_layers > 0:
            self.deconv_layers = self._make_deconv_layer(
                num_deconv_layers,
                num_deconv_filters,
                num_deconv_kernels,
            )
        elif num_deconv_layers == 0:
            self.deconv_layers = nn.Identity()
        else:
            raise ValueError(
                f'num_deconv_layers ({num_deconv_layers}) should >= 0.')

        identity_final_layer = False
        if extra is not None and 'final_conv_kernel' in extra:
            assert extra['final_conv_kernel'] in [0, 1, 3]
            if extra['final_conv_kernel'] == 3:
                padding = 1
            elif extra['final_conv_kernel'] == 1:
                padding = 0
            else:
                # 0 for Identity mapping.
                identity_final_layer = True
            kernel_size = extra['final_conv_kernel']
        else:
            kernel_size = 1
            padding = 0

        if identity_final_layer:
            self.final_layer = nn.Identity()
        else:
            conv_channels = num_deconv_filters[
                -1] if num_deconv_layers > 0 else self.in_channels

            layers = []
            if extra is not None:
                num_conv_layers = extra.get('num_conv_layers', 0)
                num_conv_kernels = extra.get('num_conv_kernels',
                                             [1] * num_conv_layers)

                # Optional conv-BN-ReLU stack that keeps the channel count
                # and (through the "same"-style padding) the spatial size for
                # odd kernel sizes.
                for i in range(num_conv_layers):
                    layers.append(
                        build_conv_layer(
                            dict(type='Conv2d'),
                            in_channels=conv_channels,
                            out_channels=conv_channels,
                            kernel_size=num_conv_kernels[i],
                            stride=1,
                            padding=(num_conv_kernels[i] - 1) // 2))
                    layers.append(
                        build_norm_layer(dict(type='BN'), conv_channels)[1])
                    layers.append(nn.ReLU(inplace=True))

            # Final projection to ``out_channels``.
            layers.append(
                build_conv_layer(
                    cfg=dict(type='Conv2d'),
                    in_channels=conv_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding))

            if len(layers) > 1:
                self.final_layer = nn.Sequential(*layers)
            else:
                self.final_layer = layers[0]

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Make deconv layers.

        Each layer is a stride-2 deconv followed by BatchNorm2d and ReLU.

        Args:
            num_layers (int): Number of deconv layers to build.
            num_filters (list|tuple): Output channels of each deconv layer.
            num_kernels (list|tuple): Kernel size of each deconv layer.

        Returns:
            nn.Sequential: The stacked deconv layers.

        Raises:
            ValueError: If ``num_layers`` differs from the length of
                ``num_filters`` or ``num_kernels``.
        """
        if num_layers != len(num_filters):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_filters({len(num_filters)})'
            raise ValueError(error_msg)
        if num_layers != len(num_kernels):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_kernels({len(num_kernels)})'
            raise ValueError(error_msg)

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=self.in_channels,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2d(planes))
            layers.append(nn.ReLU(inplace=True))
            # The next layer consumes what this one produced.
            self.in_channels = planes

        return nn.Sequential(*layers)

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Get configurations for deconv layers.

        Args:
            deconv_kernel (int): Kernel size; one of 2, 3 or 4.

        Returns:
            tuple: ``(deconv_kernel, padding, output_padding)``.

        Raises:
            ValueError: If ``deconv_kernel`` is not 2, 3 or 4.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding

    def forward(self, x):
        """Forward function.

        Returns:
            Tensor: 5-D tensor of shape
            ``(N, C // depth_size, depth_size, H, W)`` where ``N, C, H, W``
            is the shape produced by the final layer.
        """
        x = self.deconv_layers(x)
        x = self.final_layer(x)
        N, C, H, W = x.shape
        # reshape the 2D heatmap to 3D heatmap
        x = x.reshape(N, C // self.depth_size, self.depth_size, H, W)
        return x

    def init_weights(self):
        """Initialize model weights."""
        for _, m in self.deconv_layers.named_modules():
            if isinstance(m, nn.ConvTranspose2d):
                normal_init(m, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
        for m in self.final_layer.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)


class Heatmap1DHead(nn.Module):
    """Heatmap1DHead is a sub-module of Interhand3DHead, and outputs 1D
    heatmaps.

    Args:
        in_channels (int): Number of input channels
        heatmap_size (int): Heatmap size
        hidden_dims (list|tuple): Number of feature dimension of FC layers.
    """

    def __init__(self, in_channels=2048, heatmap_size=64, hidden_dims=(512, )):
        super().__init__()

        self.in_channels = in_channels
        self.heatmap_size = heatmap_size

        feature_dims = [in_channels, *hidden_dims, heatmap_size]
        self.fc = self._make_linear_layers(feature_dims, relu_final=False)

    def soft_argmax_1d(self, heatmap1d):
        """Return the softmax-weighted expected bin index along dim 1.

        Args:
            heatmap1d (Tensor): Unnormalized scores whose dim 1 has
                ``heatmap_size`` entries.

        Returns:
            Tensor: Expected index, with dim 1 reduced.
        """
        heatmap1d = F.softmax(heatmap1d, 1)
        accu = heatmap1d * torch.arange(
            self.heatmap_size, dtype=heatmap1d.dtype,
            device=heatmap1d.device)[None, :]
        coord = accu.sum(dim=1)
        return coord

    def _make_linear_layers(self, feat_dims, relu_final=False):
        """Make linear layers.

        Args:
            feat_dims (list[int]): Feature sizes; layer ``i`` maps
                ``feat_dims[i]`` to ``feat_dims[i + 1]``.
            relu_final (bool): Whether to append a ReLU after the last
                linear layer. Every earlier layer always gets one.

        Returns:
            nn.Sequential: The stacked linear (and ReLU) layers.
        """
        layers = []
        for i in range(len(feat_dims) - 1):
            layers.append(nn.Linear(feat_dims[i], feat_dims[i + 1]))
            if i < len(feat_dims) - 2 or \
                    (i == len(feat_dims) - 2 and relu_final):
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function.

        Returns:
            Tensor: Soft-argmax value of the 1D heatmap, reshaped to
            ``(-1, 1)``.
        """
        heatmap1d = self.fc(x)
        value = self.soft_argmax_1d(heatmap1d).view(-1, 1)
        return value

    def init_weights(self):
        """Initialize model weights."""
        for m in self.fc.modules():
            if isinstance(m, nn.Linear):
                normal_init(m, mean=0, std=0.01, bias=0)


class MultilabelClassificationHead(nn.Module):
    """MultilabelClassificationHead is a sub-module of Interhand3DHead, and
    outputs hand type classification.

    Args:
        in_channels (int): Number of input channels
        num_labels (int): Number of labels
        hidden_dims (list|tuple): Number of hidden dimension of FC layers.
    """

    def __init__(self, in_channels=2048, num_labels=2, hidden_dims=(512, )):
        super().__init__()

        self.in_channels = in_channels
        self.num_labels = num_labels
        # Historically misspelled attribute name; kept as an alias so that
        # any existing code reading it keeps working.
        self.num_labesl = num_labels

        feature_dims = [in_channels, *hidden_dims, num_labels]
        self.fc = self._make_linear_layers(feature_dims, relu_final=False)

    def _make_linear_layers(self, feat_dims, relu_final=False):
        """Make linear layers.

        Args:
            feat_dims (list[int]): Feature sizes; layer ``i`` maps
                ``feat_dims[i]`` to ``feat_dims[i + 1]``.
            relu_final (bool): Whether to append a ReLU after the last
                linear layer. Every earlier layer always gets one.

        Returns:
            nn.Sequential: The stacked linear (and ReLU) layers.
        """
        layers = []
        for i in range(len(feat_dims) - 1):
            layers.append(nn.Linear(feat_dims[i], feat_dims[i + 1]))
            if i < len(feat_dims) - 2 or \
                    (i == len(feat_dims) - 2 and relu_final):
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function.

        Returns:
            Tensor: Sigmoid of the FC output (one value per label).
        """
        labels = torch.sigmoid(self.fc(x))
        return labels

    def init_weights(self):
        """Initialize model weights."""
        for m in self.fc.modules():
            if isinstance(m, nn.Linear):
                normal_init(m, mean=0, std=0.01, bias=0)


@HEADS.register_module()
class Interhand3DHead(nn.Module):
    """Interhand 3D head of paper ref: Gyeongsik Moon. "InterHand2.6M: A
    Dataset and Baseline for 3D Interacting Hand Pose Estimation from a Single
    RGB Image".

    Args:
        keypoint_head_cfg (dict): Configs of Heatmap3DHead for hand
            keypoint estimation.
        root_head_cfg (dict): Configs of Heatmap1DHead for relative
            hand root depth estimation.
        hand_type_head_cfg (dict): Configs of MultilabelClassificationHead
            for hand type classification.
        loss_keypoint (dict): Config for keypoint loss. Default: None.
        loss_root_depth (dict): Config for relative root depth loss.
            Default: None.
        loss_hand_type (dict): Config for hand type classification
            loss. Default: None.
        train_cfg (dict): Training config; stored on the head as-is (an
            empty dict when None). Default: None.
        test_cfg (dict): Testing config. The keys read by this head are
            ``target_type`` (default ``'GaussianHeatmap'``) and
            ``shift_heatmap`` (default False). Default: None.
    """

    def __init__(self,
                 keypoint_head_cfg,
                 root_head_cfg,
                 hand_type_head_cfg,
                 loss_keypoint=None,
                 loss_root_depth=None,
                 loss_hand_type=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()

        # build sub-module heads
        self.right_hand_head = Heatmap3DHead(**keypoint_head_cfg)
        self.left_hand_head = Heatmap3DHead(**keypoint_head_cfg)
        self.root_head = Heatmap1DHead(**root_head_cfg)
        self.hand_type_head = MultilabelClassificationHead(
            **hand_type_head_cfg)
        self.neck = GlobalAveragePooling()

        # build losses
        self.keypoint_loss = build_loss(loss_keypoint)
        self.root_depth_loss = build_loss(loss_root_depth)
        self.hand_type_loss = build_loss(loss_hand_type)
        self.train_cfg = {} if train_cfg is None else train_cfg
        self.test_cfg = {} if test_cfg is None else test_cfg
        self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')

    def init_weights(self):
        """Initialize the weights of all four sub-module heads."""
        self.left_hand_head.init_weights()
        self.right_hand_head.init_weights()
        self.root_head.init_weights()
        self.hand_type_head.init_weights()

    def get_loss(self, output, target, target_weight):
        """Calculate loss for hand keypoint heatmaps, relative root depth and
        hand type.

        Args:
            output (list[Tensor]): a list of outputs from multiple heads.
            target (list[Tensor]): a list of targets for multiple heads.
            target_weight (list[Tensor]): a list of targets weight for
                multiple heads.

        Returns:
            dict: Losses under the keys ``hand_loss``, ``rel_root_loss`` and
            ``hand_type_loss``.
        """
        losses = dict()

        # hand keypoint loss
        assert not isinstance(self.keypoint_loss, nn.Sequential)
        out, tar, tar_weight = output[0], target[0], target_weight[0]
        assert tar.dim() == 5 and tar_weight.dim() == 3
        losses['hand_loss'] = self.keypoint_loss(out, tar, tar_weight)

        # relative root depth loss
        assert not isinstance(self.root_depth_loss, nn.Sequential)
        out, tar, tar_weight = output[1], target[1], target_weight[1]
        assert tar.dim() == 2 and tar_weight.dim() == 2
        losses['rel_root_loss'] = self.root_depth_loss(out, tar, tar_weight)

        # hand type loss
        assert not isinstance(self.hand_type_loss, nn.Sequential)
        out, tar, tar_weight = output[2], target[2], target_weight[2]
        assert tar.dim() == 2 and tar_weight.dim() in [1, 2]
        losses['hand_type_loss'] = self.hand_type_loss(out, tar, tar_weight)

        return losses

    def get_accuracy(self, output, target, target_weight):
        """Calculate accuracy for hand type.

        Args:
            output (list[Tensor]): a list of outputs from multiple heads.
            target (list[Tensor]): a list of targets for multiple heads.
            target_weight (list[Tensor]): a list of targets weight for
                multiple heads.

        Returns:
            dict: Hand type accuracy under the key ``acc_classification``.
        """
        accuracy = dict()
        # Only the hand type entries (index 2) are used here.
        avg_acc = multilabel_classification_accuracy(
            output[2].detach().cpu().numpy(),
            target[2].detach().cpu().numpy(),
            target_weight[2].detach().cpu().numpy(),
        )
        accuracy['acc_classification'] = float(avg_acc)
        return accuracy

    def forward(self, x):
        """Forward function.

        Returns:
            list[Tensor]: ``[heatmap3d, rel_root_depth, hand_type]`` where
            the 3D heatmaps of the right and left hand heads are
            concatenated (right first) along dim 1.
        """
        outputs = []
        outputs.append(
            torch.cat([self.right_hand_head(x),
                       self.left_hand_head(x)], dim=1))
        x = self.neck(x)
        outputs.append(self.root_head(x))
        outputs.append(self.hand_type_head(x))
        return outputs

    def inference_model(self, x, flip_pairs=None):
        """Inference function.

        Returns:
            output (list[np.ndarray]): list of output hand keypoint
            heatmaps, relative root depth and hand type.

        Args:
            x (torch.Tensor[N,K,H,W]): Input features.
            flip_pairs (None | list[tuple]):
                Pairs of keypoints which are mirrored. When given, the
                outputs are treated as predictions on a flipped input and
                are flipped back.
        """

        output = self.forward(x)

        if flip_pairs is not None:
            # flip 3D heatmap
            heatmap_3d = output[0]
            N, K, D, H, W = heatmap_3d.shape
            # reshape 3D heatmap to 2D heatmap
            heatmap_3d = heatmap_3d.reshape(N, K * D, H, W)
            # 2D heatmap flip
            heatmap_3d_flipped_back = flip_back(
                heatmap_3d.detach().cpu().numpy(),
                flip_pairs,
                target_type=self.target_type)
            # reshape back to 3D heatmap
            heatmap_3d_flipped_back = heatmap_3d_flipped_back.reshape(
                N, K, D, H, W)
            # feature is not aligned, shift flipped heatmap for higher accuracy
            if self.test_cfg.get('shift_heatmap', False):
                heatmap_3d_flipped_back[...,
                                        1:] = heatmap_3d_flipped_back[..., :-1]
            output[0] = heatmap_3d_flipped_back

            # flip relative hand root depth
            output[1] = -output[1].detach().cpu().numpy()

            # flip hand type: swap the two label columns
            hand_type = output[2].detach().cpu().numpy()
            hand_type_flipped_back = hand_type.copy()
            hand_type_flipped_back[:, 0] = hand_type[:, 1]
            hand_type_flipped_back[:, 1] = hand_type[:, 0]
            output[2] = hand_type_flipped_back
        else:
            output = [out.detach().cpu().numpy() for out in output]

        return output

    def decode(self, img_metas, output, **kwargs):
        """Decode hand keypoint, relative root depth and hand type.

        Args:
            img_metas (list(dict)): Information about data augmentation
                By default this includes:

                - "image_file": path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
                - "heatmap3d_depth_bound": depth bound of hand keypoint
                    3D heatmap
                - "root_depth_bound": depth bound of relative root depth
                    1D heatmap
            output (list[np.ndarray]): model predicted 3D heatmaps, relative
                root depth and hand type.

        Returns:
            dict: Decoded results under the keys ``boxes``, ``image_paths``,
            ``bbox_ids`` (None when the first meta has no ``bbox_id``),
            ``preds``, ``rel_root_depth`` and ``hand_type``.
        """

        batch_size = len(img_metas)
        result = {}

        heatmap3d_depth_bound = np.ones(batch_size, dtype=np.float32)
        root_depth_bound = np.ones(batch_size, dtype=np.float32)
        center = np.zeros((batch_size, 2), dtype=np.float32)
        scale = np.zeros((batch_size, 2), dtype=np.float32)
        image_paths = []
        score = np.ones(batch_size, dtype=np.float32)
        if 'bbox_id' in img_metas[0]:
            bbox_ids = []
        else:
            bbox_ids = None

        for i in range(batch_size):
            heatmap3d_depth_bound[i] = img_metas[i]['heatmap3d_depth_bound']
            root_depth_bound[i] = img_metas[i]['root_depth_bound']
            center[i, :] = img_metas[i]['center']
            scale[i, :] = img_metas[i]['scale']
            image_paths.append(img_metas[i]['image_file'])

            if 'bbox_score' in img_metas[i]:
                # ``.item()`` extracts the single score as a Python scalar.
                # Assigning a 1-element array to a scalar slot is deprecated
                # in recent NumPy; a non-single-element score still raises
                # ValueError as before.
                score[i] = np.asarray(img_metas[i]['bbox_score']).item()
            if bbox_ids is not None:
                bbox_ids.append(img_metas[i]['bbox_id'])

        all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
        all_boxes[:, 0:2] = center[:, 0:2]
        all_boxes[:, 2:4] = scale[:, 0:2]
        # scale is defined as: bbox_size / 200.0, so we
        # need multiply 200.0 to get bbox size
        all_boxes[:, 4] = np.prod(scale * 200.0, axis=1)
        all_boxes[:, 5] = score
        result['boxes'] = all_boxes
        result['image_paths'] = image_paths
        result['bbox_ids'] = bbox_ids

        # decode 3D heatmaps of hand keypoints
        heatmap3d = output[0]
        preds, maxvals = keypoints_from_heatmaps3d(heatmap3d, center, scale)
        keypoints_3d = np.zeros((batch_size, preds.shape[1], 4),
                                dtype=np.float32)
        keypoints_3d[:, :, 0:3] = preds[:, :, 0:3]
        keypoints_3d[:, :, 3:4] = maxvals
        # transform keypoint depth to camera space
        keypoints_3d[:, :, 2] = \
            (keypoints_3d[:, :, 2] / self.right_hand_head.depth_size - 0.5) \
            * heatmap3d_depth_bound[:, np.newaxis]

        result['preds'] = keypoints_3d

        # decode relative hand root depth
        # transform relative root depth to camera space
        # ``output[1]`` has one column per sample (the root head ends in
        # ``.view(-1, 1)``), so the per-sample bound needs a trailing axis.
        # Without it, (N, 1) * (N,) would broadcast to (N, N) and mix the
        # bounds of different samples whenever the batch size is > 1.
        result['rel_root_depth'] = (output[1] / self.root_head.heatmap_size -
                                    0.5) * root_depth_bound[:, np.newaxis]

        # decode hand type
        result['hand_type'] = output[2] > 0.5
        return result