import numpy as np
import imageio
import skimage
import cv2
import torch
from torch.nn import functional as F


def load_rgb(path):
    """Load an image and return it as a float32 array with values in [-1, 1].

    The result is reordered with ``transpose(2, 0, 1)``, so the image read
    from ``path`` must be three-dimensional (presumably H x W x C, giving
    C x H x W -- confirm against the data being loaded).
    """
    img = imageio.imread(path)
    img = skimage.img_as_float32(img)

    # pixel values between [-1,1]
    img -= 0.5
    img *= 2.
    img = img.transpose(2, 0, 1)
    return img


def load_mask(path):
    """Load an image as grayscale and return the boolean mask ``value > 127.5``.

    NOTE(review): the threshold presumes the grayscale read yields values on
    a 0-255 scale -- confirm for the imageio version in use.
    """
    alpha = imageio.imread(path, as_gray=True)
    alpha = skimage.img_as_float32(alpha)
    object_mask = alpha > 127.5
    return object_mask


def load_mask_white_bg(path):
    """Load an image as grayscale and return the boolean mask ``value < 250.5``.

    Pixels at or above 250.5 are excluded from the mask (presumably a white
    background, per the function name).
    """
    alpha = imageio.imread(path, as_gray=True)
    alpha = skimage.img_as_float32(alpha)
    object_mask = alpha < 250.5
    return object_mask


def load_K_Rt_from_P(filename, P=None):
    """Decompose a projection matrix into 4x4 intrinsics and a 4x4 pose.

    Args:
        filename: Text file holding the projection matrix, one row per line
            with whitespace-separated values. Only read when ``P`` is None.
            If the file has exactly four lines, the first one is skipped.
        P: Optional projection matrix; when given, ``filename`` is ignored.

    Returns:
        ``(intrinsics, pose)``: ``intrinsics`` is a 4x4 array whose top-left
        3x3 block is K normalised so that ``K[2, 2] == 1``; ``pose`` is a 4x4
        float32 array whose rotation block is ``R^T`` with its first two
        columns negated and whose translation is the de-homogenised camera
        position returned by ``cv2.decomposeProjectionMatrix``.
    """
    if P is None:
        # Context manager so the file handle is always released.
        with open(filename) as f:
            lines = f.read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        rows = []
        for line in lines:
            # split() tolerates repeated/leading whitespace between values.
            tokens = line.split()
            rows.append([tokens[0], tokens[1], tokens[2], tokens[3]])
        P = np.asarray(rows).astype(np.float32).squeeze()

    out = cv2.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    # Flip the sign of the first two camera axes.
    to_gl = np.eye(3, dtype=np.float32)
    to_gl[0, 0] = -1.
    to_gl[1, 1] = -1.
    pose[:3, :3] = np.dot(R.transpose(), to_gl)
    pose[:3, 3] = (t[:3] / t[3])[:, 0]
    return intrinsics, pose


def get_camera_params(uv, pose, intrinsics):
    """Compute unit ray directions and camera locations for pixel coordinates.

    Args:
        uv: ``(batch, n_samples, 2)`` pixel coordinates (x in channel 0,
            y in channel 1).
        pose: Either ``(batch, 7)`` -- quaternion in the first four entries,
            camera location in the last three -- or ``(batch, 4, 4)`` pose
            matrices with the camera location in ``[:, :3, 3]``.
        intrinsics: ``(batch, ., .)`` intrinsics matrices, as read by ``lift``.

    Returns:
        ``(ray_dirs, cam_loc)``: normalised ``(batch, n_samples, 3)`` ray
        directions and ``(batch, 3)`` camera locations.
    """
    if pose.shape[1] == 7:  # In case of quaternion vector representation
        cam_loc = pose[:, 4:]
        R = quat_to_rot(pose[:, :4])
        p = torch.eye(4, device=pose.device).repeat(pose.shape[0], 1, 1).float()
        p[:, :3, :3] = R
        p[:, :3, 3] = cam_loc
    else:  # In case of pose matrix representation
        cam_loc = pose[:, :3, 3]
        p = pose

    batch_size, num_samples, _ = uv.shape

    # Lift every pixel to depth 1 in camera space.
    depth = torch.ones((batch_size, num_samples), device=uv.device)
    x_cam = uv[:, :, 0].view(batch_size, -1)
    y_cam = uv[:, :, 1].view(batch_size, -1)
    z_cam = depth.view(batch_size, -1)

    pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)

    # permute for batch matrix product
    pixel_points_cam = pixel_points_cam.permute(0, 2, 1)

    world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
    ray_dirs = world_coords - cam_loc[:, None, :]
    ray_dirs = F.normalize(ray_dirs, dim=2)

    return ray_dirs, cam_loc


