diff --git a/reconstruction/PBIDR/README.md b/reconstruction/PBIDR/README.md
index 9183d78..0388594 100644
--- a/reconstruction/PBIDR/README.md
+++ b/reconstruction/PBIDR/README.md
@@ -1 +1,73 @@
-OSFIR
+# Facial Geometric Detail Recovery via Implicit Representation
+
+:herb: **Facial Geometric Detail Recovery via Implicit Representation**
+
+Xingyu Ren, Alexandros Lattas, Baris Gecer, Jiankang Deng, Chao Ma, Xiaokang Yang, and Stefanos Zafeiriou.
+
+*arXiv Preprint 2022*
+
+## Introduction
+
+![overview](https://github.com/deepinsight/insightface/tree/master/reconstruction/PBIDR/figures/overview.png)
+
+This paper introduces a geometric detail recovery algorithm for single facial images. The method generates complete, high-fidelity texture maps from occluded facial images, then employs an implicit renderer and shape functions to derive fine geometric details from decoupled specular normals. As a bonus, it disentangles the facial texture into approximate diffuse albedo, diffuse shading, and specular shading in a self-supervised manner.
+
+## Installation
+
+Please refer to the installation and usage of [IDR](https://github.com/lioryariv/idr).
+
+The code is compatible with Python 3.7 and PyTorch 1.7.1. In addition, the following packages are required:
+numpy, menpo, menpo3d, scikit-image, trimesh (with pyembree), opencv, torchvision, pytorch3d 0.4.0.
+
+You can create an anaconda environment and install the dependencies from our requirements file:
+
+```
+conda create -n pbidr python=3.7
+conda activate pbidr && pip install -r requirements.txt
+```
+
+## Tutorial
+
+### Data Preprocessing
+
+We have provided several textured meshes on [Google Drive](https://drive.google.com/file/d/1R7MdWawdMSjQUOnciJ5mb1pcwoY61Tzc/view?usp=sharing) and [Baidu Drive](https://pan.baidu.com/s/16mAqB_7mlbW2--0__patWA) (password: wp47). Otherwise, please refer to [OSTeC](https://github.com/barisgecer/OSTeC) to create a textured mesh first.
+
+Please download the raw textured meshes and run:
+
+```shell
+cd ./code
+bash script/data_process.sh
+```
+
+This synthesizes the auxiliary image sets used for the subsequent implicit detail recovery.
+
+### Train & Eval
+
+You can start training with the following script:
+
+```shell
+cd ./code
+bash script/fast_train.sh
+```
+
+We also provide a script for evaluation:
+
+```shell
+cd ./code
+bash script/fast_eval.sh
+```
+
+## Citation
+
+If any part of our paper or code is helpful to your work, please consider citing:
+
+```
+
+```
+
+## Reference
+
+We referred to the following repositories when implementing our pipeline. Thanks for their great work.
+ + - [barisgecer/OSTeC](https://github.com/barisgecer/OSTeC) + - [lioryariv/idr](https://github.com/lioryariv/idr) diff --git a/reconstruction/PBIDR/code/confs/test.conf b/reconstruction/PBIDR/code/confs/test.conf new file mode 100644 index 0000000..8851fb4 --- /dev/null +++ b/reconstruction/PBIDR/code/confs/test.conf @@ -0,0 +1,72 @@ +train{ + expname = test + dataset_class = datasets.dataset.IFDataset + model_class = model.renderer.IFNetwork + loss_class = model.loss.IFLoss + learning_rate = 1.0e-4 + num_pixels = 2048 + plot_freq = 100 + alpha_milestones = [250, 500, 750, 1000, 1250] + alpha_factor = 2 + sched_milestones = [1000,1500] + sched_factor = 0.5 +} +plot{ + plot_nimgs = 1 + max_depth = 3.0 + resolution = 100 +} +loss{ + eikonal_weight = 0.1 + mask_weight = 100.0 + reg_weight = 5.0 + normal_weight = 1.0 + alpha = 50.0 +} +dataset{ + data_dir = Test + img_res = [1024, 1024] + scan_id = 0 +} +model{ + feature_vector_size = 256 + implicit_network + { + d_in = 3 + d_out = 1 + dims = [512, 512, 512, 512, 512, 512, 512, 512] + geometric_init = True + bias = 0.6 + skip_in = [4] + weight_norm = True + multires = 6 + } + diffuse_network + { + dims = [128] + weight_norm = True + multires_view = 6 + } + specular_network + { + dims = [128] + weight_norm = True + multires_view = 4 + } + albedo_network + { + dims = [256, 256, 256, 256] + weight_norm = True + multires_view = 4 + } + ray_tracer + { + object_bounding_sphere = 1.0 + sdf_threshold = 5.0e-5 + line_search_step = 0.5 + line_step_iters = 3 + sphere_tracing_iters = 10 + n_steps = 100 + n_secant_steps = 8 + } +} \ No newline at end of file diff --git a/reconstruction/PBIDR/code/datasets/dataset.py b/reconstruction/PBIDR/code/datasets/dataset.py new file mode 100644 index 0000000..ae12367 --- /dev/null +++ b/reconstruction/PBIDR/code/datasets/dataset.py @@ -0,0 +1,154 @@ +import os +import torch +import numpy as np + +import utils.general as utils +from utils import rend_util + +class IFDataset(torch.utils.data.Dataset): + """Dataset for a class of objects, where each datapoint is a SceneInstanceDataset.""" + + def __init__(self, + train_cameras, + data_dir, + img_res, + scan_id=0, + cam_file=None + ): + + self.instance_dir = os.path.join('../data', data_dir, 'scan{0}'.format(scan_id)) + + self.total_pixels = img_res[0] * img_res[1] + self.img_res = img_res + + assert os.path.exists(self.instance_dir), "Data directory is empty" + + self.sampling_idx = None + self.train_cameras = train_cameras + + image_dir = '{0}/image'.format(self.instance_dir) + image_paths = sorted(utils.glob_imgs(image_dir)) + mask_dir = '{0}/mask'.format(self.instance_dir) + mask_paths = sorted(utils.glob_imgs(mask_dir)) + + self.n_images = len(image_paths) + + self.cam_file = '{0}/cameras.npz'.format(self.instance_dir) + if cam_file is not None: + self.cam_file = '{0}/{1}'.format(self.instance_dir, cam_file) + + camera_dict = np.load(self.cam_file) + scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] + world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] + + + self.intrinsics_all = [] + self.pose_all = [] + for scale_mat, world_mat in zip(scale_mats, world_mats): + P = world_mat @ scale_mat + P = P[:3, :4] + intrinsics, pose = rend_util.load_K_Rt_from_P(None, P) + self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) + self.pose_all.append(torch.from_numpy(pose).float()) + + self.rgb_images = [] + for path in image_paths: + rgb = rend_util.load_rgb(path) + 
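# Flatten the (3, H, W) image into (H*W, 3) rows so individual pixels can later be gathered by sampling_idx
+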
rgb = rgb.reshape(3, -1).transpose(1, 0) + self.rgb_images.append(torch.from_numpy(rgb).float()) + + self.object_masks = [] + for path in mask_paths: + object_mask = rend_util.load_mask_white_bg(path) + object_mask = object_mask.reshape(-1) + self.object_masks.append(torch.from_numpy(object_mask).bool()) + + def __len__(self): + return self.n_images + + def __getitem__(self, idx): + uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32) + uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float() + uv = uv.reshape(2, -1).transpose(1, 0) + + sample = { + "object_mask": self.object_masks[idx], + "uv": uv, + "intrinsics": self.intrinsics_all[idx], + } + + ground_truth = { + "rgb": self.rgb_images[idx] + } + + if self.sampling_idx is not None: + ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :] + sample["object_mask"] = self.object_masks[idx][self.sampling_idx] + sample["uv"] = uv[self.sampling_idx, :] + + if not self.train_cameras: + sample["pose"] = self.pose_all[idx] + + return idx, sample, ground_truth + + def collate_fn(self, batch_list): + # get list of dictionaries and returns input, ground_true as dictionary for all batch instances + batch_list = zip(*batch_list) + + all_parsed = [] + for entry in batch_list: + if type(entry[0]) is dict: + # make them all into a new dict + ret = {} + for k in entry[0].keys(): + ret[k] = torch.stack([obj[k] for obj in entry]) + all_parsed.append(ret) + else: + all_parsed.append(torch.LongTensor(entry)) + + return tuple(all_parsed) + + def change_sampling_idx(self, sampling_size): + if sampling_size == -1: + self.sampling_idx = None + else: + self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size] + + def get_scale_mat(self): + return np.load(self.cam_file)['scale_mat_0'] + + def get_gt_pose(self, scaled=False): + # Load gt pose without normalization to unit sphere + camera_dict = np.load(self.cam_file) + world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] + scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] + + pose_all = [] + for scale_mat, world_mat in zip(scale_mats, world_mats): + P = world_mat + if scaled: + P = world_mat @ scale_mat + P = P[:3, :4] + _, pose = rend_util.load_K_Rt_from_P(None, P) + pose_all.append(torch.from_numpy(pose).float()) + + return torch.cat([p.float().unsqueeze(0) for p in pose_all], 0) + + def get_pose_init(self): + # get noisy initializations obtained with the linear method + cam_file = '{0}/cameras_linear_init.npz'.format(self.instance_dir) + camera_dict = np.load(cam_file) + scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] + world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] + + init_pose = [] + for scale_mat, world_mat in zip(scale_mats, world_mats): + P = world_mat @ scale_mat + P = P[:3, :4] + _, pose = rend_util.load_K_Rt_from_P(None, P) + init_pose.append(pose) + init_pose = torch.cat([torch.Tensor(pose).float().unsqueeze(0) for pose in init_pose], 0).cuda() + init_quat = rend_util.rot_to_quat(init_pose[:, :3, :3]) + init_quat = torch.cat([init_quat, init_pose[:, :3, 3]], 1) + + return init_quat diff --git a/reconstruction/PBIDR/code/evaluation/eval.py b/reconstruction/PBIDR/code/evaluation/eval.py new file mode 100644 index 0000000..efb1b6a --- /dev/null +++ b/reconstruction/PBIDR/code/evaluation/eval.py @@ -0,0 +1,211 @@ +import sys +sys.path.append('../code') +import argparse 
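+# GPUtil is used below to auto-select an available GPU when --gpu is left as 'auto'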
+import GPUtil +import os +from pyhocon import ConfigFactory +import torch +import numpy as np +import cvxpy as cp +from PIL import Image +import math + +import utils.general as utils +import utils.plots as plt +from utils import rend_util + +def evaluate(**kwargs): + torch.set_default_dtype(torch.float32) + + conf = ConfigFactory.parse_file(kwargs['conf']) + exps_folder_name = kwargs['exps_folder_name'] + evals_folder_name = kwargs['evals_folder_name'] + eval_rendering = kwargs['eval_rendering'] + eval_animation = kwargs['eval_animation'] + + expname = conf.get_string('train.expname') + kwargs['expname'] + scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else conf.get_int('dataset.scan_id', default=-1) + if scan_id != -1: + expname = expname + '_{0}'.format(scan_id) + + if kwargs['timestamp'] == 'latest': + if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], expname)): + timestamps = os.listdir(os.path.join('../', kwargs['exps_folder_name'], expname)) + if (len(timestamps)) == 0: + print('WRONG EXP FOLDER') + exit() + else: + timestamp = sorted(timestamps)[-1] + else: + print('WRONG EXP FOLDER') + exit() + else: + timestamp = kwargs['timestamp'] + + utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name)) + expdir = os.path.join('../', exps_folder_name, expname) + evaldir = os.path.join('../', evals_folder_name, expname) + utils.mkdir_ifnotexists(evaldir) + + dataset_conf = conf.get_config('dataset') + model = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'),\ + id=scan_id, datadir=dataset_conf['data_dir']) + if torch.cuda.is_available(): + model.cuda() + + + if kwargs['scan_id'] != -1: + dataset_conf['scan_id'] = kwargs['scan_id'] + eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(False, **dataset_conf) + + if eval_rendering: + eval_dataloader = torch.utils.data.DataLoader(eval_dataset, + batch_size=1, + shuffle=False, + collate_fn=eval_dataset.collate_fn + ) + total_pixels = eval_dataset.total_pixels + img_res = eval_dataset.img_res + + old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints') + + saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth")) + model.load_state_dict(saved_model_state["model_state_dict"]) + epoch = saved_model_state['epoch'] + + #################################################################################################################### + print("evaluating...") + + model.eval() + + detail_3dmm, detail_3dmm_subdivision_full = plt.get_displacement_mesh(model) + detail_3dmm.export('{0}/Detailed_3dmm_{1}.obj'.format(evaldir, epoch), 'obj') + detail_3dmm_subdivision_full.export('{0}/Subdivide_full_{1}.obj'.format(evaldir, epoch), 'obj') + + if eval_animation: + sdf_np0, sdf_np1 = plt.get_displacement_animation(model) + np.save('{0}/Cropped_Detailed_sdf_{1}.npy'.format(evaldir, epoch), sdf_np0) + np.save('{0}/Cropped_Subdivide_full_{1}.npy'.format(evaldir, epoch), sdf_np1) + + if eval_rendering: + images_dir = '{0}/rendering'.format(evaldir) + utils.mkdir_ifnotexists(images_dir) + + psnrs = [] + for data_index, (indices, model_input, ground_truth) in enumerate(eval_dataloader): + model_input["intrinsics"] = model_input["intrinsics"].cuda() + model_input["uv"] = model_input["uv"].cuda() + model_input["object_mask"] = model_input["object_mask"].cuda() + model_input['pose'] = model_input['pose'].cuda() + + split = utils.split_input(model_input, total_pixels) + res = [] + for s in split: + out = model(s) + 
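# Gather per-chunk outputs; tensors are detached since no gradients are needed during evaluation
+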
res.append({ + 'rgb_values': out['rgb_values'].detach(), + 'diffuse_values': out['diffuse_values'].detach(), + 'specular_values': out['specular_values'].detach(), + 'albedo_values': out['albedo_values'].detach(), + }) + + batch_size = ground_truth['rgb'].shape[0] + model_outputs = utils.merge_output(res, total_pixels, batch_size) + rgb_eval = model_outputs['rgb_values'] + rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3) + rgb_eval = (rgb_eval + 1.) / 2. + rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0] + rgb_eval = rgb_eval.transpose(1, 2, 0) + img = Image.fromarray((rgb_eval * 255).astype(np.uint8)) + img.save('{0}/eval_{1}.png'.format(images_dir,'%03d' % indices[0])) + + diffuse_eval = model_outputs['diffuse_values'] + diffuse_eval = diffuse_eval.reshape(batch_size, total_pixels, 3) + diffuse_eval = (diffuse_eval + 1.) / 2. + diffuse_eval = plt.lin2img(diffuse_eval, img_res).detach().cpu().numpy()[0] + diffuse_eval = diffuse_eval.transpose(1, 2, 0) + img = Image.fromarray((diffuse_eval * 255).astype(np.uint8)) + img.save('{0}/eval_{1}_diffuse.png'.format(images_dir, '%03d' % indices[0])) + + specular_eval = model_outputs['specular_values'] + specular_eval = specular_eval.reshape(batch_size, total_pixels, 3) + specular_eval = (specular_eval + 1.) / 2. + specular_eval = plt.lin2img(specular_eval, img_res).detach().cpu().numpy()[0] + specular_eval = specular_eval.transpose(1, 2, 0) + img = Image.fromarray((specular_eval * 255).astype(np.uint8)) + img.save('{0}/eval_{1}_specular.png'.format(images_dir, '%03d' % indices[0])) + + albedo_eval = model_outputs['albedo_values'] + albedo_eval = albedo_eval.reshape(batch_size, total_pixels, 3) + albedo_eval = (albedo_eval + 1.) / 2. + albedo_eval = plt.lin2img(albedo_eval, img_res).detach().cpu().numpy()[0] + albedo_eval = albedo_eval.transpose(1, 2, 0) + img = Image.fromarray((albedo_eval * 255).astype(np.uint8)) + img.save('{0}/eval_{1}_albedo.png'.format(images_dir, '%03d' % indices[0])) + + rgb_gt = ground_truth['rgb'] + rgb_gt = (rgb_gt + 1.) / 2. 
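+ # Reshape the flat pixel vector back into an image grid before computing the masked PSNR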
+ rgb_gt = plt.lin2img(rgb_gt, img_res).numpy()[0] + rgb_gt = rgb_gt.transpose(1, 2, 0) + + mask = model_input['object_mask'] + mask = plt.lin2img(mask.unsqueeze(-1), img_res).cpu().numpy()[0] + mask = mask.transpose(1, 2, 0) + + rgb_eval_masked = rgb_eval * mask + rgb_gt_masked = rgb_gt * mask + + psnr = calculate_psnr(rgb_eval_masked, rgb_gt_masked, mask) + psnrs.append(psnr) + + psnrs = np.array(psnrs).astype(np.float64) + print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".format("%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scan_id)) + + + +def calculate_psnr(img1, img2, mask): + # img1 and img2 have range [0, 1] + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) * (img2.shape[0] * img2.shape[1]) / mask.sum() + if mse == 0: + return float('inf') + return 20 * math.log10(1.0 / math.sqrt(mse)) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--conf', type=str, default='./confs/test.conf') + parser.add_argument('--expname', type=str, default='', help='The experiment name to be evaluated.') + parser.add_argument('--exps_folder', type=str, default='exps', help='The experiments folder name.') + parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]') + parser.add_argument('--timestamp', default='latest', type=str, help='The experiemnt timestamp to test.') + parser.add_argument('--checkpoint', default='latest',type=str,help='The trained model checkpoint to test') + parser.add_argument('--scan_id', type=int, default=0, help='If set, taken to be the scan id.') + parser.add_argument('--resolution', default=512, type=int, help='Grid resolution for marching cube') + parser.add_argument('--is_uniform_grid', default=False, action="store_true", help='If set, evaluate marching cube with uniform grid.') + parser.add_argument('--eval_rendering', default=False, action="store_true",help='If set, evaluate rendering quality.') + parser.add_argument('--eval_animation', default=False, action="store_true",help='If set, evaluate rendering quality.') + + opt = parser.parse_args() + + if opt.gpu == "auto": + deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False, excludeID=[], excludeUUID=[]) + gpu = deviceIDs[0] + else: + gpu = opt.gpu + + if (not gpu == 'ignore'): + os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(gpu) + + evaluate(conf=opt.conf, + expname=opt.expname, + exps_folder_name=opt.exps_folder, + evals_folder_name='evals', + timestamp=opt.timestamp, + checkpoint=opt.checkpoint, + scan_id=opt.scan_id, + resolution=opt.resolution, + eval_rendering=opt.eval_rendering, + eval_animation=opt.eval_animation + ) diff --git a/reconstruction/PBIDR/code/model/embedder.py b/reconstruction/PBIDR/code/model/embedder.py new file mode 100644 index 0000000..02f2c08 --- /dev/null +++ b/reconstruction/PBIDR/code/model/embedder.py @@ -0,0 +1,50 @@ +import torch + +""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """ + +class Embedder: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs['input_dims'] + out_dim = 0 + if self.kwargs['include_input']: + embed_fns.append(lambda x: x) + out_dim += d + + max_freq = self.kwargs['max_freq_log2'] + N_freqs = self.kwargs['num_freqs'] + + if self.kwargs['log_sampling']: + freq_bands = 2. 
** torch.linspace(0., max_freq, N_freqs) + else: + freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs['periodic_fns']: + embed_fns.append(lambda x, p_fn=p_fn, + freq=freq: p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def embed(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + +def get_embedder(multires): + embed_kwargs = { + 'include_input': True, + 'input_dims': 3, + 'max_freq_log2': multires-1, + 'num_freqs': multires, + 'log_sampling': True, + 'periodic_fns': [torch.sin, torch.cos], + } + + embedder_obj = Embedder(**embed_kwargs) + def embed(x, eo=embedder_obj): return eo.embed(x) + return embed, embedder_obj.out_dim diff --git a/reconstruction/PBIDR/code/model/loss.py b/reconstruction/PBIDR/code/model/loss.py new file mode 100644 index 0000000..9843ad3 --- /dev/null +++ b/reconstruction/PBIDR/code/model/loss.py @@ -0,0 +1,69 @@ +import torch +from torch import nn +from torch.nn import functional as F + +class IFLoss(nn.Module): + def __init__(self, eikonal_weight, mask_weight, reg_weight, normal_weight, alpha): + super().__init__() + self.eikonal_weight = eikonal_weight + self.mask_weight = mask_weight + self.reg_weight = reg_weight + self.normal_weight = normal_weight + self.alpha = alpha + self.l1_loss = nn.L1Loss(reduction='sum') + self.l2_loss = nn.MSELoss(reduction='sum') + self.cosine = nn.CosineSimilarity() + + def get_rgb_loss(self,rgb_values, rgb_gt, network_object_mask, object_mask): + if (network_object_mask & object_mask).sum() == 0: + return torch.tensor(0.0).cuda().float() + + rgb_values = rgb_values[network_object_mask & object_mask] + rgb_gt = rgb_gt.reshape(-1, 3)[network_object_mask & object_mask] + rgb_loss = self.l1_loss(rgb_values, rgb_gt) / float(object_mask.shape[0]) + return rgb_loss + + def get_eikonal_loss(self, grad_theta): + if grad_theta.shape[0] == 0: + return torch.tensor(0.0).cuda().float() + + eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean() + return eikonal_loss + + def get_mask_loss(self, sdf_output, network_object_mask, object_mask): + mask = ~(network_object_mask & object_mask) + if mask.sum() == 0: + return torch.tensor(0.0).cuda().float() + sdf_pred = -self.alpha * sdf_output[mask] + gt = object_mask[mask].float() + mask_loss = (1 / self.alpha) * F.binary_cross_entropy_with_logits(sdf_pred.squeeze(), gt, reduction='sum') / float(object_mask.shape[0]) + return mask_loss + + def get_reg_loss(self, point_gt, point_pre): + loss = self.l2_loss(point_gt, point_pre) / len(point_pre) + return loss + + def forward(self, model_outputs, ground_truth): + rgb_gt = ground_truth['rgb'].cuda() + network_object_mask = model_outputs['network_object_mask'] + object_mask = model_outputs['object_mask'] + + rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt, network_object_mask, object_mask) + mask_loss = self.get_mask_loss(model_outputs['sdf_output'], network_object_mask, object_mask) + eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta']) + reg_loss = self.get_reg_loss(model_outputs['points_mesh_ray_gt'], model_outputs['points_pre']) + normal_loss = 1 - torch.mean(self.cosine(model_outputs['points_mesh_ray_normals'], model_outputs['surface_normals'])) + loss = rgb_loss + \ + self.eikonal_weight * eikonal_loss + \ + self.mask_weight * mask_loss + \ + self.reg_weight * reg_loss + \ + self.normal_weight * normal_loss + + return { + 'loss': loss, + 'rgb_loss': rgb_loss, + 'eikonal_loss': 
eikonal_loss, + 'mask_loss': mask_loss, + 'reg_loss': reg_loss, + 'normal_loss': normal_loss, + } diff --git a/reconstruction/PBIDR/code/model/ray_tracing.py b/reconstruction/PBIDR/code/model/ray_tracing.py new file mode 100644 index 0000000..2e9ea83 --- /dev/null +++ b/reconstruction/PBIDR/code/model/ray_tracing.py @@ -0,0 +1,301 @@ +import torch +import torch.nn as nn +from utils import rend_util + +class RayTracing(nn.Module): + def __init__( + self, + object_bounding_sphere=1.0, + sdf_threshold=5.0e-5, + line_search_step=0.5, + line_step_iters=1, + sphere_tracing_iters=10, + n_steps=100, + n_secant_steps=8, + ): + super().__init__() + + self.object_bounding_sphere = object_bounding_sphere + self.sdf_threshold = sdf_threshold + self.sphere_tracing_iters = sphere_tracing_iters + self.line_step_iters = line_step_iters + self.line_search_step = line_search_step + self.n_steps = n_steps + self.n_secant_steps = n_secant_steps + + def forward(self, + sdf, + cam_loc, + object_mask, + ray_directions + ): + + batch_size, num_pixels, _ = ray_directions.shape + + sphere_intersections, mask_intersect = rend_util.get_sphere_intersection(cam_loc, ray_directions, r=self.object_bounding_sphere) + + curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis = \ + self.sphere_tracing(batch_size, num_pixels, sdf, cam_loc, ray_directions, mask_intersect, sphere_intersections) + + network_object_mask = (acc_start_dis < acc_end_dis) + + # The non convergent rays should be handled by the sampler + sampler_mask = unfinished_mask_start + sampler_net_obj_mask = torch.zeros_like(sampler_mask).bool().cuda() + if sampler_mask.sum() > 0: + sampler_min_max = torch.zeros((batch_size, num_pixels, 2)).cuda() + sampler_min_max.reshape(-1, 2)[sampler_mask, 0] = acc_start_dis[sampler_mask] + sampler_min_max.reshape(-1, 2)[sampler_mask, 1] = acc_end_dis[sampler_mask] + + sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(sdf, + cam_loc, + object_mask, + ray_directions, + sampler_min_max, + sampler_mask + ) + + curr_start_points[sampler_mask] = sampler_pts[sampler_mask] + acc_start_dis[sampler_mask] = sampler_dists[sampler_mask] + network_object_mask[sampler_mask] = sampler_net_obj_mask[sampler_mask] + + print('----------------------------------------------------------------') + print('RayTracing: object = {0}/{1}, secant on {2}/{3}.' 
+ .format(network_object_mask.sum(), len(network_object_mask), sampler_net_obj_mask.sum(), sampler_mask.sum())) + print('----------------------------------------------------------------') + + if not self.training: + return curr_start_points, \ + network_object_mask, \ + acc_start_dis + + ray_directions = ray_directions.reshape(-1, 3) + mask_intersect = mask_intersect.reshape(-1) + + in_mask = ~network_object_mask & object_mask & ~sampler_mask + out_mask = ~object_mask & ~sampler_mask + + mask_left_out = (in_mask | out_mask) & ~mask_intersect + if mask_left_out.sum() > 0: # project the origin to the not intersect points on the sphere + cam_left_out = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[mask_left_out] + rays_left_out = ray_directions[mask_left_out] + acc_start_dis[mask_left_out] = -torch.bmm(rays_left_out.view(-1, 1, 3), cam_left_out.view(-1, 3, 1)).squeeze() + curr_start_points[mask_left_out] = cam_left_out + acc_start_dis[mask_left_out].unsqueeze(1) * rays_left_out + + mask = (in_mask | out_mask) & mask_intersect + + if mask.sum() > 0: + min_dis[network_object_mask & out_mask] = acc_start_dis[network_object_mask & out_mask] + + min_mask_points, min_mask_dist = self.minimal_sdf_points(num_pixels, sdf, cam_loc, ray_directions, mask, min_dis, max_dis) + + curr_start_points[mask] = min_mask_points + acc_start_dis[mask] = min_mask_dist + + return curr_start_points, \ + network_object_mask, \ + acc_start_dis + + + def sphere_tracing(self, batch_size, num_pixels, sdf, cam_loc, ray_directions, mask_intersect, sphere_intersections): + ''' Run sphere tracing algorithm for max iterations from both sides of unit sphere intersection ''' + + sphere_intersections_points = cam_loc.reshape(batch_size, 1, 1, 3) + sphere_intersections.unsqueeze(-1) * ray_directions.unsqueeze(2) + unfinished_mask_start = mask_intersect.reshape(-1).clone() + unfinished_mask_end = mask_intersect.reshape(-1).clone() + + # Initialize start current points + curr_start_points = torch.zeros(batch_size * num_pixels, 3).cuda().float() + curr_start_points[unfinished_mask_start] = sphere_intersections_points[:,:,0,:].reshape(-1,3)[unfinished_mask_start] + acc_start_dis = torch.zeros(batch_size * num_pixels).cuda().float() + acc_start_dis[unfinished_mask_start] = sphere_intersections.reshape(-1,2)[unfinished_mask_start,0] + + # Initialize end current points + curr_end_points = torch.zeros(batch_size * num_pixels, 3).cuda().float() + curr_end_points[unfinished_mask_end] = sphere_intersections_points[:,:,1,:].reshape(-1,3)[unfinished_mask_end] + acc_end_dis = torch.zeros(batch_size * num_pixels).cuda().float() + acc_end_dis[unfinished_mask_end] = sphere_intersections.reshape(-1,2)[unfinished_mask_end,1] + + # Initizliae min and max depth + min_dis = acc_start_dis.clone() + max_dis = acc_end_dis.clone() + + # Iterate on the rays (from both sides) till finding a surface + iters = 0 + + next_sdf_start = torch.zeros_like(acc_start_dis).cuda() + next_sdf_start[unfinished_mask_start] = sdf(curr_start_points[unfinished_mask_start]) + + next_sdf_end = torch.zeros_like(acc_end_dis).cuda() + next_sdf_end[unfinished_mask_end] = sdf(curr_end_points[unfinished_mask_end]) + + while True: + # Update sdf + curr_sdf_start = torch.zeros_like(acc_start_dis).cuda() + curr_sdf_start[unfinished_mask_start] = next_sdf_start[unfinished_mask_start] + curr_sdf_start[curr_sdf_start <= self.sdf_threshold] = 0 + + curr_sdf_end = torch.zeros_like(acc_end_dis).cuda() + curr_sdf_end[unfinished_mask_end] = next_sdf_end[unfinished_mask_end] + 
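# Zero the step size of rays already within sdf_threshold of the surface so they stop marching
+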
curr_sdf_end[curr_sdf_end <= self.sdf_threshold] = 0 + + # Update masks + unfinished_mask_start = unfinished_mask_start & (curr_sdf_start > self.sdf_threshold) + unfinished_mask_end = unfinished_mask_end & (curr_sdf_end > self.sdf_threshold) + + if (unfinished_mask_start.sum() == 0 and unfinished_mask_end.sum() == 0) or iters == self.sphere_tracing_iters: + break + iters += 1 + + # Make step + # Update distance + acc_start_dis = acc_start_dis + curr_sdf_start + acc_end_dis = acc_end_dis - curr_sdf_end + + # Update points + curr_start_points = (cam_loc.unsqueeze(1) + acc_start_dis.reshape(batch_size, num_pixels, 1) * ray_directions).reshape(-1, 3) + curr_end_points = (cam_loc.unsqueeze(1) + acc_end_dis.reshape(batch_size, num_pixels, 1) * ray_directions).reshape(-1, 3) + + # Fix points which wrongly crossed the surface + next_sdf_start = torch.zeros_like(acc_start_dis).cuda() + next_sdf_start[unfinished_mask_start] = sdf(curr_start_points[unfinished_mask_start]) + + next_sdf_end = torch.zeros_like(acc_end_dis).cuda() + next_sdf_end[unfinished_mask_end] = sdf(curr_end_points[unfinished_mask_end]) + + not_projected_start = next_sdf_start < 0 + not_projected_end = next_sdf_end < 0 + not_proj_iters = 0 + while (not_projected_start.sum() > 0 or not_projected_end.sum() > 0) and not_proj_iters < self.line_step_iters: + # Step backwards + acc_start_dis[not_projected_start] -= ((1 - self.line_search_step) / (2 ** not_proj_iters)) * curr_sdf_start[not_projected_start] + curr_start_points[not_projected_start] = (cam_loc.unsqueeze(1) + acc_start_dis.reshape(batch_size, num_pixels, 1) * ray_directions).reshape(-1, 3)[not_projected_start] + + acc_end_dis[not_projected_end] += ((1 - self.line_search_step) / (2 ** not_proj_iters)) * curr_sdf_end[not_projected_end] + curr_end_points[not_projected_end] = (cam_loc.unsqueeze(1) + acc_end_dis.reshape(batch_size, num_pixels, 1) * ray_directions).reshape(-1, 3)[not_projected_end] + + # Calc sdf + next_sdf_start[not_projected_start] = sdf(curr_start_points[not_projected_start]) + next_sdf_end[not_projected_end] = sdf(curr_end_points[not_projected_end]) + + # Update mask + not_projected_start = next_sdf_start < 0 + not_projected_end = next_sdf_end < 0 + not_proj_iters += 1 + + unfinished_mask_start = unfinished_mask_start & (acc_start_dis < acc_end_dis) + unfinished_mask_end = unfinished_mask_end & (acc_start_dis < acc_end_dis) + + return curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis + + def ray_sampler(self, sdf, cam_loc, object_mask, ray_directions, sampler_min_max, sampler_mask): + ''' Sample the ray in a given range and run secant on rays which have sign transition ''' + + batch_size, num_pixels, _ = ray_directions.shape + n_total_pxl = batch_size * num_pixels + sampler_pts = torch.zeros(n_total_pxl, 3).cuda().float() + sampler_dists = torch.zeros(n_total_pxl).cuda().float() + + intervals_dist = torch.linspace(0, 1, steps=self.n_steps).cuda().view(1, 1, -1) + + pts_intervals = sampler_min_max[:, :, 0].unsqueeze(-1) + intervals_dist * (sampler_min_max[:, :, 1] - sampler_min_max[:, :, 0]).unsqueeze(-1) + points = cam_loc.reshape(batch_size, 1, 1, 3) + pts_intervals.unsqueeze(-1) * ray_directions.unsqueeze(2) + + # Get the non convergent rays + mask_intersect_idx = torch.nonzero(sampler_mask).flatten() + points = points.reshape((-1, self.n_steps, 3))[sampler_mask, :, :] + pts_intervals = pts_intervals.reshape((-1, self.n_steps))[sampler_mask] + + sdf_val_all = [] + for pnts in torch.split(points.reshape(-1, 3), 100000, 
dim=0): + sdf_val_all.append(sdf(pnts)) + sdf_val = torch.cat(sdf_val_all).reshape(-1, self.n_steps) + + tmp = torch.sign(sdf_val) * torch.arange(self.n_steps, 0, -1).cuda().float().reshape((1, self.n_steps)) # Force argmin to return the first min value + sampler_pts_ind = torch.argmin(tmp, -1) + sampler_pts[mask_intersect_idx] = points[torch.arange(points.shape[0]), sampler_pts_ind, :] + sampler_dists[mask_intersect_idx] = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind] + + true_surface_pts = object_mask[sampler_mask] + net_surface_pts = (sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind] < 0) + + # take points with minimal SDF value for P_out pixels + p_out_mask = ~(true_surface_pts & net_surface_pts) + n_p_out = p_out_mask.sum() + if n_p_out > 0: + out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1) + sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][torch.arange(n_p_out), out_pts_idx, :] + sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[p_out_mask, :][torch.arange(n_p_out), out_pts_idx] + + # Get Network object mask + sampler_net_obj_mask = sampler_mask.clone() + sampler_net_obj_mask[mask_intersect_idx[~net_surface_pts]] = False + + # Run Secant method + secant_pts = net_surface_pts & true_surface_pts if self.training else net_surface_pts + n_secant_pts = secant_pts.sum() + if n_secant_pts > 0: + # Get secant z predictions + z_high = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind][secant_pts] + sdf_high = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind][secant_pts] + z_low = pts_intervals[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1] + sdf_low = sdf_val[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1] + cam_loc_secant = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape((-1, 3))[mask_intersect_idx[secant_pts]] + ray_directions_secant = ray_directions.reshape((-1, 3))[mask_intersect_idx[secant_pts]] + z_pred_secant = self.secant(sdf_low, sdf_high, z_low, z_high, cam_loc_secant, ray_directions_secant, sdf) + + # Get points + sampler_pts[mask_intersect_idx[secant_pts]] = cam_loc_secant + z_pred_secant.unsqueeze(-1) * ray_directions_secant + sampler_dists[mask_intersect_idx[secant_pts]] = z_pred_secant + + return sampler_pts, sampler_net_obj_mask, sampler_dists + + def secant(self, sdf_low, sdf_high, z_low, z_high, cam_loc, ray_directions, sdf): + ''' Runs the secant method for interval [z_low, z_high] for n_secant_steps ''' + + z_pred = - sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low + for i in range(self.n_secant_steps): + p_mid = cam_loc + z_pred.unsqueeze(-1) * ray_directions + sdf_mid = sdf(p_mid) + ind_low = sdf_mid > 0 + if ind_low.sum() > 0: + z_low[ind_low] = z_pred[ind_low] + sdf_low[ind_low] = sdf_mid[ind_low] + ind_high = sdf_mid < 0 + if ind_high.sum() > 0: + z_high[ind_high] = z_pred[ind_high] + sdf_high[ind_high] = sdf_mid[ind_high] + + z_pred = - sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low + + return z_pred + + def minimal_sdf_points(self, num_pixels, sdf, cam_loc, ray_directions, mask, min_dis, max_dis): + ''' Find points with minimal SDF value on rays for P_out pixels ''' + + n_mask_points = mask.sum() + + n = self.n_steps + # steps = torch.linspace(0.0, 1.0,n).cuda() + steps = torch.empty(n).uniform_(0.0, 1.0).cuda() + mask_max_dis = max_dis[mask].unsqueeze(-1) + mask_min_dis = min_dis[mask].unsqueeze(-1) + steps = steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis) + 
mask_min_dis + + mask_points = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[mask] + mask_rays = ray_directions[mask, :] + + mask_points_all = mask_points.unsqueeze(1).repeat(1, n, 1) + steps.unsqueeze(-1) * mask_rays.unsqueeze( + 1).repeat(1, n, 1) + points = mask_points_all.reshape(-1, 3) + + mask_sdf_all = [] + for pnts in torch.split(points, 100000, dim=0): + mask_sdf_all.append(sdf(pnts)) + + mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n) + min_vals, min_idx = mask_sdf_all.min(-1) + min_mask_points = mask_points_all.reshape(-1, n, 3)[torch.arange(0, n_mask_points), min_idx] + min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx] + + return min_mask_points, min_mask_dist diff --git a/reconstruction/PBIDR/code/model/renderer.py b/reconstruction/PBIDR/code/model/renderer.py new file mode 100644 index 0000000..7081698 --- /dev/null +++ b/reconstruction/PBIDR/code/model/renderer.py @@ -0,0 +1,461 @@ +import torch +import torch.nn as nn +import numpy as np +import trimesh +import os + +from utils import rend_util +from model.embedder import * +from model.ray_tracing import RayTracing +from model.sample_network import SampleNetwork + + +def barycentric_coordinates(p, select_vertices): + + a = select_vertices[:, 0, :] + b = select_vertices[:, 1, :] + c = select_vertices[:, 2, :] + # p = point + + v0 = b - a + v1 = c - a + v2 = p - a + d00 = (v0 * v0).sum(axis=1) + d01 = (v0 * v1).sum(axis=1) + d11 = (v1 * v1).sum(axis=1) + d20 = (v2 * v0).sum(axis=1) + d21 = (v2 * v1).sum(axis=1) + denom = d00 * d11 - d01 * d01 + v = (d11 * d20 - d01 * d21) / denom + w = (d00 * d21 - d01 * d20) / denom + u = 1 - v - w + + return np.vstack([u, v, w]).T + +class ImplicitNetwork(nn.Module): + def __init__( + self, + feature_vector_size, + d_in, + d_out, + dims, + geometric_init=True, + bias=1.0, + skip_in=(), + weight_norm=True, + multires=0 + ): + super().__init__() + + dims = [d_in] + dims + [d_out + feature_vector_size] + + self.embed_fn = None + if multires > 0: + embed_fn, input_ch = get_embedder(multires) + self.embed_fn = embed_fn + dims[0] = input_ch + + self.num_layers = len(dims) + self.skip_in = skip_in + + for l in range(0, self.num_layers - 1): + if l + 1 in self.skip_in: + out_dim = dims[l + 1] - dims[0] + else: + out_dim = dims[l + 1] + + lin = nn.Linear(dims[l], out_dim) + + if geometric_init: + if l == self.num_layers - 2: + torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) + torch.nn.init.constant_(lin.bias, -bias) + elif multires > 0 and l == 0: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.constant_(lin.weight[:, 3:], 0.0) + torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) + elif multires > 0 and l in self.skip_in: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) + else: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + self.softplus = nn.Softplus(beta=100) + + def forward(self, input, compute_grad=False): + if self.embed_fn is not None: + input = self.embed_fn(input) + + x = input + + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + if l in self.skip_in: + x = torch.cat([x, input], 1) / np.sqrt(2) + + x = lin(x) + + if l < self.num_layers - 2: + x = 
self.softplus(x) + + return x + + def gradient(self, x): + x.requires_grad_(True) + y = self.forward(x)[:,:1] + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return gradients.unsqueeze(1) + +class AlbedoNetwork(nn.Module): + def __init__( + self, + feature_vector_size, + dims=[512, 512, 512, 512], + weight_norm=True, + multires_view=4, + ): + super().__init__() + + dims = [3 + feature_vector_size] + dims + [3] + embedview_fn, input_ch = get_embedder(multires_view) + self.embedview_fn = embedview_fn + dims[0] += (input_ch - 3) + self.num_layers = len(dims) + + for l in range(0, self.num_layers - 1): + out_dim = dims[l + 1] + lin = nn.Linear(dims[l], out_dim) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + + def forward(self, points, feature_vectors): + + Mpoints = self.embedview_fn(points) + x = torch.cat([Mpoints, feature_vectors], dim=-1) + + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.relu(x) + + x = self.tanh(x) + return x + +class SpecularNetwork(nn.Module): + def __init__( + self, + dims=[256, 256, 256], + weight_norm=True, + multires_view=4 + ): + super().__init__() + dims = [3 + 3] + dims + [1] + + embedview_fn, input_ch = get_embedder(multires_view) + self.embedview_fn = embedview_fn + dims[0] += (input_ch - 3) + dims[0] += (input_ch - 3) + self.num_layers = len(dims) + + for l in range(0, self.num_layers - 1): + out_dim = dims[l + 1] + lin = nn.Linear(dims[l], out_dim) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + + def forward(self, normals, view_dirs): + + Mview_dirs = self.embedview_fn(view_dirs) + Mnormals= self.embedview_fn(normals) + x = torch.cat([Mview_dirs, Mnormals], dim=-1) + + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.relu(x) + + x = self.tanh(x) + return x + def optimaize(self): + return + +class DiffuseNetwork(nn.Module): + def __init__( + self, + dims=[256, 256, 256], + weight_norm=True, + multires_view=6, + ): + super().__init__() + + dims = [3] + dims + [1] + embedview_fn, input_ch = get_embedder(multires_view) + self.embedview_fn = embedview_fn + dims[0] += (input_ch - 3) + self.num_layers = len(dims) + + for l in range(0, self.num_layers - 1): + out_dim = dims[l + 1] + lin = nn.Linear(dims[l], out_dim) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + + def forward(self, normals): + + Mnormals = self.embedview_fn(normals) + x = Mnormals + + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.relu(x) + + x = self.tanh(x) + return x + +class IFNetwork(nn.Module): + def __init__(self, conf, id, datadir): + super().__init__() + self.feature_vector_size = conf.get_int('feature_vector_size') + self.implicit_network = ImplicitNetwork(self.feature_vector_size, **conf.get_config('implicit_network')) + # self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.get_config('rendering_network')) + + self.diffuse_network = 
DiffuseNetwork(**conf.get_config('diffuse_network')) + self.specular_network = SpecularNetwork(**conf.get_config('specular_network')) + self.albedo_network = AlbedoNetwork(self.feature_vector_size, **conf.get_config('albedo_network')) + + self.ray_tracer = RayTracing(**conf.get_config('ray_tracer')) + self.sample_network = SampleNetwork() + self.object_bounding_sphere = conf.get_float('ray_tracer.object_bounding_sphere') + self.mesh = trimesh.load_mesh('{0}/mesh.obj'.format(os.path.join('../data', datadir, 'scan{0}'.format(id))), + process=False, use_embree=True) + self.faces = self.mesh.faces + self.vertex_normals = np.array(self.mesh.vertex_normals) + self.vertices = np.array(self.mesh.vertices) + print('Loaded Mesh') + + def forward(self, input): + + # Parse model input + points_predicted = None + + intrinsics = input["intrinsics"] + uv = input["uv"] + pose = input["pose"] + object_mask = input["object_mask"].reshape(-1) + + ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics) + batch_size, num_pixels, _ = ray_dirs.shape + + self.implicit_network.eval() + with torch.no_grad(): + points, network_object_mask, dists = self.ray_tracer(sdf=lambda x: self.implicit_network(x)[:, 0], + cam_loc=cam_loc, + object_mask=object_mask, + ray_directions=ray_dirs) + self.implicit_network.train() + + points = (cam_loc.unsqueeze(1) + dists.reshape(batch_size, num_pixels, 1) * ray_dirs).reshape(-1, 3) + points_normal = self.implicit_network.gradient(points) + sdf_output = self.implicit_network(points)[:, 0:1] + ray_dirs = ray_dirs.reshape(-1, 3) + + ray_dirs_np = ray_dirs.cpu().numpy() + cam_loc_np = np.concatenate([cam_loc.cpu().numpy()] * len(ray_dirs_np), axis=0) + # points_mesh_ray: may have the more points than surface mask points, + # Need an Index for the Points_Mesh_Ray + points_mesh_ray, index_ray, index_tri = self.mesh.ray.intersects_location(ray_origins=cam_loc_np, + ray_directions=ray_dirs_np, + multiple_hits=False) + # Index ray: total 2048 / ~1200 + MeshRay_mask = torch.tensor([True if i in index_ray else False for i in range(len(cam_loc_np))], dtype=torch.bool).to(points.device) + network_object_mask = network_object_mask & MeshRay_mask + + if self.training: + + surface_mask = network_object_mask & object_mask + + listA = surface_mask.cpu().detach().numpy() + A = [int(a) for a in listA] + AA = [i for i, a in enumerate(A) if a == 1] # surface mask 的 index + MeshRay_Index = np.array([i for i, a in enumerate(index_ray) if a in AA], dtype=int) + + face_points_index = self.faces[index_tri][MeshRay_Index] + select_vertex_normals = self.vertex_normals[face_points_index] + select_vertices = self.vertices[face_points_index] + + points_mesh_ray = points_mesh_ray[MeshRay_Index] + bcoords = barycentric_coordinates(points_mesh_ray, select_vertices) + resampled_normals = np.sum(np.expand_dims(bcoords, -1) * select_vertex_normals, 1) + + # Mesh Pull + resampled_normals = torch.tensor(resampled_normals).to(points) + points_mesh_ray = torch.tensor(points_mesh_ray).to(points) + sdf_points_mesh_ray = self.implicit_network(points_mesh_ray)[:, 0:1] + g_points_mesh_ray = self.implicit_network.gradient(points_mesh_ray) + points_predicted = points_mesh_ray - g_points_mesh_ray.squeeze() * sdf_points_mesh_ray + + surface_points = points[surface_mask] + surface_dists = dists[surface_mask].unsqueeze(-1) + surface_ray_dirs = ray_dirs[surface_mask] + surface_cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[surface_mask] + surface_output = sdf_output[surface_mask] + N = 
surface_points.shape[0] + + # Sample points for the eikonal loss + eik_bounding_box = self.object_bounding_sphere + n_eik_points = batch_size * num_pixels // 2 + eikonal_points = torch.empty(n_eik_points, 3).uniform_(-eik_bounding_box, eik_bounding_box).cuda() + eikonal_pixel_points = points.clone() + eikonal_pixel_points = eikonal_pixel_points.detach() + eikonal_points = torch.cat([eikonal_points, eikonal_pixel_points], 0) + + points_all = torch.cat([surface_points, eikonal_points], dim=0) + + output = self.implicit_network(surface_points) + surface_sdf_values = output[:N, 0:1].detach() + + g = self.implicit_network.gradient(points_all) + surface_points_grad = g[:N, 0, :].clone().detach() + grad_theta = g[N:, 0, :] + + differentiable_surface_points = self.sample_network(surface_output, + surface_sdf_values, + surface_points_grad, + surface_dists, + surface_cam_loc, + surface_ray_dirs) + + else: + surface_mask = network_object_mask + differentiable_surface_points = points[surface_mask] + grad_theta = None + + listA = surface_mask.cpu().detach().numpy() + A = [int(a) for a in listA] + AA = [i for i, a in enumerate(A) if a == 1] # surface mask 的 index + MeshRay_Index = np.array([i for i, a in enumerate(index_ray) if a in AA], dtype=int) + + face_points_index = self.faces[index_tri][MeshRay_Index] + select_vertex_normals = self.vertex_normals[face_points_index] + select_vertices = self.vertices[face_points_index] + + points_mesh_ray = points_mesh_ray[MeshRay_Index] + bcoords = barycentric_coordinates(points_mesh_ray, select_vertices) + resampled_normals = np.sum(np.expand_dims(bcoords, -1) * select_vertex_normals, 1) + resampled_normals = torch.tensor(resampled_normals).to(points) + + view = -ray_dirs[surface_mask] + + rgb_values = torch.ones_like(points).float().cuda() + diffuse_values = torch.ones_like(points).float().cuda() + specular_values = torch.ones_like(points).float().cuda() + albedo_values = torch.ones_like(points).float().cuda() + + if differentiable_surface_points.shape[0] > 0: + + rgb_values[surface_mask] = self.get_rbg_value(differentiable_surface_points, view, resampled_normals) + diffuse_values[surface_mask] = self.get_diffuse_value(differentiable_surface_points, view, resampled_normals) + + specular_values[surface_mask] = self.get_specular_value(differentiable_surface_points, view) + albedo_values[surface_mask] = self.get_albedo_value(differentiable_surface_points, view) + + output = { + 'points': points, + 'points_pre': points_predicted, + 'points_mesh_ray_gt': points[surface_mask], + 'points_mesh_ray_normals': resampled_normals, + 'surface_normals': points_normal[surface_mask].reshape([-1, 3]), + + 'rgb_values': rgb_values, + 'diffuse_values': diffuse_values, + 'specular_values': specular_values, + 'albedo_values': albedo_values, + + 'sdf_output': sdf_output, + 'network_object_mask': network_object_mask, + 'object_mask': object_mask, + 'grad_theta': grad_theta + } + + return output + + def get_rbg_value(self, points, view_dirs, diffuse_normals): + output = self.implicit_network(points) + g = self.implicit_network.gradient(points) + normals = g[:, 0, :] + feature_vectors = output[:, 1:] + + diffuse_shading = self.diffuse_network(diffuse_normals) + specular_shading = self.specular_network(normals, view_dirs) + albedo = self.albedo_network(points, feature_vectors) + + diffuse_shading = (diffuse_shading + 1.) / 2. + specular_shading = (specular_shading + 1.) / 2. + albedo = (albedo + 1.) / 2. 
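+ # The tanh outputs above are remapped from [-1, 1] to [0, 1] before composing the final color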
+ + rgb_vals = diffuse_shading * albedo + specular_shading + rgb_vals = (rgb_vals * 2.) - 1. + + return rgb_vals + + def get_diffuse_value(self, points, view_dirs, diffuse_normals): + + diffuse_shading = self.diffuse_network(diffuse_normals) + return diffuse_shading.expand([-1, 3]) + + def get_albedo_value(self, points, view_dirs): + output = self.implicit_network(points) + feature_vectors = output[:, 1:] + albedo = self.albedo_network(points, feature_vectors) + + return albedo + + def get_specular_value(self, points, view_dirs): + g = self.implicit_network.gradient(points) + normals = g[:, 0, :] + + specular_shading = self.specular_network(normals, view_dirs) + return specular_shading.expand([-1, 3]) \ No newline at end of file diff --git a/reconstruction/PBIDR/code/model/sample_network.py b/reconstruction/PBIDR/code/model/sample_network.py new file mode 100644 index 0000000..746ac93 --- /dev/null +++ b/reconstruction/PBIDR/code/model/sample_network.py @@ -0,0 +1,20 @@ +import torch.nn as nn +import torch + +class SampleNetwork(nn.Module): + ''' + Represent the intersection (sample) point as differentiable function of the implicit geometry and camera parameters. + See equation 3 in the paper for more details. + ''' + + def forward(self, surface_output, surface_sdf_values, surface_points_grad, surface_dists, surface_cam_loc, surface_ray_dirs): + # t -> t(theta) + surface_ray_dirs_0 = surface_ray_dirs.detach() + surface_points_dot = torch.bmm(surface_points_grad.view(-1, 1, 3), + surface_ray_dirs_0.view(-1, 3, 1)).squeeze(-1) + surface_dists_theta = surface_dists - (surface_output - surface_sdf_values) / surface_points_dot + + # t(theta) -> x(theta,c,v) + surface_points_theta_c_v = surface_cam_loc + surface_dists_theta * surface_ray_dirs + + return surface_points_theta_c_v diff --git a/reconstruction/PBIDR/code/preprocess/get_aux_dataset.py b/reconstruction/PBIDR/code/preprocess/get_aux_dataset.py new file mode 100644 index 0000000..63c84c3 --- /dev/null +++ b/reconstruction/PBIDR/code/preprocess/get_aux_dataset.py @@ -0,0 +1,158 @@ +import os +import sys +sys.path.append(os.path.abspath('')) +import torch +import argparse +import numpy as np +from pytorch3d.io import load_objs_as_meshes, save_obj,load_obj +from pytorch3d.renderer import ( + look_at_view_transform, + PerspectiveCameras, + # FoVPerspectiveCameras, + PointLights, + # DirectionalLights, + # Materials, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + # SoftPhongShader, + # SoftSilhouetteShader, + SoftPhongShader, + # TexturesVertex, + Materials +) +from PIL import Image + +print("Start to get aux dataset!") + +if __name__ == '__main__': + + parser = argparse.ArgumentParser("PreProcessing") + parser.add_argument('--gpu', '-g', type=str, default='0',help='GPU') + parser.add_argument('--input', '-i', type=str, default='../raw_data', help='Location of Raw Textured Mesh Dataset') + parser.add_argument('--output', '-o', type=int, required=True, help='New aux dataset') + parser.add_argument('--yaw', type=int, default=15, help='num_views_yaw') + parser.add_argument('--yaw_angle', type=int, default=45, help='yaw_angle') + parser.add_argument('--pitch', type=int, default=9, help='num_views_pitch') + parser.add_argument('--pitch_angle', type=int, default=30, help='pitch_angle') + parser.add_argument('--datapath', type=str, default='../data/', help='Location of code data') + + parser.add_argument('--dataset', '-d', type=str, default='Face', help='FaceTest dataset') + + args = parser.parse_args() + 
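# Pin the process to the requested GPU before any CUDA context is created
+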
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + # Setup + if torch.cuda.is_available(): + device = torch.device("cuda:0") + torch.cuda.set_device(device) + else: + device = torch.device("cpu") + + # Set paths + DATA_DIR = args.input + IMAGE_DIR = os.path.join(DATA_DIR, "mesh") + + if os.path.exists(IMAGE_DIR): + os.system("rm -r " + IMAGE_DIR) + os.mkdir(IMAGE_DIR) + obj_filename = os.path.join(DATA_DIR, "mesh.obj") + + if not os.path.exists(os.path.join(args.datapath, args.dataset)): + os.mkdir(os.path.join(args.datapath, args.dataset)) + + # Load obj file + mesh = load_objs_as_meshes([obj_filename], device=device,load_textures=True) + print(obj_filename) + print("Loaded Mesh") + + # the number of different viewpoints from which we want to render the mesh. + def Ry(q): + return np.array([[-np.cos(q * np.pi / 180), 0, -np.sin(q * np.pi / 180)], [0, 1, 0], + [np.sin(q * np.pi / 180), 0, -np.cos(q * np.pi / 180)]]) + def Rx(q): + return np.array([[-1, 0, 0], [0, np.cos(q * np.pi / 180), np.sin(q * np.pi / 180)], + [0, np.sin(q * np.pi / 180), -np.cos(q * np.pi / 180)]]) + + def get_R_matrix(azim, axis="Ry"): + print("Rotation Martix {}".format(axis)) + aa = [] + if axis == "Ry": + for q in azim: + aa.append(Ry(q)) + RRR = torch.tensor(np.array(aa)).to(device) + else: + for q in azim: + aa.append(Rx(q)) + RRR = torch.tensor(np.array(aa)).to(device) + return RRR + + num_views = args.yaw + args.pitch + + yaw_dim = torch.linspace(-1 * args.yaw_angle, args.yaw_angle, args.yaw) + pitch_dim = torch.linspace(-1 * args.pitch_angle, args.pitch_angle , args.pitch) + + lights = PointLights(device=device, location=[[0, 50, 100]], ambient_color=((1.0, 1.0, 1.0), ), diffuse_color=((0.0, 0.0, 0.0), ), specular_color=((0.0, 0.0, 0.0), )) + RRy, TTy = look_at_view_transform(dist=8, elev=0, azim=yaw_dim, up=((0, 1, 0),), device=device) + + TTx = TTy[:args.pitch] + RRx = get_R_matrix(azim=pitch_dim, axis="Rx") + + Rtotal = torch.cat([RRy, RRx], dim=0) + Ttotal = torch.cat([TTy, TTx], dim=0) + + cameras = PerspectiveCameras(device=device, focal_length=4500, principal_point=((512, 512),), R=Rtotal, T=Ttotal, + image_size=((1024, 1024),)) + + if num_views != 1: + camera = PerspectiveCameras(device=device, focal_length=4500, principal_point=((512, 512),), R=Rtotal[None, 1, ...], + T=Ttotal[None, 1, ...], image_size=((1024, 1024),)) + else: + camera = PerspectiveCameras(device=device, focal_length=4500, principal_point=((512, 512),), + R=Rtotal, + T=Ttotal, image_size=((1024, 1024),)) + + mymaterials = Materials(device=device, shininess=8) + raster_settings = RasterizationSettings( + image_size=1024, + blur_radius=0.0, + faces_per_pixel=1, + ) + renderer = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=camera, + raster_settings=raster_settings + ), + shader=SoftPhongShader( + device=device, + cameras=camera, + lights=lights, + materials=mymaterials, + ) + ) + + meshes = mesh.extend(num_views) + target_images = renderer(meshes, cameras=cameras, lights=lights) + target_rgb = [target_images[i, ..., :3] for i in range(num_views)] + target_cameras = [PerspectiveCameras(device=device, focal_length=4500, principal_point=((512, 512),), R=Rtotal[None, i, ...], + T=Ttotal[None, i, ...], image_size=((1024, 1024),)) for i in range(num_views)] + + # RGB images + if not os.path.exists(os.path.join(IMAGE_DIR, 'image')): + os.mkdir(os.path.join(IMAGE_DIR, 'image')) + if not os.path.exists(os.path.join(IMAGE_DIR, 'mask')): + os.mkdir(os.path.join(IMAGE_DIR, 'mask')) + + for i in range(len(target_images)): + img = 
Image.fromarray((target_images[i, ..., :3].cpu().numpy() * 255).astype(np.uint8)) + img.save(os.path.join(IMAGE_DIR, 'image/{0}.png'.format('%03d' % int(i+1)))) + img.save(os.path.join(IMAGE_DIR, 'mask/{0}.png'.format('%03d' % int(i+1)))) + np.save(os.path.join(IMAGE_DIR,'R.npy'), Rtotal.cpu().numpy()) + np.save(os.path.join(IMAGE_DIR,'T.npy'), Ttotal.cpu().numpy()) + + SCAN_DIR = args.datapath + args.dataset + '/scan' + str(args.output) + "/" + if os.path.exists(SCAN_DIR): + os.system("rm -r " + SCAN_DIR) + os.system("cp -r " + IMAGE_DIR + " " + SCAN_DIR) + os.system("cp " + DATA_DIR + "/mesh.* " + SCAN_DIR + ".") + print("Finished") \ No newline at end of file diff --git a/reconstruction/PBIDR/code/preprocess/preprocess_cameras.py b/reconstruction/PBIDR/code/preprocess/preprocess_cameras.py new file mode 100644 index 0000000..43b9ced --- /dev/null +++ b/reconstruction/PBIDR/code/preprocess/preprocess_cameras.py @@ -0,0 +1,100 @@ +import numpy as np +import matplotlib.image as mpimg +import matplotlib.pyplot as plt +import cv2 +import argparse +from glob import glob +import os +import sys +import pickle +sys.path.append('../code') +from scipy.spatial.transform import Rotation +import utils.general as utils + + +def get_Ps_from_Faces(R, T): + Ps = [] + cam_locs = [] + + intrinsics = np.concatenate([[4500.0], [0.0], [512.0], [0.0], [4500.0], [512.0], [0.0], [0.0], [1.0]], axis=0) + intrinsics = np.reshape(intrinsics, [3, 3]) + + projection = np.concatenate([[1.0], [0.0], [0.0], [0.0], [0.0], [1.0], [0.0], [0.0], [0.0], [0.0], [1.0], [0.0]], axis=0) + projection = np.reshape(projection, [3, 4]) + + I14 = np.concatenate([[0.0], [0.0], [0.0], [1.0]], axis=0) + I14 = np.reshape(I14, [1, 4]) + + for i in range(0, len(R)): + R0 = R[i] + T0 = T[i].reshape(3, 1) + + p = np.concatenate([np.concatenate([R[i].T, T[i].reshape(3, 1)], axis=1), I14], axis=0) + P = intrinsics @ projection @ p + P = P.astype(np.float64) + + camera_loc = -np.dot(R0, T0) + cam_locs.append(camera_loc) + Ps.append(P) + + return np.array(Ps) + + +def get_all_mask_points_white_bg(masks_dir): + mask_paths = sorted(utils.glob_imgs(masks_dir)) + mask_points_all=[] + mask_ims = [] + for path in mask_paths: + img = mpimg.imread(path) + cur_mask = img.max(axis=2) < 0.9 + mask_points = np.where(img.max(axis=2) < 0.9) + xs = mask_points[1] + ys = mask_points[0] + mask_points_all.append(np.stack((xs,ys,np.ones_like(xs))).astype(float)) + mask_ims.append(cur_mask) + return mask_points_all,np.array(mask_ims) + + +def get_normalization(source_dir): + print('Preprocessing', source_dir) + + masks_dir= '{0}/mask'.format(source_dir) + mask_points_all, masks_all = get_all_mask_points_white_bg(masks_dir) + number_of_cameras = len(masks_all) + R = np.load('{0}/R.npy'.format(source_dir)) + T = np.load('{0}/T.npy'.format(source_dir)) + Ps = get_Ps_from_Faces(R, T) + normalization = np.eye(4).astype(np.float32) + + cameras_new={} + for i in range(number_of_cameras): + cameras_new['scale_mat_%d' % i] = normalization + cameras_new['world_mat_%d' % i] = np.concatenate((Ps[i],np.array([[0,0,0,1.0]])),axis=0).astype(np.float32) + + np.savez('{0}/{1}.npz'.format(source_dir, "cameras"), **cameras_new) + print(normalization) + print('--------------------------------------------------------') + + if False: #for debugging + for i in range(number_of_cameras): + plt.figure() + + plt.imshow(mpimg.imread('%s/%03d.png' % ('{0}/mask'.format(source_dir), i+1))) + xy = (Ps[i,:2, :] @ (np.concatenate((np.array(all_Xs), np.ones((len(all_Xs), 1))), axis=1).T)) / ( + 
+
+    if False:  # for debugging only; needs a set of 3D points `all_Xs` (not defined in this file) to reproject
+        for i in range(number_of_cameras):
+            plt.figure()
+
+            plt.imshow(mpimg.imread('%s/%03d.png' % ('{0}/mask'.format(source_dir), i+1)))
+            xy = (Ps[i, :2, :] @ (np.concatenate((np.array(all_Xs), np.ones((len(all_Xs), 1))), axis=1).T)) / (
+                 Ps[i, 2, :] @ (np.concatenate((np.array(all_Xs), np.ones((len(all_Xs), 1))), axis=1).T))
+
+            plt.plot(xy[0, :], xy[1, :], '*')
+            plt.show()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--scan_id', '-i', type=int, default=0, help='scan id of the source folder to preprocess')
+    parser.add_argument('--dataset', '-d', type=str, default='Face', help='dataset dir')
+    opt = parser.parse_args()
+
+    SCAN_DIR = '../data/' + opt.dataset + '/scan' + str(opt.scan_id)
+    get_normalization(SCAN_DIR)
+
+    print('Done!')
\ No newline at end of file
diff --git a/reconstruction/PBIDR/code/script/data_process.sh b/reconstruction/PBIDR/code/script/data_process.sh
new file mode 100644
index 0000000..8cc319f
--- /dev/null
+++ b/reconstruction/PBIDR/code/script/data_process.sh
@@ -0,0 +1,6 @@
+set -ex
+
+GPU=0
+
+python preprocess/get_aux_dataset.py -g $GPU -i '../raw_data/0' -o 0 -d 'Test' --yaw 17 --pitch 0
+python preprocess/preprocess_cameras.py -i 0 -d 'Test'
\ No newline at end of file
diff --git a/reconstruction/PBIDR/code/script/fast_eval.sh b/reconstruction/PBIDR/code/script/fast_eval.sh
new file mode 100644
index 0000000..2608eda
--- /dev/null
+++ b/reconstruction/PBIDR/code/script/fast_eval.sh
@@ -0,0 +1,4 @@
+set -ex
+
+GPU=0
+python evaluation/eval.py --conf ./confs/test.conf --scan_id 0 --gpu $GPU --checkpoint 400 --eval_rendering
\ No newline at end of file
diff --git a/reconstruction/PBIDR/code/script/fast_train.sh b/reconstruction/PBIDR/code/script/fast_train.sh
new file mode 100644
index 0000000..a2fa2da
--- /dev/null
+++ b/reconstruction/PBIDR/code/script/fast_train.sh
@@ -0,0 +1,4 @@
+set -ex
+
+GPU=0
+python training/runner.py --conf ./confs/test.conf --scan_id 0 --gpu $GPU --nepoch 400
\ No newline at end of file
diff --git a/reconstruction/PBIDR/code/training/runner.py b/reconstruction/PBIDR/code/training/runner.py
new file mode 100644
index 0000000..eba2cb9
--- /dev/null
+++ b/reconstruction/PBIDR/code/training/runner.py
@@ -0,0 +1,58 @@
+import sys
+sys.path.append('../code')
+import argparse
+import GPUtil
+import torch
+import random
+import numpy as np
+
+from training.train import IFTrainRunner
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+
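+# setup_seed(0) below fixes every RNG source (torch, CUDA, numpy, random) and
+# forces deterministic cuDNN kernels, so repeated runs of the same scan are
+# reproducible at a small speed cost.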
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
+    parser.add_argument('--nepoch', type=int, default=400, help='number of epochs to train for')
+    parser.add_argument('--nepoch_freeze', type=int, default=1000, help='epoch after which the diffuse branch is frozen')
+    parser.add_argument('--conf', type=str, default='./confs/test.conf')
+    parser.add_argument('--expname', type=str, default='')
+    parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]')
+    parser.add_argument('--is_continue', default=False, action="store_true", help='If set, indicates continuing from a previous run.')
+    parser.add_argument('--timestamp', default='latest', type=str, help='The timestamp of the run to be used in case of continuing from a previous run.')
+    parser.add_argument('--checkpoint', default='latest', type=str, help='The checkpoint epoch number of the run to be used in case of continuing from a previous run.')
+    parser.add_argument('--train_cameras', default=False, action="store_true", help='If set, optimizing also camera location.')
+    parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.')
+
+    opt = parser.parse_args()
+
+    if opt.gpu == "auto":
+        deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False, excludeID=[], excludeUUID=[])
+        gpu = deviceIDs[0]
+    else:
+        gpu = opt.gpu
+
+    setup_seed(0)
+
+    trainrunner = IFTrainRunner(conf=opt.conf,
+                                batch_size=opt.batch_size,
+                                nepochs=opt.nepoch,
+                                nepoch_freeze=opt.nepoch_freeze,
+                                expname=opt.expname,
+                                gpu_index=gpu,
+                                exps_folder_name='exps',
+                                is_continue=opt.is_continue,
+                                timestamp=opt.timestamp,
+                                checkpoint=opt.checkpoint,
+                                scan_id=opt.scan_id,
+                                train_cameras=opt.train_cameras
+                                )
+
+    trainrunner.run()
diff --git a/reconstruction/PBIDR/code/training/train.py b/reconstruction/PBIDR/code/training/train.py
new file mode 100644
index 0000000..9ae6617
--- /dev/null
+++ b/reconstruction/PBIDR/code/training/train.py
@@ -0,0 +1,284 @@
+import os
+from datetime import datetime
+from pyhocon import ConfigFactory
+import sys
+import torch
+import torch.nn as nn
+
+import utils.general as utils
+import utils.plots as plt
+
+class IFTrainRunner():
+    def __init__(self, **kwargs):
+        torch.set_default_dtype(torch.float32)
+        torch.set_num_threads(1)
+
+        self.conf = ConfigFactory.parse_file(kwargs['conf'])
+        self.batch_size = kwargs['batch_size']
+
+        self.nepochs = kwargs['nepochs']
+        self.nepoch_freeze = kwargs['nepoch_freeze']
+
+        self.exps_folder_name = kwargs['exps_folder_name']
+        self.GPU_INDEX = kwargs['gpu_index']
+        self.train_cameras = kwargs['train_cameras']
+
+        self.expname = self.conf.get_string('train.expname') + kwargs['expname']
+        scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int('dataset.scan_id', default=-1)
+        if scan_id != -1:
+            self.expname = self.expname + '_{0}'.format(scan_id)
+
+        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
+            if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], self.expname)):
+                timestamps = os.listdir(os.path.join('../', kwargs['exps_folder_name'], self.expname))
+                if (len(timestamps)) == 0:
+                    is_continue = False
+                    timestamp = None
+                else:
+                    timestamp = sorted(timestamps)[-1]
+                    is_continue = True
+            else:
+                is_continue = False
+                timestamp = None
+        else:
+            timestamp = kwargs['timestamp']
+            is_continue = kwargs['is_continue']
+
+        utils.mkdir_ifnotexists(os.path.join('../', self.exps_folder_name))
+        self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
+        utils.mkdir_ifnotexists(self.expdir)
+        self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
+        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))
+
+        self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
+        utils.mkdir_ifnotexists(self.plots_dir)
+
+        # create checkpoints dirs
+        self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
+        utils.mkdir_ifnotexists(self.checkpoints_path)
+        self.model_params_subdir = "ModelParameters"
+        self.optimizer_params_subdir = "OptimizerParameters"
+        self.scheduler_params_subdir = "SchedulerParameters"
+
+        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
+        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
+        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir))
+
+        if self.train_cameras:
+            self.optimizer_cam_params_subdir = "OptimizerCamParameters"
+            self.cam_params_subdir = "CamParameters"
+
+            utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir))
+            utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.cam_params_subdir))
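+        # Every run gets its own timestamped folder under ../exps/<expname>/;
+        # checkpoints are written both as <epoch>.pth and latest.pth (see
+        # save_checkpoints below), which is what --is_continue resumes from.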
+
+        os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf')))
+
+        if (not self.GPU_INDEX == 'ignore'):
+            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)
+
+        print('shell command : {0}'.format(' '.join(sys.argv)))
+
+        print('Loading data ...')
+
+        dataset_conf = self.conf.get_config('dataset')
+        if kwargs['scan_id'] != -1:
+            dataset_conf['scan_id'] = kwargs['scan_id']
+
+        self.train_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(self.train_cameras,
+                                                                                          **dataset_conf)
+
+        print('Finish loading data ...')
+
+        self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset,
+                                                            batch_size=self.batch_size,
+                                                            shuffle=True,
+                                                            collate_fn=self.train_dataset.collate_fn
+                                                            )
+        self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset,
+                                                           batch_size=self.conf.get_int('plot.plot_nimgs'),
+                                                           shuffle=True,
+                                                           collate_fn=self.train_dataset.collate_fn
+                                                           )
+
+        self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=self.conf.get_config('model'),
+                                                                                id=scan_id, datadir=dataset_conf['data_dir'])
+        if torch.cuda.is_available():
+            self.model.cuda()
+
+        self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss'))
+
+        self.lr = self.conf.get_float('train.learning_rate')
+
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+
+        self.sched_milestones = self.conf.get_list('train.sched_milestones', default=[])
+        self.sched_factor = self.conf.get_float('train.sched_factor', default=0.0)
+        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, self.sched_milestones, gamma=self.sched_factor)
+
+        # settings for camera optimization
+        if self.train_cameras:
+            num_images = len(self.train_dataset)
+
+            self.pose_vecs = torch.nn.Embedding(num_images, 7, sparse=True).cuda()
+            self.pose_vecs.weight.data.copy_(self.train_dataset.get_pose_init())
+            self.optimizer_cam = torch.optim.SparseAdam(list(self.pose_vecs.parameters()), self.conf.get_float('train.learning_rate_cam'))
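+            # Note: each image owns a 7-D pose vector (unit quaternion plus
+            # camera location, cf. rend_util.quat_to_rot); SparseAdam suits the
+            # sparse Embedding since a step only touches the sampled image.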
+
+        self.start_epoch = 0
+        if is_continue:
+            old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')
+
+            saved_model_state = torch.load(
+                os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
+            self.model.load_state_dict(saved_model_state["model_state_dict"])
+            self.start_epoch = saved_model_state['epoch']
+
+            data = torch.load(
+                os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
+            self.optimizer.load_state_dict(data["optimizer_state_dict"])
+
+            data = torch.load(
+                os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth"))
+            self.scheduler.load_state_dict(data["scheduler_state_dict"])
+
+            if self.train_cameras:
+                data = torch.load(
+                    os.path.join(old_checkpnts_dir, self.optimizer_cam_params_subdir, str(kwargs['checkpoint']) + ".pth"))
+                self.optimizer_cam.load_state_dict(data["optimizer_cam_state_dict"])
+
+                data = torch.load(
+                    os.path.join(old_checkpnts_dir, self.cam_params_subdir, str(kwargs['checkpoint']) + ".pth"))
+                self.pose_vecs.load_state_dict(data["pose_vecs_state_dict"])
+
+        self.num_pixels = self.conf.get_int('train.num_pixels')
+        self.total_pixels = self.train_dataset.total_pixels
+        self.img_res = self.train_dataset.img_res
+        self.n_batches = len(self.train_dataloader)
+        self.plot_freq = self.conf.get_int('train.plot_freq')
+        self.plot_conf = self.conf.get_config('plot')
+
+        self.alpha_milestones = self.conf.get_list('train.alpha_milestones', default=[])
+        self.alpha_factor = self.conf.get_float('train.alpha_factor', default=0.0)
+        for acc in self.alpha_milestones:
+            if self.start_epoch > acc:
+                self.loss.alpha = self.loss.alpha * self.alpha_factor
+
+    def save_checkpoints(self, epoch):
+        torch.save(
+            {"epoch": epoch, "model_state_dict": self.model.state_dict()},
+            os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
+        torch.save(
+            {"epoch": epoch, "model_state_dict": self.model.state_dict()},
+            os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
+
+        torch.save(
+            {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
+            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
+        torch.save(
+            {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
+            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))
+
+        torch.save(
+            {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
+            os.path.join(self.checkpoints_path, self.scheduler_params_subdir, str(epoch) + ".pth"))
+        torch.save(
+            {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
+            os.path.join(self.checkpoints_path, self.scheduler_params_subdir, "latest.pth"))
+
+        if self.train_cameras:
+            torch.save(
+                {"epoch": epoch, "optimizer_cam_state_dict": self.optimizer_cam.state_dict()},
+                os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir, str(epoch) + ".pth"))
+            torch.save(
+                {"epoch": epoch, "optimizer_cam_state_dict": self.optimizer_cam.state_dict()},
+                os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir, "latest.pth"))
+
+            torch.save(
+                {"epoch": epoch, "pose_vecs_state_dict": self.pose_vecs.state_dict()},
+                os.path.join(self.checkpoints_path, self.cam_params_subdir, str(epoch) + ".pth"))
+            torch.save(
+                {"epoch": epoch, "pose_vecs_state_dict": self.pose_vecs.state_dict()},
+                os.path.join(self.checkpoints_path, self.cam_params_subdir, "latest.pth"))
+
+    def run(self):
+        print("training...")
+
+        for epoch in range(self.start_epoch, self.nepochs + 1):
+
+            if epoch in self.alpha_milestones:
+                self.loss.alpha = self.loss.alpha * self.alpha_factor
+
+            if epoch % 100 == 0 and epoch != 0:
+                self.save_checkpoints(epoch)
+
+            if epoch % self.plot_freq == 0 and epoch != 0:
+                self.model.eval()
+                if self.train_cameras:
+                    self.pose_vecs.eval()
+                self.train_dataset.change_sampling_idx(-1)
+                indices, model_input, ground_truth = next(iter(self.plot_dataloader))
+
+                model_input["intrinsics"] = model_input["intrinsics"].cuda()
+                model_input["uv"] = model_input["uv"].cuda()
+                model_input["object_mask"] = model_input["object_mask"].cuda()
+                # model_input[""] = ground_truth["rgb"].cuda()
+
+                if self.train_cameras:
+                    pose_input = self.pose_vecs(indices.cuda())
+                    model_input['pose'] = pose_input
+                else:
+                    model_input['pose'] = model_input['pose'].cuda()
+
+                detail_3dmm, detail_3dmm_subdivision_full = plt.get_displacement_mesh(self.model)
+                detail_3dmm.export('{0}/Detailed_3dmm_{1}.obj'.format(self.plots_dir, epoch), 'obj')
+                detail_3dmm_subdivision_full.export('{0}/Subdivide_full_{1}.obj'.format(self.plots_dir, epoch), 'obj')
+
+                self.model.train()
+                if self.train_cameras:
+                    self.pose_vecs.train()
+
+            self.train_dataset.change_sampling_idx(self.num_pixels)
+
+            if epoch > self.nepoch_freeze:
+                print("Freeze Diffuse Part...")
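+                # Freezing the diffuse and albedo branches late in training
+                # (epoch > nepoch_freeze) is presumably meant to push the
+                # remaining photometric residual into the specular shading and
+                # geometry terms only.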
+                self.model.diffuse_network.eval()
+                self.model.albedo_network.eval()
+
+            for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader):
+
+                model_input["intrinsics"] = model_input["intrinsics"].cuda()
+                model_input["uv"] = model_input["uv"].cuda()
+                model_input["object_mask"] = model_input["object_mask"].cuda()
+
+                if self.train_cameras:
+                    pose_input = self.pose_vecs(indices.cuda())
+                    model_input['pose'] = pose_input
+                else:
+                    model_input['pose'] = model_input['pose'].cuda()
+
+                model_outputs = self.model(model_input)
+                loss_output = self.loss(model_outputs, ground_truth)
+
+                loss = loss_output['loss']
+
+                self.optimizer.zero_grad()
+                if self.train_cameras:
+                    self.optimizer_cam.zero_grad()
+
+                loss.backward()
+
+                self.optimizer.step()
+                if self.train_cameras:
+                    self.optimizer_cam.step()
+
+                print(
+                    '{0} [{1}] ({2}/{3}): loss = {4}, rgb_loss = {5}, normal_loss = {6}, reg_loss = {7}, eikonal_loss = {8}, mask_loss = {9}, alpha = {10}, lr = {11}'
+                    .format(self.expname, epoch, data_index, self.n_batches, loss.item(),
+                            loss_output['rgb_loss'].item(),
+                            loss_output['normal_loss'].item(),
+                            loss_output['reg_loss'].item(),
+                            loss_output['eikonal_loss'].item(),
+                            loss_output['mask_loss'].item(),
+                            self.loss.alpha,
+                            self.scheduler.get_lr()[0]))
+
+            self.scheduler.step()
\ No newline at end of file
diff --git a/reconstruction/PBIDR/code/utils/general.py b/reconstruction/PBIDR/code/utils/general.py
new file mode 100644
index 0000000..f76c700
--- /dev/null
+++ b/reconstruction/PBIDR/code/utils/general.py
@@ -0,0 +1,66 @@
+import os
+from glob import glob
+import torch
+
+def mkdir_ifnotexists(directory):
+    if not os.path.exists(directory):
+        os.mkdir(directory)
+
+def get_class(kls):
+    parts = kls.split('.')
+    module = ".".join(parts[:-1])
+    m = __import__(module)
+    for comp in parts[1:]:
+        m = getattr(m, comp)
+    return m
+
+def glob_imgs(path):
+    imgs = []
+    for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG']:
+        imgs.extend(glob(os.path.join(path, ext)))
+    return imgs
+
+def split_input(model_input, total_pixels):
+    '''
+    Split the input to fit CUDA memory for large resolutions.
+    Decrease the value of n_pixels in case of a CUDA out-of-memory error.
+    '''
+    n_pixels = 10000
+    split = []
+    for i, indx in enumerate(torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0)):
+        data = model_input.copy()
+        data['uv'] = torch.index_select(model_input['uv'], 1, indx)
+        data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx)
+        split.append(data)
+    return split
+
+def split_input_albedo(model_input, total_pixels):
+    '''
+    Split the input to fit CUDA memory for large resolutions.
+    Decrease the value of n_pixels in case of a CUDA out-of-memory error.
+    '''
+    n_pixels = 10000
+    split = []
+    for i, indx in enumerate(torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0)):
+        data = model_input.copy()
+        data['uv'] = torch.index_select(model_input['uv'], 1, indx)
+        data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx)
+        data['rgb'] = torch.index_select(model_input['rgb'], 1, indx)
+        split.append(data)
+    return split
+
+def merge_output(res, total_pixels, batch_size):
+    ''' Merge the split output. '''
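+    # Each entry in `res` holds one pixel chunk produced by split_input;
+    # chunks are concatenated along the pixel axis and flattened back to
+    # (batch_size * total_pixels, d), i.e. the inverse of the splitting above.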
+    model_outputs = {}
+    for entry in res[0]:
+        if res[0][entry] is None:
+            continue
+        if len(res[0][entry].shape) == 1:
+            model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res],
+                                             1).reshape(batch_size * total_pixels)
+        else:
+            model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res],
+                                             1).reshape(batch_size * total_pixels, -1)
+
+    return model_outputs
\ No newline at end of file
diff --git a/reconstruction/PBIDR/code/utils/plots.py b/reconstruction/PBIDR/code/utils/plots.py
new file mode 100644
index 0000000..fde881e
--- /dev/null
+++ b/reconstruction/PBIDR/code/utils/plots.py
@@ -0,0 +1,424 @@
+import plotly.graph_objs as go
+import plotly.offline as offline
+import numpy as np
+import torch
+from skimage import measure
+import torchvision
+import trimesh
+from PIL import Image
+from utils import rend_util
+import pickle
+
+
+def plot_latent(model, latent, indices, model_outputs, pose, rgb_gt, path, epoch, img_res, plot_nimgs, max_depth, resolution):
+    # arrange data to plot
+    batch_size, num_samples, _ = rgb_gt.shape
+
+    network_object_mask = model_outputs['network_object_mask']
+    points = model_outputs['points'].reshape(batch_size, num_samples, 3)
+    rgb_eval = model_outputs['rgb_values']
+    rgb_eval = rgb_eval.reshape(batch_size, num_samples, 3)
+
+    depth = torch.ones(batch_size * num_samples).cuda().float() * max_depth
+    depth[network_object_mask] = rend_util.get_depth(points, pose).reshape(-1)[network_object_mask]
+    depth = depth.reshape(batch_size, num_samples, 1)
+    network_object_mask = network_object_mask.reshape(batch_size, -1)
+
+    cam_loc, cam_dir = rend_util.get_camera_for_plot(pose)
+
+    # plot rendered images
+    plot_images(rgb_eval, rgb_gt, path, epoch, plot_nimgs, img_res)
+
+    # plot depth maps
+    plot_depth_maps(depth, path, epoch, plot_nimgs, img_res)
+
+    data = []
+
+    # plot surface
+    surface_traces = get_surface_trace(path=path,
+                                       epoch=epoch,
+                                       sdf=lambda x: model.implicit_network(torch.cat([latent.expand(len(x), -1), x], 1))[:, 0],
+                                       resolution=resolution
+                                       )
+    data.append(surface_traces[0])
+
+    # plot cameras locations
+    for i, loc, dir in zip(indices, cam_loc, cam_dir):
+        data.append(get_3D_quiver_trace(loc.unsqueeze(0), dir.unsqueeze(0), name='camera_{0}'.format(i)))
+
+    for i, p, m in zip(indices, points, network_object_mask):
+        p = p[m]
+        sampling_idx = torch.randperm(p.shape[0])[:2048]
+        p = p[sampling_idx, :]
+
+        val = model.implicit_network(torch.cat([latent.expand(len(p), -1), p], 1))
+        caption = ["sdf: {0} ".format(v[0].item()) for v in val]
+
+        data.append(get_3D_scatter_trace(p, name='intersection_points_{0}'.format(i), caption=caption))
+
+    fig = go.Figure(data=data)
+    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
+                      yaxis=dict(range=[-3, 3], autorange=False),
+                      zaxis=dict(range=[-3, 3], autorange=False),
+                      aspectratio=dict(x=1, y=1, z=1))
+    fig.update_layout(scene=scene_dict, width=1400, height=1400, showlegend=True)
+    filename = '{0}/surface_{1}.html'.format(path, epoch)
+    offline.plot(fig, filename=filename, auto_open=False)
+
+
+def plot(model, indices, model_outputs, pose, rgb_gt, path, epoch, img_res, plot_nimgs, max_depth, resolution):
+    # arrange data to plot
+    batch_size, num_samples, _ = rgb_gt.shape
+
+    network_object_mask = model_outputs['network_object_mask']
+    points = model_outputs['points'].reshape(batch_size, num_samples, 3)
+    rgb_eval = model_outputs['rgb_values']
+    rgb_eval = rgb_eval.reshape(batch_size, num_samples, 3)
+
+    depth = torch.ones(batch_size * num_samples).cuda().float() * max_depth
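+    # Pixels whose rays missed the surface keep the max_depth background
+    # value; hit pixels are overwritten with true z-depth in the camera frame.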
+    depth[network_object_mask] = rend_util.get_depth(points, pose).reshape(-1)[network_object_mask]
+    depth = depth.reshape(batch_size, num_samples, 1)
+    network_object_mask = network_object_mask.reshape(batch_size, -1)
+
+    cam_loc, cam_dir = rend_util.get_camera_for_plot(pose)
+
+    # plot rendered images
+    plot_images(rgb_eval, rgb_gt, path, epoch, plot_nimgs, img_res)
+
+    # plot depth maps
+    plot_depth_maps(depth, path, epoch, plot_nimgs, img_res)
+
+    data = []
+
+    # plot surface
+    surface_traces = get_surface_trace(path=path,
+                                       epoch=epoch,
+                                       sdf=lambda x: model.implicit_network(x)[:, 0],
+                                       resolution=resolution
+                                       )
+    data.append(surface_traces[0])
+
+    # plot cameras locations
+    for i, loc, dir in zip(indices, cam_loc, cam_dir):
+        data.append(get_3D_quiver_trace(loc.unsqueeze(0), dir.unsqueeze(0), name='camera_{0}'.format(i)))
+
+    for i, p, m in zip(indices, points, network_object_mask):
+        p = p[m]
+        sampling_idx = torch.randperm(p.shape[0])[:2048]
+        p = p[sampling_idx, :]
+
+        val = model.implicit_network(p)
+        caption = ["sdf: {0} ".format(v[0].item()) for v in val]
+
+        data.append(get_3D_scatter_trace(p, name='intersection_points_{0}'.format(i), caption=caption))
+
+    fig = go.Figure(data=data)
+    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
+                      yaxis=dict(range=[-3, 3], autorange=False),
+                      zaxis=dict(range=[-3, 3], autorange=False),
+                      aspectratio=dict(x=1, y=1, z=1))
+    fig.update_layout(scene=scene_dict, width=1400, height=1400, showlegend=True)
+    filename = '{0}/surface_{1}.html'.format(path, epoch)
+    offline.plot(fig, filename=filename, auto_open=False)
+
+
+def get_3D_scatter_trace(points, name='', size=3, caption=None):
+    assert points.shape[1] == 3, "3d scatter plot input points are not correctly shaped"
+    assert len(points.shape) == 2, "3d scatter plot input points are not correctly shaped"
+
+    trace = go.Scatter3d(
+        x=points[:, 0].cpu(),
+        y=points[:, 1].cpu(),
+        z=points[:, 2].cpu(),
+        mode='markers',
+        name=name,
+        marker=dict(
+            size=size,
+            line=dict(
+                width=2,
+            ),
+            opacity=1.0,
+        ), text=caption)
+
+    return trace
+
+
+def get_3D_quiver_trace(points, directions, color='#bd1540', name=''):
+    assert points.shape[1] == 3, "3d cone plot input points are not correctly shaped"
+    assert len(points.shape) == 2, "3d cone plot input points are not correctly shaped"
+    assert directions.shape[1] == 3, "3d cone plot input directions are not correctly shaped"
+    assert len(directions.shape) == 2, "3d cone plot input directions are not correctly shaped"
+
+    trace = go.Cone(
+        name=name,
+        x=points[:, 0].cpu(),
+        y=points[:, 1].cpu(),
+        z=points[:, 2].cpu(),
+        u=directions[:, 0].cpu(),
+        v=directions[:, 1].cpu(),
+        w=directions[:, 2].cpu(),
+        sizemode='absolute',
+        sizeref=0.125,
+        showscale=False,
+        colorscale=[[0, color], [1, color]],
+        anchor="tail"
+    )
+
+    return trace
+
+
+def get_surface_trace(path, epoch, sdf, resolution=100, return_mesh=False):
+    grid = get_grid_uniform(resolution)
+    points = grid['grid_points']
+
+    z = []
+    for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
+        z.append(sdf(pnts).detach().cpu().numpy())
+    z = np.concatenate(z, axis=0)
+
+    if (not (np.min(z) > 0 or np.max(z) < 0)):
+
+        z = z.astype(np.float32)
+
+        verts, faces, normals, values = measure.marching_cubes_lewiner(
+            volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
+                             grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
+            level=0,
+            spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
+                     grid['xyz'][0][2] - grid['xyz'][0][1],
+                     grid['xyz'][0][2] - grid['xyz'][0][1]))
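+        # level=0 extracts the zero level set of the SDF; the sign check above
+        # skips marching cubes when the grid contains no zero crossing. The
+        # grid is uniform, hence a single spacing value for all three axes.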
+
+        verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
+
+        I, J, K = faces.transpose()
+
+        traces = [go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
+                            i=I, j=J, k=K, name='implicit_surface',
+                            opacity=1.0)]
+
+        meshexport = trimesh.Trimesh(verts, faces, vertex_normals=-normals)
+        meshexport.export('{0}/surface_{1}.ply'.format(path, epoch), 'ply')
+
+        if return_mesh:
+            return meshexport
+        return traces
+    return None
+
+
+def get_surface_high_res_mesh(sdf, resolution=100):
+    # get low res mesh to sample point cloud
+    grid = get_grid_uniform(100)
+    z = []
+    points = grid['grid_points']
+
+    for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
+        z.append(sdf(pnts).detach().cpu().numpy())
+    z = np.concatenate(z, axis=0)
+
+    z = z.astype(np.float32)
+
+    verts, faces, normals, values = measure.marching_cubes_lewiner(
+        volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
+                         grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
+        level=0,
+        spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
+                 grid['xyz'][0][2] - grid['xyz'][0][1],
+                 grid['xyz'][0][2] - grid['xyz'][0][1]))
+
+    verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
+
+    mesh_low_res = trimesh.Trimesh(verts, faces, vertex_normals=-normals)
+    # return mesh_low_res
+
+    components = mesh_low_res.split(only_watertight=False)
+    areas = np.array([c.area for c in components], dtype=float)
+    mesh_low_res = components[areas.argmax()]
+
+    recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
+    recon_pc = torch.from_numpy(recon_pc).float().cuda()
+
+    # Center and align the recon pc
+    s_mean = recon_pc.mean(dim=0)
+    s_cov = recon_pc - s_mean
+    s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
+    vecs = torch.eig(s_cov, True)[1].transpose(0, 1)
+    if torch.det(vecs) < 0:
+        vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs)
+    helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),
+                       (recon_pc - s_mean).unsqueeze(-1)).squeeze()
+
+    grid_aligned = get_grid(helper.cpu(), resolution)
+
+    grid_points = grid_aligned['grid_points']
+
+    g = []
+    for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)):
+        g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),
+                           pnts.unsqueeze(-1)).squeeze() + s_mean)
+    grid_points = torch.cat(g, dim=0)
+
+    # MC to new grid
+    points = grid_points
+    z = []
+    for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
+        z.append(sdf(pnts).detach().cpu().numpy())
+    z = np.concatenate(z, axis=0)
+
+    meshexport = None
+    if (not (np.min(z) > 0 or np.max(z) < 0)):
+
+        z = z.astype(np.float32)
+
+        verts, faces, normals, values = measure.marching_cubes_lewiner(
+            volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0],
+                             grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]),
+            level=0,
+            spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
+                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
+                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1]))
+
+        verts = torch.from_numpy(verts).cuda().float()
+        verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),
+                          verts.unsqueeze(-1)).squeeze()
+        verts = (verts + grid_points[0]).cpu().numpy()
+
+        meshexport = trimesh.Trimesh(verts, faces, vertex_normals=-normals)
+
+    return meshexport
+
+
+def get_displacement_mesh(model):
+
+    def get_detailed_mesh(base_mesh):
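+        '''
+        Displace each template vertex along its vertex normal by the signed
+        distance the implicit network predicts there: v' = v - s(v) * n(v).
+        For example, a vertex with sdf 0.01 (slightly outside the implicit
+        surface) moves 0.01 units inward, snapping the template onto the
+        recovered detailed surface to first order.
+        '''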
+        origin_mesh = base_mesh.copy()
+        mesh_points = torch.tensor(np.array(origin_mesh.vertices).astype(np.float32)).cuda().requires_grad_(True)
+        sdfs = model.implicit_network(mesh_points)[:, 0:1]
+        sdf_np = sdfs.detach().cpu().numpy()
+        vertices_normal = origin_mesh.vertex_normals
+        new_vertices = -1.0 * sdf_np * vertices_normal + np.array(origin_mesh.vertices)
+        new_mesh = trimesh.Trimesh(vertices=new_vertices, faces=origin_mesh.faces, process=False, visual=origin_mesh.visual)
+        print('Detailed mesh created!')
+        return new_mesh
+
+    orimesh = model.mesh
+    subdivision_mesh = orimesh.subdivide()
+    new_mesh = get_detailed_mesh(orimesh)
+    submesh_full = get_detailed_mesh(subdivision_mesh)
+
+    return new_mesh, submesh_full
+
+def get_displacement_animation(model):
+
+    def get_detailed_mesh(base_mesh):
+        origin_mesh = base_mesh.copy()
+        mesh_points = torch.tensor(np.array(origin_mesh.vertices).astype(np.float32)).cuda().requires_grad_(True)
+        sdfs = model.implicit_network(mesh_points)[:, 0:1]
+        sdf_np = sdfs.detach().cpu().numpy()
+
+        print('Detailed mesh created!')
+        return sdf_np
+
+    orimesh = model.mesh
+    subdi_mesh_full = orimesh.subdivide()
+    sdf_np0 = get_detailed_mesh(orimesh)
+    sdf_np1 = get_detailed_mesh(subdi_mesh_full)
+
+    return sdf_np0, sdf_np1
+
+
+def get_NormalMaps(model):
+
+    def get_detailed_mesh(base_mesh):
+        origin_mesh = base_mesh.copy()
+        mesh_points = torch.tensor(np.array(origin_mesh.vertices).astype(np.float32)).cuda().requires_grad_(True)
+        detailed_normal = model.implicit_network.gradient(mesh_points)
+        detailed_normal = detailed_normal.squeeze().detach().cpu().numpy()
+        vertices_normal = np.asarray(origin_mesh.vertex_normals)
+        return vertices_normal, detailed_normal
+
+    orimesh = model.mesh
+    smooth1, detail1 = get_detailed_mesh(orimesh)
+    return smooth1, detail1
+
+def get_grid_uniform(resolution):
+    x = np.linspace(-1.0, 1.0, resolution)
+    y = x
+    z = x
+
+    xx, yy, zz = np.meshgrid(x, y, z)
+    grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)
+
+    return {"grid_points": grid_points.cuda(),
+            "shortest_axis_length": 2.0,
+            "xyz": [x, y, z],
+            "shortest_axis_index": 0}
+
+def get_grid(points, resolution):
+    eps = 0.2
+    input_min = torch.min(points, dim=0)[0].squeeze().numpy()
+    input_max = torch.max(points, dim=0)[0].squeeze().numpy()
+
+    bounding_box = input_max - input_min
+    shortest_axis = np.argmin(bounding_box)
+    if (shortest_axis == 0):
+        x = np.linspace(input_min[shortest_axis] - eps,
+                        input_max[shortest_axis] + eps, resolution)
+        length = np.max(x) - np.min(x)
+        y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
+        z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
+    elif (shortest_axis == 1):
+        y = np.linspace(input_min[shortest_axis] - eps,
+                        input_max[shortest_axis] + eps, resolution)
+        length = np.max(y) - np.min(y)
+        x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
+        z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
+    elif (shortest_axis == 2):
+        z = np.linspace(input_min[shortest_axis] - eps,
+                        input_max[shortest_axis] + eps, resolution)
+        length = np.max(z) - np.min(z)
+        x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))
+        y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))
+
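+    # The axis resolution is set on the shortest (eps-padded) bounding-box
+    # side and the same spacing is reused for the other two axes, so the
+    # stacked grid below has approximately cubic cells.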
+    xx, yy, zz = np.meshgrid(x, y, z)
+    grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()
+    return {"grid_points": grid_points,
+            "shortest_axis_length": length,
+            "xyz": [x, y, z],
+            "shortest_axis_index": shortest_axis}
+
+def plot_depth_maps(depth_maps, path, epoch, plot_nrow, img_res):
+    depth_maps_plot = lin2img(depth_maps, img_res)
+
+    tensor = torchvision.utils.make_grid(depth_maps_plot.repeat(1, 3, 1, 1),
+                                         scale_each=True,
+                                         normalize=True,
+                                         nrow=plot_nrow).cpu().detach().numpy()
+    tensor = tensor.transpose(1, 2, 0)
+    scale_factor = 255
+    tensor = (tensor * scale_factor).astype(np.uint8)
+
+    img = Image.fromarray(tensor)
+    img.save('{0}/depth_{1}.png'.format(path, epoch))
+
+def plot_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res):
+    ground_true = (ground_true.cuda() + 1.) / 2.
+    rgb_points = (rgb_points + 1.) / 2.
+
+    output_vs_gt = torch.cat((rgb_points, ground_true), dim=0)
+    output_vs_gt_plot = lin2img(output_vs_gt, img_res)
+
+    tensor = torchvision.utils.make_grid(output_vs_gt_plot,
+                                         scale_each=False,
+                                         normalize=False,
+                                         nrow=plot_nrow).cpu().detach().numpy()
+
+    tensor = tensor.transpose(1, 2, 0)
+    scale_factor = 255
+    tensor = (tensor * scale_factor).astype(np.uint8)
+
+    img = Image.fromarray(tensor)
+    img.save('{0}/rendering_{1}.png'.format(path, epoch))
+
+def lin2img(tensor, img_res):
+    batch_size, num_samples, channels = tensor.shape
+    return tensor.permute(0, 2, 1).view(batch_size, channels, img_res[0], img_res[1])
\ No newline at end of file
diff --git a/reconstruction/PBIDR/code/utils/rend_util.py b/reconstruction/PBIDR/code/utils/rend_util.py
new file mode 100644
index 0000000..7a255e2
--- /dev/null
+++ b/reconstruction/PBIDR/code/utils/rend_util.py
@@ -0,0 +1,192 @@
+import numpy as np
+import imageio
+import skimage
+import cv2
+import torch
+from torch.nn import functional as F
+
+def load_rgb(path):
+    img = imageio.imread(path)
+    img = skimage.img_as_float32(img)
+
+    # pixel values between [-1,1]
+    img -= 0.5
+    img *= 2.
+    img = img.transpose(2, 0, 1)
+    return img
+
+def load_mask(path):
+    alpha = imageio.imread(path, as_gray=True)
+    # as_gray yields a float image in [0, 255]; img_as_float32 does not
+    # rescale float input, so the 0-255 thresholds below remain valid
+    alpha = skimage.img_as_float32(alpha)
+    object_mask = alpha > 127.5
+
+    return object_mask
+
+def load_mask_white_bg(path):
+    alpha = imageio.imread(path, as_gray=True)
+    alpha = skimage.img_as_float32(alpha)
+    object_mask = alpha < 250.5
+
+    return object_mask
+
+def load_K_Rt_from_P(filename, P=None):
+    if P is None:
+        lines = open(filename).read().splitlines()
+        if len(lines) == 4:
+            lines = lines[1:]
+        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+        P = np.asarray(lines).astype(np.float32).squeeze()
+
+    out = cv2.decomposeProjectionMatrix(P)
+    K = out[0]
+    R = out[1]
+    t = out[2]
+
+    K = K / K[2, 2]
+    intrinsics = np.eye(4)
+    intrinsics[:3, :3] = K
+
+    pose = np.eye(4, dtype=np.float32)
+    to_gl = np.eye(3, dtype=np.float32)
+    to_gl[0, 0] = -1.
+    to_gl[1, 1] = -1.
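+    # cv2.decomposeProjectionMatrix follows the OpenCV axis convention;
+    # flipping the x and y axes here (diag(-1, -1, 1)) presumably brings the
+    # recovered camera-to-world pose in line with the PyTorch3D cameras used
+    # to render the auxiliary views.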
+    pose[:3, :3] = np.dot(R.transpose(), to_gl)
+    pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+    return intrinsics, pose
+
+def get_camera_params(uv, pose, intrinsics):
+    if pose.shape[1] == 7:  # In case of quaternion vector representation
+        cam_loc = pose[:, 4:]
+        R = quat_to_rot(pose[:, :4])
+        p = torch.eye(4).repeat(pose.shape[0], 1, 1).cuda().float()
+        p[:, :3, :3] = R
+        p[:, :3, 3] = cam_loc
+    else:  # In case of pose matrix representation
+        cam_loc = pose[:, :3, 3]
+        p = pose
+
+    batch_size, num_samples, _ = uv.shape
+
+    depth = torch.ones((batch_size, num_samples)).cuda()
+    x_cam = uv[:, :, 0].view(batch_size, -1)
+    y_cam = uv[:, :, 1].view(batch_size, -1)
+    z_cam = depth.view(batch_size, -1)
+
+    pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
+
+    # permute for batch matrix product
+    pixel_points_cam = pixel_points_cam.permute(0, 2, 1)
+
+    world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
+    ray_dirs = world_coords - cam_loc[:, None, :]
+    ray_dirs = F.normalize(ray_dirs, dim=2)
+
+    return ray_dirs, cam_loc
+
+def get_camera_for_plot(pose):
+    if pose.shape[1] == 7:  # In case of quaternion vector representation
+        cam_loc = pose[:, 4:].detach()
+        R = quat_to_rot(pose[:, :4].detach())
+    else:  # In case of pose matrix representation
+        cam_loc = pose[:, :3, 3]
+        R = pose[:, :3, :3]
+    cam_dir = R[:, :3, 2]
+    return cam_loc, cam_dir
+
+def lift(x, y, z, intrinsics):
+    # parse intrinsics
+    intrinsics = intrinsics.cuda()
+    fx = intrinsics[:, 0, 0]
+    fy = intrinsics[:, 1, 1]
+    cx = intrinsics[:, 0, 2]
+    cy = intrinsics[:, 1, 2]
+    sk = intrinsics[:, 0, 1]
+
+    x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
+    y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
+
+    # homogeneous coordinates
+    return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1)
+
+def quat_to_rot(q):
+    batch_size, _ = q.shape
+    q = F.normalize(q, dim=1)
+    R = torch.ones((batch_size, 3, 3)).cuda()
+    qr = q[:, 0]
+    qi = q[:, 1]
+    qj = q[:, 2]
+    qk = q[:, 3]
+    R[:, 0, 0] = 1 - 2 * (qj**2 + qk**2)
+    R[:, 0, 1] = 2 * (qj*qi - qk*qr)
+    R[:, 0, 2] = 2 * (qi*qk + qr*qj)
+    R[:, 1, 0] = 2 * (qj*qi + qk*qr)
+    R[:, 1, 1] = 1 - 2 * (qi**2 + qk**2)
+    R[:, 1, 2] = 2 * (qj*qk - qi*qr)
+    R[:, 2, 0] = 2 * (qk*qi - qj*qr)
+    R[:, 2, 1] = 2 * (qj*qk + qi*qr)
+    R[:, 2, 2] = 1 - 2 * (qi**2 + qj**2)
+    return R
+
+def rot_to_quat(R):
+    batch_size, _, _ = R.shape
+    q = torch.ones((batch_size, 4)).cuda()
+
+    R00 = R[:, 0, 0]
+    R01 = R[:, 0, 1]
+    R02 = R[:, 0, 2]
+    R10 = R[:, 1, 0]
+    R11 = R[:, 1, 1]
+    R12 = R[:, 1, 2]
+    R20 = R[:, 2, 0]
+    R21 = R[:, 2, 1]
+    R22 = R[:, 2, 2]
+
+    q[:, 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2
+    q[:, 1] = (R21 - R12) / (4 * q[:, 0])
+    q[:, 2] = (R02 - R20) / (4 * q[:, 0])
+    q[:, 3] = (R10 - R01) / (4 * q[:, 0])
+    return q
+
+def get_sphere_intersection(cam_loc, ray_directions, r=1.0):
+    # Input: n_images x 4 x 4 ; n_images x n_rays x 3
+    # Output: n_images * n_rays x 2 (close and far) ; n_images * n_rays
+
+    n_imgs, n_pix, _ = ray_directions.shape
+
+    cam_loc = cam_loc.unsqueeze(-1)
+    ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze()
+    under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1) ** 2 - r ** 2)
+
+    under_sqrt = under_sqrt.reshape(-1)
+    mask_intersect = under_sqrt > 0
+
+    sphere_intersections = torch.zeros(n_imgs * n_pix, 2).cuda().float()
+    sphere_intersections[mask_intersect] = torch.sqrt(under_sqrt[mask_intersect]).unsqueeze(-1) * torch.Tensor([-1, 1]).cuda().float()
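+    # Roots of the ray-sphere quadratic |o + t*d|^2 = r^2 with unit d:
+    # t = -<o, d> +/- sqrt(<o, d>^2 - (|o|^2 - r^2)); the +/- sqrt term is
+    # set above, and -<o, d> is added on the next line.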
+    sphere_intersections[mask_intersect] -= ray_cam_dot.reshape(-1)[mask_intersect].unsqueeze(-1)
+
+    sphere_intersections = sphere_intersections.reshape(n_imgs, n_pix, 2)
+    sphere_intersections = sphere_intersections.clamp_min(0.0)
+    mask_intersect = mask_intersect.reshape(n_imgs, n_pix)
+
+    return sphere_intersections, mask_intersect
+
+def get_depth(points, pose):
+    ''' Returns depth from 3D points according to camera pose '''
+    batch_size, num_samples, _ = points.shape
+    if pose.shape[1] == 7:  # In case of quaternion vector representation
+        cam_loc = pose[:, 4:]
+        R = quat_to_rot(pose[:, :4])
+        pose = torch.eye(4).unsqueeze(0).repeat(batch_size, 1, 1).cuda().float()
+        pose[:, :3, 3] = cam_loc
+        pose[:, :3, :3] = R
+
+    points_hom = torch.cat((points, torch.ones((batch_size, num_samples, 1)).cuda()), dim=2)
+
+    # permute for batch matrix product
+    points_hom = points_hom.permute(0, 2, 1)
+
+    points_cam = torch.inverse(pose).bmm(points_hom)
+    depth = points_cam[:, 2, :][:, :, None]
+    return depth
+
diff --git a/reconstruction/PBIDR/figures/overview.png b/reconstruction/PBIDR/figures/overview.png
new file mode 100644
index 0000000..1f89f64
Binary files /dev/null and b/reconstruction/PBIDR/figures/overview.png differ
diff --git a/reconstruction/PBIDR/requirements.txt b/reconstruction/PBIDR/requirements.txt
new file mode 100644
index 0000000..6d62eab
--- /dev/null
+++ b/reconstruction/PBIDR/requirements.txt
@@ -0,0 +1,196 @@
+aadict==0.2.3
+absl-py==0.12.0
+aiohttp==3.7.4.post0
+ansicolors==1.1.8
+ansiwrap==0.8.4
+apptools==5.1.0
+argon2-cffi==20.1.0
+asset==0.6.13
+astor==0.8.1
+async-generator==1.10
+async-timeout==3.0.1
+attrs==20.3.0
+autobahn==21.3.1
+Automat==20.2.0
+autoprop==3.0.0
+backcall==0.2.0
+bleach==3.3.0
+cached-property==1.5.2
+certifi==2020.12.5
+cffi==1.14.5
+chainmap==1.0.3
+chardet==4.0.0
+charset-normalizer==2.0.4
+chumpy==0.70
+cloudpickle==1.6.0
+colorlog==5.0.1
+combomethod==1.0.12
+ConfigArgParse==1.4.1
+configobj==5.0.6
+constantly==15.1.0
+cryptography==35.0.0
+cvxpy==1.1.13
+cycler==0.10.0
+Cython==0.29.24
+dask==2021.7.0
+decorator==4.4.2
+defusedxml==0.7.1
+ecos==2.0.7.post1
+entrypoints==0.3
+envisage==6.0.1
+face-alignment==1.3.3
+freetype-py==2.2.0
+fsspec==2021.7.0
+future==0.18.2
+fvcore==0.1.5.post20210423
+gast==0.4.0
+glcontext==2.3.4
+globre==0.1.5
+glooey==0.3.5
+google-pasta==0.2.0
+GPUtil==1.4.0
+grpcio==1.37.1
+h5py==3.2.1
+hyperlink==21.0.0
+idna==3.2
+imageio==2.9.0
+imageio-ffmpeg==0.4.3
+importlib-metadata==4.0.1
+importlib-resources==5.3.0
+incremental==21.3.0
+intspan==1.6.1
+iopath==0.1.8
+ipykernel==5.5.3
+ipython==7.23.0
+ipython-genutils==0.2.0
+jedi==0.18.0
+Jinja2==2.11.3
+jsonschema==3.2.0
+jupyter-client==6.2.0
+jupyter-core==4.7.1
+jupyterlab-pygments==0.1.2
+Keras-Applications==1.0.8
+Keras-Preprocessing==1.1.2
+kiwisolver==1.3.1
+llvmlite==0.36.0
+locket==0.2.1
+lxml==4.6.3
+mapbox-earcut==0.12.10
+Markdown==3.3.4
+MarkupSafe==1.1.1
+matplotlib==3.4.1
+matplotlib-inline==0.1.2
+mayavi==4.7.3
+mementos==1.3.1
+menpo==0.11.0
+menpo3d==0.8.3
+meshio==4.4.6
+mistune==0.8.4
+moderngl==5.6.4
+more-itertools==8.8.0
+mpmath==1.2.1
+msgpack==1.0.2
+multidict==5.2.0
+nbclient==0.5.3
+nbconvert==6.0.7
+nbformat==5.1.3
+nest-asyncio==1.5.1
+networkx==2.5.1
+notebook==6.3.0
+nulltype==2.3.1
+numba==0.53.1
+numpy==1.20.2
+opencv-python==4.5.1.48
+options==1.4.10
+osqp==0.6.2.post0
+packaging==20.9
+pandas==1.3.5
+pandocfilters==1.4.3
+parso==0.8.2
+partd==1.2.0
+pexpect==4.8.0
+pickleshare==0.7.5
+Pillow==8.2.0
+plotly==5.1.0
+portalocker==2.3.0
+prometheus-client==0.10.1
+prompt-toolkit==3.0.18
+protobuf==3.17.0
+psutil==5.8.0
+ptyprocess==0.7.0
+pycollada==0.7.1
+pycparser==2.20
+pyembree==0.1.6
+pyface==7.3.0
+pyglet==1.5.18
+Pygments==2.9.0
+pyhocon==0.3.57
+PyOpenGL==3.1.0
+pyparsing==2.4.7
+PyQt5==5.15.6
+PyQt5-Qt5==5.15.2
+PyQt5-sip==12.9.0
+pyrender==0.1.45
+pyrsistent==0.17.3
+python-dateutil==2.8.1
+python-fcl==0.0.12
+pytorch3d==0.4.0
+pytz==2021.3
+PyWavelets==1.1.1
+PyYAML==5.4.1
+pyzmq==22.0.3
+qdldl==0.1.5.post0
+readline==6.2.4.1
+requests==2.26.0
+retrying==1.3.3
+Rtree==0.9.7
+say==1.6.6
+scikit-image==0.18.1
+scipy==1.6.2
+scs==2.1.3
+Send2Trash==1.5.0
+Shapely==1.7.1
+show==1.6.0
+signature-dispatch==0.1.0
+simplere==1.2.13
+six==1.12.0
+svg.path==4.1
+sympy==1.8
+tabulate==0.8.9
+tenacity==7.0.0
+tensorboard==1.14.0
+tensorflow-estimator==1.14.0
+tensorflow-gpu==1.14.0
+termcolor==1.1.0
+terminado==0.9.4
+testpath==0.4.4
+textdata==2.4.1
+textwrap3==0.9.2
+tifffile==2021.4.8
+toolz==0.11.1
+torch==1.7.1
+torchvision==0.7.0
+tornado==6.1
+tqdm==4.60.0
+traitlets==5.0.5
+traits==6.3.1
+traitsui==7.2.1
+triangle==20200424
+trimesh==3.9.25
+Twisted==21.7.0
+txaio==21.2.1
+typing-extensions==3.7.4.3
+urllib3==1.26.6
+vecrec==0.3.0
+vtk==9.0.1
+wcwidth==0.2.5
+webencodings==0.5.1
+Werkzeug==2.0.1
+wrapt==1.12.1
+wslink==1.1.0
+xatlas==0.0.5
+xxhash==2.0.2
+yacs==0.1.8
+yarl==1.7.0
+zipp==3.4.1
+zope.interface==5.4.0