You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
184 lines
6.2 KiB
184 lines
6.2 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from ..builder import MESH_MODELS
|
|
|
|
# smplx is an optional dependency: record whether it is importable so that
# SMPL.__init__ can fail with a clear message instead of a NameError.
try:
    from smplx import SMPL as SMPL_
    has_smpl = True
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    has_smpl = False
|
|
|
|
|
|
@MESH_MODELS.register_module()
class SMPL(nn.Module):
    """SMPL 3d human mesh model of paper ref: Matthew Loper. ``SMPL: A skinned
    multi-person linear model''. This module is based on the smplx project
    (https://github.com/vchoutas/smplx).

    Args:
        smpl_path (str): The path to the folder where the model weights are
            stored.
        joints_regressor (str): The path to the file where the joints
            regressor weights are stored.
    """

    def __init__(self, smpl_path, joints_regressor):
        super().__init__()

        # `assert` is kept (rather than raising ImportError) so the exception
        # type seen by existing callers does not change.
        assert has_smpl, 'Please install smplx to use SMPL.'

        # Options shared by all three gendered models: global orientation,
        # body pose and translation are not created inside the smplx models;
        # `forward` passes them in on every call instead.
        common_cfg = dict(
            model_path=smpl_path,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False)

        self.smpl_neutral = SMPL_(gender='neutral', **common_cfg)
        # The gendered models are only ever called with explicit `betas`
        # (see `forward`), so they do not create their own shape parameters.
        self.smpl_male = SMPL_(create_betas=False, gender='male', **common_cfg)
        self.smpl_female = SMPL_(
            create_betas=False, gender='female', **common_cfg)

        # Prepend a singleton batch dim so the regressor broadcasts against
        # batched vertices in `torch.matmul` (see `smpl_forward`).
        joints_regressor = torch.tensor(
            np.load(joints_regressor), dtype=torch.float)[None, ...]
        self.register_buffer('joints_regressor', joints_regressor)

        self.num_verts = self.smpl_neutral.get_num_verts()
        self.num_joints = self.joints_regressor.shape[1]

    def smpl_forward(self, model, **kwargs):
        """Apply a specific SMPL model with given model parameters.

        Note:
            B: batch size
            V: number of vertices
            K: number of joints

        Args:
            model (nn.Module): The smplx model to run (one of
                ``smpl_neutral``, ``smpl_male`` or ``smpl_female``).
            **kwargs: Keyword arguments forwarded unchanged to ``model``.
                Must contain ``betas``; its first dimension is taken as the
                batch size.

        Returns:
            outputs (dict): Dict with mesh vertices and joints.
                - vertices: Tensor([B, V, 3]), mesh vertices
                - joints: Tensor([B, K, 3]), 3d joints regressed
                    from mesh vertices.
        """
        betas = kwargs['betas']
        batch_size = betas.shape[0]
        device = betas.device
        output = {}
        if batch_size == 0:
            # Skip the model call for an empty batch and return empty tensors
            # that still carry the expected trailing shapes.
            output['vertices'] = betas.new_zeros([0, self.num_verts, 3])
            output['joints'] = betas.new_zeros([0, self.num_joints, 3])
        else:
            smpl_out = model(**kwargs)
            output['vertices'] = smpl_out.vertices
            output['joints'] = torch.matmul(
                self.joints_regressor.to(device), output['vertices'])
        return output

    def get_faces(self):
        """Return mesh faces.

        Note:
            F: number of faces

        Returns:
            faces: np.ndarray([F, 3]), mesh faces
        """
        return self.smpl_neutral.faces

    def forward(self,
                betas,
                body_pose,
                global_orient,
                transl=None,
                gender=None):
        """Forward function.

        Note:
            B: batch size
            J: number of controllable joints of model, for smpl model J=23
            K: number of joints

        Args:
            betas: Tensor([B, 10]), human body shape parameters of SMPL model.
            body_pose: Tensor([B, J*3] or [B, J, 3, 3]), human body pose
                parameters of SMPL model. It should be axis-angle vector
                ([B, J*3]) or rotation matrix ([B, J, 3, 3]).
            global_orient: Tensor([B, 3] or [B, 1, 3, 3]), global orientation
                of human body. It should be axis-angle vector ([B, 3]) or
                rotation matrix ([B, 1, 3, 3]).
            transl: Tensor([B, 3]), global translation of human body.
            gender: Tensor([B]), gender parameters of human body. -1 for
                neutral, 0 for male, 1 for female. If None, the neutral
                model is used for every sample. Samples whose gender is
                neither negative, 0 nor 1 are left as all-zero outputs.

        Returns:
            outputs (dict): Dict with mesh vertices and joints.
                - vertices: Tensor([B, V, 3]), mesh vertices
                - joints: Tensor([B, K, 3]), 3d joints regressed from
                    mesh vertices.
        """
        batch_size = betas.shape[0]
        # A 2-D body_pose is axis-angle ([B, J*3]); otherwise it is already
        # rotation matrices and must not be converted again.
        pose2rot = body_pose.dim() == 2

        if batch_size == 0 or gender is None:
            return self.smpl_forward(
                self.smpl_neutral,
                betas=betas,
                body_pose=body_pose,
                global_orient=global_orient,
                transl=transl,
                pose2rot=pose2rot)

        output = {
            'vertices': betas.new_zeros([batch_size, self.num_verts, 3]),
            'joints': betas.new_zeros([batch_size, self.num_joints, 3]),
        }

        # Run each gendered model only on the samples of that gender and
        # scatter the results back into the full-batch output.
        gender_models = (
            (gender < 0, self.smpl_neutral),
            (gender == 0, self.smpl_male),
            (gender == 1, self.smpl_female),
        )
        for mask, model in gender_models:
            _out = self.smpl_forward(
                model,
                betas=betas[mask],
                body_pose=body_pose[mask],
                global_orient=global_orient[mask],
                transl=transl[mask] if transl is not None else None,
                pose2rot=pose2rot)
            output['vertices'][mask] = _out['vertices']
            output['joints'][mask] = _out['joints']

        return output
|
|
|