def get_camera_for_plot(pose):
    """Return camera locations and the third column of each rotation matrix.

    Accepts the same two pose layouts as ``get_camera_params``; in the
    quaternion layout the inputs are detached before use.
    """
    if pose.shape[1] == 7:  # In case of quaternion vector representation
        cam_loc = pose[:, 4:].detach()
        R = quat_to_rot(pose[:, :4].detach())
    else:  # In case of pose matrix representation
        cam_loc = pose[:, :3, 3]
        R = pose[:, :3, :3]
    cam_dir = R[:, :3, 2]
    return cam_loc, cam_dir


def lift(x, y, z, intrinsics):
    """Back-project pixel coordinates at depth ``z`` to homogeneous camera points.

    Args:
        x, y, z: ``(batch, n)`` tensors of pixel coordinates and depths.
        intrinsics: ``(batch, ., .)`` matrices; focal lengths are read from
            ``[0, 0]``/``[1, 1]``, the principal point from ``[0, 2]``/``[1, 2]``
            and the skew from ``[0, 1]``.

    Returns:
        ``(batch, n, 4)`` tensor ``(x_lift, y_lift, z, 1)``.
    """
    # parse intrinsics (moved to the device of the coordinates they act on)
    intrinsics = intrinsics.to(x.device)
    fx = intrinsics[:, 0, 0].unsqueeze(-1)
    fy = intrinsics[:, 1, 1].unsqueeze(-1)
    cx = intrinsics[:, 0, 2].unsqueeze(-1)
    cy = intrinsics[:, 1, 2].unsqueeze(-1)
    sk = intrinsics[:, 0, 1].unsqueeze(-1)

    x_lift = (x - cx + cy * sk / fy - sk * y / fy) / fx * z
    y_lift = (y - cy) / fy * z

    # homogeneous
    return torch.stack((x_lift, y_lift, z, torch.ones_like(z)), dim=-1)


def quat_to_rot(q):
    """Convert ``(batch, 4)`` quaternions (real part first) to rotation matrices.

    The quaternions are normalised first. Returns a ``(batch, 3, 3)`` tensor
    allocated on the same device as ``q``.
    """
    batch_size, _ = q.shape
    q = F.normalize(q, dim=1)
    R = torch.ones((batch_size, 3, 3), device=q.device)
    qr = q[:, 0]
    qi = q[:, 1]
    qj = q[:, 2]
    qk = q[:, 3]
    R[:, 0, 0] = 1 - 2 * (qj ** 2 + qk ** 2)
    R[:, 0, 1] = 2 * (qj * qi - qk * qr)
    R[:, 0, 2] = 2 * (qi * qk + qr * qj)
    R[:, 1, 0] = 2 * (qj * qi + qk * qr)
    R[:, 1, 1] = 1 - 2 * (qi ** 2 + qk ** 2)
    R[:, 1, 2] = 2 * (qj * qk - qi * qr)
    R[:, 2, 0] = 2 * (qk * qi - qj * qr)
    R[:, 2, 1] = 2 * (qj * qk + qi * qr)
    R[:, 2, 2] = 1 - 2 * (qi ** 2 + qj ** 2)
    return R


