import torch
import torch.nn as nn
# from model.module.trans import Transformer as Transformer_s
# from model.module.trans_hypothesis import Transformer
import numpy as np
from einops import rearrange
from collections import OrderedDict
from torch.nn import functional as F
from torch.nn import init
import scipy.sparse as sp

from timm.models.layers import DropPath


class Model(nn.Module):
    def __init__(self, args):
        super().__init__()

        layers, channel, d_hid, length = args.layers, args.channel, args.d_hid, args.frames
        self.num_joints_in, self.num_joints_out = args.n_joints, args.out_joints
        args.d_hid = 256  # note: the hidden width is hard-coded here and overrides args.d_hid
        isTrainning = args.train

        # dimension transfer: lift 2D joint coordinates to the hidden width
        self.pose_emb = nn.Linear(2, args.d_hid, bias=False)
        self.gelu = nn.GELU()

        # self.flow_emb = nn.Linear(2, args.d_hid, bias=False)
        # self.gelu = nn.GELU()

        self.mlpmixer = MlpMixer(6, args.frames, 17, args.d_hid, isTrainning)

        self.pose_lift = nn.Linear(args.d_hid, 3, bias=False)

        # self.sequence_pos_encoder = PositionalEncoding(args.d_hid, 0.1)

        # self.tem_pool = nn.AdaptiveAvgPool1d(1)
        # self.lpm = LearnedPosMap2D(args.frames, 18)

    def forward(self, x):
        # x: [B, 2, T, J, M]; keep the first person and move to [B, T, J, 2]
        x = x[:, :, :, :, 0].permute(0, 2, 3, 1).contiguous()
        # x = x.view(x.shape[0], x.shape[1], x.shape[2], -1)  # b, t, j, 2

        b, t, j, c = x.shape

        # g = torch.zeros([b, t, 1, c]).cuda()
        # x = torch.cat((x, g), -2)

        x = self.pose_emb(x)
        x = self.gelu(x)

        # x = x.reshape(b, t, j, c)

        x = self.mlpmixer(x)

        x = self.pose_lift(x)

        return x

def normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

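
# Example (illustrative sketch, not used by the model itself): row-normalizing a
# small adjacency matrix so that every non-empty row sums to 1; all-zero rows
# stay zero because their inverse degree is clamped to 0 above.
#
#   adj = sp.coo_matrix(np.array([[0., 1., 1.],
#                                 [1., 0., 0.],
#                                 [0., 0., 0.]]))
#   normalize(adj).toarray()   # rows: [0, 0.5, 0.5], [1, 0, 0], [0, 0, 0]
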

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor is deprecated; torch.sparse_coo_tensor is the equivalent constructor
    return torch.sparse_coo_tensor(indices, values, shape)

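
# Example (illustrative sketch): the conversion keeps only the non-zero entries
# of the scipy matrix as a float32 COO tensor.
#
#   eye = sparse_mx_to_torch_sparse_tensor(sp.eye(3))
#   eye.to_dense()   # 3x3 identity matrix, dtype torch.float32
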

def adj_mx_from_edges(num_pts, edges, sparse=False):
    edges = np.array(edges, dtype=np.int32)
    data, i, j = np.ones(edges.shape[0]), edges[:, 0], edges[:, 1]
    adj_mx = sp.coo_matrix((data, (i, j)), shape=(num_pts, num_pts), dtype=np.float32)

    # build symmetric adjacency matrix
    adj_mx = adj_mx + adj_mx.T.multiply(adj_mx.T > adj_mx) - adj_mx.multiply(adj_mx.T > adj_mx)
    # adj_mx = normalize(adj_mx + sp.eye(adj_mx.shape[0]))
    if sparse:
        adj_mx = sparse_mx_to_torch_sparse_tensor(adj_mx)
    else:
        adj_mx = torch.tensor(adj_mx.todense(), dtype=torch.float)
    # return the per-node degree (row sum) rather than the full adjacency matrix
    return adj_mx.sum(-1)

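
# Example (illustrative sketch): with the 17-joint skeleton defined in cross_att
# below, this returns the per-joint degree vector (row sums of the symmetric
# adjacency matrix), which is what cross_att stores as `self.graph`.
#
#   edges = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8],
#            [8, 9], [9, 10], [8, 11], [11, 12], [12, 13], [8, 14], [14, 15],
#            [15, 16]]
#   deg = adj_mx_from_edges(17, edges)   # tensor of shape [17]
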

class ChebConv(nn.Module):
    """
    The ChebNet convolution operation.
    :param in_c: int, number of input channels.
    :param out_c: int, number of output channels.
    :param K: int, the order of the Chebyshev polynomial.
    """

    def __init__(self, in_c, out_c, K, bias=True, normalize=True):
        super(ChebConv, self).__init__()
        self.normalize = normalize

        self.weight = nn.Parameter(torch.Tensor(K + 1, 1, in_c, out_c))  # [K+1, 1, in_c, out_c]
        init.xavier_normal_(self.weight)

        if bias:
            self.bias = nn.Parameter(torch.Tensor(1, 1, out_c))
            init.zeros_(self.bias)
        else:
            self.register_parameter("bias", None)

        self.K = K + 1

    def forward(self, inputs, graph):
        """
        :param inputs: the input data, [B, N, C]
        :param graph: the graph structure, [N, N]
        :return: convolution result, [B, N, D]
        """
        L = ChebConv.get_laplacian(graph, self.normalize)  # [N, N]
        mul_L = self.cheb_polynomial(L).unsqueeze(1)  # [K, 1, N, N]

        result = torch.matmul(mul_L, inputs)  # [K, B, N, C]
        result = torch.matmul(result, self.weight)  # [K, B, N, D]
        result = torch.sum(result, dim=0) + self.bias  # [B, N, D]

        return result

    def cheb_polynomial(self, laplacian):
        """
        Compute the Chebyshev polynomials of the graph Laplacian, using the
        recurrence T_k(L) = 2 * L * T_{k-1}(L) - T_{k-2}(L).
        :param laplacian: the graph laplacian, [N, N].
        :return: the multi-order Chebyshev laplacian, [K, N, N].
        """
        N = laplacian.size(0)  # [N, N]
        multi_order_laplacian = torch.zeros([self.K, N, N], device=laplacian.device, dtype=torch.float)  # [K, N, N]
        multi_order_laplacian[0] = torch.eye(N, device=laplacian.device, dtype=torch.float)

        if self.K == 1:
            return multi_order_laplacian

        multi_order_laplacian[1] = laplacian
        for k in range(2, self.K):
            multi_order_laplacian[k] = 2 * torch.mm(laplacian, multi_order_laplacian[k - 1]) - \
                                       multi_order_laplacian[k - 2]

        return multi_order_laplacian

    @staticmethod
    def get_laplacian(graph, normalize):
        """
        Return the Laplacian of the graph.
        :param graph: the graph structure without self loops, [N, N].
        :param normalize: whether to use the normalized Laplacian
            L = I - D^{-1/2} A D^{-1/2}.
        :return: graph laplacian.
        """
        if normalize:
            D = torch.diag(torch.sum(graph, dim=-1) ** (-1 / 2))
            L = torch.eye(graph.size(0), device=graph.device, dtype=graph.dtype) - torch.mm(torch.mm(D, graph), D)
        else:
            D = torch.diag(torch.sum(graph, dim=-1))
            L = D - graph
        return L
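
# Example (illustrative sketch; the layer is only referenced from a commented-out
# `self.gate_gs` line in cross_att below, so this usage is an assumption, not the
# original training setup):
#
#   conv = ChebConv(in_c=2, out_c=8, K=2)
#   graph = torch.tensor([[0., 1., 0.],
#                         [1., 0., 1.],
#                         [0., 1., 0.]])        # symmetric adjacency, no self loops
#   out = conv(torch.rand(4, 3, 2), graph)      # -> [4, 3, 8]
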

class cross_att(nn.Module):
    def __init__(self, d_time, d_joint, d_coor, isTrainning=False, head=4):
        super().__init__()

        self.qkv = nn.Linear(d_coor, d_coor * 3)
        self.head = head
        self.layer_norm = nn.LayerNorm(d_coor)
        # self.lpm_st_1 = LearnedPosMap2D(d_time, d_joint, gamma=4)
        self.scale = d_coor ** -0.5
        self.proj = nn.Linear(d_coor, d_coor)
        self.d_time = d_time
        self.d_joint = d_joint

        # self.gate_s = MSLSP(d_time, d_joint, d_coor // 2)
        # self.gate_t = MSLSP(d_time, d_joint, d_coor // 2)
        self.gate_t = nn.Conv2d(d_coor // 2, d_coor // 2, kernel_size=3, stride=1, padding=1, groups=d_coor // 2)
        self.gate_s = nn.Conv2d(d_coor // 2, d_coor // 2, kernel_size=3, stride=1, padding=1, groups=d_coor // 2)

        # self.gate_gs = ChebConv(d_coor // 2, d_coor // 2, K=2)
        # self.scf = nn.Parameter(0.0001 * torch.Tensor(1, 1, d_coor // 8))
        # init.xavier_normal_(self.scf)

        self.body_edges = torch.tensor([[0, 1], [1, 2], [2, 3],
                                        [0, 4], [4, 5], [5, 6],
                                        [0, 7], [7, 8], [8, 9], [9, 10],
                                        [8, 11], [11, 12], [12, 13],
                                        [8, 14], [14, 15], [15, 16]], dtype=torch.long)
        # [0,17],[1,17],[2,17],[3,17],[4,17],[5,17],[6,17],[7,17],[8,17],[9,17],
        # [10,17],[11,17],[12,17],[13,17],[14,17],[15,17],[16,17]
        # self.conv_2 = nn.Conv2d(d_coor, d_coor, kernel_size=5, stride=1, padding=2, groups=d_coor)
        self.graph = adj_mx_from_edges(d_joint, self.body_edges).long().cuda()
        self.emb = nn.Embedding(20, d_coor // 8, padding_idx=0)
        self.part = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 0, 0]).long().cuda()

        # self.lpm_s = LearnedPosMap2D(d_time, d_joint)
        # self.lpm_t = LearnedPosMap2D(d_time, d_joint)

        self.drop = DropPath(0.5)

    def forward(self, input):
        b, t, s, c = input.shape

        # input = input + self.lpm_st_1(input)
        h = input
        x = self.layer_norm(input)
        qkv = self.qkv(x)

        qkv = qkv.reshape(b, t, s, c, 3).permute(4, 0, 1, 2, 3)  # 3, b, t, s, c

        # split the channels: one half attends over joints, the other over time
        qkv_s, qkv_t = qkv.chunk(2, 4)

        q_s, k_s, v_s = qkv_s[0], qkv_s[1], qkv_s[2]  # b, t, s, c//2
        q_t, k_t, v_t = qkv_t[0], qkv_t[1], qkv_t[2]  # b, t, s, c//2

        q_s = rearrange(q_s, 'b t s (h c) -> (b h t) s c', h=self.head)
        k_s = rearrange(k_s, 'b t s (h c) -> (b h t) c s', h=self.head)

        q_t = rearrange(q_t, 'b t s (h c) -> (b h s) t c', h=self.head)
        k_t = rearrange(k_t, 'b t s (h c) -> (b h s) c t', h=self.head)

        att_s = (q_s @ k_s) * self.scale  # (b*h*t), s, s
        att_t = (q_t @ k_t) * self.scale  # (b*h*s), t, t

        att_s = att_s.softmax(-1)
        att_t = att_t.softmax(-1)

        v_s = rearrange(v_s, 'b t s c -> b c t s')
        v_t = rearrange(v_t, 'b t s c -> b c t s')

        # local enhancement via depth-wise convolutions over the value maps
        lep_s = self.gate_s(v_s)
        lep_t = self.gate_t(v_t)
        v_s = rearrange(v_s, 'b c t s -> (b t) s c')
        # sep_s = self.gate_gs(v_s, self.graph)
        # learned per-body-part embedding, shared across batch and time
        sep_s = self.emb(self.part).unsqueeze(0)

        # sep_s = rearrange(sep_s, '(b t) s (h c) -> (b h t) s c', t=t, h=self.head)

        lep_s = rearrange(lep_s, 'b (h c) t s -> (b h t) s c', h=self.head)
        lep_t = rearrange(lep_t, 'b (h c) t s -> (b h s) t c', h=self.head)

        v_s = rearrange(v_s, '(b t) s (h c) -> (b h t) s c', t=t, h=self.head)
        # v_s = rearrange(v_s, 'b (h c) t s -> (b h t) s c', h=self.head)
        v_t = rearrange(v_t, 'b (h c) t s -> (b h s) t c', h=self.head)

        # v = torch.cat((v1, v2), -1)

        x_s = att_s @ v_s + lep_s + 0.0001 * self.drop(sep_s)  # (b*h*t), s, c//(2h)
        x_t = att_t @ v_t + lep_t  # (b*h*s), t, c//(2h)

        x_s = rearrange(x_s, '(b h t) s c -> b h t s c', h=self.head, t=t)
        x_t = rearrange(x_t, '(b h s) t c -> b h t s c', h=self.head, s=s)

        x = torch.cat((x_s, x_t), -1)
        x = rearrange(x, 'b h t s c -> b t s (h c)')

        x = self.proj(x)
        x = x + h
        return x
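
# Example (illustrative sketch; CUDA is assumed because the block pins its
# joint-index buffers to the GPU): cross_att is shape-preserving over
# [B, T, J, C] inputs, with joint attention and temporal attention each
# operating on half of the channels.
#
#   blk = cross_att(d_time=27, d_joint=17, d_coor=256).cuda()
#   y = blk(torch.rand(2, 27, 17, 256).cuda())   # -> [2, 27, 17, 256]
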

class MLP_3D(nn.Module):
    def __init__(self, d_time, d_joint, d_coor, isTrainning=False):
        super().__init__()

        self.d_time = d_time
        self.d_joint = d_joint
        self.d_coor = d_coor

        self.layer_norm1 = nn.LayerNorm(self.d_coor)
        self.layer_norm2 = nn.LayerNorm(self.d_coor)

        self.mlp1 = Mlp(self.d_coor, self.d_coor * 4, self.d_coor)

        self.cross_att = cross_att(d_time, d_joint, d_coor, isTrainning)
        self.drop = DropPath(0.0)

    def forward(self, input):
        b, t, s, c = input.shape

        x = self.cross_att(input)
        x = x + self.drop(self.mlp1(self.layer_norm1(x)))

        return x

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.1):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Mlp_C(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.sig = nn.Sigmoid()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        b, t, s, c = x.shape
        # gate = self.avg(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        gate = self.fc1(x)
        gate = self.act(gate)
        gate = self.drop(gate)
        gate = self.fc2(gate)
        gate = self.sig(gate)
        # gate = gate.expand(b, t, s, c)
        x = x * gate
        return x

class MlpMixer(nn.Module):
    def __init__(self, num_block, d_time, d_joint, d_coor, isTrainning=False):
        super(MlpMixer, self).__init__()

        self.num_block = num_block
        self.d_time = d_time
        self.d_joint = d_joint
        self.d_coor = d_coor

        self.mixerblocks = nn.ModuleList([
            MLP_3D(self.d_time, self.d_joint, self.d_coor, isTrainning)
            for _ in range(self.num_block)
        ])

    def forward(self, input):
        # stacked blocks
        for i in range(self.num_block):
            input = self.mixerblocks[i](input)

        return input

if __name__ == "__main__":
    from argparse import Namespace

    # Minimal smoke test. The argument values below are assumptions chosen to
    # match the shapes used elsewhere in this file (17 joints, 351 frames, a
    # hidden width of 256); adjust them to your own configuration. CUDA is
    # assumed because cross_att pins its joint-index buffers to the GPU.
    args = Namespace(layers=6, channel=256, d_hid=256, frames=351,
                     n_joints=17, out_joints=17, train=False)

    net = Model(args).cuda()
    inputs = torch.rand(1, 2, args.frames, args.n_joints, 1).cuda()  # [B, 2, T, J, M]

    output = net(inputs)
    print(output.size())

    from thop import profile

    flops, params = profile(net, inputs=(inputs,))
    print(flops)
    print(params)

    """
    for name, param in net.named_parameters():
        if param.requires_grad:
            print(name, ':', param.size())
    """