# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
#
# 250 lines
# 8.6 KiB

import os
import glob
import torch
import random
import logging
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
import scipy.io as scio
from common.opt import opts
from common.utils import *
from common.camera import get_uvd2xyz
from common.load_data_hm36_tds import Fusion
from common.h36m_dataset import Human36mDataset
from model.block.refine import refine
from model.hpformer_3 import Model
from torch.cuda.amp import autocast as autocast
import time
# Parse command-line options once at import time; the __main__ block reads this module-level `opt`.
opt = opts().parse()
# Restrict the GPUs PyTorch can see; this should happen before any CUDA context is created.
os.environ.update(CUDA_VISIBLE_DEVICES=opt.gpu)
def train(opt, actions, train_loader, model, optimizer, epoch):
    """Train ``model`` for one epoch over ``train_loader``.

    Thin wrapper around ``step`` in ``'train'`` mode; returns whatever ``step``
    returns for that split (the averaged training loss).
    """
    return step(
        'train',
        opt,
        actions,
        dataLoader=train_loader,
        model=model,
        optimizer=optimizer,
        epoch=epoch,
    )
def val(opt, actions, val_loader, model):
    """Evaluate ``model`` on ``val_loader`` with autograd disabled.

    Delegates to ``step`` in ``'test'`` mode and returns its ``(p1, p2)`` result.
    """
    # torch.no_grad() also works as a decorator; wrap step for this single call.
    run_without_grad = torch.no_grad()(step)
    return run_without_grad('test', opt, actions, val_loader, model)
def step(split, opt, actions, dataLoader, model, optimizer=None, epoch=None):
    """Run one pass over ``dataLoader`` in training or evaluation mode.

    Args:
        split: ``'train'`` to optimise the models, ``'test'`` to evaluate them.
        opt: parsed options; this function reads ``amp``, ``pad``, ``refine``,
            ``dataset`` and ``train``.
        actions: action names passed to ``define_error_list``.
        dataLoader: iterable yielding
            ``(batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, cam_ind)``.
        model: dict holding the pose model under ``'trans'`` and the
            refinement model under ``'refine'``.
        optimizer: optimiser stepped once per batch; required for ``'train'``.
        epoch: unused here; kept for interface compatibility.

    Returns:
        For ``'train'``: the accumulated average loss (``AccumLoss.avg``).
        For ``'test'``: ``(p1, p2)`` from ``print_error``, computed from the
        refined predictions when ``opt.refine`` is set.
    """
    model_trans = model['trans']
    model_refine = model['refine']
    is_train = split == 'train'
    # Module.train(False) is equivalent to Module.eval().
    model_trans.train(is_train)
    model_refine.train(is_train)

    loss_all = {'loss': AccumLoss()}
    action_error_sum = define_error_list(actions)
    action_error_sum_refine = define_error_list(actions)

    use_amp = bool(opt.amp)
    scaler = None
    if is_train:
        print(f'amp:{opt.amp}')
        if use_amp:
            # NOTE(review): a new GradScaler is built on every call, so the loss
            # scale restarts from its initial value each epoch.
            scaler = torch.cuda.amp.GradScaler()

    for data in tqdm(dataLoader):
        batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, cam_ind = data
        [input_2D, gt_3D, batch_cam, scale, bb_box] = get_varialbe(
            split, [input_2D, gt_3D, batch_cam, scale, bb_box])

        if is_train:
            # autocast(enabled=False) is a no-op, so this covers both AMP and full precision.
            with autocast(enabled=use_amp):
                output_3D = model_trans(input_2D)
        else:
            # Test-time flip augmentation; also returns the unflipped 2D input.
            input_2D, output_3D = input_augmentation(input_2D, model_trans)

        # Joint 0 is zeroed in the target (and in predictions below) before comparison.
        out_target = gt_3D.clone()
        out_target[:, :, 0] = 0
        # Prediction for the frame at index opt.pad (presumably the centre frame),
        # kept with a length-1 time axis.
        output_3D_single = output_3D[:, opt.pad].unsqueeze(1)
        if out_target.size(1) > 1:
            out_target_single = out_target[:, opt.pad].unsqueeze(1)
            gt_3D_single = gt_3D[:, opt.pad].unsqueeze(1)
        else:
            out_target_single = out_target
            gt_3D_single = gt_3D

        if opt.refine:
            pred_uv = input_2D[:, opt.pad, :, :].unsqueeze(1)
            # Combine 2D (u, v) with the predicted depth channel of the pad frame.
            uvd = torch.cat((pred_uv, output_3D_single[:, :, :, 2].unsqueeze(-1)), -1)
            xyz = get_uvd2xyz(uvd, gt_3D_single, batch_cam)
            xyz[:, :, 0, :] = 0
            post_out = model_refine(output_3D_single, xyz)

        if is_train:
            with autocast(enabled=use_amp):
                if opt.refine:
                    loss = mpjpe_cal(post_out, out_target_single)
                else:
                    loss = mpjpe_cal(output_3D, out_target)

            N = input_2D.size(0)
            loss_all['loss'].update(loss.detach().cpu().numpy() * N, N)

            optimizer.zero_grad()
            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

        elif split == 'test':
            # output_3D_single is a view of output_3D, so zero its joint 0 directly.
            output_3D_single[:, :, 0, :] = 0
            # out_target_single equals out_target when the loader yields one target
            # frame, and matches the prediction's shape otherwise.
            action_error_sum = test_calculation(output_3D_single, out_target_single, action,
                                                action_error_sum, opt.dataset, subject)
            if opt.refine:
                # Score the refinement output itself (previously the unrefined
                # prediction was scored here, duplicating the base metric).
                post_out[:, :, 0, :] = 0
                action_error_sum_refine = test_calculation(post_out, out_target_single, action,
                                                           action_error_sum_refine, opt.dataset, subject)

    if is_train:
        return loss_all['loss'].avg
    if split == 'test':
        error_sum = action_error_sum_refine if opt.refine else action_error_sum
        p1, p2 = print_error(opt.dataset, error_sum, opt.train)
        return p1, p2
    return None