def rot_to_quat(R):
    """Convert ``(batch, 3, 3)`` rotation matrices to quaternions (real part first).

    NOTE: the vector part is obtained by dividing by the real part, so the
    result is not finite when ``1 + trace(R)`` is zero (or negative through
    rounding).
    """
    batch_size, _, _ = R.shape
    q = torch.ones((batch_size, 4), device=R.device)

    R00 = R[:, 0, 0]
    R01 = R[:, 0, 1]
    R02 = R[:, 0, 2]
    R10 = R[:, 1, 0]
    R11 = R[:, 1, 1]
    R12 = R[:, 1, 2]
    R20 = R[:, 2, 0]
    R21 = R[:, 2, 1]
    R22 = R[:, 2, 2]

    q[:, 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2
    q[:, 1] = (R21 - R12) / (4 * q[:, 0])
    q[:, 2] = (R02 - R20) / (4 * q[:, 0])
    q[:, 3] = (R10 - R01) / (4 * q[:, 0])
    return q


def get_sphere_intersection(cam_loc, ray_directions, r=1.0):
    """Intersect rays with an origin-centred sphere of radius ``r``.

    Args:
        cam_loc: ``(n_images, 3)`` ray origins.
        ray_directions: ``(n_images, n_rays, 3)`` ray directions.
        r: Sphere radius.

    Returns:
        ``(sphere_intersections, mask_intersect)``: ``(n_images, n_rays, 2)``
        near/far distances along each ray, clamped to be non-negative and
        left at zero for rays that miss, and the ``(n_images, n_rays)``
        boolean mask of rays whose discriminant is strictly positive.
    """
    n_imgs, n_pix, _ = ray_directions.shape
    device = ray_directions.device

    cam_loc = cam_loc.unsqueeze(-1)
    # Drop only the trailing bmm axis: a bare squeeze() would also collapse
    # the batch axis when n_pix == 1 and break the broadcast below.
    ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze(-1)
    under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1) ** 2 - r ** 2)

    under_sqrt = under_sqrt.reshape(-1)
    mask_intersect = under_sqrt > 0

    sphere_intersections = torch.zeros(n_imgs * n_pix, 2, device=device).float()
    # -sqrt gives the near root, +sqrt the far root.
    signs = torch.tensor([-1.0, 1.0], device=device).float()
    sphere_intersections[mask_intersect] = \
        torch.sqrt(under_sqrt[mask_intersect]).unsqueeze(-1) * signs
    sphere_intersections[mask_intersect] -= \
        ray_cam_dot.reshape(-1)[mask_intersect].unsqueeze(-1)

    sphere_intersections = sphere_intersections.reshape(n_imgs, n_pix, 2)
    sphere_intersections = sphere_intersections.clamp_min(0.0)
    mask_intersect = mask_intersect.reshape(n_imgs, n_pix)

    return sphere_intersections, mask_intersect


def get_depth(points, pose):
    """Return the depth of 3D points according to the camera pose.

    Args:
        points: ``(batch, n_samples, 3)`` points.
        pose: ``(batch, 7)`` quaternion + location, or ``(batch, 4, 4)`` pose
            matrices; the pose is inverted to map the points into the camera
            frame.

    Returns:
        ``(batch, n_samples, 1)`` tensor holding the third camera-space
        coordinate of every point.
    """
    batch_size, num_samples, _ = points.shape
    if pose.shape[1] == 7:  # In case of quaternion vector representation
        cam_loc = pose[:, 4:]
        R = quat_to_rot(pose[:, :4])
        pose = torch.eye(4, device=points.device).unsqueeze(0).repeat(batch_size, 1, 1).float()
        pose[:, :3, 3] = cam_loc
        pose[:, :3, :3] = R

    points_hom = torch.cat(
        (points, torch.ones((batch_size, num_samples, 1), device=points.device)),
        dim=2,
    )

    # permute for batch matrix product
    points_hom = points_hom.permute(0, 2, 1)

    points_cam = torch.inverse(pose).bmm(points_hom)
    depth = points_cam[:, 2, :][:, :, None]
    return depth