You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
5.7 KiB
137 lines
5.7 KiB
import argparse
|
|
import os
|
|
import math
|
|
import time
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
class opts():
|
|
def __init__(self):
|
|
self.parser = argparse.ArgumentParser()
|
|
|
|
def init(self):
|
|
self.parser.add_argument('--layers', default=6, type=int)
|
|
self.parser.add_argument('--channel', default=256, type=int)
|
|
self.parser.add_argument('--d_hid', default=192, type=int) # 嵌入维度
|
|
self.parser.add_argument('--dataset', type=str, default='h36m')
|
|
self.parser.add_argument('-k', '--keypoints', default='cpn_ft_h36m_dbb', type=str)
|
|
self.parser.add_argument('--data_augmentation', type=bool, default=True)
|
|
self.parser.add_argument('--reverse_augmentation', type=bool, default=False)
|
|
self.parser.add_argument('--test_augmentation', type=bool, default=True)
|
|
self.parser.add_argument('--crop_uv', type=int, default=0)
|
|
self.parser.add_argument('--root_path', type=str, default='./dataset/')
|
|
self.parser.add_argument('-a', '--actions', default='*', type=str)
|
|
self.parser.add_argument('--downsample', default=1, type=int)
|
|
self.parser.add_argument('--subset', default=1, type=float)
|
|
self.parser.add_argument('-s', '--stride', default=1, type=int)
|
|
self.parser.add_argument('--gpu', default='1', type=str, help='')
|
|
self.parser.add_argument('--train', type=int, default=0)
|
|
self.parser.add_argument('--test', type=int, default=1)
|
|
self.parser.add_argument('--nepoch', type=int, default=80)
|
|
self.parser.add_argument('-b','--batchSize', type=int, default=1024)
|
|
self.parser.add_argument('--lr', type=float, default=1e-3)
|
|
self.parser.add_argument('--lr_refine', type=float, default=1e-5)
|
|
self.parser.add_argument('--lr_decay_large', type=float, default=0.5)
|
|
self.parser.add_argument('--large_decay_epoch', type=int, default=80)
|
|
self.parser.add_argument('--workers', type=int, default=8)
|
|
self.parser.add_argument('-lrd', '--lr_decay', default=0.96, type=float)
|
|
self.parser.add_argument('-f','--frames', type=int, default=243)
|
|
self.parser.add_argument('--pad', type=int, default=121)
|
|
self.parser.add_argument('--refine', action='store_true')
|
|
self.parser.add_argument('--reload', type=int, default=0) # 是否加载预训练模型
|
|
self.parser.add_argument('--refine_reload', type=int, default=0)
|
|
self.parser.add_argument('-c','--checkpoint', type=str, default='model')
|
|
self.parser.add_argument('--previous_dir', type=str, default='')
|
|
self.parser.add_argument('--n_joints', type=int, default=17)
|
|
self.parser.add_argument('--out_joints', type=int, default=17)
|
|
self.parser.add_argument('--out_all', type=int, default=1)
|
|
self.parser.add_argument('--in_channels', type=int, default=2)
|
|
self.parser.add_argument('--out_channels', type=int, default=3)
|
|
self.parser.add_argument('-previous_best_threshold', type=float, default= math.inf)
|
|
self.parser.add_argument('-previous_name', type=str, default='')
|
|
self.parser.add_argument('--previous_refine_name', type=str, default='')
|
|
self.parser.add_argument('--manualSeed', type=int, default=1)
|
|
|
|
self.parser.add_argument('--MAE', action='store_true')
|
|
self.parser.add_argument('-tmr','--temporal_mask_rate', type=float, default=0)
|
|
self.parser.add_argument('-smn', '--spatial_mask_num', type=int, default=0)
|
|
self.parser.add_argument('-tds', '--t_downsample', type=int, default=3)
|
|
|
|
self.parser.add_argument('--MAE_reload', type=int, default=0)
|
|
self.parser.add_argument('-r', '--resume', action='store_true')
|
|
|
|
self.parser.add_argument('-mt', '--model_type', type=str, default='')
|
|
self.parser.add_argument('--amp', type=int, default=0)
|
|
|
|
|
|
|
|
|
|
|
|
def parse(self):
|
|
self.init()
|
|
self.opt = self.parser.parse_args()
|
|
|
|
self.opt.pad = (self.opt.frames-1) // 2
|
|
|
|
self.opt.part_list = [
|
|
[8, 9, 10], # 头
|
|
[0, 7, 8], # 身体
|
|
[11, 12, 13], # 左手
|
|
[14, 15, 16], # 右手
|
|
[4, 5, 6], # 左腿
|
|
[1, 2, 3] #右腿
|
|
]
|
|
|
|
stride_num = {
|
|
'9': [1, 3, 3],
|
|
'27': [3, 3, 3],
|
|
'351': [3, 9, 13],
|
|
'81': [3, 3, 3, 3],
|
|
'243': [3, 3, 3, 3, 3],
|
|
}
|
|
|
|
if str(self.opt.frames) in stride_num:
|
|
self.opt.stride_num = stride_num[str(self.opt.frames)]
|
|
else:
|
|
self.opt.stride_num = None
|
|
print('no stride_num')
|
|
exit()
|
|
|
|
self.opt.subjects_train = 'S1,S5,S6,S7,S8'
|
|
self.opt.subjects_test = 'S9,S11'
|
|
#self.opt.subjects_test = 'S11'
|
|
|
|
#if self.opt.train:
|
|
logtime = time.strftime('%m%d_%H%M_%S_')
|
|
|
|
ckp_suffix = ''
|
|
if self.opt.refine:
|
|
ckp_suffix='_refine'
|
|
elif self.opt.MAE:
|
|
ckp_suffix = '_pretrain'
|
|
else:
|
|
ckp_suffix = '_STCFormer'
|
|
self.opt.checkpoint = 'checkpoint/'+self.opt.checkpoint + '_%d'%(self.opt.pad*2+1) + \
|
|
'_%s'%self.opt.model_type
|
|
|
|
if not os.path.exists(self.opt.checkpoint):
|
|
os.makedirs(self.opt.checkpoint)
|
|
|
|
if self.opt.train:
|
|
args = dict((name, getattr(self.opt, name)) for name in dir(self.opt)
|
|
if not name.startswith('_'))
|
|
|
|
file_name = os.path.join(self.opt.checkpoint, 'opt.txt')
|
|
with open(file_name, 'wt') as opt_file:
|
|
opt_file.write('==> Args:\n')
|
|
for k, v in sorted(args.items()):
|
|
opt_file.write(' %s: %s\n' % (str(k), str(v)))
|
|
opt_file.write('==> Args:\n')
|
|
|
|
return self.opt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|