# NOTE: web-viewer page chrome (topic-selection banner, "419 lines / 16 KiB"
# counters) stripped from this scraped source; the Python module begins below.
import os
|
|
import glob
|
|
import torch
|
|
import random
|
|
import logging
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
import torch.nn as nn
|
|
import torch.utils.data
|
|
import torch.optim as optim
|
|
from common.opt import opts
|
|
from common.utils import *
|
|
from common.camera import get_uvd2xyz
|
|
from common.load_data_3dhp_mae import Fusion
|
|
from common.h36m_dataset import Human36mDataset
|
|
from model.block.refine import refine
|
|
from model.stc_pe_3dhp import Model
|
|
from model.stmo_pretrain import Model_MAE
|
|
|
|
#from thop import clever_format
|
|
#from thop.profile import profile
|
|
import scipy.io as scio
|
|
|
|
# Parse command-line options once at import time; `opt` is shared by every
# function below and by the __main__ block.
opt = opts().parse()
# Restrict visible GPUs to those requested on the command line.
# NOTE(review): the __main__ block later overwrites this with "0".
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
|
|
|
def train(opt, actions, train_loader, model, optimizer, epoch):
    """Run one training epoch.

    Thin wrapper that delegates to step() in 'train' mode and returns its
    result (loss when opt.MAE, otherwise (loss, mpjpe)).
    """
    outcome = step('train', opt, actions, train_loader, model, optimizer, epoch)
    return outcome
|
|
|
def val(opt, actions, val_loader, model):
    """Evaluate on the validation set with gradients disabled.

    Delegates to step() in 'test' mode under torch.no_grad() and returns
    its result.
    """
    with torch.no_grad():
        result = step('test', opt, actions, val_loader, model)
    return result
