You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

223 lines
7.5 KiB

import torch
import torch.nn as nn
# from model.module.trans import Transformer as Transformer_s
# from model.module.trans_hypothesis import Transformer
import numpy as np
from einops import rearrange
from collections import OrderedDict
from torch.nn import functional as F
from torch.nn import init
import scipy.sparse as sp
from timm.models.layers import DropPath
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear -> activation -> dropout -> linear -> dropout.

    Both linear layers are built without bias terms, and a single shared
    ``nn.Dropout`` module is applied after the activation and again after
    the second projection.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the last dimension of the output; falls back
            to ``in_features`` when falsy.
        act_layer: Zero-argument callable that builds the activation module.
        drop: Dropout probability.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.1):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width, bias=False)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width, bias=False)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class STC_ATTENTION(nn.Module):
    """Attention layer that splits channels into a spatial and a temporal half.

    The input is layer-normalised and projected to q/k/v. The channel axis is
    then cut in two: the first half attends across joints within each frame,
    the second half attends across frames for each joint. The temporal branch
    additionally receives a depthwise 3x3 convolution of its values and a
    heavily down-weighted learned per-joint group embedding. Both halves are
    concatenated, projected, and added back onto the input.

    Args:
        d_time: Number of frames. Stored, but not otherwise used by this layer.
        d_joint: Number of joints. Stored, but not otherwise used by this layer.
        d_coor: Embedding (channel) dimension. ``d_coor // 2`` must be
            divisible by ``head``.
        head: Number of attention heads used in each branch.
    """

    def __init__(self, d_time, d_joint, d_coor, head=8):
        super().__init__()
        self.qkv = nn.Linear(d_coor, d_coor * 3)
        self.head = head
        self.layer_norm = nn.LayerNorm(d_coor)
        # NOTE(review): the scale uses the half-channel width rather than the
        # per-head width; left untouched so existing weights behave the same.
        self.scale = (d_coor // 2) ** -0.5
        self.proj = nn.Linear(d_coor, d_coor)
        self.d_time = d_time
        self.d_joint = d_joint

        # Learned embedding for 5 joint groups, one vector per group.
        self.emb = nn.Embedding(5, d_coor // head // 2)
        # Joint index -> group index lookup (17 entries, so forward() expects
        # 17 joints). Presumably body-part groups of a 17-joint skeleton --
        # TODO confirm against the dataset's joint order.
        # Registered as a non-persistent buffer so it follows the module
        # across .to()/.cuda() and works on CPU-only hosts, without adding a
        # key to the state_dict of existing checkpoints.
        self.register_buffer(
            "part",
            torch.tensor([0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 3, 3, 3, 4, 4, 4], dtype=torch.long),
            persistent=False,
        )

        # Depthwise 3x3 convolution applied to the temporal-branch values.
        self.sep2_t = nn.Conv2d(d_coor // 2, d_coor // 2, kernel_size=3, stride=1, padding=1, groups=d_coor // 2)
        self.drop = DropPath(0.5)

    def forward(self, input):
        """Run spatial/temporal attention and return ``input`` plus the result.

        Args:
            input: Tensor of shape ``(b, t, s, c)`` with ``c == d_coor``.

        Returns:
            Tensor of the same shape as ``input``.
        """
        b, t, s, c = input.shape
        residual = input
        x = self.layer_norm(input)

        # (b, t, s, c) -> (b, t, s, 3c) -> (3, b, t, s, c)
        qkv = self.qkv(x).reshape(b, t, s, c, 3).permute(4, 0, 1, 2, 3)

        # First channel half feeds the spatial branch, second half the temporal one.
        qkv_s, qkv_t = qkv.chunk(2, 4)  # each (3, b, t, s, c//2)
        q_s, k_s, v_s = qkv_s.unbind(0)  # each (b, t, s, c//2)
        q_t, k_t, v_t = qkv_t.unbind(0)  # each (b, t, s, c//2)

        # Fold heads into the batch axis; keys are laid out pre-transposed.
        q_s = rearrange(q_s, 'b t s (h c) -> (b h t) s c', h=self.head)  # (b*h*t, s, c//2//h)
        k_s = rearrange(k_s, 'b t s (h c) -> (b h t) c s', h=self.head)  # (b*h*t, c//2//h, s)
        q_t = rearrange(q_t, 'b t s (h c) -> (b h s) t c', h=self.head)  # (b*h*s, t, c//2//h)
        k_t = rearrange(k_t, 'b t s (h c) -> (b h s) c t', h=self.head)  # (b*h*s, c//2//h, t)

        att_s = ((q_s @ k_s) * self.scale).softmax(-1)  # (b*h*t, s, s)
        att_t = ((q_t @ k_t) * self.scale).softmax(-1)  # (b*h*s, t, t)

        # Channels-first layout for the convolution.
        v_s = rearrange(v_s, 'b t s c -> b c t s')
        v_t = rearrange(v_t, 'b t s c -> b c t s')

        # Convolutional term added to the temporal branch.
        sep2_t = self.sep2_t(v_t)  # (b, c//2, t, s)
        sep2_t = rearrange(sep2_t, 'b (h c) t s -> (b h s) t c', h=self.head)  # (b*h*s, t, c//2//h)

        # Per-joint group embedding, broadcast over batch, head and time.
        sep_t = self.emb(self.part)[None, None, None]  # (1, 1, 1, 17, c//2//h)

        v_s = rearrange(v_s, 'b (h c) t s -> (b h t) s c', h=self.head)  # (b*h*t, s, c//2//h)
        v_t = rearrange(v_t, 'b (h c) t s -> (b h s) t c', h=self.head)  # (b*h*s, t, c//2//h)

        x_s = att_s @ v_s
        x_t = att_t @ v_t + sep2_t

        x_s = rearrange(x_s, '(b h t) s c -> b h t s c', h=self.head, t=t)  # (b, h, t, s, c//2//h)
        x_t = rearrange(x_t, '(b h s) t c -> b h t s c', h=self.head, s=s)  # (b, h, t, s, c//2//h)
        x_t = x_t + 1e-9 * self.drop(sep_t)

        x = torch.cat((x_s, x_t), -1)  # (b, h, t, s, c//h)
        x = rearrange(x, 'b h t s c -> b t s (h c)')  # (b, t, s, c)

        # Output projection and skip connection.
        return self.proj(x) + residual
class STC_BLOCK(nn.Module):
    """One encoder block: STC attention followed by a residual MLP.

    Args:
        d_time: Forwarded to ``STC_ATTENTION``.
        d_joint: Forwarded to ``STC_ATTENTION``.
        d_coor: Channel width used by the layer norm, the MLP (hidden width
            ``4 * d_coor``) and the attention module.
    """

    def __init__(self, d_time, d_joint, d_coor):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_coor)
        self.mlp = Mlp(in_features=d_coor, hidden_features=d_coor * 4, out_features=d_coor)
        self.stc_att = STC_ATTENTION(d_time=d_time, d_joint=d_joint, d_coor=d_coor)
        # Drop probability 0.0, so this stochastic-depth layer is effectively a pass-through.
        self.drop = DropPath(0.0)

    def forward(self, input):
        """Apply attention, then add the normalised MLP output to it."""
        # The values are unused; the unpacking only fails early on inputs
        # that are not 4-dimensional.
        _batch, _frames, _joints, _channels = input.shape
        attended = self.stc_att(input)
        feed_forward = self.mlp(self.layer_norm(attended))
        return attended + self.drop(feed_forward)
class STCFormer(nn.Module):
    """Sequential stack of ``num_block`` identical ``STC_BLOCK`` layers.

    Args:
        num_block: Number of blocks in the stack.
        d_time: Forwarded to every ``STC_BLOCK``.
        d_joint: Forwarded to every ``STC_BLOCK``.
        d_coor: Forwarded to every ``STC_BLOCK``.
    """

    def __init__(self, num_block, d_time, d_joint, d_coor):
        super().__init__()
        self.num_block, self.d_time = num_block, d_time
        self.d_joint, self.d_coor = d_joint, d_coor
        self.stc_block = nn.ModuleList(
            [STC_BLOCK(d_time, d_joint, d_coor) for _ in range(num_block)]
        )

    def forward(self, input):
        """Feed ``input`` through each block in order and return the result."""
        features = input
        for block in self.stc_block:
            features = block(features)
        return features
class Model(nn.Module):
    """Linear embedding -> GELU -> STCFormer stack -> linear regression head.

    Args:
        args: Object exposing ``layers``, ``d_hid``, ``frames``, ``n_joints``
            and ``out_joints`` attributes.
    """

    def __init__(self, args):
        super().__init__()
        depth = args.layers
        width = args.d_hid
        n_frames = args.frames
        joints_in = args.n_joints
        # Read for parity with the expected args schema; not used below.
        _joints_out = args.out_joints

        # Lifts 2-channel inputs to the hidden width.
        self.pose_emb = nn.Linear(2, width, bias=False)
        self.gelu = nn.GELU()
        self.stcformer = STCFormer(depth, n_frames, joints_in, width)
        # Maps hidden features to 3 output channels.
        self.regress_head = nn.Linear(width, 3, bias=False)

    def forward(self, x):
        """Map ``x`` (last dimension 2) to an output whose last dimension is 3.

        ``x`` is presumably laid out as (batch, frame, joint, coordinate) --
        TODO confirm against the data loader.
        """
        embedded = self.gelu(self.pose_emb(x))
        encoded = self.stcformer(embedded)
        return self.regress_head(encoded)
class Args:
    """Plain container for the hyper-parameters that ``Model`` reads.

    Args:
        layers: Stored as ``self.layers``.
        d_hid: Stored as ``self.d_hid``.
        frames: Stored as ``self.frames``.
        n_joints: Stored as ``self.n_joints``.
        out_joints: Stored as ``self.out_joints``.
    """

    def __init__(self, layers, d_hid, frames, n_joints, out_joints):
        self.layers, self.d_hid, self.frames = layers, d_hid, frames
        self.n_joints, self.out_joints = n_joints, out_joints
if __name__ == "__main__":
    # Smoke test: build a 6-layer model, run one random input through it and
    # print the output size, the FLOP estimate and the parameter count.
    args = Args(layers=6, d_hid=256, frames=27, n_joints=17, out_joints=17)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = Model(args).to(device)
    inputs = torch.rand([1, 27, 17, 2]).to(device)

    output = net(inputs)
    print(output.size())

    # Imported here so thop is only required when running this script.
    from thop import profile

    # flops = 2 * macs; compute the model's computation cost and parameter count.
    macs, params = profile(net, inputs=(inputs,))
    print(2*macs)
    print(params)