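"""STCFormer-style spatio-temporal transformer for 2D-to-3D pose lifting.

Overview, as read from the code below: each block layer-normalizes the
(batch, frame, joint, channel) features, splits the channels into a spatial
half and a temporal half, runs multi-head self-attention over joints (per
frame) and over frames (per joint) respectively, then concatenates the two
halves before a linear projection and residual connection.
"""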
import torch
import torch.nn as nn
# from model.module.trans import Transformer as Transformer_s
# from model.module.trans_hypothesis import Transformer
import numpy as np
from einops import rearrange
from collections import OrderedDict
from torch.nn import functional as F
from torch.nn import init
import scipy.sparse as sp

from timm.models.layers import DropPath

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.1):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

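# STC_ATTENTION: the qkv channels are chunked into a spatial group (attention
# over the s joints within each frame) and a temporal group (attention over the
# t frames for each joint). Two positional cues are mixed into the temporal
# half: "sep1", a learned embedding indexed by a fixed joint-to-body-part
# assignment (scaled down and passed through DropPath), and "sep2", a
# depth-wise 3x3 convolution over the temporal value tensor.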
class STC_ATTENTION(nn.Module):
    def __init__(self, d_time, d_joint, d_coor, head=8):
        """
        d_time:  number of frames
        d_joint: number of joints
        d_coor:  embedding dimension
        """
        super().__init__()
        # print(d_time, d_joint, d_coor, head)
        self.qkv = nn.Linear(d_coor, d_coor * 3)
        self.head = head
        self.layer_norm = nn.LayerNorm(d_coor)

        self.scale = (d_coor // 2) ** -0.5
        self.proj = nn.Linear(d_coor, d_coor)
        self.d_time = d_time
        self.d_joint = d_joint

        # sep1: learned embedding for 5 body parts over the 17 joints
        # print(d_coor)
        self.emb = nn.Embedding(5, d_coor // head // 2)
        # fixed joint-to-body-part assignment; kept as a non-persistent buffer
        # so it follows the module's device instead of being pinned to CUDA
        self.register_buffer(
            'part',
            torch.tensor([0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 3, 3, 3, 4, 4, 4]).long(),
            persistent=False,
        )

        # sep2: depth-wise temporal convolution on the temporal half of the channels
        self.sep2_t = nn.Conv2d(d_coor // 2, d_coor // 2, kernel_size=3, stride=1, padding=1, groups=d_coor // 2)
        # self.sep2_s = nn.Conv2d(d_coor // 2, d_coor // 2, kernel_size=3, stride=1, padding=1, groups=d_coor // 2)

        self.drop = DropPath(0.5)

    def forward(self, input):
        b, t, s, c = input.shape

        h = input
        x = self.layer_norm(input)

        qkv = self.qkv(x)  # b,t,s,c -> b,t,s,3*c
        qkv = qkv.reshape(b, t, s, c, 3).permute(4, 0, 1, 2, 3)  # 3,b,t,s,c

        # space group and time group
        qkv_s, qkv_t = qkv.chunk(2, 4)  # [3,b,t,s,c//2], [3,b,t,s,c//2]

        q_s, k_s, v_s = qkv_s[0], qkv_s[1], qkv_s[2]  # b,t,s,c//2
        q_t, k_t, v_t = qkv_t[0], qkv_t[1], qkv_t[2]  # b,t,s,c//2

        # reshape for matmul
        q_s = rearrange(q_s, 'b t s (h c) -> (b h t) s c', h=self.head)  # b,t,s,c//2 -> b*h*t,s,c//2//h
        k_s = rearrange(k_s, 'b t s (h c) -> (b h t) c s', h=self.head)  # b,t,s,c//2 -> b*h*t,c//2//h,s

        q_t = rearrange(q_t, 'b t s (h c) -> (b h s) t c', h=self.head)  # b,t,s,c//2 -> b*h*s,t,c//2//h
        k_t = rearrange(k_t, 'b t s (h c) -> (b h s) c t', h=self.head)  # b,t,s,c//2 -> b*h*s,c//2//h,t

        att_s = (q_s @ k_s) * self.scale  # b*h*t,s,s
        att_t = (q_t @ k_t) * self.scale  # b*h*s,t,t

        att_s = att_s.softmax(-1)  # b*h*t,s,s
        att_t = att_t.softmax(-1)  # b*h*s,t,t

        v_s = rearrange(v_s, 'b t s c -> b c t s')
        v_t = rearrange(v_t, 'b t s c -> b c t s')

        # sep2
        # sep2_s = self.sep2_s(v_s)  # b,c//2,t,s
        sep2_t = self.sep2_t(v_t)  # b,c//2,t,s
        # sep2_s = rearrange(sep2_s, 'b (h c) t s -> (b h t) s c', h=self.head)  # b*h*t,s,c//2//h
        sep2_t = rearrange(sep2_t, 'b (h c) t s -> (b h s) t c', h=self.head)  # b*h*s,t,c//2//h

        # sep1
        # v_s = rearrange(v_s, 'b c t s -> (b t) s c')
        # v_t = rearrange(v_t, 'b c t s -> (b s) t c')
        # print(lep_s.shape)
        # sep_s = self.emb(self.part).unsqueeze(0)  # 1,s,c//2//h
        sep_t = self.emb(self.part).unsqueeze(0).unsqueeze(0).unsqueeze(0)  # 1,1,1,s,c//2//h

        # MSA
        v_s = rearrange(v_s, 'b (h c) t s -> (b h t) s c', h=self.head)  # b*h*t,s,c//2//h
        v_t = rearrange(v_t, 'b (h c) t s -> (b h s) t c', h=self.head)  # b*h*s,t,c//2//h

        # x_s = att_s @ v_s + sep2_s + 0.0001 * self.drop(sep_s)  # b*h*t,s,c//2//h
        x_s = att_s @ v_s
        x_t = att_t @ v_t + sep2_t  # b*h*s,t,c//2//h

        x_s = rearrange(x_s, '(b h t) s c -> b h t s c', h=self.head, t=t)  # b*h*t,s,c//2//h -> b,h,t,s,c//2//h
        x_t = rearrange(x_t, '(b h s) t c -> b h t s c', h=self.head, s=s)  # b*h*s,t,c//2//h -> b,h,t,s,c//2//h

        x_t = x_t + 1e-9 * self.drop(sep_t)

        x = torch.cat((x_s, x_t), -1)  # b,h,t,s,c//h
        x = rearrange(x, 'b h t s c -> b t s (h c)')  # b,t,s,c

        # projection and skip-connection
        x = self.proj(x)
        x = x + h
        return x

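# STC_BLOCK: one transformer block = STC attention (which applies its own
# pre-norm and residual internally) followed by a LayerNorm + MLP sub-layer
# with a residual connection and DropPath.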
class STC_BLOCK(nn.Module):
    def __init__(self, d_time, d_joint, d_coor):
        super().__init__()

        self.layer_norm = nn.LayerNorm(d_coor)

        self.mlp = Mlp(d_coor, d_coor * 4, d_coor)

        self.stc_att = STC_ATTENTION(d_time, d_joint, d_coor)
        self.drop = DropPath(0.0)

    def forward(self, input):
        b, t, s, c = input.shape
        x = self.stc_att(input)
        x = x + self.drop(self.mlp(self.layer_norm(x)))

        return x

class STCFormer(nn.Module):
    def __init__(self, num_block, d_time, d_joint, d_coor):
        super(STCFormer, self).__init__()

        self.num_block = num_block
        self.d_time = d_time
        self.d_joint = d_joint
        self.d_coor = d_coor

        self.stc_block = []
        for i in range(self.num_block):
            self.stc_block.append(STC_BLOCK(self.d_time, self.d_joint, self.d_coor))
        self.stc_block = nn.ModuleList(self.stc_block)

    def forward(self, input):
        # run the stacked STC blocks in sequence
        for i in range(self.num_block):
            input = self.stc_block[i](input)
        return input

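# Model: lifts 2D keypoints to 3D. Input is (batch, frames, joints, 2); a
# per-joint linear layer embeds the 2D coordinates into d_hid channels, the
# stacked STC blocks mix information across frames and joints, and a linear
# regression head maps each joint back to 3 coordinates, giving
# (batch, frames, joints, 3).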
class Model(nn.Module):
    def __init__(self, args):
        super().__init__()

        layers, d_hid, frames = args.layers, args.d_hid, args.frames
        num_joints_in, num_joints_out = args.n_joints, args.out_joints

        # layers, length, d_hid = layers, frames, d_hid
        # num_joints_in, num_joints_out = 17, 17

        self.pose_emb = nn.Linear(2, d_hid, bias=False)
        self.gelu = nn.GELU()
        self.stcformer = STCFormer(layers, frames, num_joints_in, d_hid)
        self.regress_head = nn.Linear(d_hid, 3, bias=False)

    def forward(self, x):
        # b, t, s, c = x.shape  # batch, frame, joint, coordinate

        # dimension transfer
        x = self.pose_emb(x)
        x = self.gelu(x)

        # spatio-temporal correlation
        x = self.stcformer(x)

        # regression head
        x = self.regress_head(x)

        return x

class Args:
    def __init__(self, layers, d_hid, frames, n_joints, out_joints):
        self.layers = layers
        self.d_hid = d_hid
        self.frames = frames
        self.n_joints = n_joints
        self.out_joints = out_joints

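# Quick smoke test: a 6-layer, 256-channel model over 27 frames and 17 joints.
# With the random (1, 27, 17, 2) input below, the printed size should be
# torch.Size([1, 27, 17, 3]).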
if __name__ == "__main__":
    # inputs = torch.rand(64, 351, 34)  # [btz, channel, T, H, W]
    # inputs = torch.rand(1, 64, 4, 112, 112)  # [btz, channel, T, H, W]
    args = Args(layers=6, d_hid=256, frames=27, n_joints=17, out_joints=17)
    net = Model(args)
    inputs = torch.rand([1, 27, 17, 2])
    if torch.cuda.is_available():
        net = net.cuda()
        inputs = inputs.cuda()
    output = net(inputs)
    print(output.size())

    from thop import profile
    # flops = 2 * macs; report the model's compute cost and parameter count
    macs, params = profile(net, inputs=(inputs,))
    print(2 * macs)
    print(params)