def input_augmentation(input_2D, model_trans, joints_left=None, joints_right=None):
    """Run flip test-time augmentation and average the two predictions.

    Args:
        input_2D: tensor whose dim 1 holds the original input at index 0 and the
            horizontally flipped input at index 1.
        model_trans: callable mapping a 2D input to a 4-D prediction whose dim 2
            indexes joints and whose last-dim index 0 is the x coordinate
            (assumed from the indexing below — confirm against the model).
        joints_left, joints_right: paired joint indices that are swapped when
            un-mirroring the flipped prediction. ``None`` keeps the previous
            hard-coded 17-joint lists (presumably the Human3.6M layout).

    Returns:
        ``(input_2D_non_flip, output_3D)``: the unflipped input and the mean of the
        unflipped prediction and the re-mirrored flipped prediction.

    Raises:
        ValueError: if the two joint lists have different lengths.
    """
    if joints_left is None:
        joints_left = [4, 5, 6, 11, 12, 13]
    if joints_right is None:
        joints_right = [1, 2, 3, 14, 15, 16]
    joints_left = list(joints_left)
    joints_right = list(joints_right)
    if len(joints_left) != len(joints_right):
        raise ValueError('joints_left and joints_right must have the same length')

    input_2D_non_flip = input_2D[:, 0]
    input_2D_flip = input_2D[:, 1]
    output_3D_non_flip = model_trans(input_2D_non_flip)
    output_3D_flip = model_trans(input_2D_flip)
    # Mirror the flipped prediction back: negate x, then swap left/right joints.
    # The right-hand side uses advanced indexing (a copy), so the in-place swap is safe.
    output_3D_flip[:, :, :, 0] *= -1
    output_3D_flip[:, :, joints_left + joints_right, :] = output_3D_flip[:, :, joints_right + joints_left, :]
    output_3D = (output_3D_non_flip + output_3D_flip) / 2
    return input_2D_non_flip, output_3D