|
|
|
def step(split, opt, actions, dataLoader, model, optimizer=None, epoch=None):
    """Run one epoch over `dataLoader`.

    split     -- 'train' (backprop through `optimizer`) or 'test'.
    opt       -- parsed options; opt.MAE selects 2D masked-autoencoder
                 pre-training, otherwise 2D->3D pose lifting.
    model     -- dict with 'trans', 'refine' and 'MAE' sub-networks.

    Returns:
      train + MAE     : average loss * 1000
      train + lifting : (average loss, average joint error)
      test  + MAE     : average test error * 1000
      test  + lifting : average test error; when opt.train == 0 it also
                        dumps per-sequence predictions to a .mat file.

    NOTE(review): indentation was reconstructed from a whitespace-mangled
    paste — nesting of the inference-collection branch should be confirmed
    against the original repository.
    """
    model_trans = model['trans']
    model_refine = model['refine']
    model_MAE = model['MAE']

    # Switch all sub-networks between train/eval mode together.
    if split == 'train':
        model_trans.train()
        model_refine.train()
        model_MAE.train()
    else:
        model_trans.eval()
        model_refine.eval()
        model_MAE.eval()

    loss_all = {'loss': AccumLoss()}
    error_sum = AccumLoss()
    error_sum_test = AccumLoss()

    action_error_sum = define_error_list(actions)
    action_error_sum_post_out = define_error_list(actions)
    action_error_sum_MAE = define_error_list(actions)

    # Symmetric joint indices used by the horizontal-flip augmentation.
    joints_left = [5, 6, 7, 11, 12, 13]
    joints_right = [2, 3, 4, 8, 9, 10]

    # seq_name -> accumulated predictions, saved to .mat in inference mode.
    data_inference = {}

    for i, data in enumerate(tqdm(dataLoader, 0)):

        if opt.MAE:
            # ---- masked-autoencoder pre-training on 2D poses ----
            if split == "train":
                batch_cam, input_2D, seq, subject, scale, bb_box, cam_ind = data
            else:
                batch_cam, input_2D, seq, scale, bb_box = data
            [input_2D, batch_cam, scale, bb_box] = get_varialbe(split,[input_2D, batch_cam, scale, bb_box])

            N = input_2D.size(0)
            f = opt.frames

            # Temporal mask: mark `mask_num` of the f frames as masked,
            # then shuffle so the masked frames are randomly placed.
            mask_num = int(f*opt.temporal_mask_rate)
            mask = np.hstack([
                np.zeros(f - mask_num),
                np.ones(mask_num),
            ]).flatten()

            # NOTE(review): np.random.seed() reseeds from the OS entropy
            # here, defeating the global manualSeed for this shuffle.
            np.random.seed()
            np.random.shuffle(mask)

            mask = torch.from_numpy(mask).to(torch.bool).cuda()

            # Spatial mask: hide `spatial_mask_num` joints per frame.
            # NOTE(review): joints are sampled from range(0, 16), so joint
            # index 16 is never spatially masked — confirm intentional.
            spatial_mask = np.zeros((f, 17), dtype=bool)
            for k in range(f):
                ran = random.sample(range(0, 16), opt.spatial_mask_num)
                spatial_mask[k, ran] = True

            if opt.test_augmentation and split == 'test':
                input_2D, output_2D = input_augmentation_MAE(input_2D, model_MAE, joints_left, joints_right, mask, spatial_mask)
            else:
                # (N, T, J, C) -> (N, C, T, J, 1) channel-first layout.
                input_2D = input_2D.view(N, -1, opt.n_joints, opt.in_channels, 1).permute(0, 3, 1, 2, 4).type(
                    torch.cuda.FloatTensor)
                output_2D = model_MAE(input_2D, mask, spatial_mask)

            # Back to (N, T, J, 2) for the loss.
            input_2D = input_2D.permute(0, 2, 3, 1, 4).view(N, -1, opt.n_joints, 2)
            output_2D = output_2D.permute(0, 2, 3, 1, 4).view(N, -1, opt.n_joints, 2)

            # The network emits unmasked frames first, then masked frames,
            # so the target is reordered the same way before comparison.
            loss = mpjpe_cal(output_2D, torch.cat((input_2D[:, ~mask], input_2D[:, mask]), dim=1))

        else:
            # ---- 2D -> 3D pose lifting ----
            if split == "train":
                batch_cam, gt_3D, input_2D, seq, subject, scale, bb_box, cam_ind = data
            else:
                batch_cam, gt_3D, input_2D, seq, scale, bb_box = data

            [input_2D, gt_3D, batch_cam, scale, bb_box] = get_varialbe(split,
                                                                       [input_2D, gt_3D, batch_cam, scale, bb_box])

            N = input_2D.size(0)

            # Joint 14 is zeroed in the target — the error is measured
            # relative to that (root) joint.
            out_target = gt_3D.clone().view(N, -1, opt.out_joints, opt.out_channels)
            out_target[:, :, 14] = 0
            gt_3D = gt_3D.view(N, -1, opt.out_joints, opt.out_channels).type(torch.cuda.FloatTensor)

            # Single-frame targets: the centre frame (index opt.pad).
            if out_target.size(1) > 1:
                out_target_single = out_target[:, opt.pad].unsqueeze(1)
                gt_3D_single = gt_3D[:, opt.pad].unsqueeze(1)
            else:
                out_target_single = out_target
                gt_3D_single = gt_3D

            if opt.test_augmentation and split =='test':
                input_2D, output_3D = input_augmentation(input_2D, model_trans, joints_left, joints_right)
            else:
                input_2D = input_2D.view(N, -1, opt.n_joints, opt.in_channels, 1).permute(0, 3, 1, 2, 4).type(torch.cuda.FloatTensor)
                output_3D = model_trans(input_2D)

            # Undo the per-sample normalisation scale on the prediction.
            output_3D = output_3D * scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, output_3D.size(1),opt.out_joints, opt.out_channels)
            output_3D_single = output_3D[:,opt.pad].unsqueeze(1)

            # Train supervises the whole sequence; test scores only the
            # centre frame.
            if split == 'train':
                pred_out = output_3D
            elif split == 'test':
                pred_out = output_3D_single

            input_2D = input_2D.permute(0, 2, 3, 1, 4).view(N, -1, opt.n_joints ,2)

            if opt.refine:
                # Combine the input uv with the predicted depth, lift to
                # camera space, and refine around that estimate.
                pred_uv = input_2D
                uvd = torch.cat((pred_uv[:, opt.pad, :, :].unsqueeze(1), output_3D_single[:, :, :, 2].unsqueeze(-1)), -1)
                xyz = get_uvd2xyz(uvd, gt_3D_single, batch_cam)
                xyz[:, :, 0, :] = 0
                post_out = model_refine(output_3D_single, xyz)
                loss = mpjpe_cal(post_out, out_target_single)
            else:
                loss = mpjpe_cal(pred_out, out_target)

        # Accumulate the batch loss weighted by batch size.
        loss_all['loss'].update(loss.detach().cpu().numpy() * N, N)

        if split == 'train':
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if not opt.MAE:
                # Zero the root joint before measuring the training error,
                # matching the zeroed target.
                if opt.refine:
                    post_out[:,:,14,:] = 0
                    joint_error = mpjpe_cal(post_out, out_target_single).item()
                else:
                    pred_out[:,:,14,:] = 0
                    joint_error = mpjpe_cal(pred_out, out_target).item()

                error_sum.update(joint_error*N, N)

        elif split == 'test':
            if opt.MAE:
                # Reconstruction error against the reordered target.
                joint_error_test = mpjpe_cal(torch.cat((input_2D[:, ~mask], input_2D[:, mask]), dim=1), output_2D).item()
            else:
                pred_out[:, :, 14, :] = 0
                joint_error_test = mpjpe_cal(pred_out, out_target).item()
                out = pred_out

                # Inference-only runs also collect per-sequence predictions
                # (frames concatenated along the last axis).
                if opt.train == 0:
                    for seq_cnt in range(len(seq)):
                        seq_name = seq[seq_cnt]
                        if seq_name in data_inference:
                            data_inference[seq_name] = np.concatenate(
                                (data_inference[seq_name], out[seq_cnt].permute(2, 1, 0).cpu().numpy()), axis=2)
                        else:
                            data_inference[seq_name] = out[seq_cnt].permute(2, 1, 0).cpu().numpy()

            error_sum_test.update(joint_error_test * N, N)

    if split == 'train':
        if opt.MAE:
            return loss_all['loss'].avg*1000
        else:
            return loss_all['loss'].avg, error_sum.avg
    elif split == 'test':
        if opt.MAE:
            return error_sum_test.avg*1000
        if opt.refine:
            p1, p2 = print_error(opt.dataset, action_error_sum_post_out, opt.train)
        else:
            if opt.train == 0:
                # Persist collected predictions for offline evaluation.
                for seq_name in data_inference.keys():
                    data_inference[seq_name] = data_inference[seq_name][:, :, None, :]
                mat_path = os.path.join(opt.checkpoint, 'inference_data_81_3dhp.mat')
                scio.savemat(mat_path, data_inference)

        return error_sum_test.avg
