update PBIDR code

This commit is contained in:
XingyuRenSJTU
2022-03-19 14:24:51 +08:00
parent 478aafb4fc
commit b6cd2acc68
21 changed files with 2903 additions and 1 deletions

View File

@@ -0,0 +1,20 @@
import torch.nn as nn
import torch
class SampleNetwork(nn.Module):
'''
Represent the intersection (sample) point as differentiable function of the implicit geometry and camera parameters.
See equation 3 in the paper for more details.
'''
def forward(self, surface_output, surface_sdf_values, surface_points_grad, surface_dists, surface_cam_loc, surface_ray_dirs):
# t -> t(theta)
surface_ray_dirs_0 = surface_ray_dirs.detach()
surface_points_dot = torch.bmm(surface_points_grad.view(-1, 1, 3),
surface_ray_dirs_0.view(-1, 3, 1)).squeeze(-1)
surface_dists_theta = surface_dists - (surface_output - surface_sdf_values) / surface_points_dot
# t(theta) -> x(theta,c,v)
surface_points_theta_c_v = surface_cam_loc + surface_dists_theta * surface_ray_dirs
return surface_points_theta_c_v