import numpy as np


class ChunkedGenerator:
    """Index pose sequences as fixed-length chunks and cut them out on demand.

    The constructor walks every sequence in ``poses_2d`` and records one
    ``(key, start, end, flip, reverse)`` tuple per chunk in ``self.pairs``.
    ``get_batch`` turns one such tuple into the actual (padded, optionally
    mirrored / time-reversed) arrays.

    Two indexing modes exist:

    * ``train`` truthy: every sequence is split into consecutive chunks of
      ``chunk_length`` frames; each key is reshaped to three components and
      unpacked in ``get_batch`` as ``(subject, seq, cam_index)``.
    * ``train`` falsy: one single-frame chunk is produced for every chunk
      start selected by the boolean mask ``valid_frame[key]``; each key is
      treated as a single scalar.

    Arguments:
    batch_size -- used for ``num_batches`` and the pre-allocated buffers
    cameras -- mapping key -> camera parameter vector, or None
    poses_3d -- mapping key -> 3D pose array (indexed as frames x joints x
                coords), or None
    poses_2d -- mapping key -> 2D pose array (same indexing), required
    valid_frame -- mapping key -> per-chunk mask; only read when ``train`` is
                   falsy
    chunk_length -- number of output frames per chunk
    pad -- number of extra 2D frames taken on each side of a chunk
    causal_shift -- shifts the 2D window backwards by this many frames
    shuffle -- whether ``next_pairs`` returns the pairs in random order
    random_seed -- seed of the internal ``np.random.RandomState``
    augment -- also index horizontally flipped copies of every chunk
    reverse_aug -- also index time-reversed copies of every chunk
    kps_left, kps_right -- joint indices swapped when flipping 2D poses
    joints_left, joints_right -- joint indices swapped when flipping 3D poses
    endless, out_all, MAE -- stored flags; ``out_all`` makes the 3D window
                             equal to the padded 2D window, ``MAE`` skips the
                             3D poses entirely in ``get_batch``
    train -- selects the indexing mode described above
    """

    def __init__(self, batch_size, cameras, poses_3d, poses_2d, valid_frame,
                 chunk_length=1, pad=0, causal_shift=0,
                 shuffle=False, random_seed=1234,
                 augment=False, reverse_aug=False, kps_left=None, kps_right=None,
                 joints_left=None, joints_right=None,
                 endless=False, out_all=False, MAE=False, train=True):
        assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d))
        assert cameras is None or len(cameras) == len(poses_2d)

        pairs = []
        self.saved_index = {}
        start_index = 0
        last_key = None

        for key in poses_2d.keys():
            n_frames = poses_2d[key].shape[0]
            assert poses_3d is None or n_frames == poses_3d[key].shape[0]

            # Round the chunk count up and centre the chunks on the sequence;
            # chunks that stick out are edge-padded later in get_batch.
            n_chunks = (n_frames + chunk_length - 1) // chunk_length
            offset = (n_chunks * chunk_length - n_frames) // 2

            if train:
                bounds = np.arange(n_chunks + 1) * chunk_length - offset
                lows, highs = bounds[:-1], bounds[1:]
                key_width = 3
            else:
                bounds = np.arange(n_chunks) * chunk_length - offset
                lows = bounds[valid_frame[key].astype(bool)]
                highs = lows + np.ones(lows.shape[0], dtype=int)
                key_width = 1

            # One key row per chunk (sized from the chunk count, not from
            # ``len(bounds - 1)`` which is an element-wise subtraction).
            keys = np.tile(np.array(key).reshape([1, key_width]), (len(lows), 1))
            pairs += self._build_pairs(keys, lows, highs, augment, reverse_aug)

            # Frame range of this sequence in the concatenation of all
            # sequences. Uses the 2D length so that poses_3d may be None.
            end_index = start_index + n_frames
            self.saved_index[key] = [start_index, end_index]
            start_index = end_index
            last_key = key

        # Pre-allocated buffers, shaped after the last sequence visited.
        # Skipped when there is no sequence at all.
        if last_key is not None:
            if cameras is not None:
                self.batch_cam = np.empty((batch_size, cameras[last_key].shape[-1]))
            if poses_3d is not None:
                self.batch_3d = np.empty((batch_size, chunk_length,
                                          poses_3d[last_key].shape[-2],
                                          poses_3d[last_key].shape[-1]))
            self.batch_2d = np.empty((batch_size, chunk_length + 2 * pad,
                                      poses_2d[last_key].shape[-2],
                                      poses_2d[last_key].shape[-1]))

        self.num_batches = (len(pairs) + batch_size - 1) // batch_size
        self.batch_size = batch_size
        self.random = np.random.RandomState(random_seed)
        self.pairs = pairs
        self.shuffle = shuffle
        self.pad = pad
        self.causal_shift = causal_shift
        self.endless = endless
        self.state = None

        self.cameras = cameras
        self.poses_3d = poses_3d
        self.poses_2d = poses_2d

        self.augment = augment
        self.kps_left = kps_left
        self.kps_right = kps_right
        self.joints_left = joints_left
        self.joints_right = joints_right
        self.out_all = out_all
        self.MAE = MAE
        self.valid_frame = valid_frame
        self.train = train

    @staticmethod
    def _build_pairs(keys, lows, highs, augment, reverse_aug):
        """Return the (key, start, end, flip, reverse) tuples for one sequence.

        Always emits the plain chunks; adds time-reversed copies when
        ``reverse_aug`` is set; adds flipped copies when ``augment`` is set
        (flipped *and* reversed if ``reverse_aug`` is set as well).
        """
        off = np.full(len(lows), False, dtype=bool)
        on = ~off
        pairs = list(zip(keys, lows, highs, off, off))
        if reverse_aug:
            pairs += list(zip(keys, lows, highs, off, on))
        if augment:
            pairs += list(zip(keys, lows, highs, on, on if reverse_aug else off))
        return pairs

    @staticmethod
    def _edge_padded(sequence, low, high, pad_left, pad_right):
        """Slice ``sequence[low:high]`` and edge-pad it along the frame axis."""
        window = sequence[low:high]
        if pad_left != 0 or pad_right != 0:
            return np.pad(window, ((pad_left, pad_right), (0, 0), (0, 0)), 'edge')
        return window

    @staticmethod
    def _mirror(poses, left, right):
        """Negate coordinate 0 and swap the left/right joints, in place."""
        # list() so that tuple/array index collections are concatenated
        # rather than added element-wise.
        left, right = list(left), list(right)
        poses[:, :, 0] *= -1
        poses[:, left + right] = poses[:, right + left]

    def num_frames(self):
        """Return ``num_batches * batch_size`` (chunk count rounded up to full batches)."""
        return self.num_batches * self.batch_size

    def random_state(self):
        """Return the ``RandomState`` used for shuffling."""
        return self.random

    def set_random_state(self, random):
        """Replace the ``RandomState`` used for shuffling."""
        self.random = random

    def augment_enabled(self):
        """Return whether flip augmentation was requested."""
        return self.augment

    def next_pairs(self):
        """Return ``(start_index, pairs)`` for the next pass over the data.

        With no saved ``state`` this is ``(0, pairs)``, where ``pairs`` is
        ``self.pairs`` itself or, when shuffling, a new list holding the same
        tuples in random order. Otherwise the saved ``state`` is returned
        unchanged.
        """
        if self.state is not None:
            return self.state
        if not self.shuffle:
            return 0, self.pairs
        # Permute indices instead of the pairs themselves: the tuples mix an
        # array with scalars, and NumPy >= 1.24 refuses to build an array
        # from such ragged input.
        order = self.random.permutation(len(self.pairs))
        return 0, [self.pairs[i] for i in order]

    def get_batch(self, seq_i, start_3d, end_3d, flip, reverse):
        """Cut one chunk out of the stored sequences.

        Arguments:
        seq_i -- key row of a pair (three components in train mode, one
                 otherwise)
        start_3d, end_3d -- frame range of the chunk; may exceed the sequence,
                            in which case the result is edge-padded
        flip -- mirror the poses (and camera entries 2 and 7) horizontally
        reverse -- reverse the frame order

        Returns, in train mode:
            MAE:     (zeros(9), batch_2d, seq, subject, cam_index)
            else:    (cam, batch_3d, batch_2d, seq, subject, cam_index)
                     where cam is the camera vector, zeros(9) if there are 3D
                     poses but no cameras, or None if there is neither, and
                     batch_3d is None without 3D poses.
        In test mode:
            MAE:     (zeros(9), batch_2d, seq_name, None, None)
            else:    (zeros(9), batch_3d, batch_2d, seq_name, None, None)
                     with batch_3d being None when no 3D poses were given.

        The most recent arrays are also kept on ``self.batch_2d``,
        ``self.batch_3d`` and ``self.batch_cam``.
        """
        if self.train:
            subject, seq, cam_index = seq_i
            seq_name = (subject, seq, cam_index)
        else:
            seq_name = seq_i[0]

        # The 2D window is the 3D window widened by ``pad`` on both sides
        # and moved back by ``causal_shift``.
        start_2d = start_3d - self.pad - self.causal_shift
        end_2d = end_3d + self.pad - self.causal_shift

        seq_2d = self.poses_2d[seq_name].copy()
        low_2d = max(start_2d, 0)
        high_2d = min(end_2d, seq_2d.shape[0])
        pad_left_2d = low_2d - start_2d
        pad_right_2d = end_2d - high_2d
        self.batch_2d = self._edge_padded(seq_2d, low_2d, high_2d,
                                          pad_left_2d, pad_right_2d)
        if flip:
            self._mirror(self.batch_2d, self.kps_left, self.kps_right)
        if reverse:
            self.batch_2d = self.batch_2d[::-1].copy()

        has_3d = self.poses_3d is not None
        if has_3d and not self.MAE:
            seq_3d = self.poses_3d[seq_name].copy()
            if self.out_all:
                # Supervise every input frame: reuse the 2D window.
                low_3d, high_3d = low_2d, high_2d
                pad_left_3d, pad_right_3d = pad_left_2d, pad_right_2d
            else:
                low_3d = max(start_3d, 0)
                high_3d = min(end_3d, seq_3d.shape[0])
                pad_left_3d = low_3d - start_3d
                pad_right_3d = end_3d - high_3d
            self.batch_3d = self._edge_padded(seq_3d, low_3d, high_3d,
                                              pad_left_3d, pad_right_3d)
            if flip:
                self._mirror(self.batch_3d, self.joints_left, self.joints_right)
            if reverse:
                self.batch_3d = self.batch_3d[::-1].copy()

        if self.cameras is not None:
            self.batch_cam = self.cameras[seq_name].copy()
            if flip:
                # NOTE(review): entries 2 and 7 are presumably the horizontal
                # camera terms that change sign under mirroring -- confirm
                # against the camera parameter layout.
                self.batch_cam[2] *= -1
                self.batch_cam[7] *= -1

        if self.train:
            if self.MAE:
                return np.zeros(9), self.batch_2d.copy(), seq, subject, int(cam_index)
            if not has_3d and self.cameras is None:
                return None, None, self.batch_2d.copy(), seq, subject, int(cam_index)
            if has_3d and self.cameras is None:
                return np.zeros(9), self.batch_3d.copy(), self.batch_2d.copy(), seq, subject, int(cam_index)
            if not has_3d:
                return self.batch_cam, None, self.batch_2d.copy(), seq, subject, int(cam_index)
            return self.batch_cam, self.batch_3d.copy(), self.batch_2d.copy(), seq, subject, int(cam_index)

        if self.MAE:
            return np.zeros(9), self.batch_2d.copy(), seq_name, None, None
        # Without 3D poses there is nothing to return in the 3D slot.
        batch_3d = self.batch_3d.copy() if has_3d else None
        return np.zeros(9), batch_3d, self.batch_2d.copy(), seq_name, None, None