You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

419 lines
16 KiB

import os
import glob
import torch
import random
import logging
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
from common.opt import opts
from common.utils import *
from common.camera import get_uvd2xyz
from common.load_data_3dhp_mae import Fusion
from common.h36m_dataset import Human36mDataset
from model.block.refine import refine
from model.stc_pe_3dhp import Model
from model.stmo_pretrain import Model_MAE
#from thop import clever_format
#from thop.profile import profile
import scipy.io as scio
# Options are parsed once at import time; the __main__ block below passes this
# object into train() / val() / step().
opt = opts().parse()
# Expose only the device(s) listed in opt.gpu to CUDA. CUDA reads this variable
# when it is first initialised, so it must be set before any device is touched.
os.environ.update(CUDA_VISIBLE_DEVICES=opt.gpu)
def train(opt, actions, train_loader, model, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Thin wrapper that forwards everything to ``step`` with ``split='train'`` and
    returns whatever ``step`` returns for that split.
    """
    return step(split='train', opt=opt, actions=actions, dataLoader=train_loader,
                model=model, optimizer=optimizer, epoch=epoch)
def val(opt, actions, val_loader, model):
    """Evaluate ``model`` on ``val_loader`` with autograd disabled.

    Forwards to ``step`` with ``split='test'`` inside ``torch.no_grad()`` and
    returns its result.
    """
    no_grad_ctx = torch.no_grad()
    with no_grad_ctx:
        result = step(split='test', opt=opt, actions=actions,
                      dataLoader=val_loader, model=model)
    return result
def step(split, opt, actions, dataLoader, model, optimizer=None, epoch=None):
    """Run one pass over ``dataLoader`` in training or evaluation mode.

    Two pipelines are selected by ``opt.MAE``:

    * ``opt.MAE`` truthy: masked 2D reconstruction with ``model['MAE']``. A random
      temporal mask and a per-frame spatial mask are drawn for each batch. The loss is
      ``mpjpe_cal`` between the reconstruction and the input reordered as
      (unmasked frames, masked frames).
    * otherwise: 2D -> 3D lifting with ``model['trans']``, optionally followed by
      ``model['refine']`` when ``opt.refine`` is set.

    Args:
        split: ``'train'`` (forward, backward, optimizer step) or ``'test'``.
        opt: parsed options namespace.
        actions: forwarded to ``define_error_list``.
        dataLoader: iterable of batches; the tuple layout depends on ``opt.MAE``
            and ``split`` (see the unpacking below).
        model: dict with keys ``'trans'``, ``'refine'`` and ``'MAE'``.
        optimizer: used only when ``split == 'train'``.
        epoch: not used here; kept for call-site compatibility.

    Returns:
        train + MAE: average loss * 1000.
        train:       ``(average loss, average joint error)``.
        test + MAE:  average reconstruction error * 1000.
        test:        average joint error. When ``opt.train == 0`` and ``opt.refine``
                     is off, the centre-frame predictions are also written per
                     sequence to ``<opt.checkpoint>/inference_data_81_3dhp.mat``.
    """
    model_trans = model['trans']
    model_refine = model['refine']
    model_MAE = model['MAE']

    if split == 'train':
        model_trans.train()
        model_refine.train()
        model_MAE.train()
    else:
        model_trans.eval()
        model_refine.eval()
        model_MAE.eval()

    loss_all = {'loss': AccumLoss()}
    error_sum = AccumLoss()
    error_sum_test = AccumLoss()
    action_error_sum_post_out = define_error_list(actions)

    # Joint indices swapped by the flip test-time augmentation.
    joints_left = [5, 6, 7, 11, 12, 13]
    joints_right = [2, 3, 4, 8, 9, 10]

    # seq_name -> list of per-sample arrays; concatenated once at the end.
    data_inference = {}

    for data in tqdm(dataLoader, 0):
        if opt.MAE:
            if split == "train":
                batch_cam, input_2D, seq, subject, scale, bb_box, cam_ind = data
            else:
                batch_cam, input_2D, seq, scale, bb_box = data
            [input_2D, batch_cam, scale, bb_box] = get_varialbe(split, [input_2D, batch_cam, scale, bb_box])

            N = input_2D.size(0)
            f = opt.frames

            # Temporal mask: exactly mask_num of the f frames are masked, in random order.
            mask_num = int(f * opt.temporal_mask_rate)
            mask = np.hstack([
                np.zeros(f - mask_num),
                np.ones(mask_num),
            ]).flatten()
            # NOTE(review): reseeding from OS entropy every batch makes the temporal
            # mask irreproducible even though __main__ fixes opt.manualSeed.
            np.random.seed()
            np.random.shuffle(mask)
            mask = torch.from_numpy(mask).to(torch.bool).cuda()

            # Spatial mask: opt.spatial_mask_num joints masked per frame.
            # NOTE(review): range(0, 16) never selects joint 16 of the 17 columns;
            # confirm whether that exclusion is intended.
            spatial_mask = np.zeros((f, 17), dtype=bool)
            for k in range(f):
                ran = random.sample(range(0, 16), opt.spatial_mask_num)
                spatial_mask[k, ran] = True

            if opt.test_augmentation and split == 'test':
                input_2D, output_2D = input_augmentation_MAE(input_2D, model_MAE, joints_left, joints_right, mask, spatial_mask)
            else:
                input_2D = input_2D.view(N, -1, opt.n_joints, opt.in_channels, 1).permute(0, 3, 1, 2, 4).type(
                    torch.cuda.FloatTensor)
                output_2D = model_MAE(input_2D, mask, spatial_mask)

            # (N, C, T, J, 1) -> (N, T, J, 2)
            input_2D = input_2D.permute(0, 2, 3, 1, 4).view(N, -1, opt.n_joints, 2)
            output_2D = output_2D.permute(0, 2, 3, 1, 4).view(N, -1, opt.n_joints, 2)

            loss = mpjpe_cal(output_2D, torch.cat((input_2D[:, ~mask], input_2D[:, mask]), dim=1))
        else:
            if split == "train":
                batch_cam, gt_3D, input_2D, seq, subject, scale, bb_box, cam_ind = data
            else:
                batch_cam, gt_3D, input_2D, seq, scale, bb_box = data
            [input_2D, gt_3D, batch_cam, scale, bb_box] = get_varialbe(split,
                                                                       [input_2D, gt_3D, batch_cam, scale, bb_box])

            N = input_2D.size(0)

            # Joint 14 is zeroed in the target (and in predictions before scoring below);
            # presumably it is the root joint of this 17-joint layout.
            out_target = gt_3D.clone().view(N, -1, opt.out_joints, opt.out_channels)
            out_target[:, :, 14] = 0
            gt_3D = gt_3D.view(N, -1, opt.out_joints, opt.out_channels).type(torch.cuda.FloatTensor)

            # Centre-frame (index opt.pad) versions of the targets.
            if out_target.size(1) > 1:
                out_target_single = out_target[:, opt.pad].unsqueeze(1)
                gt_3D_single = gt_3D[:, opt.pad].unsqueeze(1)
            else:
                out_target_single = out_target
                gt_3D_single = gt_3D

            if opt.test_augmentation and split == 'test':
                input_2D, output_3D = input_augmentation(input_2D, model_trans, joints_left, joints_right)
            else:
                input_2D = input_2D.view(N, -1, opt.n_joints, opt.in_channels, 1).permute(0, 3, 1, 2, 4).type(
                    torch.cuda.FloatTensor)
                output_3D = model_trans(input_2D)

            # Multiply every coordinate by the per-sample `scale` (assumed shape (N,) —
            # it is unsqueezed three times and tiled to the prediction's shape).
            output_3D = output_3D * scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(
                1, output_3D.size(1), opt.out_joints, opt.out_channels)
            output_3D_single = output_3D[:, opt.pad].unsqueeze(1)

            if split == 'train':
                pred_out = output_3D
            elif split == 'test':
                pred_out = output_3D_single
            # Target with the same temporal extent as pred_out: the whole sequence for
            # training, the centre frame for testing. Comparing the single-frame test
            # prediction with the full out_target would broadcast over every frame.
            target = out_target if split == 'train' else out_target_single

            input_2D = input_2D.permute(0, 2, 3, 1, 4).view(N, -1, opt.n_joints, 2)

            if opt.refine:
                pred_uv = input_2D
                uvd = torch.cat((pred_uv[:, opt.pad, :, :].unsqueeze(1),
                                 output_3D_single[:, :, :, 2].unsqueeze(-1)), -1)
                xyz = get_uvd2xyz(uvd, gt_3D_single, batch_cam)
                # NOTE(review): joint 0 is zeroed here while joint 14 is zeroed everywhere
                # else in this function — confirm which root the refine model expects.
                xyz[:, :, 0, :] = 0
                post_out = model_refine(output_3D_single, xyz)
                loss = mpjpe_cal(post_out, out_target_single)
            else:
                loss = mpjpe_cal(pred_out, target)

        loss_all['loss'].update(loss.detach().cpu().numpy() * N, N)

        if split == 'train':
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if not opt.MAE:
                if opt.refine:
                    post_out[:, :, 14, :] = 0
                    joint_error = mpjpe_cal(post_out, out_target_single).item()
                else:
                    pred_out[:, :, 14, :] = 0
                    joint_error = mpjpe_cal(pred_out, target).item()
                error_sum.update(joint_error * N, N)

        elif split == 'test':
            if opt.MAE:
                joint_error_test = mpjpe_cal(torch.cat((input_2D[:, ~mask], input_2D[:, mask]), dim=1), output_2D).item()
            else:
                pred_out[:, :, 14, :] = 0
                joint_error_test = mpjpe_cal(pred_out, target).item()

                if opt.train == 0:
                    # One host transfer per batch. Each sample becomes (C, J, 1), the same
                    # layout as pred_out[i].permute(2, 1, 0).
                    batch_np = pred_out.permute(0, 3, 2, 1).cpu().numpy()
                    for seq_name, sample in zip(seq, batch_np):
                        data_inference.setdefault(seq_name, []).append(sample)

            error_sum_test.update(joint_error_test * N, N)

    if split == 'train':
        if opt.MAE:
            return loss_all['loss'].avg * 1000
        return loss_all['loss'].avg, error_sum.avg

    elif split == 'test':
        if opt.MAE:
            return error_sum_test.avg * 1000
        if opt.refine:
            # NOTE(review): action_error_sum_post_out is never updated in this loop, so
            # this only prints the empty accumulator.
            print_error(opt.dataset, action_error_sum_post_out, opt.train)
        elif opt.train == 0:
            # Stack each sequence's frames along the last axis: (C, J, 1, n_frames).
            data_inference = {
                seq_name: np.concatenate(chunks, axis=2)[:, :, None, :]
                for seq_name, chunks in data_inference.items()
            }
            mat_path = os.path.join(opt.checkpoint, 'inference_data_81_3dhp.mat')
            scio.savemat(mat_path, data_inference)
        return error_sum_test.avg
def input_augmentation_MAE(input_2D, model_trans, joints_left, joints_right, mask, spatial_mask=None):
    """Flip test-time augmentation for the masked 2D reconstruction model.

    ``input_2D`` has layout (N, 2, T, J, C). Index 0 on dim 1 is the original sample
    and index 1 its flipped copy. Each copy is reshaped to (N, C, T, J, 1) and passed
    through ``model_trans`` with the same ``mask`` / ``spatial_mask``. The flipped
    prediction is mapped back by negating the x channel and swapping the left/right
    joints, and the two predictions are averaged.

    Args:
        input_2D: tensor of shape (N, 2, T, J, C).
        model_trans: callable ``model_trans(x, mask, spatial_mask)`` returning a tensor
            in the same (N, C, T, J, 1) layout as its input (``step`` undoes that
            layout with ``permute(0, 2, 3, 1, 4)``).
        joints_left, joints_right: joint indices exchanged to undo the mirror.
        mask, spatial_mask: forwarded unchanged to ``model_trans``.

    Returns:
        ``(input_2D_non_flip, output_2D)``: the original view in (N, C, T, J, 1)
        layout and the averaged reconstruction.
    """
    N, _, T, J, C = input_2D.shape

    input_2D_flip = input_2D[:, 1].view(N, T, J, C, 1).permute(0, 3, 1, 2, 4)
    input_2D_non_flip = input_2D[:, 0].view(N, T, J, C, 1).permute(0, 3, 1, 2, 4)

    output_2D_flip = model_trans(input_2D_flip, mask, spatial_mask)
    # Output layout is (N, C, T, J, 1): coordinates live on dim 1 and joints on dim 3.
    # Undo the mirror on the x channel (dim 1, index 0). Indexing [:, :, :, 0] would
    # hit joint 0 instead.
    output_2D_flip[:, 0] *= -1
    output_2D_flip[:, :, :, joints_left + joints_right] = output_2D_flip[:, :, :, joints_right + joints_left]

    output_2D_non_flip = model_trans(input_2D_non_flip, mask, spatial_mask)

    output_2D = (output_2D_non_flip + output_2D_flip) / 2
    return input_2D_non_flip, output_2D
def input_augmentation(input_2D, model_trans, joints_left, joints_right):
    """Flip test-time augmentation for the 2D -> 3D lifting model.

    Args:
        input_2D: tensor of shape (N, 2, T, J, C). Index 0 on dim 1 is the original
            sample and index 1 its flipped copy.
        model_trans: lifting model, called on tensors in (N, C, T, J, 1) layout.
            Its prediction is indexed with coordinates on axis 3 and joints on axis 2.
        joints_left, joints_right: joint indices exchanged to undo the mirror.

    Returns:
        ``(original input in (N, C, T, J, 1) layout, mean of the original prediction
        and the un-mirrored prediction of the flipped copy)``.
    """
    batch, _, frames, n_joints, n_coords = input_2D.shape

    def to_model_layout(sample):
        # (N, T, J, C) -> (N, C, T, J, 1)
        return sample.view(batch, frames, n_joints, n_coords, 1).permute(0, 3, 1, 2, 4)

    mirrored_in = to_model_layout(input_2D[:, 1])
    plain_in = to_model_layout(input_2D[:, 0])

    # Mirrored pass first, then map it back: negate x and swap left/right joints.
    mirrored_out = model_trans(mirrored_in)
    mirrored_out[:, :, :, 0] *= -1
    swap_dst = joints_left + joints_right
    swap_src = joints_right + joints_left
    mirrored_out[:, :, swap_dst] = mirrored_out[:, :, swap_src]

    plain_out = model_trans(plain_in)

    averaged = (plain_out + mirrored_out) / 2
    return plain_in, averaged
if __name__ == '__main__':
    # Device selection is left to opt.gpu, which is written to CUDA_VISIBLE_DEVICES
    # right after the options are parsed at module level. Hard-coding the variable
    # here would silently override that option.

    # Fixed seed for Python, NumPy and torch (CPU and all GPUs); cuDNN autotuning is
    # disabled in favour of deterministic kernels.
    opt.manualSeed = 1
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if opt.train == 1:
        logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %H:%M:%S',
                            filename=os.path.join(opt.checkpoint, 'train.log'), level=logging.INFO)

    root_path = opt.root_path
    actions = define_actions(opt.actions)

    if opt.train:
        train_data = Fusion(opt=opt, train=True, root_path=root_path, MAE=opt.MAE)
        train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=opt.batchSize,
                                                       shuffle=True, num_workers=int(opt.workers), pin_memory=True)
    if opt.test:
        test_data = Fusion(opt=opt, train=False, root_path=root_path, MAE=opt.MAE)
        test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=opt.batchSize,
                                                      shuffle=False, num_workers=int(opt.workers), pin_memory=True)

    opt.out_joints = 17

    # Built in this order (trans, refine, MAE); the optimizer collects parameters in the same order.
    model = {
        'trans': nn.DataParallel(Model(opt)).cuda(),
        'refine': nn.DataParallel(refine(opt)).cuda(),
        'MAE': nn.DataParallel(Model_MAE(opt)).cuda(),
    }

    model_params = sum(parameter.numel() for parameter in model['trans'].parameters())
    print('INFO: Trainable parameter count:', model_params)

    if opt.MAE_reload == 1:
        # Partial load: copy only checkpoint entries whose keys exist in the lifting model.
        model_dict = model['trans'].state_dict()
        pre_dict = torch.load(opt.previous_dir)
        model_dict.update({k: v for k, v in pre_dict.items() if k in model_dict})
        model['trans'].load_state_dict(model_dict)

    if opt.reload == 1:
        # Strict load: every key of the lifting model must be present in the checkpoint.
        # The state dict is filled in place so its metadata is kept for load_state_dict.
        model_dict = model['trans'].state_dict()
        pre_dict = torch.load(opt.previous_dir)
        for name in model_dict:
            model_dict[name] = pre_dict[name]
        model['trans'].load_state_dict(model_dict)

    if opt.refine_reload == 1:
        refine_dict = model['refine'].state_dict()
        pre_dict_refine = torch.load(opt.previous_refine_name)
        for name in refine_dict:
            refine_dict[name] = pre_dict_refine[name]
        model['refine'].load_state_dict(refine_dict)

    # A single Adam optimizer (one param group) over all three models.
    lr = opt.lr
    all_param = [p for sub_model in model.values() for p in sub_model.parameters()]
    optimizer_all = optim.Adam(all_param, lr=opt.lr, amsgrad=True)

    for epoch in range(1, opt.nepoch):
        if opt.train == 1:
            if opt.MAE:
                loss = train(opt, actions, train_dataloader, model, optimizer_all, epoch)
            else:
                loss, mpjpe = train(opt, actions, train_dataloader, model, optimizer_all, epoch)

        if opt.test == 1:
            p1 = val(opt, actions, test_dataloader, model)
            data_threshold = p1

            # Save checkpoints whenever the validation error drops below the best so far.
            if opt.train and data_threshold < opt.previous_best_threshold:
                if opt.MAE:
                    opt.previous_name = save_model(opt.previous_name, opt.checkpoint, epoch, data_threshold,
                                                   model['MAE'], 'MAE')
                else:
                    opt.previous_name = save_model(opt.previous_name, opt.checkpoint, epoch, data_threshold,
                                                   model['trans'], 'no_refine')
                    if opt.refine:
                        opt.previous_refine_name = save_model(opt.previous_refine_name, opt.checkpoint, epoch,
                                                              data_threshold, model['refine'], 'refine')
                opt.previous_best_threshold = data_threshold

            if opt.train == 0:
                # Evaluation-only run: report once and stop.
                print('p1: %.2f' % (p1))
                break

            if opt.MAE:
                logging.info('epoch: %d, lr: %.7f, loss: %.4f, p1: %.2f' % (
                    epoch, lr, loss, p1))
                print('e: %d, lr: %.7f, loss: %.4f, p1: %.2f' % (epoch, lr, loss, p1))
            else:
                logging.info('epoch: %d, lr: %.7f, loss: %.4f, MPJPE: %.2f, p1: %.2f' % (epoch, lr, loss, mpjpe, p1))
                print('e: %d, lr: %.7f, loss: %.4f, M: %.2f, p1: %.2f' % (epoch, lr, loss, mpjpe, p1))

        # Step-wise LR schedule: a larger decay every opt.large_decay_epoch epochs,
        # otherwise the regular decay.
        if epoch % opt.large_decay_epoch == 0:
            decay = opt.lr_decay_large
        else:
            decay = opt.lr_decay
        for param_group in optimizer_all.param_groups:
            param_group['lr'] *= decay
        # The logged LR is decayed once per epoch, not once per param group.
        lr *= decay