import torch
import torch.nn as nn
# from model.module.trans import Transformer as Transformer_s
# from model.module.trans_hypothesis import Transformer
import numpy as np
from einops import rearrange
from collections import OrderedDict
from torch.nn import functional as F
from torch.nn import init
import scipy.sparse as sp

from timm.models.layers import DropPath
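
# HPFormer: a transformer that lifts 2D keypoint sequences to 3D poses.
# Each block splits the channel dimension into three groups and applies temporal
# self-attention at three granularities: per joint (Joint_ATTENTION), per body
# part (Part_ATTENTION), and over the whole pose (Pose_ATTENTION).
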
class Mlp(nn.Module):
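    """Two-layer feed-forward block (fc1 -> activation -> dropout -> fc2 -> dropout)
    used as the MLP inside each HP_BLOCK."""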
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.1):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Joint_ATTENTION(nn.Module):
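    """Temporal multi-head self-attention applied to each joint independently:
    every joint contributes one token per frame and attention runs over the
    d_time frames of that joint."""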
    def __init__(self, d_time, d_joint, d_coor, head=8):
        """
        d_time: number of frames
        d_joint: number of joints
        d_coor: embedding dimension
        """
        super().__init__()
        self.qkv = nn.Linear(d_coor, d_coor * 3)
        self.head = head

        self.scale = d_coor ** -0.5
        self.d_time = d_time
        self.d_joint = d_joint

        self.pos_emb = nn.Embedding(d_time, d_coor)
        # frame indices are registered as a non-persistent buffer so they move with the module across devices
        self.register_buffer('frame_idx', torch.arange(d_time, dtype=torch.long), persistent=False)

        self.drop = DropPath(0.5)

    def forward(self, input):
        b, t, s, c = input.shape

        # add the learned temporal position embedding
        emb = self.pos_emb(self.frame_idx)
        input = input + emb[None, :, None, :]

        qkv = self.qkv(input)  # b, t, s, c -> b, t, s, 3*c
        qkv_t = qkv.reshape(b, t, s, c, 3).permute(4, 0, 1, 2, 3)  # 3, b, t, s, c

        q_t, k_t, v_t = qkv_t[0], qkv_t[1], qkv_t[2]  # b, t, s, c

        q_t = rearrange(q_t, 'b t s (h c) -> (b h s) t c', h=self.head)  # b,t,s,c -> b*h*s,t,c//h
        k_t = rearrange(k_t, 'b t s (h c) -> (b h s) c t', h=self.head)  # b,t,s,c -> b*h*s,c//h,t

        att_t = (q_t @ k_t) * self.scale  # b*h*s,t,t
        att_t = att_t.softmax(-1)  # b*h*s,t,t

        v_t = rearrange(v_t, 'b t s c -> b c t s')

        # MSA
        v_t = rearrange(v_t, 'b (h c) t s -> (b h s) t c', h=self.head)  # b*h*s,t,c//h

        x_t = att_t @ v_t  # b*h*s,t,c//h

        x_t = rearrange(x_t, '(b h s) t c -> b t s (h c)', s=s, h=self.head)  # b,t,s,c

        return x_t


class Part_ATTENTION(nn.Module):
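    """Temporal multi-head self-attention over body parts: joints are grouped into
    the 3-joint parts given by part_list, each part's joints are concatenated along
    the channel dimension, and attention runs over the d_time frames of every part.
    Joints shared by several parts are averaged when mapping back to joints."""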
    def __init__(self, d_time, d_joint, d_coor, part_list, head=8):
        """
        d_time: number of frames
        d_joint: number of joints
        d_coor: embedding dimension
        part_list: list of parts, each part being a list of joint indices
        """
        super().__init__()

        self.head = head

        self.num_of_part = len(part_list)
        self.num_joint_of_part = len(part_list[0])

        self.scale = (d_coor * self.num_joint_of_part) ** -0.5
        self.d_time = d_time
        self.d_joint = d_joint
        self.layer_norm = nn.LayerNorm(d_coor * self.num_joint_of_part)

        self.pos_embed = nn.Embedding(d_time, d_coor * self.num_joint_of_part)
        # index tensors are registered as non-persistent buffers so they move with the module across devices
        self.register_buffer('frame_idx', torch.arange(d_time, dtype=torch.long), persistent=False)

        self.qkv = nn.Linear(d_coor * self.num_joint_of_part, d_coor * self.num_joint_of_part * 3)
        self.drop = DropPath(0.5)

        # check part_list
        for part in part_list:
            assert len(part) == 3  # each part should have 3 joints
            for idx in part:
                assert 0 <= idx < d_joint  # joint index should be less than d_joint

        idx_joint2part = torch.tensor([idx for part in part_list for idx in part], dtype=torch.long)
        self.register_buffer('idx_joint2part', idx_joint2part.flatten(), persistent=False)
        idx_part2joint = list(range(d_joint))
        for i, idx in enumerate(self.idx_joint2part):
            idx_part2joint[idx] = i
        self.register_buffer('idx_part2joint', torch.tensor(idx_part2joint, dtype=torch.long), persistent=False)

        self.overlap = self.get_overlap()

    # find joints that appear in more than one part
    def get_overlap(self):
        overlap_list = [-1] * self.d_joint
        for i, idx in enumerate(self.idx_joint2part):
            if overlap_list[idx] == -1:
                overlap_list[idx] = i
            else:
                if not isinstance(overlap_list[idx], list):
                    overlap_i = overlap_list[idx]
                    overlap_list[idx] = list()
                    overlap_list[idx].append(overlap_i)
                overlap_list[idx].append(i)

        overlap = []
        for i in overlap_list:
            if isinstance(i, list):
                overlap.append(i)

        if len(overlap) == 0:
            return None
        else:
            return overlap

    def forward(self, input):
        # regroup joints into parts: b,t,s,c -> b,t,p,(j*c)
        input = torch.index_select(input, 2, self.idx_joint2part)
        input = rearrange(input, 'b t (p j) c -> b t p (j c)', j=self.num_joint_of_part)

        b, t, p, c = input.shape

        emb = self.pos_embed(self.frame_idx)
        input = input + emb[None, :, None, :]

        qkv = self.qkv(input)  # b, t, p, c -> b, t, p, 3*c
        qkv_t = qkv.reshape(b, t, p, c, 3).permute(4, 0, 1, 2, 3)  # 3, b, t, p, c

        q_t, k_t, v_t = qkv_t[0], qkv_t[1], qkv_t[2]  # b, t, p, c

        q_t = rearrange(q_t, 'b t s (h c) -> (b h s) t c', h=self.head)  # b,t,p,c -> b*h*p,t,c//h
        k_t = rearrange(k_t, 'b t s (h c) -> (b h s) c t', h=self.head)  # b,t,p,c -> b*h*p,c//h,t

        att_t = (q_t @ k_t) * self.scale  # b*h*p,t,t
        att_t = att_t.softmax(-1)  # b*h*p,t,t

        v_t = rearrange(v_t, 'b t p c -> b c t p')

        # MSA
        v_t = rearrange(v_t, 'b (h c) t p -> (b h p) t c', h=self.head)  # b*h*p,t,c//h

        x_t = att_t @ v_t  # b*h*p,t,c//h

        x = rearrange(x_t, '(b h p) t c -> b h t p c', h=self.head, p=p)  # b*h*p,t,c//h -> b,h,t,p,c//h

        # restore the original joint layout
        x = rearrange(x, 'b h t p c -> b t p (h c)')
        x = rearrange(x, 'b t p (j c) -> b t (p j) c', j=self.num_joint_of_part)

        # average the features of joints shared by multiple parts
        if self.overlap:
            for overlap in self.overlap:
                idx = overlap[-1]
                for i in overlap[:-1]:
                    x[:, :, idx, :] += x[:, :, i, :]
                x[:, :, idx, :] /= len(overlap)

        x = torch.index_select(x, 2, self.idx_part2joint)
        return x


class Pose_ATTENTION(nn.Module):
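    """Temporal multi-head self-attention over the whole pose: all joints of a frame
    are flattened into a single token and attention runs over the d_time frames."""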
    def __init__(self, d_time, d_joint, d_coor, head=8):
        """
        d_time: number of frames
        d_joint: number of joints
        d_coor: embedding dimension
        """
        super().__init__()
        self.head = head

        self.scale = (d_coor * d_joint) ** -0.5
        self.d_time = d_time
        self.d_joint = d_joint

        self.pos_emb = nn.Embedding(d_time, d_coor * d_joint)
        # frame indices are registered as a non-persistent buffer so they move with the module across devices
        self.register_buffer('frame_idx', torch.arange(d_time, dtype=torch.long), persistent=False)

        self.qkv = nn.Linear(d_coor * d_joint, d_coor * d_joint * 3)
        self.drop = DropPath(0.5)

    def forward(self, input):
        b, t, s, c = input.shape
        input = rearrange(input, 'b t s c -> b t (s c)')

        emb = self.pos_emb(self.frame_idx)
        input = input + emb[None, :]

        qkv = self.qkv(input)  # b, t, s*c -> b, t, 3*s*c
        qkv_t = qkv.reshape(b, t, s * c, 3).permute(3, 0, 1, 2)  # 3, b, t, s*c

        q_t, k_t, v_t = qkv_t[0], qkv_t[1], qkv_t[2]  # b, t, s*c

        # reshape for matmul
        q_t = rearrange(q_t, 'b t (h c) -> (b h) t c', h=self.head)  # b,t,s*c -> b*h,t,s*c//h
        k_t = rearrange(k_t, 'b t (h c) -> (b h) c t', h=self.head)  # b,t,s*c -> b*h,s*c//h,t

        att_t = (q_t @ k_t) * self.scale  # b*h,t,t
        att_t = att_t.softmax(-1)  # b*h,t,t

        v_t = rearrange(v_t, 'b t (h c) -> (b h) t c', h=self.head)  # b*h,t,s*c//h

        x_t = att_t @ v_t  # b*h,t,s*c//h

        x_t = rearrange(x_t, '(b h) t (s c) -> b t s (h c)', h=self.head, s=s)  # b*h,t,s*c//h -> b,t,s,c

        return x_t


class HP_BLOCK(nn.Module):
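    """One HPFormer block: the channel dimension is split into three equal chunks that
    are processed in parallel by joint-, part-, and pose-level attention, concatenated,
    added back to the residual, and followed by a pre-norm MLP with a second residual."""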
    def __init__(self, d_time, d_joint, d_coor, part_list):
        super().__init__()

        self.layer_norm = nn.LayerNorm(d_coor)

        self.mlp = Mlp(d_coor, d_coor * 4, d_coor)

        self.joint_att = Joint_ATTENTION(d_time, d_joint, d_coor // 3)
        self.part_att = Part_ATTENTION(d_time, d_joint, d_coor // 3, part_list)
        self.pose_att = Pose_ATTENTION(d_time, d_joint, d_coor // 3)

        self.drop = DropPath(0.0)

    def forward(self, input):
        b, t, s, c = input.shape
        h = input
        x = self.layer_norm(input)

        x_joint, x_part, x_pose = x.chunk(3, 3)

        x = torch.cat((
            self.joint_att(x_joint),
            self.part_att(x_part),
            self.pose_att(x_pose)
        ), -1)

        x = x + h
        x = x + self.drop(self.mlp(self.layer_norm(x)))

        return x


class HPFormer(nn.Module):
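    """A sequential stack of num_block HP_BLOCKs."""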
    def __init__(self, num_block, d_time, d_joint, d_coor, part_list):
        super(HPFormer, self).__init__()

        self.num_block = num_block
        self.d_time = d_time
        self.d_joint = d_joint
        self.d_coor = d_coor
        self.part_list = part_list

        self.hp_block = []
        for _ in range(self.num_block):
            self.hp_block.append(HP_BLOCK(self.d_time, self.d_joint, self.d_coor, self.part_list))
        self.hp_block = nn.ModuleList(self.hp_block)

    def forward(self, input):
        for i in range(self.num_block):
            input = self.hp_block[i](input)

        return input


class Model(nn.Module):
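    """2D-to-3D pose lifting model: a linear embedding of the 2D joint coordinates,
    an HPFormer backbone for spatio-temporal correlation, and a linear regression
    head that predicts 3D coordinates for every joint."""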
    def __init__(self, args):
        super().__init__()

        layers, d_hid, frames = args.layers, args.d_hid, args.frames
        num_joints_in, num_joints_out = args.n_joints, args.out_joints
        part_list = args.part_list

        # layers, length, d_hid = layers, frames, d_hid
        # num_joints_in, num_joints_out = 17,17

        self.pose_emb = nn.Linear(2, d_hid, bias=False)
        self.gelu = nn.GELU()
        self.hpformer = HPFormer(layers, frames, num_joints_in, d_hid, part_list)
        self.regress_head = nn.Linear(d_hid, 3, bias=False)

    def forward(self, x):
        # b, t, s, c = x.shape  # batch, frame, joint, coordinate

        # dimension transfer
        x = self.pose_emb(x)
        x = self.gelu(x)

        # spatio-temporal correlation
        x = self.hpformer(x)

        # regression head
        x = self.regress_head(x)

        return x


class Args:
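    """Minimal argument container for the demo in the __main__ block below."""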
    def __init__(self, layers, d_hid, frames, n_joints, out_joints):
        self.layers = layers
        self.d_hid = d_hid
        self.frames = frames
        self.n_joints = n_joints
        self.out_joints = out_joints


if __name__ == "__main__":
    # inputs = torch.rand(64, 351, 34)  # [btz, channel, T, H, W]
    # inputs = torch.rand(1, 64, 4, 112, 112)  # [btz, channel, T, H, W]
    args = Args(layers=6, d_hid=192, frames=27, n_joints=17, out_joints=17)
    args.part_list = [
        [8, 9, 10],    # head
        [0, 7, 8],     # torso
        [11, 12, 13],  # left arm
        [14, 15, 16],  # right arm
        [4, 5, 6],     # left leg
        [1, 2, 3]      # right leg
    ]
    net = Model(args)
    inputs = torch.rand([1, 27, 17, 2])
    if torch.cuda.is_available():
        net = net.cuda()
        inputs = inputs.cuda()
    output = net(inputs)
    print(output.size())

    from thop import profile
    # flops = 2 * macs; profile the model's computation and parameter count
    macs, params = profile(net, inputs=(inputs,))
    print(2 * macs)
    print(params)