You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
329 lines
12 KiB
329 lines
12 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import mmcv
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
|
|
normal_init)
|
|
from mmcv.utils import digit_version
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
|
|
from mmpose.models.utils.ops import resize
|
|
from ..backbones.resnet import BasicBlock, Bottleneck
|
|
from ..builder import NECKS
|
|
|
|
try:
|
|
from mmcv.ops import DeformConv2d
|
|
has_mmcv_full = True
|
|
except (ImportError, ModuleNotFoundError):
|
|
has_mmcv_full = False
|
|
|
|
|
|
@NECKS.register_module()
class PoseWarperNeck(nn.Module):
    """PoseWarper neck.

    `"Learning temporal pose estimation from sparsely-labeled videos"
    <https://arxiv.org/abs/1906.04016>`_.

    Args:
        in_channels (int): Number of input channels from backbone
        out_channels (int): Number of output channels
        inner_channels (int): Number of intermediate channels of the res block
        deform_groups (int): Number of groups in the deformable conv
        dilations (list|tuple): different dilations of the offset conv layers
        trans_conv_kernel (int): the kernel of the trans conv layer, which is
            used to get heatmap from the output of backbone. Must be one of
            0, 1 or 3; 0 selects an identity mapping. Default: 1
        res_blocks_cfg (dict|None): config of residual blocks. If None,
            use the default values. If not None, it should contain the
            following keys:

            - block (str): the type of residual block, Default: 'BASIC'.
            - num_blocks (int): the number of blocks, Default: 20.

        offsets_kernel (int): the kernel of offset conv layer.
        deform_conv_kernel (int): the kernel of deformable conv layer.
        in_index (int|Sequence[int]): Input feature index. Default: 0
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            Default: None.

            - 'resize_concat': Multiple feature maps will be resized to
                the same size as the first one and then concatenated
                together. Usually used in FCN head of HRNet.
            - 'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            - None: Only one select feature map is allowed.

        freeze_trans_layer (bool): Whether to freeze the transition layer
            (stop grad and set eval mode). Default: True.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        im2col_step (int): the argument `im2col_step` in deformable conv,
            Default: 80.
        align_corners (bool): the `align_corners` argument handed to
            `resize` when `input_transform` is 'resize_concat'.
            Default: False.
    """
    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
    minimum_mmcv_version = '1.3.17'

    def __init__(self,
                 in_channels,
                 out_channels,
                 inner_channels,
                 deform_groups=17,
                 dilations=(3, 6, 12, 18, 24),
                 trans_conv_kernel=1,
                 res_blocks_cfg=None,
                 offsets_kernel=3,
                 deform_conv_kernel=3,
                 in_index=0,
                 input_transform=None,
                 freeze_trans_layer=True,
                 norm_eval=False,
                 im2col_step=80,
                 align_corners=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inner_channels = inner_channels
        self.deform_groups = deform_groups
        self.dilations = dilations
        self.trans_conv_kernel = trans_conv_kernel
        self.res_blocks_cfg = res_blocks_cfg
        self.offsets_kernel = offsets_kernel
        self.deform_conv_kernel = deform_conv_kernel
        self.in_index = in_index
        self.input_transform = input_transform
        self.freeze_trans_layer = freeze_trans_layer
        self.norm_eval = norm_eval
        self.im2col_step = im2col_step
        # Read by `_transform_inputs` in the 'resize_concat' path; it used to
        # be referenced without ever being assigned.
        self.align_corners = align_corners

        # build the transition layer (0 means identity mapping)
        assert trans_conv_kernel in [0, 1, 3]
        if trans_conv_kernel == 0:
            self.trans_layer = nn.Identity()
        else:
            self.trans_layer = build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=trans_conv_kernel,
                stride=1,
                # kernel 3 -> padding 1, kernel 1 -> padding 0
                padding=trans_conv_kernel // 2)

        # build chain of residual blocks
        if res_blocks_cfg is not None and not isinstance(res_blocks_cfg, dict):
            raise TypeError('res_blocks_cfg should be dict or None.')

        if res_blocks_cfg is None:
            block_type = 'BASIC'
            num_blocks = 20
        else:
            block_type = res_blocks_cfg.get('block', 'BASIC')
            num_blocks = res_blocks_cfg.get('num_blocks', 20)

        block = self.blocks_dict[block_type]

        # The first block changes the channel count from `out_channels` to
        # `inner_channels`, so its shortcut needs a 1x1 projection.
        downsample = nn.Sequential(
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=out_channels,
                out_channels=inner_channels,
                kernel_size=1,
                stride=1,
                bias=False),
            build_norm_layer(dict(type='BN'), inner_channels)[1])
        res_layers = [
            block(
                in_channels=out_channels,
                out_channels=inner_channels,
                downsample=downsample)
        ]
        res_layers.extend(
            block(inner_channels, inner_channels)
            for _ in range(1, num_blocks))
        self.offset_feats = nn.Sequential(*res_layers)

        # build offset layers, one per dilation rate
        self.num_offset_layers = len(dilations)
        assert self.num_offset_layers > 0, 'Number of offset layers ' \
            'should be larger than 0.'

        target_offset_channels = 2 * offsets_kernel**2 * deform_groups

        self.offset_layers = nn.ModuleList([
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=inner_channels,
                out_channels=target_offset_channels,
                kernel_size=offsets_kernel,
                stride=1,
                dilation=dilation,
                padding=dilation,
                bias=False,
            ) for dilation in dilations
        ])

        # build deformable conv layers
        assert digit_version(mmcv.__version__) >= \
            digit_version(self.minimum_mmcv_version), \
            f'Current MMCV version: {mmcv.__version__}, ' \
            f'but MMCV >= {self.minimum_mmcv_version} is required, see ' \
            f'https://github.com/open-mmlab/mmcv/issues/1440, ' \
            f'Please install the latest MMCV.'

        if not has_mmcv_full:
            raise ImportError('Please install the full version of mmcv '
                              'to use `DeformConv2d`.')

        self.deform_conv_layers = nn.ModuleList([
            DeformConv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=deform_conv_kernel,
                stride=1,
                padding=(deform_conv_kernel // 2) * dilation,
                dilation=dilation,
                deform_groups=deform_groups,
                im2col_step=self.im2col_step,
            ) for dilation in dilations
        ])

        self.freeze_layers()

    def freeze_layers(self):
        """Freeze the transition layer when `freeze_trans_layer` is set.

        The layer is switched to eval mode and gradients are disabled for
        all of its parameters.
        """
        if self.freeze_trans_layer:
            self.trans_layer.eval()

            for param in self.trans_layer.parameters():
                param.requires_grad = False

    def init_weights(self):
        """Initialize the weights of all sub-modules."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)
            elif isinstance(m, DeformConv2d):
                # All-zero kernel with a single 1 at the spatial centre of
                # the (k, k) channel pair, i.e. output channel k copies
                # input channel k at the kernel centre.
                filler = torch.zeros(
                    m.weight.shape,
                    dtype=torch.float32,
                    device=m.weight.device)
                center_h = m.weight.size(2) // 2
                center_w = m.weight.size(3) // 2
                for k in range(m.weight.size(0)):
                    filler[k, k, center_h, center_w] = 1.0
                m.weight = torch.nn.Parameter(filler)
                m.weight.requires_grad = True

        # posewarper offset layer weight initialization
        for m in self.offset_layers.modules():
            constant_init(m, 0)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor] | Tensor): multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """
        # Anything that is not a list is passed through untouched.
        if not isinstance(inputs, list):
            return inputs

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def _warp(self, features, offset_features):
        """Warp `features` with every offset/deform-conv pair and average.

        Args:
            features (Tensor): features fed to the deformable convs.
            offset_features (Tensor): features fed to the offset convs.

        Returns:
            Tensor: mean of the outputs of all deformable conv layers.
        """
        warped_heatmap = 0
        for offset_layer, deform_conv in zip(self.offset_layers,
                                             self.deform_conv_layers):
            offset = offset_layer(offset_features)
            warped_heatmap_tmp = deform_conv(features, offset)
            warped_heatmap += warped_heatmap_tmp / self.num_offset_layers
        return warped_heatmap

    def forward(self, inputs, frame_weight):
        """Warp per-frame features and fuse them with `frame_weight`.

        Args:
            inputs (list | tuple): backbone outputs. With more than one
                element, each element is one frame and the first element is
                the reference frame. With exactly one element, its first
                dimension is split into `len(frame_weight)` equal chunks,
                one per frame, the first chunk being the reference.
            frame_weight (Sequence): one weight per frame; frames whose
                weight equals 0 do not contribute to the output.

        Returns:
            Tensor: the weighted sum of the warped heatmaps. If every
            weight is 0 the initial value `0` is returned unchanged.
        """
        assert isinstance(inputs, (list, tuple)), 'PoseWarperNeck inputs ' \
            'should be list or tuple, even though the length is 1, ' \
            'for unified processing.'

        output_heatmap = 0
        if len(inputs) > 1:
            inputs = [self._transform_inputs(feat) for feat in inputs]
            inputs = [self.trans_layer(feat) for feat in inputs]

            # calculate difference features w.r.t. the reference frame
            diff_features = [
                self.offset_feats(inputs[0] - feat) for feat in inputs
            ]

            for i, feat in enumerate(inputs):
                if frame_weight[i] == 0:
                    continue
                warped_heatmap = self._warp(feat, diff_features[i])
                output_heatmap += warped_heatmap * frame_weight[i]

        else:
            inputs = inputs[0]
            inputs = self._transform_inputs(inputs)
            inputs = self.trans_layer(inputs)

            num_frames = len(frame_weight)
            batch_size = inputs.size(0) // num_frames
            # The first chunk is the reference; tile it so it lines up with
            # every frame chunk in the stacked tensor.
            ref_x = inputs[:batch_size]
            ref_x_tiled = ref_x.repeat(num_frames, 1, 1, 1)

            offset_features = self.offset_feats(ref_x_tiled - inputs)
            warped_heatmap = self._warp(inputs, offset_features)

            for i in range(num_frames):
                if frame_weight[i] == 0:
                    continue
                frame_slice = slice(i * batch_size, (i + 1) * batch_size)
                output_heatmap += warped_heatmap[frame_slice] * \
                    frame_weight[i]

        return output_heatmap

    def train(self, mode=True):
        """Convert the model into training mode."""
        super().train(mode)
        # `super().train` flips every sub-module, so re-apply the freeze.
        self.freeze_layers()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
|
|
|