|
|
|
def input_augmentation_MAE(input_2D, model_trans, joints_left, joints_right, mask, spatial_mask=None):
    """Test-time flip augmentation for the 2D masked-autoencoder model.

    `input_2D` stacks the plain view (index 0) and the horizontally flipped
    view (index 1) along dim 1. Both views are run through `model_trans`;
    the flipped prediction is mapped back (x negated, left/right joints
    swapped) and averaged with the plain one.

    Returns the plain network input and the averaged prediction.
    """
    N, _, T, J, C = input_2D.shape

    # (N, T, J, C) -> (N, C, T, J, 1): channel-first layout for the model.
    view_flip = input_2D[:, 1].view(N, T, J, C, 1).permute(0, 3, 1, 2, 4)
    view_plain = input_2D[:, 0].view(N, T, J, C, 1).permute(0, 3, 1, 2, 4)

    pred_flip = model_trans(view_flip, mask, spatial_mask)

    # Undo the horizontal flip: negate x and swap the symmetric joints.
    pred_flip[:, :, :, 0] *= -1
    pred_flip[:, :, :, joints_left + joints_right] = pred_flip[:, :, :, joints_right + joints_left]

    pred_plain = model_trans(view_plain, mask, spatial_mask)

    # Average the two de-augmented predictions.
    output_2D = (pred_plain + pred_flip) / 2

    return view_plain, output_2D
|
|
|
def input_augmentation(input_2D, model_trans, joints_left, joints_right):
    """Test-time flip augmentation for the 2D->3D lifting model.

    `input_2D` stacks the plain view (index 0) and the horizontally flipped
    view (index 1) along dim 1. Both views are lifted; the flipped 3D
    prediction is mapped back (x negated, left/right joints swapped) and
    averaged with the plain one.

    Returns the plain network input and the averaged 3D prediction.
    """
    N, _, T, J, C = input_2D.shape

    # (N, T, J, C) -> (N, C, T, J, 1): channel-first layout for the model.
    view_flip = input_2D[:, 1].view(N, T, J, C, 1).permute(0, 3, 1, 2, 4)
    view_plain = input_2D[:, 0].view(N, T, J, C, 1).permute(0, 3, 1, 2, 4)

    pred_flip = model_trans(view_flip)

    # Undo the horizontal flip: negate x and swap the symmetric joints.
    pred_flip[:, :, :, 0] *= -1
    pred_flip[:, :, joints_left + joints_right] = pred_flip[:, :, joints_right + joints_left]

    pred_plain = model_trans(view_plain)

    # Average the two de-augmented predictions.
    output_3D = (pred_plain + pred_flip) / 2

    return view_plain, output_3D