def _load_matching_weights(module, checkpoint_path):
    """Overwrite each entry of ``module.state_dict()`` with the same-named tensor from a checkpoint.

    Keys present in the checkpoint but not expected by ``module`` are ignored; a
    key expected by ``module`` but absent from the checkpoint raises ``KeyError``
    (assumes the checkpoint file holds a plain state dict — confirm).
    """
    pretrained = torch.load(checkpoint_path)
    state = module.state_dict()
    for name in state:
        state[name] = pretrained[name]
    module.load_state_dict(state)


if __name__ == '__main__':
    # CUDA_VISIBLE_DEVICES was already set from opt.gpu right after option parsing.
    opt.manualSeed = 42
    random.seed(opt.manualSeed)
    # Seed NumPy as well so any NumPy-based randomness is reproducible too.
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if opt.train:
        logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %H:%M:%S',
                            filename=os.path.join(opt.checkpoint, 'train.log'), level=logging.INFO)

    root_path = opt.root_path
    dataset_path = root_path + 'data_3d_' + opt.dataset + '.npz'
    dataset = Human36mDataset(dataset_path, opt)
    actions = define_actions(opt.actions)

    batch_size = opt.batchSize // opt.stride
    if opt.train:
        train_data = Fusion(opt=opt, train=True, dataset=dataset, root_path=root_path, tds=opt.t_downsample)
        train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                                       shuffle=True, num_workers=int(opt.workers),
                                                       pin_memory=True)
    test_data = Fusion(opt=opt, train=False, dataset=dataset, root_path=root_path, tds=opt.t_downsample)
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                                  shuffle=False, num_workers=int(opt.workers),
                                                  pin_memory=True)

    opt.out_joints = dataset.skeleton().num_joints()

    model = {}
    model['trans'] = Model(opt).cuda()
    model['refine'] = refine(opt).cuda()

    if opt.reload:
        _load_matching_weights(model['trans'], opt.previous_dir)
    if opt.refine_reload:
        _load_matching_weights(model['refine'], opt.previous_refine_name)

    # One optimiser over the parameters of every model, in dict insertion order.
    all_param = [param for sub_model in model.values() for param in sub_model.parameters()]
    lr = opt.lr
    optimizer_all = optim.Adam(all_param, lr=opt.lr, amsgrad=True)

    # NOTE(review): range(1, opt.nepoch) runs opt.nepoch - 1 epochs; confirm this is intended.
    for epoch in range(1, opt.nepoch):
        if opt.train:
            loss = train(opt, actions, train_dataloader, model, optimizer_all, epoch)
        p1, p2 = val(opt, actions, test_dataloader, model)

        # Checkpoint whenever the evaluation error beats the best seen so far.
        if opt.train and p1 < opt.previous_best_threshold:
            opt.previous_name = save_model(opt.previous_name, opt.checkpoint, epoch, p1,
                                           model['trans'], 'no_refine')
            if opt.refine:
                opt.previous_refine_name = save_model(opt.previous_refine_name, opt.checkpoint, epoch,
                                                      p1, model['refine'], 'refine')
            opt.previous_best_threshold = p1

        if not opt.train:
            # Evaluation-only run: report once and stop.
            print('p1: %.2f, p2: %.2f' % (p1, p2))
            break

        logging.info('epoch: %d, lr: %.7f, loss: %.4f, p1: %.2f, p2: %.2f', epoch, lr, loss, p1, p2)
        print('e: %d, lr: %.7f, loss: %.4f, p1: %.2f, p2: %.2f' % (epoch, lr, loss, p1, p2))

        # Every large_decay_epoch epochs apply the large decay, otherwise the regular one.
        decay = opt.lr_decay_large if epoch % opt.large_decay_epoch == 0 else opt.lr_decay
        for param_group in optimizer_all.param_groups:
            param_group['lr'] *= decay
        lr *= decay