You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
41 lines
1.4 KiB
41 lines
1.4 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import numpy as np
|
|
from torch.utils.data import Dataset
|
|
|
|
from mmpose.datasets.builder import DATASETS, build_dataset
|
|
|
|
|
|
@DATASETS.register_module()
class Body3DSemiSupervisionDataset(Dataset):
    """Mix dataset for semi-supervised training in 3D human pose estimation.

    The dataset combines data from two datasets (a labeled one and an
    unlabeled one) and returns a single dict containing data from both.

    Iteration is driven by the unlabeled dataset: ``len(self)`` equals the
    length of the unlabeled dataset, and each unlabeled item is paired with
    a labeled item drawn at random (``np.random.randint``) from the labeled
    dataset.

    Args:
        labeled_dataset: Config passed to ``build_dataset`` for the dataset
            with 3D keypoint annotations (presumably a config dict — confirm
            against callers).
        unlabeled_dataset: Config passed to ``build_dataset`` for the dataset
            without 3D keypoint annotations.

    Raises:
        ValueError: If the built labeled dataset is empty, because no
            labeled sample could ever be drawn to pair with an unlabeled one.
    """

    def __init__(self, labeled_dataset, unlabeled_dataset):
        super().__init__()
        self.labeled_dataset = build_dataset(labeled_dataset)
        self.unlabeled_dataset = build_dataset(unlabeled_dataset)
        # Fail fast with a clear message. Otherwise the first __getitem__
        # would raise an opaque "high <= low" error from
        # np.random.randint(0, 0).
        if len(self.labeled_dataset) == 0:
            raise ValueError(
                'labeled_dataset is empty; Body3DSemiSupervisionDataset '
                'needs at least one labeled sample to pair with each '
                'unlabeled sample.')
        # The epoch length follows the unlabeled dataset only.
        self.length = len(self.unlabeled_dataset)

    def __len__(self):
        """Get the size of the dataset (the unlabeled dataset's length)."""
        return self.length

    def __getitem__(self, i):
        """Get the ``i``-th unlabeled item merged with a random labeled item.

        Args:
            i (int): Index into the unlabeled dataset.

        Returns:
            dict: The unlabeled sample updated with the labeled sample's
            entries. On key collisions, the labeled values overwrite the
            unlabeled ones.
        """
        data = self.unlabeled_dataset[i]
        # NOTE(review): this uses NumPy's global RNG. Presumably per-worker
        # seeding is handled by the data loader; confirm, or workers may
        # draw identical sequences.
        rand_ind = np.random.randint(0, len(self.labeled_dataset))
        labeled_data = self.labeled_dataset[rand_ind]
        # NOTE(review): update() mutates the object returned by the
        # unlabeled dataset. This is safe only if that dataset returns a
        # fresh dict per call; confirm for any caching dataset.
        data.update(labeled_data)
        return data
|
|
|