# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import logging

import mmcv
import numpy as np
from mmcv.image import imwrite
from mmcv.utils.misc import deprecated_api_warning
from mmcv.visualization.image import imshow
from mmcv_custom.checkpoint import load_checkpoint

from mmpose.core import imshow_bboxes, imshow_keypoints
from .. import builder
from ..builder import POSENETS
from .base import BasePose

try:
    from mmcv.runner import auto_fp16
except ImportError:
    warnings.warn('auto_fp16 from mmpose will be deprecated from v0.15.0. '
                  'Please install mmcv>=1.1.4')
    from mmpose.core import auto_fp16


@POSENETS.register_module()
class TopDownSelf(BasePose):
    """Top-down pose detectors.

    Args:
        backbone (dict): Backbone modules to extract feature.
        neck (dict): Neck module to further process features.
            Default: None.
        keypoint_head (dict): Keypoint head to process feature.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path to the pretrained models.
        loss_pose (None): Deprecated arguments. Please use
            `loss_keypoint` for heads instead.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 keypoint_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 loss_pose=None):
        super().__init__()
        self.fp16_enabled = False

        self.backbone = builder.build_backbone(backbone)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if neck is not None:
            self.neck = builder.build_neck(neck)

        if keypoint_head is not None:
            keypoint_head['train_cfg'] = train_cfg
            keypoint_head['test_cfg'] = test_cfg

            if 'loss_keypoint' not in keypoint_head and loss_pose is not None:
                warnings.warn(
                    '`loss_pose` for TopDown is deprecated, '
                    'use `loss_keypoint` for heads instead. See '
                    'https://github.com/open-mmlab/mmpose/pull/382'
                    ' for more information.', DeprecationWarning)
                keypoint_head['loss_keypoint'] = loss_pose

            self.keypoint_head = builder.build_head(keypoint_head)

        self.init_weights(pretrained=pretrained)

    @property
    def with_neck(self):
        """Check if has neck."""
        return hasattr(self, 'neck')

    @property
    def with_keypoint(self):
        """Check if has keypoint_head."""
        return hasattr(self, 'keypoint_head')

    def init_weights(self, pretrained=None):
        """Weight initialization for model."""
        self.backbone.init_weights(pretrained)
        if self.with_neck:
            self.neck.init_weights()
        if self.with_keypoint:
            self.keypoint_head.init_weights()

    @auto_fp16(apply_to=('img', 'sam_img'))
    def forward(self,
                img,
                sam_img,  # extra input for the SAM encoder
                target=None,
                target_weight=None,
                img_metas=None,
                return_loss=True,
                return_heatmap=False,
                **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True. Note this setting will change the expected inputs.
        When `return_loss=True`, img and img_meta are single-nested (i.e.
        Tensor and List[dict]), and when `return_loss=False`, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.

        Note:
            - batch_size: N
            - num_keypoints: K
            - num_img_channel: C (Default: 3)
            - img height: imgH
            - img width: imgW
            - heatmaps height: H
            - heatmaps width: W

        Args:
            img (torch.Tensor[NxCximgHximgW]): Input images.
            sam_img (torch.Tensor): Additional input images for the SAM
                encoder branch of the backbone.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.
            target_weight (torch.Tensor[NxKx1]): Weights across
                different joint types.
            img_metas (list(dict)): Information about data augmentation
                By default this includes:

                - "image_file": path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
            return_loss (bool): Option to `return loss`. `return loss=True`
                for training, `return loss=False` for validation & test.
            return_heatmap (bool): Option to return heatmap.

        Returns:
            dict|tuple: if `return loss` is true, then return losses. \
                Otherwise, return predicted poses, boxes, image paths \
                and heatmaps.
        """
        if return_loss:
            # Debug visualization of img / sam_img (e.g. with OpenCV or PIL):
            # print(sam_img[0].shape)
            # imshow(sam_img[0].cpu().numpy().transpose(1, 2, 0), wait_time=5000)
            # modified: also pass sam_img
            return self.forward_train(img, sam_img, target, target_weight,
                                      img_metas, **kwargs)
        # modified: also pass sam_img
        return self.forward_test(
            img, sam_img, img_metas, return_heatmap=return_heatmap, **kwargs)

    # modified: forward_train takes the extra sam_img input
    def forward_train(self, img, sam_img, target, target_weight, img_metas,
                      **kwargs):
        """Defines the computation performed at every call when training."""
        # modified: the backbone consumes both img and sam_img
        output = self.backbone(img, sam_img)  # B, C, Hp, Wp
        if self.with_neck:
            output = self.neck(output)
        if self.with_keypoint:
            output = self.keypoint_head(output)

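        # Loss computation: `get_loss` returns the keypoint loss terms and
        # `get_accuracy` returns training-time metrics (e.g. a PCK-based
        # accuracy); both dicts are merged so the runner logs them together.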
        # if return loss
        losses = dict()
        if self.with_keypoint:
            keypoint_losses = self.keypoint_head.get_loss(
                output, target, target_weight)
            losses.update(keypoint_losses)
            keypoint_accuracy = self.keypoint_head.get_accuracy(
                output, target, target_weight)
            losses.update(keypoint_accuracy)

        return losses

    # modified: forward_test takes the extra sam_img input
    def forward_test(self, img, sam_img, img_metas, return_heatmap=False,
                     **kwargs):
        """Defines the computation performed at every call when testing."""
        assert img.size(0) == len(img_metas)
        batch_size, _, img_height, img_width = img.shape
        if batch_size > 1:
            assert 'bbox_id' in img_metas[0]

        result = {}

        # modified: the backbone consumes both img and sam_img
        features = self.backbone(img, sam_img)
        if self.with_neck:
            features = self.neck(features)
        if self.with_keypoint:
            output_heatmap = self.keypoint_head.inference_model(
                features, flip_pairs=None)

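        # Flip test (test-time augmentation): run the network on a
        # horizontally flipped copy and average the two heatmaps. The extra
        # sam_img is flipped the same way so both inputs stay spatially
        # aligned; this mirrors the standard mmpose TopDown flip test,
        # extended here to the second input.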
        if self.test_cfg.get('flip_test', True):
            img_flipped = img.flip(3)
            # modified: flip sam_img as well
            sam_img_flipped = sam_img.flip(3)
            features_flipped = self.backbone(img_flipped, sam_img_flipped)
            if self.with_neck:
                features_flipped = self.neck(features_flipped)
            if self.with_keypoint:
                output_flipped_heatmap = self.keypoint_head.inference_model(
                    features_flipped, img_metas[0]['flip_pairs'])
                output_heatmap = (output_heatmap +
                                  output_flipped_heatmap) * 0.5

        if self.with_keypoint:
            keypoint_result = self.keypoint_head.decode(
                img_metas, output_heatmap, img_size=[img_width, img_height])
            result.update(keypoint_result)

            if not return_heatmap:
                output_heatmap = None

            result['output_heatmap'] = output_heatmap

        return result

    # modified: forward_dummy takes the extra sam_img input
    def forward_dummy(self, img, sam_img):
        """Used for computing network FLOPs.

        See ``tools/get_flops.py``.

        Args:
            img (torch.Tensor): Input image.
            sam_img (torch.Tensor): Additional input image for the SAM
                encoder branch of the backbone.

        Returns:
            Tensor: Output heatmaps.
        """
        output = self.backbone(img, sam_img)
        if self.with_neck:
            output = self.neck(output)
        if self.with_keypoint:
            output = self.keypoint_head(output)
        return output
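
    # FLOPs sketch (illustrative; `model`, `img` and `sam_img` are assumed
    # to be an instance of this class and two appropriately sized tensors):
    #
    #   with torch.no_grad():
    #       heatmaps = model.forward_dummy(img, sam_img)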

    @deprecated_api_warning({'pose_limb_color': 'pose_link_color'},
                            cls_name='TopDown')
    def show_result(self,
                    img,
                    result,
                    skeleton=None,
                    kpt_score_thr=0.3,
                    bbox_color='green',
                    pose_kpt_color=None,
                    pose_link_color=None,
                    text_color='white',
                    radius=4,
                    thickness=1,
                    font_scale=0.5,
                    bbox_thickness=1,
                    win_name='',
                    show=False,
                    show_keypoint_weight=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed.
            result (list[dict]): The results to draw over `img`
                (bbox_result, pose_result).
            skeleton (list[list]): The connection of keypoints.
                skeleton is 0-based indexing.
            kpt_score_thr (float, optional): Minimum score of keypoints
                to be shown. Default: 0.3.
            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
            pose_kpt_color (np.array[Nx3]): Color of N keypoints.
                If None, do not draw keypoints.
            pose_link_color (np.array[Mx3]): Color of M links.
                If None, do not draw links.
            text_color (str or tuple or :obj:`Color`): Color of texts.
            radius (int): Radius of circles.
            thickness (int): Thickness of lines.
            font_scale (float): Font scales of texts.
            bbox_thickness (int): Thickness of bbox borders.
            win_name (str): The window name.
            show (bool): Whether to show the image. Default: False.
            show_keypoint_weight (bool): Whether to change the transparency
                using the predicted confidence scores of keypoints.
            wait_time (int): Value of waitKey param.
                Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            Tensor: Visualized img, only if not `show` or `out_file`.
        """
        img = mmcv.imread(img)
        img = img.copy()

        bbox_result = []
        bbox_labels = []
        pose_result = []
        for res in result:
            if 'bbox' in res:
                bbox_result.append(res['bbox'])
                bbox_labels.append(res.get('label', None))
            pose_result.append(res['keypoints'])

        if bbox_result:
            bboxes = np.vstack(bbox_result)
            # draw bounding boxes
            imshow_bboxes(
                img,
                bboxes,
                labels=bbox_labels,
                colors=bbox_color,
                text_color=text_color,
                thickness=bbox_thickness,
                font_scale=font_scale,
                show=False)

        if pose_result:
            imshow_keypoints(img, pose_result, skeleton, kpt_score_thr,
                             pose_kpt_color, pose_link_color, radius,
                             thickness)

        if show:
            imshow(img, win_name, wait_time)

        if out_file is not None:
            imwrite(img, out_file)

        return img