You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
616 lines
21 KiB
616 lines
21 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy as cp
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from mmcv.cnn import (ConvModule, MaxPool2d, constant_init, kaiming_init,
|
|
normal_init)
|
|
|
|
from ..builder import BACKBONES
|
|
from .base_backbone import BaseBackbone
|
|
|
|
|
|
class RSB(nn.Module):
    """Residual Steps block for RSN. Paper ref: Cai et al. "Learning Delicate
    Local Representations for Multi-Person Pose Estimation" (ECCV 2020).

    A strided 1x1 conv maps the input to ``num_steps`` groups of
    ``branch_channels`` channels, where
    ``branch_channels = in_channels * expand_times // res_top_channels``.
    Group ``s`` (0-based) runs through ``s + 1`` chained 3x3 convs. Every
    conv of step ``s`` except the last also adds the output at the same
    position of step ``s - 1``. The final output of each step is
    concatenated, fused by a 1x1 conv without activation, added to the
    (optionally downsampled) identity and passed through ReLU.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        num_steps (int): Numbers of steps in RSB. Must be greater than 1.
            Default: 4
        stride (int): Stride of the block, applied by the first 1x1 conv.
            Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None.
        with_cp (bool): Stored as ``self.with_cp``; the forward pass of this
            block does not read it. Default: False.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        expand_times (int): Times by which the in_channels are expanded.
            Default: 26.
        res_top_channels (int): Number of channels of feature output by
            ResNet_top. Default: 64.
    """

    expansion = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_steps=4,
                 stride=1,
                 downsample=None,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 expand_times=26,
                 res_top_channels=64):
        # Copy so neither the shared default dict nor the caller's dict can
        # be mutated through this block.
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        assert num_steps > 1
        self.in_channels = in_channels
        self.branch_channels = in_channels * expand_times // res_top_channels
        self.out_channels = out_channels
        self.stride = stride
        self.downsample = downsample
        self.with_cp = with_cp
        self.norm_cfg = norm_cfg
        self.num_steps = num_steps

        stacked_channels = num_steps * self.branch_channels
        self.conv_bn_relu1 = ConvModule(
            in_channels,
            stacked_channels,
            kernel_size=1,
            stride=stride,
            padding=0,
            norm_cfg=norm_cfg,
            inplace=False)
        # Step s (1-based) owns conv_bn_relu2_s_1 ... conv_bn_relu2_s_s.
        # Registration order (row by row) is kept stable for checkpoints.
        for step in range(1, num_steps + 1):
            for pos in range(1, step + 1):
                self.add_module(
                    f'conv_bn_relu2_{step}_{pos}',
                    ConvModule(
                        self.branch_channels,
                        self.branch_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=norm_cfg,
                        inplace=False))
        self.conv_bn3 = ConvModule(
            stacked_channels,
            out_channels * self.expansion,
            kernel_size=1,
            stride=1,
            padding=0,
            act_cfg=None,
            norm_cfg=norm_cfg,
            inplace=False)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input feature; split along dim 1 after the
                first conv.

        Returns:
            torch.Tensor: ``relu(fused_steps + identity)``.
        """
        shortcut = x
        chunks = torch.split(self.conv_bn_relu1(x), self.branch_channels, 1)

        prev_row = []
        step_outs = []
        for step in range(self.num_steps):
            row = []
            for pos in range(step + 1):
                # The first conv of a step reads its own chunk; later convs
                # chain on the previous conv of the same step.
                feat = chunks[step] if pos == 0 else row[pos - 1]
                if step > pos:
                    # Lateral connection from the same position one step up.
                    feat = feat + prev_row[pos]
                conv = getattr(self, f'conv_bn_relu2_{step + 1}_{pos + 1}')
                row.append(conv(feat))
            step_outs.append(row[step])
            prev_row = row

        fused = self.conv_bn3(torch.cat(tuple(step_outs), 1))
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        return self.relu(fused + shortcut)
|
|
|
|
|
|
class Downsample_module(nn.Module):
    """Downsample module for RSN.

    Stacks ``num_units`` layers (``layer1`` ... ``layer{num_units}``) of
    ``block``. Layer ``i`` (0-based) outputs
    ``in_channels * 2 ** i * block.expansion`` channels and, for ``i > 0``,
    uses stride 2.

    Args:
        block (nn.Module): Downsample block class (RSB in this file). It must
            expose an ``expansion`` attribute and accept the keyword
            arguments used in ``_make_layer``.
        num_blocks (list): Number of blocks in each downsample unit. Its
            length must equal ``num_units``.
        num_steps (int): Number of steps in a block. Default: 4
        num_units (int): Numbers of downsample units. Default: 4
        has_skip (bool): Have skip connections from prior upsample
            module or not. Default: False
        norm_cfg (dict): dictionary to construct and config norm layer.
            It is applied to every block of every unit. Default:
            dict(type='BN')
        in_channels (int): Number of channels of the input feature to
            downsample module. Default: 64
        expand_times (int): Times by which the in_channels are expanded.
            Default: 26.
    """

    def __init__(self,
                 block,
                 num_blocks,
                 num_steps=4,
                 num_units=4,
                 has_skip=False,
                 norm_cfg=dict(type='BN'),
                 in_channels=64,
                 expand_times=26):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.has_skip = has_skip
        # Running channel count; advanced by each ``_make_layer`` call.
        self.in_channels = in_channels
        assert len(num_blocks) == num_units
        self.num_blocks = num_blocks
        self.num_units = num_units
        self.num_steps = num_steps
        self.norm_cfg = norm_cfg
        self.layer1 = self._make_layer(
            block,
            in_channels,
            num_blocks[0],
            expand_times=expand_times,
            res_top_channels=in_channels)
        for i in range(1, num_units):
            module_name = f'layer{i + 1}'
            self.add_module(
                module_name,
                self._make_layer(
                    block,
                    in_channels * pow(2, i),
                    num_blocks[i],
                    stride=2,
                    expand_times=expand_times,
                    res_top_channels=in_channels))

    def _make_layer(self,
                    block,
                    out_channels,
                    blocks,
                    stride=1,
                    expand_times=26,
                    res_top_channels=64):
        """Build one downsample unit made of ``blocks`` chained blocks.

        Only the first block receives ``stride`` and the identity-branch
        projection. ``self.in_channels`` is then advanced to
        ``out_channels * block.expansion`` so that the following blocks and
        the next call start from this unit's output width. Every block is
        built with ``self.norm_cfg``.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            # 1x1 projection so the identity matches the block output shape.
            downsample = ConvModule(
                self.in_channels,
                out_channels * block.expansion,
                kernel_size=1,
                stride=stride,
                padding=0,
                norm_cfg=self.norm_cfg,
                act_cfg=None,
                inplace=True)

        units = list()
        units.append(
            block(
                self.in_channels,
                out_channels,
                num_steps=self.num_steps,
                stride=stride,
                downsample=downsample,
                norm_cfg=self.norm_cfg,
                expand_times=expand_times,
                res_top_channels=res_top_channels))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            # Previously these blocks omitted norm_cfg and silently fell back
            # to the block's default norm layer.
            units.append(
                block(
                    self.in_channels,
                    out_channels,
                    num_steps=self.num_steps,
                    norm_cfg=self.norm_cfg,
                    expand_times=expand_times,
                    res_top_channels=res_top_channels))

        return nn.Sequential(*units)

    def forward(self, x, skip1, skip2):
        """Run all downsample units.

        Args:
            x (torch.Tensor): Input feature.
            skip1 (list | None): Per-unit skip features; indexed by unit and
                only read when ``has_skip`` is True.
            skip2 (list | None): Same as ``skip1`` for the second skip path.

        Returns:
            tuple[torch.Tensor]: Output of every unit, last unit first.
        """
        out = list()
        for i in range(self.num_units):
            x = getattr(self, f'layer{i + 1}')(x)
            if self.has_skip:
                x = x + skip1[i] + skip2[i]
            out.append(x)
        return tuple(reversed(out))
