You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

92 lines
2.5 KiB

import sys
import numpy as np
import torch
def normalize_screen_coordinates(X, w, h):
    """Map pixel coordinates into a width-normalised, centred frame.

    The x axis is mapped from [0, w] to [-1, 1]. The y axis is scaled by the
    same factor (2 / w), so the aspect ratio is kept: [0, h] maps to
    [-h/w, h/w].

    Args:
        X: array whose last axis holds (x, y) pixel coordinates. Subtracting a
            Python list relies on NumPy broadcasting, so X is presumably a
            NumPy array; confirm against callers.
        w: image width in pixels.
        h: image height in pixels.

    Returns:
        Array of the same shape as X holding the normalised coordinates.
    """
    assert X.shape[-1] == 2
    # Per-axis shift that centres the image on the origin after scaling.
    centre_shift = [1, h / w]
    return X / w * 2 - centre_shift
def image_coordinates(X, w, h):
    """Undo ``normalize_screen_coordinates`` and return pixel coordinates.

    Adds back the (1, h/w) centring shift, then rescales by w / 2.

    Args:
        X: array whose last axis holds normalised (x, y) coordinates
            (presumably a NumPy array, because a Python list is added to it).
        w: image width in pixels.
        h: image height in pixels.

    Returns:
        Array of the same shape as X holding pixel coordinates.
    """
    assert X.shape[-1] == 2
    # Reverse camera frame normalization: shift first, then rescale.
    centre_shift = [1, h / w]
    shifted = X + centre_shift
    return shifted * w / 2
def world_to_camera(X, R, t):
    """Transform world-space points into the camera frame.

    Subtracts the translation ``t`` from ``X`` and then rotates the result by
    the inverse of the quaternion ``R``.

    Args:
        X: NumPy array whose last axis holds 3-D points.
        R: quaternion(s) with the scalar part first. Presumably a single (4,)
            array, since it is tiled once per point; confirm with callers.
        t: translation, broadcast against X.

    Returns:
        NumPy array of camera-space points with the same shape as X.
    """
    leading_shape = X.shape[:-1]
    # One copy of the inverse rotation per point, as qrot requires matching
    # leading dimensions.
    inverse_rotations = np.tile(wrap(qinverse, R), (*leading_shape, 1))
    return wrap(qrot, inverse_rotations, X - t)
def camera_to_world(X, R, t):
    """Transform camera-space points back into the world frame.

    Rotates ``X`` by the quaternion ``R`` and then adds the translation ``t``.
    This is the inverse of ``world_to_camera``.

    Args:
        X: NumPy array whose last axis holds 3-D points.
        R: quaternion(s) with the scalar part first. Presumably a single (4,)
            array, since it is tiled once per point; confirm with callers.
        t: translation, broadcast against the rotated points.

    Returns:
        NumPy array of world-space points with the same shape as X.
    """
    # qrot needs one quaternion per point.
    per_point_rotation = np.tile(R, (*X.shape[:-1], 1))
    rotated = wrap(qrot, per_point_rotation, X)
    return rotated + t
def wrap(func, *args, unsqueeze=False):
    """Call a torch function on NumPy inputs and hand back NumPy outputs.

    Each positional argument that is a ``np.ndarray`` (or subclass) is
    converted with ``torch.from_numpy``, so it shares memory with the array.
    When *unsqueeze* is true, the converted tensor also gets a leading
    dimension of size 1. Any other argument is passed through unchanged.

    The return value of ``func`` is converted back as follows:

    * A ``torch.Tensor`` becomes a NumPy array. When *unsqueeze* is true, its
      leading size-1 dimension is squeezed off first.
    * A tuple is returned as a tuple, with every tensor element converted the
      same way and other elements left untouched.
    * Anything else is returned as-is.

    Tensors are detached before conversion, so a result that requires grad
    no longer makes ``.numpy()`` raise.

    Args:
        func: callable operating on torch tensors.
        *args: positional arguments for ``func``.
        unsqueeze: add and remove a leading batch dimension around the call.

    Returns:
        The converted result of ``func(*args)``.
    """

    def to_torch(arg):
        # Convert only NumPy arrays; leave other arguments for func as given.
        if not isinstance(arg, np.ndarray):
            return arg
        tensor = torch.from_numpy(arg)
        return tensor.unsqueeze(0) if unsqueeze else tensor

    def to_numpy(res):
        # Convert only tensors; other values are returned unchanged.
        if not isinstance(res, torch.Tensor):
            return res
        if unsqueeze:
            res = res.squeeze(0)
        # detach() keeps .numpy() from raising on grad-tracking results.
        return res.detach().numpy()

    result = func(*[to_torch(arg) for arg in args])
    if isinstance(result, tuple):
        return tuple(to_numpy(res) for res in result)
    return to_numpy(result)