|
|
|
if __name__ == '__main__':
    # Pin the job to GPU 0 (overrides the value derived from opt.gpu above).
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    opt.manualSeed = 1

    # Seed every RNG source for reproducibility.
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)

    # Deterministic cuDNN kernels (slower but reproducible).
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if opt.train == 1:
        logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %H:%M:%S', \
                            filename=os.path.join(opt.checkpoint, 'train.log'), level=logging.INFO)

    root_path = opt.root_path
    dataset_path = root_path + 'data_3d_' + opt.dataset + '.npz'

    #dataset = Human36mDataset(dataset_path, opt)
    actions = define_actions(opt.actions)

    if opt.train:
        train_data = Fusion(opt=opt, train=True, root_path=root_path, MAE=opt.MAE)
        train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=opt.batchSize,
                                                       shuffle=True, num_workers=int(opt.workers), pin_memory=True)
    if opt.test:
        test_data = Fusion(opt=opt, train=False, root_path=root_path, MAE=opt.MAE)
        test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=opt.batchSize,
                                                      shuffle=False, num_workers=int(opt.workers), pin_memory=True)

    opt.out_joints = 17

    # Three sub-networks: the lifting transformer, the refinement module,
    # and the masked autoencoder used for pre-training.
    model = {}
    model['trans'] = nn.DataParallel(Model(opt)).cuda()
    model['refine'] = nn.DataParallel(refine(opt)).cuda()
    model['MAE'] = nn.DataParallel(Model_MAE(opt)).cuda()

    model_params = 0
    for parameter in model['trans'].parameters():
        model_params += parameter.numel()
    print('INFO: Trainable parameter count:', model_params)

    # Warm-start the lifting model from MAE pre-trained weights: only the
    # parameters whose names match are copied over.
    if opt.MAE_reload == 1:
        model_dict = model['trans'].state_dict()

        MAE_path = opt.previous_dir

        pre_dict = torch.load(MAE_path)

        state_dict = {k: v for k, v in pre_dict.items() if k in model_dict.keys()}

        model_dict.update(state_dict)
        model['trans'].load_state_dict(model_dict)

    model_dict = model['trans'].state_dict()
    if opt.reload == 1:
        # Full reload: every parameter must exist in the checkpoint.
        no_refine_path = opt.previous_dir

        pre_dict = torch.load(no_refine_path)
        for name, key in model_dict.items():
            model_dict[name] = pre_dict[name]
        model['trans'].load_state_dict(model_dict)

    refine_dict = model['refine'].state_dict()
    if opt.refine_reload == 1:
        refine_path = opt.previous_refine_name

        pre_dict_refine = torch.load(refine_path)
        for name, key in refine_dict.items():
            refine_dict[name] = pre_dict_refine[name]
        model['refine'].load_state_dict(refine_dict)

    # One Adam optimizer over the parameters of all three sub-networks.
    all_param = []
    lr = opt.lr
    for i_model in model:
        all_param += list(model[i_model].parameters())
    optimizer_all = optim.Adam(all_param, lr=opt.lr, amsgrad=True)

    for epoch in range(1, opt.nepoch):
        if opt.train == 1:
            # MAE pre-training returns only the loss; lifting also returns
            # the training MPJPE.
            if not opt.MAE:
                loss, mpjpe = train(opt, actions, train_dataloader, model, optimizer_all, epoch)
            else:
                loss = train(opt, actions, train_dataloader, model, optimizer_all, epoch)
        if opt.test == 1:
            if not opt.MAE:
                p1 = val(opt, actions, test_dataloader, model)
            else:
                p1 = val(opt, actions, test_dataloader, model)
            data_threshold = p1

            # Checkpoint whenever the validation metric improves.
            if opt.train and data_threshold < opt.previous_best_threshold:
                if opt.MAE:
                    opt.previous_name = save_model(opt.previous_name, opt.checkpoint, epoch, data_threshold,
                                                   model['MAE'], 'MAE')
                else:
                    opt.previous_name = save_model(opt.previous_name, opt.checkpoint, epoch, data_threshold, model['trans'], 'no_refine')

                    if opt.refine:
                        opt.previous_refine_name = save_model(opt.previous_refine_name, opt.checkpoint, epoch,
                                                              data_threshold, model['refine'], 'refine')
                opt.previous_best_threshold = data_threshold

            if opt.train == 0:
                # Evaluation-only run: report once and stop.
                print('p1: %.2f' % (p1))
                break
            else:
                if opt.MAE:
                    logging.info('epoch: %d, lr: %.7f, loss: %.4f, p1: %.2f' % (
                        epoch, lr, loss, p1))
                    print('e: %d, lr: %.7f, loss: %.4f, p1: %.2f' % (epoch, lr, loss, p1))
                else:
                    logging.info('epoch: %d, lr: %.7f, loss: %.4f, MPJPE: %.2f, p1: %.2f' % (epoch, lr, loss, mpjpe, p1))
                    print('e: %d, lr: %.7f, loss: %.4f, M: %.2f, p1: %.2f' % (epoch, lr, loss, mpjpe, p1))

        # Learning-rate schedule: a large decay every `large_decay_epoch`
        # epochs, the regular per-epoch decay otherwise.
        if epoch % opt.large_decay_epoch == 0:
            for param_group in optimizer_all.param_groups:
                param_group['lr'] *= opt.lr_decay_large
            lr *= opt.lr_decay_large
        else:
            for param_group in optimizer_all.param_groups:
                param_group['lr'] *= opt.lr_decay
            lr *= opt.lr_decay
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|