|
|
|
|
|
|
class Upsample_unit(nn.Module):
    """Upsample unit for upsample module.

    Projects the skip-in feature ``x`` to ``unit_channels``. When
    ``ind > 0`` it also resizes the previous unit's output to ``x``'s
    spatial size, projects it and adds it. The sum then goes through ReLU.
    Optionally emits two skip features and a cross-stage feature map.

    Args:
        ind (int): Index of this unit. Units with ``ind > 0`` interpolate and
            fuse the previous unit's output; together with ``num_units`` and
            ``gen_cross_conv`` it decides whether the feature map for the
            next hourglass-like module is generated.
        num_units (int): Number of units that form an upsample module. Along
            with ``ind`` and ``gen_cross_conv``, ``num_units`` is used to
            decide whether to generate the feature map for the next
            hourglass-like module.
        in_channels (int): Channel number of the skip-in feature maps from
            the corresponding downsample unit.
        unit_channels (int): Channel number in this unit. Default: 256.
        gen_skip (bool): Whether or not to generate skips for the posterior
            downsample module. Default: False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default: False
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        out_channels (int): Number of channels of feature output by upsample
            module. Must equal to in_channels of downsample module.
            Default: 64
    """

    def __init__(self,
                 ind,
                 num_units,
                 in_channels,
                 unit_channels=256,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.num_units = num_units
        self.norm_cfg = norm_cfg

        def conv1x1(src_channels, dst_channels, **extra):
            # Every conv in this unit is a 1x1, stride-1, unpadded ConvModule.
            return ConvModule(
                src_channels,
                dst_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                inplace=True,
                **extra)

        # Submodules are created in the original order so parameter order
        # and weight initialisation stay unchanged.
        self.in_skip = conv1x1(in_channels, unit_channels, act_cfg=None)
        self.relu = nn.ReLU(inplace=True)

        self.ind = ind
        if self.ind > 0:
            self.up_conv = conv1x1(unit_channels, unit_channels, act_cfg=None)

        self.gen_skip = gen_skip
        if self.gen_skip:
            self.out_skip1 = conv1x1(in_channels, in_channels)
            self.out_skip2 = conv1x1(unit_channels, in_channels)

        self.gen_cross_conv = gen_cross_conv
        if self.ind == num_units - 1 and self.gen_cross_conv:
            self.cross_conv = conv1x1(unit_channels, out_channels)

    def forward(self, x, up_x):
        """Fuse the skip-in feature with the previous unit's output.

        Args:
            x (torch.Tensor): Skip-in feature; dims 2 and 3 give the target
                spatial size.
            up_x (torch.Tensor | None): Previous unit's output; only read
                when ``ind > 0``.

        Returns:
            tuple: ``(out, skip1, skip2, cross_conv)``. The last three are
            None when the corresponding branch is not built.
        """
        fused = self.in_skip(x)
        if self.ind > 0:
            resized = F.interpolate(
                up_x,
                size=(x.size(2), x.size(3)),
                mode='bilinear',
                align_corners=True)
            fused = fused + self.up_conv(resized)
        fused = self.relu(fused)

        if self.gen_skip:
            skip1, skip2 = self.out_skip1(x), self.out_skip2(fused)
        else:
            skip1, skip2 = None, None

        is_last = self.ind == self.num_units - 1
        cross_conv = (
            self.cross_conv(fused)
            if is_last and self.gen_cross_conv else None)

        return fused, skip1, skip2, cross_conv
|
|
|
|
|
|
class Upsample_module(nn.Module):
    """Upsample module for RSN.

    Builds ``num_units`` upsample units named ``up1`` ... ``up{num_units}``.
    Unit ``i`` consumes ``x[i]``, whose channel count is
    ``RSB.expansion * out_channels * 2 ** (num_units - 1 - i)``. The inputs
    are therefore expected widest first, which is the order
    ``Downsample_module.forward`` returns.

    Args:
        unit_channels (int): Channel number in the upsample units.
            Default: 256.
        num_units (int): Numbers of upsample units. Default: 4
        gen_skip (bool): Whether to generate skip for posterior downsample
            module or not. Default: False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default: False
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        out_channels (int): Number of channels of feature output by upsample
            module. Must equal to in_channels of downsample module. It is
            also the channel count of the generated cross-stage feature map.
            Default: 64
    """

    def __init__(self,
                 unit_channels=256,
                 num_units=4,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.in_channels = list()
        for i in range(num_units):
            self.in_channels.append(RSB.expansion * out_channels * pow(2, i))
        # Widest input first, matching the order the units are applied in.
        self.in_channels.reverse()
        self.num_units = num_units
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.norm_cfg = norm_cfg
        for i in range(num_units):
            module_name = f'up{i + 1}'
            self.add_module(
                module_name,
                Upsample_unit(
                    i,
                    self.num_units,
                    self.in_channels[i],
                    unit_channels,
                    self.gen_skip,
                    self.gen_cross_conv,
                    norm_cfg=self.norm_cfg,
                    # Was hard-coded to 64, which broke the cross-stage
                    # feature whenever out_channels != 64.
                    out_channels=out_channels))

    def forward(self, x):
        """Apply the upsample units in order.

        Args:
            x (Sequence[torch.Tensor]): One feature per unit, widest first.

        Returns:
            tuple: ``(out, skip1, skip2, cross_conv)``, where ``out`` is the
            list of unit outputs in application order, ``skip1`` and
            ``skip2`` are the per-unit skip features in reversed order, and
            ``cross_conv`` is the last unit's cross-stage feature (or None).
        """
        out = list()
        skip1 = list()
        skip2 = list()
        cross_conv = None
        last = self.num_units - 1
        for i in range(self.num_units):
            module_i = getattr(self, f'up{i + 1}')
            # The first unit has no previous output to fuse.
            up_x = out[i - 1] if i > 0 else None
            outi, skip1_i, skip2_i, cross_i = module_i(x[i], up_x)
            if i == last:
                # Only the last unit can produce the cross-stage feature;
                # this also covers num_units == 1.
                cross_conv = cross_i
            out.append(outi)
            skip1.append(skip1_i)
            skip2.append(skip2_i)
        # Reverse so skip index i lines up with downsample unit i.
        skip1.reverse()
        skip2.reverse()

        return out, skip1, skip2, cross_conv
|
|
|
|
|
|
class Single_stage_RSN(nn.Module):
    """Single_stage Residual Steps Network.

    One hourglass-like stage: a ``Downsample_module`` built from ``RSB``
    blocks followed by an ``Upsample_module``.

    Args:
        has_skip (bool): Have skip connections from prior upsample
            module or not. Default: False
        gen_skip (bool): Whether to generate skip for posterior downsample
            module or not. Default: False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default: False
        unit_channels (int): Channel number in the upsample units.
            Default: 256.
        num_units (int): Numbers of downsample/upsample units. Default: 4
        num_steps (int): Number of steps in RSB. Default: 4
        num_blocks (list): Number of blocks in each downsample unit.
            Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks)
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        in_channels (int): Number of channels of the feature from ResNet_Top.
            Also used as the upsample module's ``out_channels``. Default: 64.
        expand_times (int): Times by which the in_channels are expanded in
            RSB. Default: 26.
    """

    def __init__(self,
                 has_skip=False,
                 gen_skip=False,
                 gen_cross_conv=False,
                 unit_channels=256,
                 num_units=4,
                 num_steps=4,
                 num_blocks=[2, 2, 2, 2],
                 norm_cfg=dict(type='BN'),
                 in_channels=64,
                 expand_times=26):
        # Copy both mutable defaults so they are never shared or mutated.
        norm_cfg = cp.deepcopy(norm_cfg)
        num_blocks = cp.deepcopy(num_blocks)
        super().__init__()
        assert len(num_blocks) == num_units
        self.has_skip = has_skip
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.num_units = num_units
        self.num_steps = num_steps
        self.unit_channels = unit_channels
        self.num_blocks = num_blocks
        self.norm_cfg = norm_cfg

        self.downsample = Downsample_module(
            RSB,
            num_blocks,
            num_steps=num_steps,
            num_units=num_units,
            has_skip=has_skip,
            norm_cfg=norm_cfg,
            in_channels=in_channels,
            expand_times=expand_times)
        self.upsample = Upsample_module(
            unit_channels=unit_channels,
            num_units=num_units,
            gen_skip=gen_skip,
            gen_cross_conv=gen_cross_conv,
            norm_cfg=norm_cfg,
            out_channels=in_channels)

    def forward(self, x, skip1, skip2):
        """Run the downsample path, then the upsample path.

        Args:
            x (torch.Tensor): Input feature of this stage.
            skip1, skip2: Skip features from the previous stage, forwarded
                unchanged to the downsample module (may be None).

        Returns:
            tuple: ``(out, skip1, skip2, cross_conv)`` from the upsample
            module.
        """
        encoded = self.downsample(x, skip1, skip2)
        decoded, next_skip1, next_skip2, cross = self.upsample(encoded)
        return decoded, next_skip1, next_skip2, cross
|
|
|
|
|
|
class ResNet_top(nn.Module):
    """ResNet top (stem) for RSN.

    A 3-input-channel 7x7 stride-2 ConvModule (with its default activation
    config) followed by a 3x3 stride-2 max-pool.

    Args:
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        channels (int): Number of channels of the feature output by
            ResNet_top. Default: 64.
    """

    def __init__(self, norm_cfg=dict(type='BN'), channels=64):
        # Copy so the shared default dict is never mutated.
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        stem_conv = ConvModule(
            3,
            channels,
            kernel_size=7,
            stride=2,
            padding=3,
            norm_cfg=norm_cfg,
            inplace=True)
        stem_pool = MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.top = nn.Sequential(stem_conv, stem_pool)

    def forward(self, img):
        """Apply the stem to ``img`` (3 input channels expected)."""
        return self.top(img)
|
|
|
|
|
|
@BACKBONES.register_module()
class RSN(BaseBackbone):
    """Residual Steps Network backbone. Paper ref: Cai et al. "Learning
    Delicate Local Representations for Multi-Person Pose Estimation" (ECCV
    2020).

    Args:
        unit_channels (int): Number of Channels in an upsample unit.
            Default: 256
        num_stages (int): Number of stages in a multi-stage RSN. Default: 4
        num_units (int): Number of downsample/upsample units in a single-stage
            RSN. Default: 4 Note: Make sure num_units == len(self.num_blocks)
        num_blocks (list): Number of RSBs (Residual Steps Block) in each
            downsample unit. Default: [2, 2, 2, 2]
        num_steps (int): Number of steps in a RSB. Default: 4
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        res_top_channels (int): Number of channels of feature from ResNet_top.
            It sets both the output width of ResNet_top and the input width
            of every stage. Default: 64.
        expand_times (int): Times by which the in_channels are expanded in
            RSB. Default: 26.

    Example:
        >>> from mmpose.models import RSN
        >>> import torch
        >>> self = RSN(num_stages=2,num_units=2,num_blocks=[2,2])
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     for feature in level_output:
        ...         print(tuple(feature.shape))
        ...
        (1, 256, 64, 64)
        (1, 256, 128, 128)
        (1, 256, 64, 64)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 unit_channels=256,
                 num_stages=4,
                 num_units=4,
                 num_blocks=[2, 2, 2, 2],
                 num_steps=4,
                 norm_cfg=dict(type='BN'),
                 res_top_channels=64,
                 expand_times=26):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        num_blocks = cp.deepcopy(num_blocks)
        super().__init__()
        self.unit_channels = unit_channels
        self.num_stages = num_stages
        self.num_units = num_units
        self.num_blocks = num_blocks
        self.num_steps = num_steps
        self.norm_cfg = norm_cfg

        assert self.num_stages > 0
        assert self.num_steps > 1
        assert self.num_units > 1
        assert self.num_units == len(self.num_blocks)
        # The stem must emit exactly the width every stage is built for;
        # previously it always used its own default of 64 channels.
        self.top = ResNet_top(norm_cfg=norm_cfg, channels=res_top_channels)
        self.multi_stage_rsn = nn.ModuleList([])
        for i in range(self.num_stages):
            # Stage 0 has no previous stage to take skips from; the last
            # stage has no next stage to feed skips / cross features to.
            has_skip = i > 0
            is_last = i == self.num_stages - 1
            gen_skip = not is_last
            gen_cross_conv = not is_last
            self.multi_stage_rsn.append(
                Single_stage_RSN(has_skip, gen_skip, gen_cross_conv,
                                 unit_channels, num_units, num_steps,
                                 num_blocks, norm_cfg, res_top_channels,
                                 expand_times))

    def forward(self, x):
        """Model forward function.

        Args:
            x (torch.Tensor): Input image batch fed to ``self.top``
                (3 input channels).

        Returns:
            list: One entry per stage, each being that stage's list of
            upsample outputs.
        """
        out_feats = []
        skip1 = None
        skip2 = None
        x = self.top(x)
        for stage in self.multi_stage_rsn:
            # Each stage's cross-stage feature becomes the next stage's input.
            out, skip1, skip2, x = stage(x, skip1, skip2)
            out_feats.append(out)

        return out_feats

    def init_weights(self, pretrained=None):
        """Initialize model weights.

        Conv2d layers get Kaiming init, BatchNorm2d layers constant 1 and
        Linear layers normal(std=0.01) inside the stages; only Conv2d layers
        of the stem are re-initialised. ``pretrained`` is not used by this
        implementation.
        """
        for m in self.multi_stage_rsn.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
            elif isinstance(m, nn.Linear):
                normal_init(m, std=0.01)

        for m in self.top.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
|
|
|