def qrot(q, v):
    """Rotate vector(s) ``v`` by quaternion(s) ``q``.

    Computes ``v + 2 * (w * (u x v) + u x (u x v))``, where ``w = q[..., :1]``
    is the scalar part and ``u = q[..., 1:]`` is the vector part. This equals
    a rotation only when ``q`` has unit norm, which is assumed here and not
    checked.

    Args:
        q: tensor whose last axis has size 4 (scalar part first).
        v: tensor whose last axis has size 3. Its leading shape must equal
            q's.

    Returns:
        Tensor with the same shape as ``v`` holding the rotated vectors.
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]
    last_axis = q.dim() - 1
    scalar, axis = q[..., :1], q[..., 1:]
    first_cross = torch.cross(axis, v, dim=last_axis)
    second_cross = torch.cross(axis, first_cross, dim=last_axis)
    return v + 2 * (scalar * first_cross + second_cross)
def qinverse(q, inplace=False):
    """Return the conjugate of quaternion(s) ``q`` (scalar part first).

    The conjugate keeps ``q[..., :1]`` and negates ``q[..., 1:]``. It equals
    the inverse only for unit quaternions.

    Args:
        q: tensor whose last axis holds the quaternion components.
        inplace: if true, negate the vector part of ``q`` itself and return
            ``q``. Otherwise build and return a new tensor.

    Returns:
        The conjugated quaternion tensor.
    """
    if not inplace:
        scalar_part = q[..., :1]
        vector_part = q[..., 1:]
        return torch.cat((scalar_part, -vector_part), dim=q.dim() - 1)
    q[..., 1:] *= -1
    return q
def get_uvd2xyz(uvd, gt_3D, cam):
    """Back-project per-joint (u, v, depth) values to root-relative 3-D points.

    Steps:

    1. Joint 0 takes its absolute depth from ``gt_3D[:, :, 0, 2]``.
    2. Joints 1..V-1 add their depth channel to that root depth.
    3. Every joint is back-projected with ``x = (u - cx) * z / fx`` and
       ``y = (v - cy) * z / fy``.
    4. The joint-0 point is subtracted, so joint 0 always ends up at the
       origin.

    Args:
        uvd: 4-D tensor unpacked as (N, T, V, C) and viewed as (-1, T, V, 3).
            Channels 0-1 are image coordinates. Channel 2 is a depth that is
            added to the root depth for joints 1..V-1 (presumably
            root-relative; confirm with the producer). The depth channel of
            joint 0 is ignored.
        gt_3D: 4-D tensor. Only ``gt_3D[:, :, 0, 2]`` (z of joint 0) is read.
        cam: tensor whose last axis holds (fx, fy) at [..., 0:2] and (cx, cy)
            at [..., 2:4]. Each slice is viewed as (-1, 1, 1, 2), i.e. one
            camera per batch element.

    Returns:
        (N, T, V, 3) tensor of root-relative 3-D points.

    ``uvd`` is not modified. All per-frame and per-joint expansion is done by
    broadcasting instead of materialised ``repeat`` copies. The values match
    the repeat-based formulation exactly.
    """
    _, T, V, _ = uvd.size()
    # Private copy: its depth channel is overwritten in place below, and
    # uvd.view() alone would alias the caller's tensor.
    dec_out_all = uvd.view(-1, T, V, 3).clone()
    # Read-only slice, never written to, so no copy is needed.
    enc_in_all = uvd[:, :, :, :2].view(-1, T, V, 2)
    # Shape (B, 1, 1, 2) broadcasts over frames and joints without copying.
    cam_f = cam[..., :2].view(-1, 1, 1, 2)
    cam_c = cam[..., 2:4].view(-1, 1, 1, 2)

    # Absolute depth per joint. z_global is a view into the private copy.
    # Joint 0 is assigned first; the joint 1.. depths read next are untouched
    # by that write.
    z_global = dec_out_all[:, :, :, 2]
    z_global[:, :, 0] = gt_3D[:, :, 0, 2]
    z_global[:, :, 1:] = dec_out_all[:, :, 1:, 2] + gt_3D[:, :, 0:1, 2]
    z_global = z_global.unsqueeze(-1)

    # Pinhole back-projection; z broadcasts across the (x, y) channel pair.
    xy = (enc_in_all - cam_c) * z_global / cam_f
    xyz_global = torch.cat((xy, z_global), -1)

    # Express every joint relative to joint 0 (broadcast over the joint axis).
    return xyz_global - xyz_global[:, :, :1, :]