You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
250 lines
8.6 KiB
250 lines
8.6 KiB
import os
|
|
import glob
|
|
import torch
|
|
import random
|
|
import logging
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
import torch.nn as nn
|
|
import torch.utils.data
|
|
import torch.optim as optim
|
|
import scipy.io as scio
|
|
from common.opt import opts
|
|
from common.utils import *
|
|
from common.camera import get_uvd2xyz
|
|
from common.load_data_hm36_tds import Fusion
|
|
from common.h36m_dataset import Human36mDataset
|
|
from model.block.refine import refine
|
|
from model.hpformer_3 import Model
|
|
from torch.cuda.amp import autocast as autocast
|
|
|
|
import time
|
|
|
|
# Parse the command-line options once when the module is loaded.
opt = opts().parse()

# Expose only the requested GPU(s) to this process; this is set before any
# `.cuda()` call in this file.
os.environ.update(CUDA_VISIBLE_DEVICES=opt.gpu)
|
|
|
|
def train(opt, actions, train_loader, model, optimizer, epoch):
    """Run one training pass over ``train_loader``.

    Thin wrapper around ``step('train', ...)``; returns the average loss that
    ``step`` reports for the epoch.
    """
    return step(
        'train', opt, actions, train_loader, model,
        optimizer=optimizer, epoch=epoch,
    )
|
|
|
|
def val(opt, actions, val_loader, model):
    """Evaluate ``model`` on ``val_loader`` with autograd disabled.

    Returns whatever ``step('test', ...)`` returns, i.e. the ``(p1, p2)``
    pair produced by ``print_error``.
    """
    with torch.no_grad():
        metrics = step('test', opt, actions, val_loader, model)
    return metrics
|
|
|
|
def step(split, opt, actions, dataLoader, model, optimizer=None, epoch=None):
    """Run one pass over ``dataLoader`` in training or evaluation mode.

    Args:
        split: ``'train'`` to optimise the models, ``'test'`` to evaluate them.
        opt: Parsed options; this function reads ``amp``, ``pad``, ``refine``,
            ``dataset`` and ``train``.
        actions: Action names used to build the per-action error tables.
        dataLoader: Iterable yielding
            ``(batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, cam_ind)``.
        model: Dict holding the ``'trans'`` and ``'refine'`` modules.
        optimizer: Optimizer stepped when ``split == 'train'``.
        epoch: Not used by this function; kept for interface compatibility.

    Returns:
        For ``'train'``: the average loss over the pass.
        For ``'test'``: the ``(p1, p2)`` pair from ``print_error``, computed on
        the refined prediction when ``opt.refine`` is set.
    """
    model_trans = model['trans']
    model_refine = model['refine']
    is_train = split == 'train'

    if is_train:
        model_trans.train()
        model_refine.train()
    else:
        model_trans.eval()
        model_refine.eval()

    loss_all = {'loss': AccumLoss()}
    action_error_sum = define_error_list(actions)
    action_error_sum_refine = define_error_list(actions)

    # Mixed precision is only ever used for the training forward + loss.
    use_amp = bool(is_train and opt.amp)
    scaler = None
    if is_train:
        print(f'amp:{opt.amp}')
        if opt.amp:
            scaler = torch.cuda.amp.GradScaler()

    for data in tqdm(dataLoader, 0):
        batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, cam_ind = data
        [input_2D, gt_3D, batch_cam, scale, bb_box] = get_varialbe(split, [input_2D, gt_3D, batch_cam, scale, bb_box])

        # The refine forward and the loss share model_trans's autocast region:
        # with AMP on, model_trans may emit half-precision tensors, and feeding
        # those to full-precision refine layers outside autocast fails with a
        # dtype mismatch. With enabled=False this context is a no-op.
        with autocast(enabled=use_amp):
            if is_train:
                output_3D = model_trans(input_2D)
            else:
                # Flip test-time augmentation; also returns the un-flipped 2D input.
                input_2D, output_3D = input_augmentation(input_2D, model_trans)

            out_target = gt_3D.clone()
            out_target[:, :, 0] = 0  # zero joint 0 (presumably the root) in the target

            # Centre-frame prediction, shape (B, 1, ...).
            output_3D_single = output_3D[:, opt.pad].unsqueeze(1)

            if out_target.size(1) > 1:
                out_target_single = out_target[:, opt.pad].unsqueeze(1)
                gt_3D_single = gt_3D[:, opt.pad].unsqueeze(1)
            else:
                out_target_single = out_target
                gt_3D_single = gt_3D

            post_out = None
            if opt.refine:
                pred_uv = input_2D[:, opt.pad, :, :].unsqueeze(1)
                uvd = torch.cat((pred_uv, output_3D_single[:, :, :, 2].unsqueeze(-1)), -1)
                xyz = get_uvd2xyz(uvd, gt_3D_single, batch_cam)
                xyz[:, :, 0, :] = 0
                post_out = model_refine(output_3D_single, xyz)

            if is_train:
                if opt.refine:
                    loss = mpjpe_cal(post_out, out_target_single)
                else:
                    loss = mpjpe_cal(output_3D, out_target)

        if is_train:
            N = input_2D.size(0)
            loss_all['loss'].update(loss.detach().cpu().numpy() * N, N)

            optimizer.zero_grad()
            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

        elif split == 'test':
            # Evaluate the centre-frame prediction against the centre-frame
            # target (identical to out_target when the target has one frame).
            output_3D_single[:, :, 0, :] = 0
            action_error_sum = test_calculation(output_3D_single, out_target_single, action,
                                                action_error_sum, opt.dataset, subject)

            if opt.refine:
                # Score the *refined* prediction, not the raw one.
                post_out[:, :, 0, :] = 0
                action_error_sum_refine = test_calculation(post_out, out_target_single, action,
                                                           action_error_sum_refine, opt.dataset, subject)

    if is_train:
        return loss_all['loss'].avg
    elif split == 'test':
        error_table = action_error_sum_refine if opt.refine else action_error_sum
        p1, p2 = print_error(opt.dataset, error_table, opt.train)
        return p1, p2
