You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

355 lines
11 KiB

import torch
import torch.nn as nn
# from model.module.trans import Transformer as Transformer_s
# from model.module.trans_hypothesis import Transformer
import numpy as np
from einops import rearrange
from collections import OrderedDict
from torch.nn import functional as F
from torch.nn import init
import scipy.sparse as sp
from timm.models.layers import DropPath
class Mlp(nn.Module):
    """Bias-free two-layer feed-forward network: Linear -> activation -> Linear.

    Dropout is applied after the activation and again after the output
    projection, using the same ``nn.Dropout`` module.

    Args:
        in_features: size of the last input dimension.
        hidden_features: hidden width; falls back to ``in_features`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.1):
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim, bias=False)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Project ``x`` through the hidden layer and back, applying dropout twice."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Joint_ATTENTION(nn.Module):
    """Multi-head temporal self-attention applied independently to every joint.

    Each joint's feature sequence over the ``t`` frames attends to itself;
    joints do not exchange information inside this module.

    Args:
        d_time: number of frames; size of the learned temporal position embedding.
        d_joint: number of joints (stored only; not used by the computation).
        d_coor: embedding dimension per joint; must be divisible by ``head``.
        head: number of attention heads.
    """

    def __init__(self, d_time, d_joint, d_coor, head=8):
        super().__init__()
        self.qkv = nn.Linear(d_coor, d_coor * 3)
        self.head = head
        # NOTE(review): scores are scaled by the full d_coor rather than the
        # per-head width d_coor // head; kept so trained weights behave the same.
        self.scale = (d_coor) ** -0.5
        self.d_time = d_time
        self.d_joint = d_joint
        self.pos_emb = nn.Embedding(d_time, d_coor)
        # Non-persistent buffer: follows .to()/.cuda() with the module instead
        # of being pinned to CUDA at construction (which failed on CPU-only
        # machines), and does not appear in the state_dict.
        self.register_buffer("frame_idx", torch.arange(d_time), persistent=False)
        # Not used in forward; kept so the module tree is unchanged.
        self.drop = DropPath(0.5)

    def forward(self, input):
        """Apply per-joint temporal attention.

        Args:
            input: tensor of shape (b, t, s, c) with t <= d_time and c == d_coor.

        Returns:
            Tensor of shape (b, t, s, c).
        """
        b, t, s, c = input.shape
        # Add the embedding of frames 0..t-1 to every joint.
        x = input + self.pos_emb(self.frame_idx[:t])[None, :, None, :]
        # The projection output is split as (c, 3), i.e. q/k/v channels are
        # interleaved rather than concatenated; kept for weight compatibility.
        qkv = self.qkv(x).reshape(b, t, s, c, 3).permute(4, 0, 1, 2, 3)  # 3, b, t, s, c
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = rearrange(q, 'b t s (h c) -> (b h s) t c', h=self.head)  # b*h*s, t, c//h
        k = rearrange(k, 'b t s (h c) -> (b h s) c t', h=self.head)  # b*h*s, c//h, t
        v = rearrange(v, 'b t s (h c) -> (b h s) t c', h=self.head)  # b*h*s, t, c//h
        att = ((q @ k) * self.scale).softmax(-1)  # b*h*s, t, t
        return rearrange(att @ v, '(b h s) t c -> b t s (h c)', s=s, h=self.head)
class Part_ATTENTION(nn.Module):
    """Multi-head temporal self-attention over body-part tokens.

    Joints are gathered in the order given by ``part_list``. The joints of each
    part are concatenated into one token per frame, and each part's token
    sequence attends over time. The result is scattered back to joint order.
    A joint that appears in several parts receives the mean of its per-part
    outputs.

    Args:
        d_time: number of frames; size of the learned temporal position embedding.
        d_joint: number of joints in the input.
        d_coor: embedding dimension per joint.
        part_list: non-empty list of parts. Each part is a list of joint indices
            in ``[0, d_joint)``, and all parts must have the same length. Every
            joint must belong to at least one part.
        head: number of attention heads; must divide ``d_coor * len(part)``.

    Raises:
        ValueError: if ``part_list`` violates the constraints above.
    """

    def __init__(self, d_time, d_joint, d_coor, part_list, head=8):
        super().__init__()
        joint_order, part2joint = self._build_index_maps(part_list, d_joint)
        self.head = head
        self.num_of_part = len(part_list)
        self.num_joint_of_part = len(part_list[0])
        part_dim = d_coor * self.num_joint_of_part
        # NOTE(review): scaled by the full token width, not the per-head width.
        self.scale = part_dim ** -0.5
        self.d_time = d_time
        self.d_joint = d_joint
        # Not used in forward; kept because it owns parameters in the state_dict.
        self.layer_norm = nn.LayerNorm(part_dim)
        self.pos_embed = nn.Embedding(d_time, part_dim)
        # Index tensors are non-persistent buffers so they follow .to()/.cuda()
        # with the module instead of being hard-wired to CUDA at construction.
        self.register_buffer("frame_idx", torch.arange(d_time), persistent=False)
        self.qkv = nn.Linear(part_dim, part_dim * 3)
        # Not used in forward; kept so the module tree is unchanged.
        self.drop = DropPath(0.5)
        # Position in part order -> joint index (duplicates kept).
        self.register_buffer("idx_joint2part", torch.tensor(joint_order, dtype=torch.long), persistent=False)
        # Joint index -> last position of that joint in part order.
        self.register_buffer("idx_part2joint", torch.tensor(part2joint, dtype=torch.long), persistent=False)
        self.overlap = self.get_overlap()

    @staticmethod
    def _build_index_maps(part_list, d_joint):
        """Validate ``part_list`` and build the joint <-> part index maps.

        Returns:
            ``(joint_order, part2joint)``. ``joint_order`` is the flattened list
            of joint indices in part order, with duplicates kept.
            ``part2joint[j]`` is the position in ``joint_order`` of the LAST
            occurrence of joint ``j``.

        Raises:
            ValueError: if ``part_list`` is empty, its parts differ in length,
                a joint index is outside ``[0, d_joint)``, or a joint belongs
                to no part. An uncovered joint would otherwise silently read
                another slot's features.
        """
        if not part_list:
            raise ValueError("part_list must contain at least one part")
        part_len = len(part_list[0])
        joint_order = []
        for part in part_list:
            if len(part) != part_len:
                raise ValueError(f"every part must have {part_len} joints, got {list(part)}")
            for idx in part:
                if not 0 <= idx < d_joint:
                    raise ValueError(f"joint index {idx} out of range [0, {d_joint})")
                joint_order.append(int(idx))
        part2joint = [-1] * d_joint
        for pos, idx in enumerate(joint_order):
            part2joint[idx] = pos
        missing = [j for j, pos in enumerate(part2joint) if pos < 0]
        if missing:
            raise ValueError(f"joints {missing} are not covered by any part")
        return joint_order, part2joint

    def get_overlap(self):
        """Group the part-order positions of joints that occur in more than one part.

        Returns:
            A list with one entry per duplicated joint, in ascending joint
            index. Each entry holds the ascending positions of that joint in
            ``self.idx_joint2part``. Returns ``None`` when no joint is shared.
        """
        positions = {}
        for pos, idx in enumerate(self.idx_joint2part.tolist()):
            positions.setdefault(idx, []).append(pos)
        overlap = [positions[j] for j in sorted(positions) if len(positions[j]) > 1]
        return overlap or None

    def forward(self, input):
        """Apply part-level temporal attention.

        Args:
            input: tensor of shape (b, t, d_joint, c) with t <= d_time and
                c == d_coor.

        Returns:
            Tensor of shape (b, t, d_joint, c).
        """
        # Gather joints in part order, then fuse each part's joints into one token.
        x = torch.index_select(input, 2, self.idx_joint2part)
        x = rearrange(x, 'b t (p j) c -> b t p (j c)', j=self.num_joint_of_part)
        b, t, p, c = x.shape
        x = x + self.pos_embed(self.frame_idx[:t])[None, :, None, :]
        # q/k/v channels are interleaved ((c, 3) split); kept for weight compatibility.
        qkv = self.qkv(x).reshape(b, t, p, c, 3).permute(4, 0, 1, 2, 3)  # 3, b, t, p, c
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = rearrange(q, 'b t p (h c) -> (b h p) t c', h=self.head)  # b*h*p, t, c//h
        k = rearrange(k, 'b t p (h c) -> (b h p) c t', h=self.head)  # b*h*p, c//h, t
        v = rearrange(v, 'b t p (h c) -> (b h p) t c', h=self.head)  # b*h*p, t, c//h
        att = ((q @ k) * self.scale).softmax(-1)  # b*h*p, t, t
        x = rearrange(att @ v, '(b h p) t c -> b t p (h c)', h=self.head, p=p)
        # Split part tokens back into individual joints (still in part order).
        x = rearrange(x, 'b t p (j c) -> b t (p j) c', j=self.num_joint_of_part)
        # Average shared joints into their last occurrence, which is the slot
        # idx_part2joint selects for that joint.
        if self.overlap:
            for group in self.overlap:
                x[:, :, group[-1], :] = x[:, :, group, :].mean(dim=2)
        return torch.index_select(x, 2, self.idx_part2joint)
class Pose_ATTENTION(nn.Module):
    """Multi-head temporal self-attention over whole-pose tokens.

    All joints of a frame are flattened into one token of width
    ``d_joint * d_coor``, and the per-frame tokens attend over time.

    Args:
        d_time: number of frames; size of the learned temporal position embedding.
        d_joint: number of joints; the input's joint axis must have this size.
        d_coor: embedding dimension per joint; must be divisible by ``head``
            (required by the output reshape).
        head: number of attention heads.
    """

    def __init__(self, d_time, d_joint, d_coor, head=8):
        super().__init__()
        self.head = head
        # NOTE(review): scaled by the full token width, not the per-head width.
        self.scale = (d_coor * d_joint) ** -0.5
        self.d_time = d_time
        self.d_joint = d_joint
        self.pos_emb = nn.Embedding(d_time, d_coor * d_joint)
        # Non-persistent buffer: follows .to()/.cuda() with the module instead
        # of being pinned to CUDA at construction; not part of the state_dict.
        self.register_buffer("frame_idx", torch.arange(d_time), persistent=False)
        self.qkv = nn.Linear(d_coor * d_joint, d_coor * d_joint * 3)
        # Not used in forward; kept so the module tree is unchanged.
        self.drop = DropPath(0.5)

    def forward(self, input):
        """Apply whole-pose temporal attention.

        Args:
            input: tensor of shape (b, t, s, c) with t <= d_time,
                s == d_joint and c == d_coor.

        Returns:
            Tensor of shape (b, t, s, c).
        """
        b, t, s, c = input.shape
        x = rearrange(input, 'b t s c -> b t (s c)')
        x = x + self.pos_emb(self.frame_idx[:t])[None, :]
        # q/k/v channels are interleaved ((s*c, 3) split); kept for weight compatibility.
        qkv = self.qkv(x).reshape(b, t, s * c, 3).permute(3, 0, 1, 2)  # 3, b, t, s*c
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = rearrange(q, 'b t (h c) -> (b h) t c', h=self.head)  # b*h, t, s*c//h
        k = rearrange(k, 'b t (h c) -> (b h) c t', h=self.head)  # b*h, s*c//h, t
        v = rearrange(v, 'b t (h c) -> (b h) t c', h=self.head)  # b*h, t, s*c//h
        att = ((q @ k) * self.scale).softmax(-1)  # b*h, t, t
        # Each head's (s*c//h)-wide output is split into s joints of c//h
        # channels, and the heads are concatenated per joint.
        return rearrange(att @ v, '(b h) t (s c) -> b t s (h c)', h=self.head, s=s)
class HP_BLOCK(nn.Module):
    """Hierarchical block combining joint-, part- and pose-level temporal attention.

    The normalised input is split channel-wise into three equal slices. The
    slices go through ``Joint_ATTENTION``, ``Part_ATTENTION`` and
    ``Pose_ATTENTION`` respectively, and the outputs are concatenated and added
    to the input. An MLP with a second residual connection follows.

    Args:
        d_time: number of frames.
        d_joint: number of joints.
        d_coor: channel width of the block; must be divisible by 3.
        part_list: joint grouping forwarded to ``Part_ATTENTION``.

    Raises:
        ValueError: if ``d_coor`` is not divisible by 3. Otherwise ``chunk``
            would produce slices that do not match the ``d_coor // 3``-wide
            attention modules.
    """

    def __init__(self, d_time, d_joint, d_coor, part_list):
        super().__init__()
        if d_coor % 3 != 0:
            raise ValueError(f"d_coor must be divisible by 3 (one slice per attention branch), got {d_coor}")
        branch_dim = d_coor // 3
        # NOTE(review): one LayerNorm is shared by the attention and MLP
        # sub-layers; kept as-is because splitting it would change the checkpoint.
        self.layer_norm = nn.LayerNorm(d_coor)
        self.mlp = Mlp(d_coor, d_coor * 4, d_coor)
        self.joint_att = Joint_ATTENTION(d_time, d_joint, branch_dim)
        self.part_att = Part_ATTENTION(d_time, d_joint, branch_dim, part_list)
        self.pose_att = Pose_ATTENTION(d_time, d_joint, branch_dim)
        # DropPath(0.0) is an identity; kept so the module tree is unchanged.
        self.drop = DropPath(0.0)

    def forward(self, input):
        """Run the three attention branches plus the MLP on a (b, t, s, d_coor) tensor."""
        residual = input
        x_joint, x_part, x_pose = self.layer_norm(input).chunk(3, dim=3)
        x = torch.cat((
            self.joint_att(x_joint),
            self.part_att(x_part),
            self.pose_att(x_pose),
        ), dim=-1)
        x = x + residual
        return x + self.drop(self.mlp(self.layer_norm(x)))
class HPFormer(nn.Module):
    """Sequential stack of ``num_block`` ``HP_BLOCK`` layers.

    Args:
        num_block: number of stacked blocks.
        d_time: number of frames.
        d_joint: number of joints.
        d_coor: channel width of every block.
        part_list: joint grouping forwarded to every block.
    """

    def __init__(self, num_block, d_time, d_joint, d_coor, part_list):
        super().__init__()
        self.num_block = num_block
        self.d_time = d_time
        self.d_joint = d_joint
        self.d_coor = d_coor
        self.part_list = part_list
        self.hp_block = nn.ModuleList(
            [HP_BLOCK(d_time, d_joint, d_coor, part_list) for _ in range(num_block)]
        )

    def forward(self, input):
        """Feed ``input`` through every block in order and return the final output."""
        out = input
        for block in self.hp_block:
            out = block(out)
        return out
class Model(nn.Module):
    """Per-joint embedding -> HPFormer -> per-joint 3-channel regression head.

    The input's last dimension must be 2 (``pose_emb`` is ``Linear(2, d_hid)``),
    and each joint yields 3 output channels. Presumably the input is
    (batch, frames, joints, 2) 2D keypoints and the output is 3D coordinates.
    TODO: confirm against the data pipeline.

    Args:
        args: object providing ``layers``, ``d_hid``, ``frames``, ``n_joints``,
            ``out_joints`` and ``part_list``.
    """

    def __init__(self, args):
        super().__init__()
        layers, d_hid, frames = args.layers, args.d_hid, args.frames
        # NOTE(review): out_joints is read but never used; the head always
        # predicts one output per input joint.
        num_joints_in, _unused_out_joints = args.n_joints, args.out_joints
        part_list = args.part_list
        self.pose_emb = nn.Linear(2, d_hid, bias=False)
        self.gelu = nn.GELU()
        self.hpformer = HPFormer(layers, frames, num_joints_in, d_hid, part_list)
        self.regress_head = nn.Linear(d_hid, 3, bias=False)

    def forward(self, x):
        """Embed each joint, run the HPFormer stack, then regress 3 values per joint."""
        embedded = self.gelu(self.pose_emb(x))
        features = self.hpformer(embedded)
        return self.regress_head(features)
class Args:
    """Plain attribute container for the hyper-parameters ``Model`` reads.

    ``Model`` also reads ``part_list``, which is not a constructor argument and
    must be assigned as an attribute after construction.
    """

    def __init__(self, layers, d_hid, frames, n_joints, out_joints):
        (self.layers, self.d_hid, self.frames,
         self.n_joints, self.out_joints) = (layers, d_hid, frames, n_joints, out_joints)
if __name__ == "__main__":
    # Smoke test: build the model, run one forward pass, then report FLOPs and
    # parameter count with thop.
    args = Args(layers=6, d_hid=192, frames=27, n_joints=17, out_joints=17)
    args.part_list = [
        [8, 9, 10],    # head
        [0, 7, 8],     # torso
        [11, 12, 13],  # left hand
        [14, 15, 16],  # right hand
        [4, 5, 6],     # left leg
        [1, 2, 3],     # right leg
    ]
    net = Model(args)
    # Derive the dummy input shape from args instead of repeating the constants.
    inputs = torch.rand([1, args.frames, args.n_joints, 2])
    if torch.cuda.is_available():
        net = net.cuda()
        inputs = inputs.cuda()
    # Inference-only check: no need to build an autograd graph.
    with torch.no_grad():
        output = net(inputs)
    print(output.size())
    from thop import profile
    # thop reports multiply-accumulates (MACs); FLOPs are approximated as 2 * MACs.
    macs, params = profile(net, inputs=(inputs,))
    print(2 * macs)
    print(params)