# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import load_checkpoint

from mmpose.core.camera import SimpleCameraTorch
from mmpose.core.post_processing.post_transforms import (
    affine_transform_torch, get_affine_transform)
from .. import builder
from ..builder import POSENETS
from .base import BasePose


class ProjectLayer(nn.Module):

    def __init__(self, image_size, heatmap_size):
        """Project layer to get voxel feature. Adapted from
        https://github.com/microsoft/voxelpose-pytorch/blob/main/lib/models/project_layer.py.

        Args:
            image_size (int or list): input size of the 2D model
            heatmap_size (int or list): output size of the 2D model
        """
        super(ProjectLayer, self).__init__()
        self.image_size = image_size
        self.heatmap_size = heatmap_size
        if isinstance(self.image_size, int):
            self.image_size = [self.image_size, self.image_size]
        if isinstance(self.heatmap_size, int):
            self.heatmap_size = [self.heatmap_size, self.heatmap_size]

    def compute_grid(self, box_size, box_center, num_bins, device=None):
        if isinstance(box_size, (int, float)):
            box_size = [box_size, box_size, box_size]
        if isinstance(num_bins, int):
            num_bins = [num_bins, num_bins, num_bins]

        grid_1D_x = torch.linspace(
            -box_size[0] / 2, box_size[0] / 2, num_bins[0], device=device)
        grid_1D_y = torch.linspace(
            -box_size[1] / 2, box_size[1] / 2, num_bins[1], device=device)
        grid_1D_z = torch.linspace(
            -box_size[2] / 2, box_size[2] / 2, num_bins[2], device=device)
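        # Shift each 1D axis by the box center and take their Cartesian
        # product, yielding one (x, y, z) world coordinate per voxel bin.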
        grid_x, grid_y, grid_z = torch.meshgrid(
            grid_1D_x + box_center[0],
            grid_1D_y + box_center[1],
            grid_1D_z + box_center[2],
        )
        grid_x = grid_x.contiguous().view(-1, 1)
        grid_y = grid_y.contiguous().view(-1, 1)
        grid_z = grid_z.contiguous().view(-1, 1)
        grid = torch.cat([grid_x, grid_y, grid_z], dim=1)

        return grid

    def get_voxel(self, feature_maps, meta, grid_size, grid_center, cube_size):
        device = feature_maps[0].device
        batch_size = feature_maps[0].shape[0]
        num_channels = feature_maps[0].shape[1]
        num_bins = cube_size[0] * cube_size[1] * cube_size[2]
        n = len(feature_maps)
        cubes = torch.zeros(
            batch_size, num_channels, 1, num_bins, n, device=device)
        w, h = self.heatmap_size
        grids = torch.zeros(batch_size, num_bins, 3, device=device)
        bounding = torch.zeros(batch_size, 1, 1, num_bins, n, device=device)
        for i in range(batch_size):
            if len(grid_center[0]) == 3 or grid_center[i][3] >= 0:
                if len(grid_center) == 1:
                    grid = self.compute_grid(
                        grid_size, grid_center[0], cube_size, device=device)
                else:
                    grid = self.compute_grid(
                        grid_size, grid_center[i], cube_size, device=device)
                grids[i:i + 1] = grid
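                # Project the voxel grid into each camera view and bilinearly
                # sample the 2D feature maps at the projected locations.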
                for c in range(n):
                    center = meta[i]['center'][c]
                    scale = meta[i]['scale'][c]

                    width, height = center * 2
                    trans = torch.as_tensor(
                        get_affine_transform(center, scale / 200.0, 0,
                                             self.image_size),
                        dtype=torch.float,
                        device=device)
                    cam_param = meta[i]['camera'][c].copy()

                    single_view_camera = SimpleCameraTorch(
                        param=cam_param, device=device)
                    xy = single_view_camera.world_to_pixel(grid)
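                    # Flag grid points that project inside this view's image;
                    # only those views contribute to the voxel average below.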
                    bounding[i, 0, 0, :, c] = (xy[:, 0] >= 0) & (
                        xy[:, 1] >= 0) & (xy[:, 0] < width) & (
                            xy[:, 1] < height)
                    xy = torch.clamp(xy, -1.0, max(width, height))
                    xy = affine_transform_torch(xy, trans)
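                    # Map pixel coordinates to heatmap scale, then normalize
                    # to [-1, 1] as expected by F.grid_sample.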
                    xy = xy * torch.tensor(
                        [w, h], dtype=torch.float,
                        device=device) / torch.tensor(
                            self.image_size, dtype=torch.float, device=device)
                    sample_grid = xy / torch.tensor([w - 1, h - 1],
                                                    dtype=torch.float,
                                                    device=device) * 2.0 - 1.0
                    sample_grid = torch.clamp(
                        sample_grid.view(1, 1, num_bins, 2), -1.1, 1.1)

                    cubes[i:i + 1, :, :, :, c] += F.grid_sample(
                        feature_maps[c][i:i + 1, :, :, :],
                        sample_grid,
                        align_corners=True)
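
        # Average the sampled features over the views in which each voxel is
        # visible (the epsilon avoids division by zero); stray NaNs are then
        # zeroed and the features clamped to [0, 1].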
        cubes = torch.sum(
            torch.mul(cubes, bounding), dim=-1) / (
                torch.sum(bounding, dim=-1) + 1e-6)
        cubes[cubes != cubes] = 0.0
        cubes = cubes.clamp(0.0, 1.0)

        cubes = cubes.view(batch_size, num_channels, cube_size[0],
                           cube_size[1], cube_size[2])
        return cubes, grids

    def forward(self, feature_maps, meta, grid_size, grid_center, cube_size):
        cubes, grids = self.get_voxel(feature_maps, meta, grid_size,
                                      grid_center, cube_size)
        return cubes, grids
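

# Illustrative use of ProjectLayer (the sizes below are hypothetical
# placeholders chosen to show the interface, not values from a real config):
#   project = ProjectLayer(image_size=[960, 512], heatmap_size=[240, 128])
#   cubes, grids = project(feature_maps, img_metas,
#                          grid_size=[8000.0, 8000.0, 2000.0],
#                          grid_center=[[0.0, -500.0, 800.0]],
#                          cube_size=[80, 80, 20])
#   # cubes: [N, C, 80, 80, 20] voxel features, grids: [N, 80*80*20, 3]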


@POSENETS.register_module()
class DetectAndRegress(BasePose):
    """DetectAndRegress approach for multiview human pose detection.

    Args:
        backbone (ConfigDict): Dictionary to construct the 2D pose detector.
        human_detector (ConfigDict): Dictionary to construct the human
            detector.
        pose_regressor (ConfigDict): Dictionary to construct the pose
            regressor.
        train_cfg (ConfigDict): Config for training. Default: None.
        test_cfg (ConfigDict): Config for testing. Default: None.
        pretrained (str): Path to the pretrained 2D model. Default: None.
        freeze_2d (bool): Whether to freeze the 2D model in training.
            Default: True.
    """

    def __init__(self,
                 backbone,
                 human_detector,
                 pose_regressor,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 freeze_2d=True):
        super(DetectAndRegress, self).__init__()
        if backbone is not None:
            self.backbone = builder.build_posenet(backbone)
            if self.training and pretrained is not None:
                load_checkpoint(self.backbone, pretrained)
        else:
            self.backbone = None

        self.freeze_2d = freeze_2d
        self.human_detector = builder.MODELS.build(human_detector)
        self.pose_regressor = builder.MODELS.build(pose_regressor)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    @staticmethod
    def _freeze(model):
        """Freeze parameters."""
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Sets the module in training mode.

        Args:
            mode (bool): whether to set training mode (``True``)
                or evaluation mode (``False``). Default: ``True``.

        Returns:
            Module: self
        """
        super().train(mode)
        if mode and self.freeze_2d and self.backbone is not None:
            self._freeze(self.backbone)

        return self

    def forward(self,
                img=None,
                img_metas=None,
                return_loss=True,
                targets=None,
                masks=None,
                targets_3d=None,
                input_heatmaps=None,
                **kwargs):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            return_loss (bool): Option to return loss. ``return_loss=True``
                for training, ``return_loss=False`` for validation & test.
            targets (list(torch.Tensor[NxKxHxW])):
                Multi-camera target heatmaps of the 2D model.
            masks (list(torch.Tensor[NxHxW])):
                Multi-camera masks of the input to the 2D model.
            targets_3d (torch.Tensor[NxcubeLxcubeWxcubeH]):
                Ground-truth 3D heatmap of human centers.
            input_heatmaps (list(torch.Tensor[NxKxHxW])):
                Multi-camera heatmaps when the 2D model is not available.
                Default: None.
            **kwargs: Other keyword arguments.

        Returns:
            dict: if ``return_loss`` is true, then return losses.
                Otherwise, return predicted poses, human centers and
                sample_id.
        """
        if return_loss:
            return self.forward_train(img, img_metas, targets, masks,
                                      targets_3d, input_heatmaps)
        else:
            return self.forward_test(img, img_metas, input_heatmaps)

    def train_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for
        the back propagation and optimizer updating, which are done in an
        optimizer hook. Note that in some complicated cases or models, the
        whole process including back propagation and optimizer updating is
        also defined in this method, such as GAN.

        Args:
            data_batch (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer
                of runner is passed to ``train_step()``. This argument is
                unused and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.
        """
        losses = self.forward(**data_batch)

        loss, log_vars = self._parse_losses(losses)
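        # Read the logged batch size from the images when available,
        # otherwise from the precomputed input heatmaps.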
        if 'img' in data_batch:
            batch_size = data_batch['img'][0].shape[0]
        else:
            assert 'input_heatmaps' in data_batch
            batch_size = data_batch['input_heatmaps'][0][0].shape[0]

        outputs = dict(loss=loss, log_vars=log_vars, num_samples=batch_size)

        return outputs

    def forward_train(self,
                      img,
                      img_metas,
                      targets=None,
                      masks=None,
                      targets_3d=None,
                      input_heatmaps=None):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            targets (list(torch.Tensor[NxKxHxW])):
                Multi-camera target heatmaps of the 2D model.
            masks (list(torch.Tensor[NxHxW])):
                Multi-camera masks of the input to the 2D model.
            targets_3d (torch.Tensor[NxcubeLxcubeWxcubeH]):
                Ground-truth 3D heatmap of human centers.
            input_heatmaps (list(torch.Tensor[NxKxHxW])):
                Multi-camera heatmaps when the 2D model is not available.
                Default: None.

        Returns:
            dict: losses.
        """
        if self.backbone is None:
            assert input_heatmaps is not None
            feature_maps = []
            for input_heatmap in input_heatmaps:
                feature_maps.append(input_heatmap[0])
        else:
            feature_maps = []
            assert isinstance(img, list)
            for img_ in img:
                feature_maps.append(self.backbone.forward_dummy(img_)[0])
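
        # Stage 1: localize human center candidates in the 3D space.
        # Stage 2: regress the full pose of each candidate.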
        losses = dict()
        human_candidates, human_loss = self.human_detector.forward_train(
            None, img_metas, feature_maps, targets_3d, return_preds=True)
        losses.update(human_loss)

        pose_loss = self.pose_regressor(
            None,
            img_metas,
            return_loss=True,
            feature_maps=feature_maps,
            human_candidates=human_candidates)
        losses.update(pose_loss)
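
        # If the 2D backbone is not frozen, supervise its heatmaps as well;
        # the extra loss terms are logged with a '_2d' suffix.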
        if not self.freeze_2d:
            losses_2d = {}
            heatmaps_tensor = torch.cat(feature_maps, dim=0)
            targets_tensor = torch.cat(targets, dim=0)
            masks_tensor = torch.cat(masks, dim=0)
            losses_2d_ = self.backbone.get_loss(heatmaps_tensor,
                                                targets_tensor, masks_tensor)
            for k, v in losses_2d_.items():
                losses_2d[k + '_2d'] = v
            losses.update(losses_2d)

        return losses

    def forward_test(
        self,
        img,
        img_metas,
        input_heatmaps=None,
    ):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            input_heatmaps (list(torch.Tensor[NxKxHxW])):
                Multi-camera heatmaps when the 2D model is not available.
                Default: None.

        Returns:
            dict: predicted poses, human centers and sample_id.
        """
        if self.backbone is None:
            assert input_heatmaps is not None
            feature_maps = []
            for input_heatmap in input_heatmaps:
                feature_maps.append(input_heatmap[0])
        else:
            feature_maps = []
            assert isinstance(img, list)
            for img_ in img:
                feature_maps.append(self.backbone.forward_dummy(img_)[0])

        human_candidates = self.human_detector.forward_test(
            None, img_metas, feature_maps)

        human_poses = self.pose_regressor(
            None,
            img_metas,
            return_loss=False,
            feature_maps=feature_maps,
            human_candidates=human_candidates)

        result = {}
        result['pose_3d'] = human_poses.cpu().numpy()
        result['human_detection_3d'] = human_candidates.cpu().numpy()
        result['sample_id'] = [img_meta['sample_id'] for img_meta in img_metas]

        return result

    def show_result(self, **kwargs):
        """Visualize the results."""
        raise NotImplementedError

    def forward_dummy(self, img, input_heatmaps=None, num_candidates=5):
        """Used for computing network FLOPs."""
        if self.backbone is None:
            assert input_heatmaps is not None
            feature_maps = []
            for input_heatmap in input_heatmaps:
                feature_maps.append(input_heatmap[0])
        else:
            feature_maps = []
            assert isinstance(img, list)
            for img_ in img:
                feature_maps.append(self.backbone.forward_dummy(img_)[0])

        _ = self.human_detector.forward_dummy(feature_maps)

        _ = self.pose_regressor.forward_dummy(feature_maps, num_candidates)
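

# A minimal sketch of how the pieces fit together (the nested configs are
# hypothetical placeholders, not a tested configuration):
#   model = DetectAndRegress(
#       backbone=None,  # rely on precomputed `input_heatmaps`
#       human_detector=dict(type='VoxelCenterDetector', ...),
#       pose_regressor=dict(type='VoxelSinglePose', ...))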


@POSENETS.register_module()
class VoxelSinglePose(BasePose):
    """VoxelPose. Please refer to the
    `paper <https://arxiv.org/abs/2004.06239>`_ for details.

    Args:
        image_size (list): input size of the 2D model.
        heatmap_size (list): output size of the 2D model.
        sub_space_size (list): Size of the cuboid human proposal.
        sub_cube_size (list): Size of the input volume to the pose net.
        num_joints (int): Number of joints.
        pose_net (ConfigDict): Dictionary to construct the pose net.
        pose_head (ConfigDict): Dictionary to construct the pose head.
        train_cfg (ConfigDict): Config for training. Default: None.
        test_cfg (ConfigDict): Config for testing. Default: None.
    """

    def __init__(
        self,
        image_size,
        heatmap_size,
        sub_space_size,
        sub_cube_size,
        num_joints,
        pose_net,
        pose_head,
        train_cfg=None,
        test_cfg=None,
    ):
        super(VoxelSinglePose, self).__init__()
        self.project_layer = ProjectLayer(image_size, heatmap_size)
        self.pose_net = builder.build_backbone(pose_net)
        self.pose_head = builder.build_head(pose_head)

        self.sub_space_size = sub_space_size
        self.sub_cube_size = sub_cube_size

        self.num_joints = num_joints
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def forward(self,
                img,
                img_metas,
                return_loss=True,
                feature_maps=None,
                human_candidates=None,
                **kwargs):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            feature_maps (list(torch.Tensor[NxCxHxW])):
                Multi-camera input feature_maps.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            human_candidates (torch.Tensor[NxPx5]):
                Human candidates.
            return_loss (bool): Option to return loss. ``return_loss=True``
                for training, ``return_loss=False`` for validation & test.
        """
        if return_loss:
            return self.forward_train(img, img_metas, feature_maps,
                                      human_candidates)
        else:
            return self.forward_test(img, img_metas, feature_maps,
                                     human_candidates)

    def forward_train(self,
                      img,
                      img_metas,
                      feature_maps=None,
                      human_candidates=None,
                      return_preds=False,
                      **kwargs):
        """Defines the computation performed at training.

        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            feature_maps (list(torch.Tensor[NxCxHxW])):
                Multi-camera input feature_maps.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            human_candidates (torch.Tensor[NxPx5]):
                Human candidates.
            return_preds (bool): Whether to return prediction results.

        Returns:
            dict: losses.
        """
        batch_size, num_candidates, _ = human_candidates.shape
        pred = human_candidates.new_zeros(batch_size, num_candidates,
                                          self.num_joints, 5)
        pred[:, :, :, 3:] = human_candidates[:, :, None, 3:]

        device = feature_maps[0].device
        gt_3d = torch.stack([
            torch.tensor(img_meta['joints_3d'], device=device)
            for img_meta in img_metas
        ])
        gt_3d_vis = torch.stack([
            torch.tensor(img_meta['joints_3d_visible'], device=device)
            for img_meta in img_metas
        ])
        valid_preds = []
        valid_targets = []
        valid_weights = []
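
        # For each candidate slot, crop a feature volume around the candidate
        # center and regress the joints; channel 3 of `pred` holds the
        # assigned ground-truth person id (negative means invalid).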
        for n in range(num_candidates):
            index = pred[:, n, 0, 3] >= 0
            num_valid = index.sum()
            if num_valid > 0:
                pose_input_cube, coordinates \
                    = self.project_layer(feature_maps,
                                         img_metas,
                                         self.sub_space_size,
                                         human_candidates[:, n, :3],
                                         self.sub_cube_size)
                pose_heatmaps_3d = self.pose_net(pose_input_cube)
                pose_3d = self.pose_head(pose_heatmaps_3d[index],
                                         coordinates[index])

                pred[index, n, :, 0:3] = pose_3d.detach()
                valid_targets.append(gt_3d[index, pred[index, n, 0, 3].long()])
                valid_weights.append(gt_3d_vis[index, pred[index, n, 0,
                                                           3].long(), :,
                                               0:1].float())
                valid_preds.append(pose_3d)

        losses = dict()
        if len(valid_preds) > 0:
            valid_targets = torch.cat(valid_targets, dim=0)
            valid_weights = torch.cat(valid_weights, dim=0)
            valid_preds = torch.cat(valid_preds, dim=0)
            losses.update(
                self.pose_head.get_loss(valid_preds, valid_targets,
                                        valid_weights))
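        # No valid candidate in the whole batch: run the networks on zero
        # inputs with zero loss weights so that every parameter still takes
        # part in the backward pass (e.g. to keep DDP replicas in sync).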
        else:
            pose_input_cube = feature_maps[0].new_zeros(
                batch_size, self.num_joints, *self.sub_cube_size)
            coordinates = feature_maps[0].new_zeros(batch_size,
                                                    *self.sub_cube_size,
                                                    3).view(batch_size, -1, 3)
            pseudo_targets = feature_maps[0].new_zeros(batch_size,
                                                       self.num_joints, 3)
            pseudo_weights = feature_maps[0].new_zeros(batch_size,
                                                       self.num_joints, 1)
            pose_heatmaps_3d = self.pose_net(pose_input_cube)
            pose_3d = self.pose_head(pose_heatmaps_3d, coordinates)
            losses.update(
                self.pose_head.get_loss(pose_3d, pseudo_targets,
                                        pseudo_weights))
        if return_preds:
            return pred, losses
        else:
            return losses

    def forward_test(self,
                     img,
                     img_metas,
                     feature_maps=None,
                     human_candidates=None,
                     **kwargs):
        """Defines the computation performed at testing.

        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            feature_maps (list(torch.Tensor[NxCxHxW])):
                Multi-camera input feature_maps.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            human_candidates (torch.Tensor[NxPx5]):
                Human candidates.

        Returns:
            torch.Tensor[NxPxKx5]: predicted poses of the human candidates.
        """
        batch_size, num_candidates, _ = human_candidates.shape
        pred = human_candidates.new_zeros(batch_size, num_candidates,
                                          self.num_joints, 5)
        pred[:, :, :, 3:] = human_candidates[:, :, None, 3:]

        for n in range(num_candidates):
            index = pred[:, n, 0, 3] >= 0
            num_valid = index.sum()
            if num_valid > 0:
                pose_input_cube, coordinates \
                    = self.project_layer(feature_maps,
                                         img_metas,
                                         self.sub_space_size,
                                         human_candidates[:, n, :3],
                                         self.sub_cube_size)
                pose_heatmaps_3d = self.pose_net(pose_input_cube)
                pose_3d = self.pose_head(pose_heatmaps_3d[index],
                                         coordinates[index])

                pred[index, n, :, 0:3] = pose_3d.detach()

        return pred

    def show_result(self, **kwargs):
        """Visualize the results."""
        raise NotImplementedError

    def forward_dummy(self, feature_maps, num_candidates=5):
        """Used for computing network FLOPs."""
        batch_size, num_channels, _, _ = feature_maps[0].shape
        pose_input_cube = feature_maps[0].new_zeros(batch_size, num_channels,
                                                    *self.sub_cube_size)
        for n in range(num_candidates):
            _ = self.pose_net(pose_input_cube)


@POSENETS.register_module()
class VoxelCenterDetector(BasePose):
    """Detect human center by 3D CNN on voxels.

    Please refer to the
    `paper <https://arxiv.org/abs/2004.06239>`_ for details.

    Args:
        image_size (list): input size of the 2D model.
        heatmap_size (list): output size of the 2D model.
        space_size (list): Size of the 3D space.
        cube_size (list): Size of the input volume to the 3D CNN.
        space_center (list): Coordinate of the center of the 3D space.
        center_net (ConfigDict): Dictionary to construct the center net.
        center_head (ConfigDict): Dictionary to construct the center head.
        train_cfg (ConfigDict): Config for training. Default: None.
        test_cfg (ConfigDict): Config for testing. Default: None.
    """

    def __init__(
        self,
        image_size,
        heatmap_size,
        space_size,
        cube_size,
        space_center,
        center_net,
        center_head,
        train_cfg=None,
        test_cfg=None,
    ):
        super(VoxelCenterDetector, self).__init__()
        self.project_layer = ProjectLayer(image_size, heatmap_size)
        self.center_net = builder.build_backbone(center_net)
        self.center_head = builder.build_head(center_head)

        self.space_size = space_size
        self.cube_size = cube_size
        self.space_center = space_center

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def assign2gt(self, center_candidates, gt_centers, gt_num_persons):
        """Assign gt id to each valid human center candidate."""
        det_centers = center_candidates[..., :3]
        batch_size = center_candidates.shape[0]
        cand_num = center_candidates.shape[1]
        cand2gt = torch.zeros(batch_size, cand_num)
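
        # Match each candidate to its nearest ground-truth center; matches
        # farther than `dist_threshold` are marked invalid with id -1.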
        for i in range(batch_size):
            cand = det_centers[i].view(cand_num, 1, -1)
            gt = gt_centers[None, i, :gt_num_persons[i]]

            dist = torch.sqrt(torch.sum((cand - gt)**2, dim=-1))
            min_dist, min_gt = torch.min(dist, dim=-1)

            cand2gt[i] = min_gt
            cand2gt[i][min_dist > self.train_cfg['dist_threshold']] = -1.0

        center_candidates[:, :, 3] = cand2gt

        return center_candidates

    def forward(self,
                img,
                img_metas,
                return_loss=True,
                feature_maps=None,
                targets_3d=None):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            heatmaps width: W
            heatmaps height: H

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            return_loss (bool): Option to return loss. ``return_loss=True``
                for training, ``return_loss=False`` for validation & test.
            targets_3d (torch.Tensor[NxcubeLxcubeWxcubeH]):
                Ground-truth 3D heatmap of human centers.
            feature_maps (list(torch.Tensor[NxKxHxW])):
                Multi-camera feature_maps.

        Returns:
            dict: if ``return_loss`` is true, then return losses.
                Otherwise, return predicted human centers.
        """
        if return_loss:
            return self.forward_train(img, img_metas, feature_maps, targets_3d)
        else:
            return self.forward_test(img, img_metas, feature_maps)

    def forward_train(self,
                      img,
                      img_metas,
                      feature_maps=None,
                      targets_3d=None,
                      return_preds=False):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            heatmaps width: W
            heatmaps height: H

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            targets_3d (torch.Tensor[NxcubeLxcubeWxcubeH]):
                Ground-truth 3D heatmap of human centers.
            feature_maps (list(torch.Tensor[NxKxHxW])):
                Multi-camera feature_maps.
            return_preds (bool): Whether to return prediction results.

        Returns:
            dict: if ``return_preds`` is true, then return losses
                and human centers. Otherwise, return losses only.
        """
        initial_cubes, _ = self.project_layer(feature_maps, img_metas,
                                              self.space_size,
                                              [self.space_center],
                                              self.cube_size)
        center_heatmaps_3d = self.center_net(initial_cubes)
        center_heatmaps_3d = center_heatmaps_3d.squeeze(1)
        center_candidates = self.center_head(center_heatmaps_3d)

        device = center_candidates.device

        gt_centers = torch.stack([
            torch.tensor(img_meta['roots_3d'], device=device)
            for img_meta in img_metas
        ])
        gt_num_persons = torch.stack([
            torch.tensor(img_meta['num_persons'], device=device)
            for img_meta in img_metas
        ])
        center_candidates = self.assign2gt(center_candidates, gt_centers,
                                           gt_num_persons)

        losses = dict()
        losses.update(
            self.center_head.get_loss(center_heatmaps_3d, targets_3d))

        if return_preds:
            return center_candidates, losses
        else:
            return losses

    def forward_test(self, img, img_metas, feature_maps=None):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            heatmaps width: W
            heatmaps height: H

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            feature_maps (list(torch.Tensor[NxKxHxW])):
                Multi-camera feature_maps.

        Returns:
            torch.Tensor: human center candidates.
        """
        initial_cubes, _ = self.project_layer(feature_maps, img_metas,
                                              self.space_size,
                                              [self.space_center],
                                              self.cube_size)
        center_heatmaps_3d = self.center_net(initial_cubes)
        center_heatmaps_3d = center_heatmaps_3d.squeeze(1)
        center_candidates = self.center_head(center_heatmaps_3d)
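        # Channel 4 holds the detection score; candidates above the test
        # threshold get validity flag 0.0 in channel 3, the rest -1.0.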
        center_candidates[..., 3] = \
            (center_candidates[..., 4] >
             self.test_cfg['center_threshold']).float() - 1.0

        return center_candidates

    def show_result(self, **kwargs):
        """Visualize the results."""
        raise NotImplementedError

    def forward_dummy(self, feature_maps):
        """Used for computing network FLOPs."""
        batch_size, num_channels, _, _ = feature_maps[0].shape
        initial_cubes = feature_maps[0].new_zeros(batch_size, num_channels,
                                                  *self.cube_size)
        _ = self.center_net(initial_cubes)