# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp
import os
from abc import ABCMeta

import numpy as np
from torch.utils.data import Dataset

from mmpose.datasets.pipelines import Compose


class MeshBaseDataset(Dataset, metaclass=ABCMeta):
"""Base dataset for 3D human mesh estimation task. In 3D humamesh
estimation task, all datasets share this BaseDataset for training and have
their own evaluate function.
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
This dataset can only be used for training.
For evaluation, subclass should write an extra evaluate function.
Args:
ann_file (str): Path to the annotation file.
img_prefix (str): Path to a directory where images are held.
Default: None.
data_cfg (dict): config
pipeline (list[dict | callable]): A sequence of data transforms.
"""
    def __init__(self,
                 ann_file,
                 img_prefix,
                 data_cfg,
                 pipeline,
                 test_mode=False):

        self.image_info = {}
        self.ann_info = {}

        self.ann_file = ann_file
        self.img_prefix = img_prefix
        self.pipeline = pipeline
        self.test_mode = test_mode

        self.ann_info['image_size'] = np.array(data_cfg['image_size'])
        self.ann_info['iuv_size'] = np.array(data_cfg['iuv_size'])
        self.ann_info['num_joints'] = data_cfg['num_joints']
        self.ann_info['flip_pairs'] = None
        self.db = []
        self.pipeline = Compose(self.pipeline)

        # flip_pairs
        # For all mesh datasets, we use the 24-joint convention of CMR and
        # SPIN.
        self.ann_info['flip_pairs'] = [[0, 5], [1, 4], [2, 3], [6, 11],
                                       [7, 10], [8, 9], [20, 21], [22, 23]]
        self.ann_info['use_different_joint_weights'] = False
        assert self.ann_info['num_joints'] == 24
        self.ann_info['joint_weights'] = np.ones([24, 1], dtype=np.float32)

        self.ann_info['uv_type'] = data_cfg['uv_type']
        self.ann_info['use_IUV'] = data_cfg['use_IUV']
        uv_type = self.ann_info['uv_type']
        self.iuv_prefix = os.path.join(self.img_prefix, f'{uv_type}_IUV_gt')
        self.db = self._get_db(ann_file)

    def _get_db(self, ann_file):
        """Load dataset."""
        data = np.load(ann_file)
        tmpl = dict(
            image_file=None,
            center=None,
            scale=None,
            rotation=0,
            joints_2d=None,
            joints_2d_visible=None,
            joints_3d=None,
            joints_3d_visible=None,
            gender=None,
            pose=None,
            beta=None,
            has_smpl=0,
            iuv_file=None,
            has_iuv=0)
        gt_db = []

        _imgnames = data['imgname']
        _scales = data['scale'].astype(np.float32)
        _centers = data['center'].astype(np.float32)
        dataset_len = len(_imgnames)

        # Get 2D keypoints
        if 'part' in data.keys():
            _keypoints = data['part'].astype(np.float32)
        else:
            _keypoints = np.zeros((dataset_len, 24, 3), dtype=np.float32)

        # Get gt 3D joints, if available
        if 'S' in data.keys():
            _joints_3d = data['S'].astype(np.float32)
        else:
            _joints_3d = np.zeros((dataset_len, 24, 4), dtype=np.float32)

        # Get gt SMPL parameters, if available
        if 'pose' in data.keys() and 'shape' in data.keys():
            _poses = data['pose'].astype(np.float32)
            _betas = data['shape'].astype(np.float32)
            has_smpl = 1
        else:
            _poses = np.zeros((dataset_len, 72), dtype=np.float32)
            _betas = np.zeros((dataset_len, 10), dtype=np.float32)
            has_smpl = 0

        # Get gender data, if available
        if 'gender' in data.keys():
            _genders = data['gender']
            # Encode gender as an integer: 0 for male ('m'), 1 otherwise.
            _genders = np.array([str(g) != 'm' for g in _genders]).astype(int)
        else:
            # -1 marks samples with unknown gender.
            _genders = -1 * np.ones(dataset_len).astype(int)

        # Get IUV image, if available
        if 'iuv_names' in data.keys():
            _iuv_names = data['iuv_names']
            # IUV ground truth is only valid when SMPL parameters exist.
            has_iuv = has_smpl
        else:
            _iuv_names = [''] * dataset_len
            has_iuv = 0

        for i in range(len(_imgnames)):
            newitem = cp.deepcopy(tmpl)
            newitem['image_file'] = os.path.join(self.img_prefix, _imgnames[i])
            newitem['scale'] = np.array([_scales[i], _scales[i]])
            newitem['center'] = _centers[i]
            newitem['joints_2d'] = _keypoints[i, :, :2]
            newitem['joints_2d_visible'] = _keypoints[i, :, -1][:, None]
            newitem['joints_3d'] = _joints_3d[i, :, :3]
            newitem['joints_3d_visible'] = _joints_3d[i, :, -1][:, None]
            newitem['pose'] = _poses[i]
            newitem['beta'] = _betas[i]
            newitem['has_smpl'] = has_smpl
            newitem['gender'] = _genders[i]
            newitem['iuv_file'] = os.path.join(self.iuv_prefix, _iuv_names[i])
            newitem['has_iuv'] = has_iuv
            gt_db.append(newitem)

        return gt_db

    def __len__(self):
        """Get the size of the dataset."""
        return len(self.db)

    def __getitem__(self, idx):
        """Get the sample given index."""
        results = cp.deepcopy(self.db[idx])
        results['ann_info'] = self.ann_info
        return self.pipeline(results)
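

# ---------------------------------------------------------------------------
# Usage sketch: how a dataset built on MeshBaseDataset can be constructed and
# indexed directly for debugging. The paths below are placeholders and the
# data_cfg values (image_size, iuv_size, uv_type, ...) are illustrative only;
# real values come from the training config. An empty pipeline simply returns
# the raw result dict.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    example_data_cfg = dict(
        image_size=[224, 224],
        iuv_size=[56, 56],
        num_joints=24,
        uv_type='BF',
        use_IUV=False)
    dataset = MeshBaseDataset(
        ann_file='data/mesh_annotations/train.npz',  # placeholder path
        img_prefix='data/images',  # placeholder path
        data_cfg=example_data_cfg,
        pipeline=[],
        test_mode=False)
    print('number of samples:', len(dataset))
    print('keys of one sample:', sorted(dataset[0].keys()))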