You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
67 lines
2.3 KiB
67 lines
2.3 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy

from mmcv import Config

from mmpose.datasets.builder import build_dataset
|
|
|
|
|
|
def test_concat_dataset():
    """Check that ``build_dataset`` can build a concatenated dataset.

    The same COCO-style test dataset is concatenated with itself through
    three different config forms, so each result must hold exactly twice
    as many samples as a single copy:

    1. an explicit ``ConcatDataset`` config,
    2. a list of dataset configs,
    3. one config whose per-dataset fields are given as sequences.
    """
    # build COCO-like dataset config
    dataset_info = Config.fromfile(
        'configs/_base_/datasets/coco.py').dataset_info

    num_joints = 17
    channel_cfg = dict(
        num_output_channels=num_joints,
        dataset_joints=num_joints,
        dataset_channel=[
            list(range(num_joints)),
        ],
        inference_channel=list(range(num_joints)))

    data_cfg = dict(
        image_size=[192, 256],
        heatmap_size=[48, 64],
        num_output_channels=channel_cfg['num_output_channels'],
        num_joints=channel_cfg['dataset_joints'],
        dataset_channel=channel_cfg['dataset_channel'],
        inference_channel=channel_cfg['inference_channel'],
        soft_nms=False,
        nms_thr=1.0,
        oks_thr=0.9,
        vis_thr=0.2,
        use_gt_bbox=True,
        det_bbox_thr=0.0,
        bbox_file='tests/data/coco/test_coco_det_AP_H_56.json',
    )

    dataset_cfg = dict(
        type='TopDownCocoDataset',
        ann_file='tests/data/coco/test_coco.json',
        img_prefix='tests/data/coco/',
        data_cfg=data_cfg,
        pipeline=[],
        dataset_info=dataset_info)

    dataset = build_dataset(dataset_cfg)

    # Case 1: build ConcatDataset explicitly
    concat_dataset_cfg = dict(
        type='ConcatDataset', datasets=[dataset_cfg, dataset_cfg])
    concat_dataset = build_dataset(concat_dataset_cfg)
    assert len(concat_dataset) == 2 * len(dataset)

    # Case 2: build ConcatDataset from cfg sequence
    concat_dataset = build_dataset([dataset_cfg, dataset_cfg])
    assert len(concat_dataset) == 2 * len(dataset)

    # Case 3: build ConcatDataset from ann_file sequence.
    # Deep-copy: a shallow ``dict.copy()`` would share the nested
    # ``data_cfg`` dict with ``dataset_cfg``, so wrapping its values in
    # lists below would also rewrite the single-dataset config.
    concat_dataset_cfg = copy.deepcopy(dataset_cfg)
    for key in ['ann_file', 'type', 'img_prefix', 'dataset_info']:
        val = concat_dataset_cfg[key]
        concat_dataset_cfg[key] = [val] * 2
    for key in ['num_joints', 'dataset_channel']:
        val = concat_dataset_cfg['data_cfg'][key]
        concat_dataset_cfg['data_cfg'][key] = [val] * 2
    concat_dataset = build_dataset(concat_dataset_cfg)
    assert len(concat_dataset) == 2 * len(dataset)

    # Building the sequence-form config must not have mutated the original.
    assert dataset_cfg['data_cfg']['num_joints'] == num_joints
    assert dataset_cfg['data_cfg']['dataset_channel'] == [
        list(range(num_joints))
    ]
|
|
|