import torch
import torch.nn as nn
# from model.module.trans import Transformer as Transformer_s
# from model.module.trans_hypothesis import Transformer
import numpy as np
from einops import rearrange
from collections import OrderedDict
from torch.nn import functional as F
from torch.nn import init
import scipy.sparse as sp
from timm.models.layers import DropPath


class Model(nn.Module):
    def __init__(self, args):
        super().__init__()
        layers, channel, d_hid, length = args.layers, args.channel, args.d_hid, args.frames
        self.num_joints_in, self.num_joints_out = args.n_joints, args.out_joints

        # the hidden width is fixed to 256 here, overriding whatever args.d_hid was passed in
        args.d_hid = 256
        isTrainning = args.train

        # dimension transfer: lift 2D joint coordinates into the hidden dimension
        self.pose_emb = nn.Linear(2, args.d_hid, bias=False)
        self.gelu = nn.GELU()
        # self.flow_emb = nn.Linear(2, args.d_hid, bias=False)

        self.mlpmixer = MlpMixer(6, args.frames, 17, args.d_hid, isTrainning)
        self.pose_lift = nn.Linear(args.d_hid, 3, bias=False)

        # self.sequence_pos_encoder = PositionalEncoding(args.d_hid, 0.1)
        # self.tem_pool = nn.AdaptiveAvgPool1d(1)
        # self.lpm = LearnedPosMap2D(args.frames, 18)
    def forward(self, x):
        # x: [B, 2, T, J, M] -> keep the first hypothesis and move coordinates last: [B, T, J, 2]
        x = x[:, :, :, :, 0].permute(0, 2, 3, 1).contiguous()
        b, t, j, c = x.shape
        # g = torch.zeros([b, t, 1, c]).cuda()
        # x = torch.cat((x, g), -2)

        x = self.pose_emb(x)   # [B, T, J, d_hid]
        x = self.gelu(x)
        x = self.mlpmixer(x)   # [B, T, J, d_hid]
        x = self.pose_lift(x)  # [B, T, J, 3]

        return x


def normalize(mx):
    """Row-normalize a sparse matrix."""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse_coo_tensor(indices, values, shape)
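

# Illustrative sketch (not called anywhere) pairing the two helpers above: row-normalize a
# scipy matrix, then convert it to a torch sparse tensor. The 3x3 identity is just example
# data, not anything used by this repository.
def _example_sparse_conversion():
    mx = sp.coo_matrix(np.eye(3, dtype=np.float32))
    t = sparse_mx_to_torch_sparse_tensor(normalize(mx))
    return t.to_dense()  # dense [3, 3] identity in this example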


def adj_mx_from_edges(num_pts, edges, sparse=False):
    edges = np.array(edges, dtype=np.int32)
    data, i, j = np.ones(edges.shape[0]), edges[:, 0], edges[:, 1]
    adj_mx = sp.coo_matrix((data, (i, j)), shape=(num_pts, num_pts), dtype=np.float32)

    # build a symmetric adjacency matrix
    adj_mx = adj_mx + adj_mx.T.multiply(adj_mx.T > adj_mx) - adj_mx.multiply(adj_mx.T > adj_mx)
    # adj_mx = normalize(adj_mx + sp.eye(adj_mx.shape[0]))

    if sparse:
        adj_mx = sparse_mx_to_torch_sparse_tensor(adj_mx)
    else:
        adj_mx = torch.tensor(adj_mx.todense(), dtype=torch.float)
    return adj_mx.sum(-1)  # note: returns the per-node degree vector, not the adjacency itself
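

# Illustrative sketch of adj_mx_from_edges() on a toy 3-node chain (0-1, 1-2); the edge list
# is example data only. Because of the trailing .sum(-1), what comes back is the per-node
# degree vector of the symmetrized graph.
def _example_adjacency():
    edges = [[0, 1], [1, 2]]
    deg = adj_mx_from_edges(3, edges)  # tensor([1., 2., 1.])
    return deg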


class ChebConv(nn.Module):
    """
    The ChebNet convolution operation.

    :param in_c: int, number of input channels.
    :param out_c: int, number of output channels.
    :param K: int, the order of the Chebyshev polynomial.
    """
    def __init__(self, in_c, out_c, K, bias=True, normalize=True):
        super(ChebConv, self).__init__()
        self.normalize = normalize
        self.weight = nn.Parameter(torch.Tensor(K + 1, 1, in_c, out_c))  # [K+1, 1, in_c, out_c]
        init.xavier_normal_(self.weight)

        if bias:
            self.bias = nn.Parameter(torch.Tensor(1, 1, out_c))
            init.zeros_(self.bias)
        else:
            self.register_parameter("bias", None)

        self.K = K + 1

    def forward(self, inputs, graph):
        """
        :param inputs: the input data, [B, N, C]
        :param graph: the graph structure, [N, N]
        :return: convolution result, [B, N, D]
        """
        L = ChebConv.get_laplacian(graph, self.normalize)  # [N, N]
        mul_L = self.cheb_polynomial(L).unsqueeze(1)       # [K, 1, N, N]
        result = torch.matmul(mul_L, inputs)               # [K, B, N, C]
        result = torch.matmul(result, self.weight)         # [K, B, N, D]
        result = torch.sum(result, dim=0)                  # [B, N, D]
        if self.bias is not None:
            result = result + self.bias
        return result

    def cheb_polynomial(self, laplacian):
        """
        Compute the Chebyshev polynomials of the graph Laplacian.

        :param laplacian: the graph Laplacian, [N, N].
        :return: the multi-order Chebyshev Laplacian, [K, N, N].
        """
        N = laplacian.size(0)  # [N, N]
        multi_order_laplacian = torch.zeros([self.K, N, N], device=laplacian.device, dtype=torch.float)  # [K, N, N]
        multi_order_laplacian[0] = torch.eye(N, device=laplacian.device, dtype=torch.float)

        if self.K == 1:
            return multi_order_laplacian

        multi_order_laplacian[1] = laplacian
        if self.K == 2:
            return multi_order_laplacian

        # Chebyshev recurrence: T_k(L) = 2 * L * T_{k-1}(L) - T_{k-2}(L)
        for k in range(2, self.K):
            multi_order_laplacian[k] = 2 * torch.mm(laplacian, multi_order_laplacian[k - 1]) - \
                                       multi_order_laplacian[k - 2]
        return multi_order_laplacian

    @staticmethod
    def get_laplacian(graph, normalize):
        """
        Return the Laplacian of the graph.

        :param graph: the graph structure without self loops, [N, N].
        :param normalize: whether to use the normalized Laplacian.
        :return: graph Laplacian.
        """
        if normalize:
            D = torch.diag(torch.sum(graph, dim=-1) ** (-1 / 2))
            L = torch.eye(graph.size(0), device=graph.device, dtype=graph.dtype) - torch.mm(torch.mm(D, graph), D)
        else:
            D = torch.diag(torch.sum(graph, dim=-1))
            L = D - graph
        return L
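

# Illustrative standalone sketch of ChebConv on a tiny 3-node graph (example values only).
# The graph tensor is the dense adjacency expected by get_laplacian(); shapes follow the
# docstrings above. Note that ChebConv is not wired into the model below: its use inside
# cross_att is commented out.
def _example_chebconv():
    graph = torch.tensor([[0., 1., 0.],
                          [1., 0., 1.],
                          [0., 1., 0.]])
    conv = ChebConv(in_c=8, out_c=16, K=2)
    x = torch.rand(4, 3, 8)          # [B, N, C]
    return conv(x, graph).shape      # torch.Size([4, 3, 16])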


class cross_att(nn.Module):
    def __init__(self, d_time, d_joint, d_coor, isTrainning=False, head=4):
        super().__init__()
        self.qkv = nn.Linear(d_coor, d_coor * 3)
        self.head = head
        self.layer_norm = nn.LayerNorm(d_coor)
        # self.lpm_st_1 = LearnedPosMap2D(d_time, d_joint, gamma=4)

        self.scale = d_coor ** -0.5
        self.proj = nn.Linear(d_coor, d_coor)
        self.d_time = d_time
        self.d_joint = d_joint

        # depthwise-convolution gates for the temporal and spatial halves of the channels
        self.gate_t = nn.Conv2d(d_coor // 2, d_coor // 2, kernel_size=3, stride=1, padding=1, groups=d_coor // 2)
        self.gate_s = nn.Conv2d(d_coor // 2, d_coor // 2, kernel_size=3, stride=1, padding=1, groups=d_coor // 2)
        # self.gate_s = MSLSP(d_time, d_joint, d_coor // 2)
        # self.gate_t = MSLSP(d_time, d_joint, d_coor // 2)
        # self.gate_gs = ChebConv(d_coor // 2, d_coor // 2, K=2)
        # self.scf = nn.Parameter(0.0001 * torch.Tensor(1, 1, d_coor // 8))
        # init.xavier_normal_(self.scf)

        # 17-joint skeleton edges (Human3.6M layout)
        self.body_edges = torch.tensor([[0, 1], [1, 2], [2, 3],
                                        [0, 4], [4, 5], [5, 6],
                                        [0, 7], [7, 8], [8, 9], [9, 10],
                                        [8, 11], [11, 12], [12, 13],
                                        [8, 14], [14, 15], [15, 16]], dtype=torch.long)
        # variant with a virtual joint 17 connected to every joint:
        # [0,17],[1,17],[2,17],[3,17],[4,17],[5,17],[6,17],[7,17],[8,17],[9,17],
        # [10,17],[11,17],[12,17],[13,17],[14,17],[15,17],[16,17]
        # self.conv_2 = nn.Conv2d(d_coor, d_coor, kernel_size=5, stride=1, padding=2, groups=d_coor)

        self.graph = adj_mx_from_edges(d_joint, self.body_edges).long().cuda()
        # learned embedding of coarse body-part ids, added to the spatial branch in forward()
        self.emb = nn.Embedding(20, d_coor // 8, padding_idx=0)
        self.part = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 0, 0]).long().cuda()

        # self.lpm_s = LearnedPosMap2D(d_time, d_joint)
        # self.lpm_t = LearnedPosMap2D(d_time, d_joint)
        self.drop = DropPath(0.5)

    def forward(self, input):
        b, t, s, c = input.shape
        h = input
        x = self.layer_norm(input)

        qkv = self.qkv(x)                                        # b, t, s, 3*c
        qkv = qkv.reshape(b, t, s, c, 3).permute(4, 0, 1, 2, 3)  # 3, b, t, s, c

        # split the channels: first half attends over joints (spatial), second half over frames (temporal)
        qkv_s, qkv_t = qkv.chunk(2, 4)
        q_s, k_s, v_s = qkv_s[0], qkv_s[1], qkv_s[2]             # b, t, s, c//2
        q_t, k_t, v_t = qkv_t[0], qkv_t[1], qkv_t[2]             # b, t, s, c//2

        q_s = rearrange(q_s, 'b t s (h c) -> (b h t) s c', h=self.head)
        k_s = rearrange(k_s, 'b t s (h c) -> (b h t) c s', h=self.head)
        q_t = rearrange(q_t, 'b t s (h c) -> (b h s) t c', h=self.head)
        k_t = rearrange(k_t, 'b t s (h c) -> (b h s) c t', h=self.head)

        att_s = (q_s @ k_s) * self.scale   # b*h*t, s, s
        att_t = (q_t @ k_t) * self.scale   # b*h*s, t, t
        att_s = att_s.softmax(-1)
        att_t = att_t.softmax(-1)

        # depthwise-convolution gates over the value maps
        v_s = rearrange(v_s, 'b t s c -> b c t s')
        v_t = rearrange(v_t, 'b t s c -> b c t s')
        lep_s = self.gate_s(v_s)
        lep_t = self.gate_t(v_t)

        v_s = rearrange(v_s, 'b c t s -> (b t) s c')
        # sep_s = self.gate_gs(v_s, self.graph)
        # learned body-part embedding, shared across batch and frames
        sep_s = self.emb(self.part).unsqueeze(0)

        lep_s = rearrange(lep_s, 'b (h c) t s -> (b h t) s c', h=self.head)
        lep_t = rearrange(lep_t, 'b (h c) t s -> (b h s) t c', h=self.head)
        v_s = rearrange(v_s, '(b t) s (h c) -> (b h t) s c', t=t, h=self.head)
        v_t = rearrange(v_t, 'b (h c) t s -> (b h s) t c', h=self.head)

        x_s = att_s @ v_s + lep_s + 0.0001 * self.drop(sep_s)   # b*h*t, s, c//(2h)
        x_t = att_t @ v_t + lep_t                               # b*h*s, t, c//(2h)

        x_s = rearrange(x_s, '(b h t) s c -> b h t s c', h=self.head, t=t)
        x_t = rearrange(x_t, '(b h s) t c -> b h t s c', h=self.head, s=s)

        # concatenate the spatial and temporal halves back along the channel dimension
        x = torch.cat((x_s, x_t), -1)
        x = rearrange(x, 'b h t s c -> b t s (h c)')

        x = self.proj(x)
        x = x + h
        return x
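

# Shape sketch for standalone use of cross_att (illustrative values): the channel dimension
# is split in half, one half attends over joints and the other over frames, each with a
# depthwise-conv gate added to the attended values, and the halves are concatenated back.
# With the default head=4, the embedding width d_coor // 8 matches each head's channel size.
# A GPU is required because __init__ places self.graph and self.part on CUDA.
def _example_cross_att():
    attn = cross_att(d_time=27, d_joint=17, d_coor=256).cuda()
    x = torch.rand(2, 27, 17, 256).cuda()   # [B, T, J, C]
    return attn(x).shape                    # torch.Size([2, 27, 17, 256])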


class MLP_3D(nn.Module):
    def __init__(self, d_time, d_joint, d_coor, isTrainning=False):
        super().__init__()
        self.d_time = d_time
        self.d_joint = d_joint
        self.d_coor = d_coor
        self.layer_norm1 = nn.LayerNorm(self.d_coor)
        self.layer_norm2 = nn.LayerNorm(self.d_coor)
        self.mlp1 = Mlp(self.d_coor, self.d_coor * 4, self.d_coor)
        self.cross_att = cross_att(d_time, d_joint, d_coor, isTrainning)
        self.drop = DropPath(0.0)

    def forward(self, input):
        b, t, s, c = input.shape
        x = self.cross_att(input)
        x = x + self.drop(self.mlp1(self.layer_norm1(x)))
        return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.1):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Mlp_C(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.sig = nn.Sigmoid()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        b, t, s, c = x.shape
        # gate = self.avg(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        gate = self.fc1(x)
        gate = self.act(gate)
        gate = self.drop(gate)
        gate = self.fc2(gate)
        gate = self.sig(gate)
        # gate = gate.expand(b, t, s, c)
        x = x * gate
        return x
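

# Mlp_C is a sigmoid channel-gating block; it is defined here but not referenced by the
# model above. Minimal standalone sketch with illustrative sizes:
def _example_mlp_c():
    gate = Mlp_C(in_features=256, hidden_features=64)
    x = torch.rand(2, 27, 17, 256)   # [B, T, J, C]
    return gate(x).shape             # torch.Size([2, 27, 17, 256])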


class MlpMixer(nn.Module):
    def __init__(self, num_block, d_time, d_joint, d_coor, isTrainning=False):
        super(MlpMixer, self).__init__()
        self.num_block = num_block
        self.d_time = d_time
        self.d_joint = d_joint
        self.d_coor = d_coor
        self.mixerblocks = nn.ModuleList([
            MLP_3D(self.d_time, self.d_joint, self.d_coor, isTrainning)
            for _ in range(self.num_block)
        ])

    def forward(self, input):
        # apply the stacked blocks in sequence
        for i in range(self.num_block):
            input = self.mixerblocks[i](input)
        return input
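

# Minimal sketch of MlpMixer on its own (illustrative sizes): it stacks num_block MLP_3D
# blocks, each a cross_att followed by a channel MLP, and preserves the [B, T, J, C] shape.
# A GPU is required for the same reason as cross_att above.
def _example_mlpmixer():
    mixer = MlpMixer(num_block=2, d_time=27, d_joint=17, d_coor=256).cuda()
    x = torch.rand(2, 27, 17, 256).cuda()
    return mixer(x).shape   # torch.Size([2, 27, 17, 256])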


if __name__ == "__main__":
    # Smoke test. The model expects a 5-D input [batch, 2, frames, joints, n_hypotheses]
    # and a GPU, because cross_att allocates CUDA tensors at construction time. The args
    # values below are placeholders covering the attributes read in Model.__init__.
    from types import SimpleNamespace

    args = SimpleNamespace(layers=6, channel=256, d_hid=256, frames=27,
                           n_joints=17, out_joints=17, train=False)
    net = Model(args).cuda()
    inputs = torch.rand(64, 2, args.frames, 17, 1).cuda()  # [batch, 2, T, J, M]
    output = net(inputs)
    print(output.size())

    from thop import profile
    flops, params = profile(net, inputs=(inputs,))
    print(flops)
    print(params)

    """
    for name, param in net.named_parameters():
        if param.requires_grad:
            print(name, ':', param.size())
    """