diff --git a/configs/vitpose_sam/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTSam_base_coco_256x192.py b/configs/vitpose_sam/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTSam_base_coco_256x192.py
index 60a1108..50b6ffa 100644
--- a/configs/vitpose_sam/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTSam_base_coco_256x192.py
+++ b/configs/vitpose_sam/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTSam_base_coco_256x192.py
@@ -59,9 +59,10 @@ model = dict(
         qkv_bias=True,
         drop_path_rate=0.3,
         frozen_stages=12,
-        freeze_attn = True,
-        freeze_ffn = True,
-        samvit_checkpoint='/root/autodl-tmp/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth'
+        freeze_attn=True,
+        freeze_ffn=True,
+        samvit_checkpoint='/root/autodl-tmp/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth',
+        sam_img_size=512
     ),
     keypoint_head=dict(
         type='TopdownHeatmapSimpleHead',
@@ -81,7 +82,7 @@ model = dict(
         modulate_kernel=11,
         use_udp=True))

-data_root = '/root/autodl-tmp/dataset/coco2017/'
+data_root = '/root/autodl-tmp/dataset/coco2017'

 data_cfg = dict(
     image_size=[192, 256],
@@ -97,6 +98,7 @@ data_cfg = dict(
     use_gt_bbox=False,
     det_bbox_thr=0.0,
     bbox_file=f'{data_root}/person_detection_results/COCO_val2017_detections_AP_H_56_person.json',
+    sam_image_size=[512, 512],
 )

 train_pipeline = [
@@ -155,21 +157,21 @@ data = dict(
     val_dataloader=dict(samples_per_gpu=12),
     test_dataloader=dict(samples_per_gpu=12),
     train=dict(
-        type='TopDownCocoDataset',
+        type='TopDownCocoDatasetSelf',
         ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
         img_prefix=f'{data_root}/train2017/',
         data_cfg=data_cfg,
         pipeline=train_pipeline,
         dataset_info={{_base_.dataset_info}}),
     val=dict(
-        type='TopDownCocoDataset',
+        type='TopDownCocoDatasetSelf',
         ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
         img_prefix=f'{data_root}/val2017/',
         data_cfg=data_cfg,
         pipeline=val_pipeline,
         dataset_info={{_base_.dataset_info}}),
     test=dict(
-        type='TopDownCocoDataset',
+        type='TopDownCocoDatasetSelf',
         ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
         img_prefix=f'{data_root}/val2017/',
         data_cfg=data_cfg,
diff --git a/mmpose/datasets/datasets/top_down/__init__.py b/mmpose/datasets/datasets/top_down/__init__.py
index cc5b46a..fc42df0 100644
--- a/mmpose/datasets/datasets/top_down/__init__.py
+++ b/mmpose/datasets/datasets/top_down/__init__.py
@@ -13,6 +13,8 @@ from .topdown_ochuman_dataset import TopDownOCHumanDataset
 from .topdown_posetrack18_dataset import TopDownPoseTrack18Dataset
 from .topdown_posetrack18_video_dataset import TopDownPoseTrack18VideoDataset

+from .topdown_coco_dataset_self import TopDownCocoDatasetSelf
+
 __all__ = [
     'TopDownAicDataset', 'TopDownCocoDataset',
@@ -27,4 +29,5 @@ __all__ = [
     'TopDownH36MDataset', 'TopDownHalpeDataset',
     'TopDownPoseTrack18VideoDataset',
+    'TopDownCocoDatasetSelf'
 ]
diff --git a/mmpose/datasets/datasets/top_down/topdown_coco_dataset_self.py b/mmpose/datasets/datasets/top_down/topdown_coco_dataset_self.py
new file mode 100644
index 0000000..d9776c0
--- /dev/null
+++ b/mmpose/datasets/datasets/top_down/topdown_coco_dataset_self.py
@@ -0,0 +1,407 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+import warnings
+from collections import OrderedDict, defaultdict
+
+import json_tricks as json
+import numpy as np
+from mmcv import Config, deprecated_api_warning
+from xtcocotools.cocoeval import COCOeval
+
+from ....core.post_processing import oks_nms, soft_oks_nms
+from ...builder import DATASETS
+from ..base import Kpt2dSviewRgbImgTopDownDataset
+
+
+@DATASETS.register_module()
+class TopDownCocoDatasetSelf(Kpt2dSviewRgbImgTopDownDataset):
+    """CocoDataset dataset for top-down pose estimation.
+
+    "Microsoft COCO: Common Objects in Context", ECCV'2014.
+    More details can be found in the `paper
+    <https://arxiv.org/abs/1405.0312>`__ .
+
+    The dataset loads raw features and applies specified transforms
+    to return a dict containing the image tensors and other information.
+
+    COCO keypoint indexes::
+
+        0: 'nose',
+        1: 'left_eye',
+        2: 'right_eye',
+        3: 'left_ear',
+        4: 'right_ear',
+        5: 'left_shoulder',
+        6: 'right_shoulder',
+        7: 'left_elbow',
+        8: 'right_elbow',
+        9: 'left_wrist',
+        10: 'right_wrist',
+        11: 'left_hip',
+        12: 'right_hip',
+        13: 'left_knee',
+        14: 'right_knee',
+        15: 'left_ankle',
+        16: 'right_ankle'
+
+    Args:
+        ann_file (str): Path to the annotation file.
+        img_prefix (str): Path to a directory where images are held.
+            Default: None.
+        data_cfg (dict): config
+        pipeline (list[dict | callable]): A sequence of data transforms.
+        dataset_info (DatasetInfo): A class containing all dataset info.
+        test_mode (bool): Store True when building test or
+            validation dataset. Default: False.
+    """
+
+    def __init__(self,
+                 ann_file,
+                 img_prefix,
+                 data_cfg,
+                 pipeline,
+                 dataset_info=None,
+                 test_mode=False):
+
+        if dataset_info is None:
+            warnings.warn(
+                'dataset_info is missing. '
+                'Check https://github.com/open-mmlab/mmpose/pull/663 '
+                'for details.', DeprecationWarning)
+            cfg = Config.fromfile('configs/_base_/datasets/coco.py')
+            dataset_info = cfg._cfg_dict['dataset_info']
+
+        super().__init__(
+            ann_file,
+            img_prefix,
+            data_cfg,
+            pipeline,
+            dataset_info=dataset_info,
+            test_mode=test_mode)
+
+        self.use_gt_bbox = data_cfg['use_gt_bbox']
+        self.bbox_file = data_cfg['bbox_file']
+        self.det_bbox_thr = data_cfg.get('det_bbox_thr', 0.0)
+        self.use_nms = data_cfg.get('use_nms', True)
+        self.soft_nms = data_cfg['soft_nms']
+        self.nms_thr = data_cfg['nms_thr']
+        self.oks_thr = data_cfg['oks_thr']
+        self.vis_thr = data_cfg['vis_thr']
+
+        self.ann_info['sam_image_size'] = np.array(data_cfg['sam_image_size'])
+
+        self.db = self._get_db()
+
+        print(f'=> num_images: {self.num_images}')
+        print(f'=> load {len(self.db)} samples')
+
+    def _get_db(self):
+        """Load dataset."""
+        if (not self.test_mode) or self.use_gt_bbox:
+            # use ground truth bbox
+            gt_db = self._load_coco_keypoint_annotations()
+        else:
+            # use bbox from detection
+            gt_db = self._load_coco_person_detection_results()
+        return gt_db
+
+    def _load_coco_keypoint_annotations(self):
+        """Ground truth bbox and keypoints."""
+        gt_db = []
+        for img_id in self.img_ids:
+            gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
+        return gt_db
+
+    def _load_coco_keypoint_annotation_kernel(self, img_id):
+        """load annotation from COCOAPI.
+
+        Note:
+            bbox:[x1, y1, w, h]
+
+        Args:
+            img_id: coco image id
+
+        Returns:
+            dict: db entry
+        """
+        img_ann = self.coco.loadImgs(img_id)[0]
+        width = img_ann['width']
+        height = img_ann['height']
+        num_joints = self.ann_info['num_joints']
+
+        ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
+        objs = self.coco.loadAnns(ann_ids)
+
+        # sanitize bboxes
+        valid_objs = []
+        for obj in objs:
+            if 'bbox' not in obj:
+                continue
+            x, y, w, h = obj['bbox']
+            x1 = max(0, x)
+            y1 = max(0, y)
+            x2 = min(width - 1, x1 + max(0, w - 1))
+            y2 = min(height - 1, y1 + max(0, h - 1))
+            if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
+                obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
+                valid_objs.append(obj)
+        objs = valid_objs
+
+        bbox_id = 0
+        rec = []
+        for obj in objs:
+            if 'keypoints' not in obj:
+                continue
+            if max(obj['keypoints']) == 0:
+                continue
+            if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
+                continue
+            joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
+            joints_3d_visible = np.zeros((num_joints, 3), dtype=np.float32)
+
+            keypoints = np.array(obj['keypoints']).reshape(-1, 3)
+            joints_3d[:, :2] = keypoints[:, :2]
+            joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])
+
+            center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
+
+            image_file = osp.join(self.img_prefix, self.id2name[img_id])
+            rec.append({
+                'image_file': image_file,
+                'center': center,
+                'scale': scale,
+                'bbox': obj['clean_bbox'][:4],
+                'rotation': 0,
+                'joints_3d': joints_3d,
+                'joints_3d_visible': joints_3d_visible,
+                'dataset': self.dataset_name,
+                'bbox_score': 1,
+                'bbox_id': bbox_id
+            })
+            bbox_id = bbox_id + 1
+
+        return rec
+
+    def _load_coco_person_detection_results(self):
+        """Load coco person detection results."""
+        num_joints = self.ann_info['num_joints']
+        all_boxes = None
+        with open(self.bbox_file, 'r') as f:
+            all_boxes = json.load(f)
+
+        if not all_boxes:
+            raise ValueError('=> Load %s fail!' % self.bbox_file)
+
+        print(f'=> Total boxes: {len(all_boxes)}')
+
+        kpt_db = []
+        bbox_id = 0
+        for det_res in all_boxes:
+            if det_res['category_id'] != 1:
+                continue
+
+            image_file = osp.join(self.img_prefix,
+                                  self.id2name[det_res['image_id']])
+            box = det_res['bbox']
+            score = det_res['score']
+
+            if score < self.det_bbox_thr:
+                continue
+
+            center, scale = self._xywh2cs(*box[:4])
+            joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
+            joints_3d_visible = np.ones((num_joints, 3), dtype=np.float32)
+            kpt_db.append({
+                'image_file': image_file,
+                'center': center,
+                'scale': scale,
+                'rotation': 0,
+                'bbox': box[:4],
+                'bbox_score': score,
+                'dataset': self.dataset_name,
+                'joints_3d': joints_3d,
+                'joints_3d_visible': joints_3d_visible,
+                'bbox_id': bbox_id
+            })
+            bbox_id = bbox_id + 1
+        print(f'=> Total boxes after filter '
+              f'low score@{self.det_bbox_thr}: {bbox_id}')
+        return kpt_db
+
+    @deprecated_api_warning(name_dict=dict(outputs='results'))
+    def evaluate(self, results, res_folder=None, metric='mAP', **kwargs):
+        """Evaluate coco keypoint results. The pose prediction results will be
+        saved in ``${res_folder}/result_keypoints.json``.
+
+        Note:
+            - batch_size: N
+            - num_keypoints: K
+            - heatmap height: H
+            - heatmap width: W
+
+        Args:
+            results (list[dict]): Testing results containing the following
+                items:
+
+                - preds (np.ndarray[N,K,3]): The first two dimensions are \
+                    coordinates, score is the third dimension of the array.
+                - boxes (np.ndarray[N,6]): [center[0], center[1], scale[0], \
+                    scale[1],area, score]
+                - image_paths (list[str]): For example, ['data/coco/val2017\
+                    /000000393226.jpg']
+                - heatmap (np.ndarray[N, K, H, W]): model output heatmap
+                - bbox_id (list(int)).
+            res_folder (str, optional): The folder to save the testing
+                results. If not specified, a temp folder will be created.
+                Default: None.
+            metric (str | list[str]): Metric to be performed. Defaults: 'mAP'.
+
+        Returns:
+            dict: Evaluation results for evaluation metric.
+        """
+        metrics = metric if isinstance(metric, list) else [metric]
+        allowed_metrics = ['mAP']
+        for metric in metrics:
+            if metric not in allowed_metrics:
+                raise KeyError(f'metric {metric} is not supported')
+
+        if res_folder is not None:
+            tmp_folder = None
+            res_file = osp.join(res_folder, 'result_keypoints.json')
+        else:
+            tmp_folder = tempfile.TemporaryDirectory()
+            res_file = osp.join(tmp_folder.name, 'result_keypoints.json')
+
+        kpts = defaultdict(list)
+
+        for result in results:
+            preds = result['preds']
+            boxes = result['boxes']
+            image_paths = result['image_paths']
+            bbox_ids = result['bbox_ids']
+
+            batch_size = len(image_paths)
+            for i in range(batch_size):
+                image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
+                kpts[image_id].append({
+                    'keypoints': preds[i],
+                    'center': boxes[i][0:2],
+                    'scale': boxes[i][2:4],
+                    'area': boxes[i][4],
+                    'score': boxes[i][5],
+                    'image_id': image_id,
+                    'bbox_id': bbox_ids[i]
+                })
+        kpts = self._sort_and_unique_bboxes(kpts)
+
+        # rescoring and oks nms
+        num_joints = self.ann_info['num_joints']
+        vis_thr = self.vis_thr
+        oks_thr = self.oks_thr
+        valid_kpts = []
+        for image_id in kpts.keys():
+            img_kpts = kpts[image_id]
+            for n_p in img_kpts:
+                box_score = n_p['score']
+                kpt_score = 0
+                valid_num = 0
+                for n_jt in range(0, num_joints):
+                    t_s = n_p['keypoints'][n_jt][2]
+                    if t_s > vis_thr:
+                        kpt_score = kpt_score + t_s
+                        valid_num = valid_num + 1
+                if valid_num != 0:
+                    kpt_score = kpt_score / valid_num
+                # rescoring
+                n_p['score'] = kpt_score * box_score
+
+            if self.use_nms:
+                nms = soft_oks_nms if self.soft_nms else oks_nms
+                keep = nms(img_kpts, oks_thr, sigmas=self.sigmas)
+                valid_kpts.append([img_kpts[_keep] for _keep in keep])
+            else:
+                valid_kpts.append(img_kpts)
+
+        self._write_coco_keypoint_results(valid_kpts, res_file)
+
+        info_str = self._do_python_keypoint_eval(res_file)
+        name_value = OrderedDict(info_str)
+
+        if tmp_folder is not None:
+            tmp_folder.cleanup()
+
+        return name_value
+
+    def _write_coco_keypoint_results(self, keypoints, res_file):
+        """Write results into a json file."""
+        data_pack = [{
+            'cat_id': self._class_to_coco_ind[cls],
+            'cls_ind': cls_ind,
+            'cls': cls,
+            'ann_type': 'keypoints',
+            'keypoints': keypoints
+        } for cls_ind, cls in enumerate(self.classes)
+                     if not cls == '__background__']
+
+        results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
+
+        with open(res_file, 'w') as f:
+            json.dump(results, f, sort_keys=True, indent=4)
+
+    def _coco_keypoint_results_one_category_kernel(self, data_pack):
+        """Get coco keypoint results."""
+        cat_id = data_pack['cat_id']
+        keypoints = data_pack['keypoints']
+        cat_results = []
+
+        for img_kpts in keypoints:
+            if len(img_kpts) == 0:
+                continue
+
+            _key_points = np.array(
+                [img_kpt['keypoints'] for img_kpt in img_kpts])
+            key_points = _key_points.reshape(-1,
+                                             self.ann_info['num_joints'] * 3)
+
+            result = [{
+                'image_id': img_kpt['image_id'],
+                'category_id': cat_id,
+                'keypoints': key_point.tolist(),
+                'score': float(img_kpt['score']),
+                'center': img_kpt['center'].tolist(),
+                'scale': img_kpt['scale'].tolist()
+            } for img_kpt, key_point in zip(img_kpts, key_points)]
+
+            cat_results.extend(result)
+
+        return cat_results
+
+    def _do_python_keypoint_eval(self, res_file):
+        """Keypoint evaluation using COCOAPI."""
+        coco_det = self.coco.loadRes(res_file)
+        coco_eval = COCOeval(self.coco, coco_det, 'keypoints', self.sigmas)
+        coco_eval.params.useSegm = None
+        coco_eval.evaluate()
+        coco_eval.accumulate()
+        coco_eval.summarize()
+
+        stats_names = [
+            'AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5',
+            'AR .75', 'AR (M)', 'AR (L)'
+        ]
+
+        info_str = list(zip(stats_names, coco_eval.stats))
+
+        return info_str
+
+    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
+        """sort kpts and remove the repeated ones."""
+        for img_id, persons in kpts.items():
+            num = len(persons)
+            kpts[img_id] = sorted(kpts[img_id], key=lambda x: x[key])
+            for i in range(num - 1, 0, -1):
+                if kpts[img_id][i][key] == kpts[img_id][i - 1][key]:
+                    del kpts[img_id][i]
+
+        return kpts
diff --git a/mmpose/datasets/pipelines/top_down_transform_self.py b/mmpose/datasets/pipelines/top_down_transform_self.py
index c6c8b49..0c8859c 100644
--- a/mmpose/datasets/pipelines/top_down_transform_self.py
+++ b/mmpose/datasets/pipelines/top_down_transform_self.py
@@ -27,7 +27,7 @@ class TopDownAffineSam:
     def __call__(self, results):
         image_size = results['ann_info']['image_size']
         # modified
-        sam_image_size = np.array([1024, 1024])
+        sam_image_size = results['ann_info']['sam_image_size']

         img = results['img']
         joints_3d = results['joints_3d']
diff --git a/mmpose/models/backbones/vit_sam.py b/mmpose/models/backbones/vit_sam.py
index 58e438e..a89788b 100644
--- a/mmpose/models/backbones/vit_sam.py
+++ b/mmpose/models/backbones/vit_sam.py
@@ -282,7 +282,8 @@ class ViTSam(BaseBackbone):
                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                  drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
                  frozen_stages=-1, ratio=1, last_norm=True,
-                 patch_padding='pad', freeze_attn=False, freeze_ffn=False, samvit_checkpoint=None
+                 patch_padding='pad', freeze_attn=False, freeze_ffn=False,
+                 samvit_checkpoint=None, sam_img_size=1024
                  ):
         # Protect mutable default arguments
         super(ViTSam, self).__init__()
@@ -324,14 +325,14 @@ class ViTSam(BaseBackbone):
         self._freeze_stages()

         # ======================== SAM-ViT ========================
-        self.sam_vit = build_vit_sam(model_name='vit_b', checkpoint=samvit_checkpoint)
+        self.sam_vit = build_vit_sam(model_name='vit_b', checkpoint=samvit_checkpoint, img_size=sam_img_size)
         self.sam_vit.eval()
         for param in self.sam_vit.parameters():
             param.requires_grad = False

         # Cross-attention
-        # self.cross_attn = Cross_Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
-        #                                   qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)
+        self.cross_attn = Cross_Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
+                                          qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)

         # After self-attention on the vit tokens, run cross-attention with the sam tokens, then pass the result through an FFN
         # self.custom_attn_ffn = CustomAttentionFFN(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
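Usage sketch (not part of the patch): a minimal check, assuming the diff above has been applied to the ViTPose/mmpose tree, of how the pieces connect — the config now registers the new TopDownCocoDatasetSelf type and carries sam_image_size through data_cfg, which the dataset copies into ann_info for TopDownAffineSam to read. Config (mmcv) and build_dataset (mmpose) are the standard APIs; the config path comes from the first hunk, and the COCO paths inside it are the author's local ones and must exist for the dataset to build.

# sketch only: verify the wiring introduced by this diff
from mmcv import Config
from mmpose.datasets import build_dataset

# load the patched config from the first hunk
cfg = Config.fromfile(
    'configs/vitpose_sam/2d_kpt_sview_rgb_img/topdown_heatmap/coco/'
    'ViTSam_base_coco_256x192.py')

# the dataset type and the SAM input resolution come from the config changes above
assert cfg.data.train.type == 'TopDownCocoDatasetSelf'
assert list(cfg.data_cfg['sam_image_size']) == [512, 512]

# building the dataset requires the annotation files under data_root to exist
dataset = build_dataset(cfg.data.train)
# TopDownCocoDatasetSelf stores the value in ann_info, where TopDownAffineSam reads it
print(dataset.ann_info['sam_image_size'])  # -> [512 512]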