import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath

"""
Network structure:

    128 joint                          128 part
       +      = 256 -> mlp * 3 ->         +      = 256 -> mlp * 3
    128 pose                           128 pose

Each block splits the embedding into two halves: the first three blocks pair
joint-level temporal attention with pose-level temporal attention, the remaining
blocks pair part-level temporal attention with pose-level temporal attention.
"""


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.1):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Joint_ATTENTION(nn.Module):
    """Temporal multi-head attention applied to each joint independently."""

    def __init__(self, d_time, d_joint, d_coor, head=8):
        """
        d_time:  number of frames
        d_joint: number of joints
        d_coor:  embedding dimension
        """
        super().__init__()
        self.qkv = nn.Linear(d_coor, d_coor * 3)
        self.head = head
        self.layer_norm = nn.LayerNorm(d_coor)
        self.scale = d_coor ** -0.5
        self.d_time = d_time
        self.d_joint = d_joint
        self.pos_emb = nn.Embedding(d_time, d_coor)
        # Registered as a non-persistent buffer so it follows the module's device
        # instead of being pinned to CUDA at construction time.
        self.register_buffer('frame_idx', torch.arange(d_time, dtype=torch.long), persistent=False)
        self.drop = DropPath(0.5)

    def forward(self, input):
        b, t, s, c = input.shape
        emb = self.pos_emb(self.frame_idx)
        input = input + emb[None, :, None, :]
        input = self.layer_norm(input)

        qkv = self.qkv(input)  # b,t,s,c -> b,t,s,3*c
        qkv_t = qkv.reshape(b, t, s, c, 3).permute(4, 0, 1, 2, 3)  # 3,b,t,s,c
        q_t, k_t, v_t = qkv_t[0], qkv_t[1], qkv_t[2]  # b,t,s,c

        q_t = rearrange(q_t, 'b t s (h c) -> (b h s) t c', h=self.head)  # b,t,s,c -> b*h*s,t,c//h
        k_t = rearrange(k_t, 'b t s (h c) -> (b h s) c t', h=self.head)  # b,t,s,c -> b*h*s,c//h,t

        att_t = (q_t @ k_t) * self.scale  # b*h*s,t,t
        att_t = att_t.softmax(-1)         # b*h*s,t,t

        v_t = rearrange(v_t, 'b t s c -> b c t s')
        v_t = rearrange(v_t, 'b (h c) t s -> (b h s) t c', h=self.head)  # b*h*s,t,c//h

        x_t = att_t @ v_t  # b*h*s,t,c//h
        x_t = rearrange(x_t, '(b h s) t c -> b t s (h c)', s=s, h=self.head)  # b,t,s,c
        return x_t


class Part_ATTENTION(nn.Module):
    """Temporal multi-head attention applied to each body part (a group of joints) independently."""

    def __init__(self, d_time, d_joint, d_coor, part_list, head=8):
        """
        d_time:  number of frames
        d_joint: number of joints
        d_coor:  embedding dimension
        """
        super().__init__()
        self.head = head
        self.num_of_part = len(part_list)
        self.num_joint_of_part = len(part_list[0])
        self.scale = (d_coor * self.num_joint_of_part) ** -0.5
        self.d_time = d_time
        self.d_joint = d_joint
        self.layer_norm = nn.LayerNorm(d_coor * self.num_joint_of_part)
        self.pos_embed = nn.Embedding(d_time, d_coor * self.num_joint_of_part)
        self.register_buffer('frame_idx', torch.arange(d_time, dtype=torch.long), persistent=False)
        self.qkv = nn.Linear(d_coor * self.num_joint_of_part, d_coor * self.num_joint_of_part * 3)
        self.drop = DropPath(0.5)

        # check part_list
        for part in part_list:
            assert len(part) == 3  # each part should contain 3 joints
            for idx in part:
                assert 0 <= idx < d_joint  # joint index must be smaller than d_joint

        # joint order -> flattened part order
        self.register_buffer(
            'idx_joint2part',
            torch.tensor([idx for part in part_list for idx in part], dtype=torch.long).flatten(),
            persistent=False)

        # flattened part order -> joint order (for joints that occur in several parts,
        # the last occurrence wins; see the overlap averaging in forward)
        idx_part2joint = list(range(d_joint))
        for i, idx in enumerate(self.idx_joint2part):
            idx_part2joint[idx] = i
        self.register_buffer('idx_part2joint', torch.tensor(idx_part2joint, dtype=torch.long), persistent=False)

        self.overlap = self.get_overlap()  # joints shared by more than one part

    def get_overlap(self):
        # For every joint, collect the positions where it appears in the flattened
        # part order; joints appearing more than once are "overlapping" joints.
        overlap_list = [-1] * self.d_joint
        for i, idx in enumerate(self.idx_joint2part):
            if overlap_list[idx] == -1:
                overlap_list[idx] = i
            else:
                if not isinstance(overlap_list[idx], list):
                    overlap_i = overlap_list[idx]
                    overlap_list[idx] = [overlap_i]
                overlap_list[idx].append(i)

        overlap = []
        for i in overlap_list:
            if isinstance(i, list):
                overlap.append(i)
        if len(overlap) == 0:
            return None
        return overlap

    def forward(self, input):
        # Regroup joints into parts: b,t,(p*j),c -> b,t,p,(j*c)
        input = torch.index_select(input, 2, self.idx_joint2part)
        input = rearrange(input, 'b t (p j) c -> b t p (j c)', j=self.num_joint_of_part)
        b, t, p, c = input.shape

        emb = self.pos_embed(self.frame_idx)
        input = input + emb[None, :, None, :]
        input = self.layer_norm(input)

        qkv = self.qkv(input)  # b,t,p,c -> b,t,p,3*c
        qkv_t = qkv.reshape(b, t, p, c, 3).permute(4, 0, 1, 2, 3)  # 3,b,t,p,c
        q_t, k_t, v_t = qkv_t[0], qkv_t[1], qkv_t[2]  # b,t,p,c

        q_t = rearrange(q_t, 'b t p (h c) -> (b h p) t c', h=self.head)  # b,t,p,c -> b*h*p,t,c//h
        k_t = rearrange(k_t, 'b t p (h c) -> (b h p) c t', h=self.head)  # b,t,p,c -> b*h*p,c//h,t

        att_t = (q_t @ k_t) * self.scale  # b*h*p,t,t
        att_t = att_t.softmax(-1)         # b*h*p,t,t

        v_t = rearrange(v_t, 'b t p c -> b c t p')
        v_t = rearrange(v_t, 'b (h c) t p -> (b h p) t c', h=self.head)  # b*h*p,t,c//h

        x_t = att_t @ v_t  # b*h*p,t,c//h
        x = rearrange(x_t, '(b h p) t c -> b h t p c', h=self.head, p=p)  # b,h,t,p,c//h

        # Restore the b,t,(p*j),c joint layout
        x = rearrange(x, 'b h t p c -> b t p (h c)')
        x = rearrange(x, 'b t p (j c) -> b t (p j) c', j=self.num_joint_of_part)

        # Average the features of joints that are shared by several parts
        if self.overlap:
            for overlap in self.overlap:
                idx = overlap[-1]
                for i in overlap[:-1]:
                    x[:, :, idx, :] += x[:, :, i, :]
                x[:, :, idx, :] /= len(overlap)

        # Map back from the flattened part order to the original joint order
        x = torch.index_select(x, 2, self.idx_part2joint)
        return x


class Pose_ATTENTION(nn.Module):
    """Temporal multi-head attention applied to the whole pose (all joints concatenated)."""

    def __init__(self, d_time, d_joint, d_coor, head=8):
        """
        d_time:  number of frames
        d_joint: number of joints
        d_coor:  embedding dimension
        """
        super().__init__()
        self.head = head
        self.scale = (d_coor * d_joint) ** -0.5
        self.d_time = d_time
        self.d_joint = d_joint
        self.layer_norm = nn.LayerNorm(d_coor * self.d_joint)
        self.pos_emb = nn.Embedding(d_time, d_coor * d_joint)
        self.register_buffer('frame_idx', torch.arange(d_time, dtype=torch.long), persistent=False)
        self.qkv = nn.Linear(d_coor * d_joint, d_coor * d_joint * 3)
        self.drop = DropPath(0.5)

    def forward(self, input):
        b, t, s, c = input.shape
        input = rearrange(input, 'b t s c -> b t (s c)')

        emb = self.pos_emb(self.frame_idx)
        input = input + emb[None, :]
        input = self.layer_norm(input)

        qkv = self.qkv(input)  # b,t,s*c -> b,t,3*s*c
        qkv_t = qkv.reshape(b, t, s * c, 3).permute(3, 0, 1, 2)  # 3,b,t,s*c
        q_t, k_t, v_t = qkv_t[0], qkv_t[1], qkv_t[2]  # b,t,s*c

        # reshape for matmul
        q_t = rearrange(q_t, 'b t (h c) -> (b h) t c', h=self.head)  # b,t,s*c -> b*h,t,s*c//h
        k_t = rearrange(k_t, 'b t (h c) -> (b h) c t', h=self.head)  # b,t,s*c -> b*h,s*c//h,t

        att_t = (q_t @ k_t) * self.scale  # b*h,t,t
        att_t = att_t.softmax(-1)         # b*h,t,t

        v_t = rearrange(v_t, 'b t (h c) -> (b h) t c', h=self.head)  # b*h,t,s*c//h

        x_t = att_t @ v_t  # b*h,t,s*c//h
        x_t = rearrange(x_t, '(b h) t (s c) -> b t s (h c)', h=self.head, s=s)  # b,t,s,c
        return x_t


class HP_BLOCK(nn.Module):
    def __init__(self, d_time, d_joint, d_coor, part_list, count):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_coor)
        self.mlp = Mlp(d_coor, d_coor * 4, d_coor)
        self.count = count
        # The embedding is split in half: the first half goes through joint attention
        # (blocks 0-2) or part attention (later blocks), the second half through pose attention.
        if count < 3:
            self.joint_att = Joint_ATTENTION(d_time, d_joint, d_coor // 2)
        else:
            self.part_att = Part_ATTENTION(d_time, d_joint, d_coor // 2, part_list)
        self.pose_att = Pose_ATTENTION(d_time, d_joint, d_coor // 2)
        self.drop = DropPath(0.0)

    def forward(self, input):
        b, t, s, c = input.shape
        h = input
        # x = self.layer_norm(input)
        x_joint, x_pose = input.chunk(2, 3)
        if self.count < 3:
            x = torch.cat((self.joint_att(x_joint), self.pose_att(x_pose)), -1)
        else:
            x = torch.cat((self.part_att(x_joint), self.pose_att(x_pose)), -1)
        x = x + h
        x = x + self.drop(self.mlp(self.layer_norm(x)))
        return x


class HPFormer(nn.Module):
    def __init__(self, num_block, d_time, d_joint, d_coor, part_list):
        super(HPFormer, self).__init__()
        self.num_block = num_block
        self.d_time = d_time
        self.d_joint = d_joint
        self.d_coor = d_coor
        self.part_list = part_list

        self.hp_block = []
        for i in range(self.num_block):
            self.hp_block.append(HP_BLOCK(self.d_time, self.d_joint, self.d_coor, self.part_list, i))
        self.hp_block = nn.ModuleList(self.hp_block)

    def forward(self, input):
        for i in range(self.num_block):
            input = self.hp_block[i](input)
        return input


class Model(nn.Module):
    def __init__(self, args):
        super().__init__()
        layers, d_hid, frames = args.layers, args.d_hid, args.frames
        num_joints_in, num_joints_out = args.n_joints, args.out_joints
        part_list = args.part_list

        self.pose_emb = nn.Linear(2, d_hid, bias=False)
        self.gelu = nn.GELU()
        self.hpformer = HPFormer(layers, frames, num_joints_in, d_hid, part_list)
        self.regress_head = nn.Linear(d_hid, 3, bias=False)

    def forward(self, x):
        # b, t, s, c = x.shape  # batch, frame, joint, coordinate

        # dimension transfer: 2D coordinates -> d_hid embedding
        x = self.pose_emb(x)
        x = self.gelu(x)

        # spatio-temporal correlation
        x = self.hpformer(x)

        # regression head: d_hid embedding -> 3D coordinates
        x = self.regress_head(x)
        return x


class Args:
    def __init__(self, layers, d_hid, frames, n_joints, out_joints):
        self.layers = layers
        self.d_hid = d_hid
        self.frames = frames
        self.n_joints = n_joints
        self.out_joints = out_joints


if __name__ == "__main__":
    args = Args(layers=6, d_hid=192, frames=27, n_joints=17, out_joints=17)
    args.part_list = [
        [8, 9, 10],    # head
        [0, 7, 8],     # torso
        [11, 12, 13],  # left arm
        [14, 15, 16],  # right arm
        [4, 5, 6],     # left leg
        [1, 2, 3],     # right leg
    ]

    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    net = Model(args)
    inputs = torch.rand([1, 27, 17, 2])  # batch, frame, joint, 2D coordinate
    if torch.cuda.is_available():
        net = net.cuda()
        inputs = inputs.cuda()

    output = net(inputs)
    print(output.size())

    from thop import profile

    # FLOPs ~= 2 * MACs; report the model's compute and parameter count
    macs, params = profile(net, inputs=(inputs,))
    print(2 * macs)
    print(params)
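    # Optional sanity check (a minimal sketch, not part of the original demo): from the
    # fourth block onwards each HP_BLOCK uses Part_ATTENTION, whose `overlap` attribute
    # records joints that belong to more than one part. With the part_list above, joint 8
    # appears in both the head part [8, 9, 10] and the torso part [0, 7, 8], i.e. at
    # flattened positions 0 and 5, so those two copies are averaged before the features
    # are mapped back to the 17-joint layout.
    part_att = net.hpformer.hp_block[3].part_att   # first block with count >= 3
    print(part_att.idx_joint2part.tolist())        # joint indices in flattened part order
    print(part_att.overlap)                        # should be [[0, 5]] for this part_list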