|
|
|
|
def input_augmentation(input_2D, model_trans):
    """Predict 3D poses with flip test-time augmentation.

    ``input_2D[:, 0]`` is treated as the plain input and ``input_2D[:, 1]`` as
    its mirrored counterpart. That both are supplied this way by the data
    loader is an assumption — TODO confirm. The mirrored prediction is mapped
    back and averaged with the plain one:

    - last-dim coordinate 0 is negated;
    - the joint slots ``[4, 5, 6, 11, 12, 13]`` and ``[1, 2, 3, 14, 15, 16]``
      are exchanged.

    Returns:
        ``(input_2D[:, 0], averaged_prediction)``.
    """
    left_side = [4, 5, 6, 11, 12, 13]
    right_side = [1, 2, 3, 14, 15, 16]

    plain_2d, mirrored_2d = input_2D[:, 0], input_2D[:, 1]

    pred_plain = model_trans(plain_2d)
    pred_mirrored = model_trans(mirrored_2d)

    # Undo the mirroring in place: flip coordinate 0, then exchange the
    # left/right joint slots (the right-hand index copies before assigning).
    pred_mirrored[:, :, :, 0] *= -1
    swapped_order = right_side + left_side
    pred_mirrored[:, :, left_side + right_side, :] = pred_mirrored[:, :, swapped_order, :]

    return plain_2d, (pred_plain + pred_mirrored) / 2
|
|
|
|
if __name__ == '__main__':
    # CUDA_VISIBLE_DEVICES is already exported at module level right after the
    # options are parsed, so it is not set a second time here.

    # Fixed seed; this overwrites any value opts() may have set.
    opt.manualSeed = 42

    # Seed every RNG this script imports (numpy included) for reproducibility.
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if opt.train:
        logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %H:%M:%S',
                            filename=os.path.join(opt.checkpoint, 'train.log'), level=logging.INFO)

    # root_path is concatenated directly, so it must end with a path separator.
    root_path = opt.root_path
    dataset_path = root_path + 'data_3d_' + opt.dataset + '.npz'

    dataset = Human36mDataset(dataset_path, opt)
    actions = define_actions(opt.actions)

    if opt.train:
        train_data = Fusion(opt=opt, train=True, dataset=dataset, root_path=root_path, tds=opt.t_downsample)
        train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=opt.batchSize // opt.stride,
                                                       shuffle=True, num_workers=int(opt.workers),
                                                       pin_memory=True)

    test_data = Fusion(opt=opt, train=False, dataset=dataset, root_path=root_path, tds=opt.t_downsample)
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=opt.batchSize // opt.stride,
                                                  shuffle=False, num_workers=int(opt.workers),
                                                  pin_memory=True)

    opt.out_joints = dataset.skeleton().num_joints()

    # Built in this order (trans, then refine) so parameter init consumes the
    # RNG in the same sequence as before.
    model = {
        'trans': Model(opt).cuda(),
        'refine': refine(opt).cuda(),
    }

    if opt.reload:
        pre_dict = torch.load(opt.previous_dir)
        model_dict = model['trans'].state_dict()
        # Copy only the keys the current model defines; extra checkpoint keys
        # are ignored and a missing key raises KeyError. The state_dict object
        # itself is updated in place so its metadata is preserved.
        for name in model_dict:
            model_dict[name] = pre_dict[name]
        model['trans'].load_state_dict(model_dict)

    if opt.refine_reload:
        pre_dict_refine = torch.load(opt.previous_refine_name)
        refine_dict = model['refine'].state_dict()
        for name in refine_dict:
            refine_dict[name] = pre_dict_refine[name]
        model['refine'].load_state_dict(refine_dict)

    # One optimizer over both sub-models.
    lr = opt.lr
    all_param = [param for net in model.values() for param in net.parameters()]
    optimizer_all = optim.Adam(all_param, lr=opt.lr, amsgrad=True)

    # Runs epochs 1 .. opt.nepoch - 1 inclusive.
    for epoch in range(1, opt.nepoch):
        if opt.train:
            loss = train(opt, actions, train_dataloader, model, optimizer_all, epoch)

        p1, p2 = val(opt, actions, test_dataloader, model)

        # Keep only the best checkpoint(s) by p1.
        if opt.train and p1 < opt.previous_best_threshold:
            opt.previous_name = save_model(opt.previous_name, opt.checkpoint, epoch, p1,
                                           model['trans'], 'no_refine')
            if opt.refine:
                opt.previous_refine_name = save_model(opt.previous_refine_name, opt.checkpoint, epoch,
                                                      p1, model['refine'], 'refine')
            opt.previous_best_threshold = p1

        if not opt.train:
            # Evaluation-only run: report once and stop.
            print('p1: %.2f, p2: %.2f' % (p1, p2))
            break

        logging.info('epoch: %d, lr: %.7f, loss: %.4f, p1: %.2f, p2: %.2f', epoch, lr, loss, p1, p2)
        print('e: %d, lr: %.7f, loss: %.4f, p1: %.2f, p2: %.2f' % (epoch, lr, loss, p1, p2))

        # Every large_decay_epoch epochs apply the large decay, otherwise the regular one.
        decay = opt.lr_decay_large if epoch % opt.large_decay_epoch == 0 else opt.lr_decay
        for param_group in optimizer_all.param_groups:
            param_group['lr'] *= decay
        lr *= decay
|
|
|
|
|
|
|
|
|
|
|
|
|