Files
insightface/reconstruction/PBIDR/code/model/sample_network.py

21 lines
900 B
Python
Raw Normal View History

2022-03-19 14:24:51 +08:00
import torch.nn as nn
import torch
class SampleNetwork(nn.Module):
'''
Represent the intersection (sample) point as differentiable function of the implicit geometry and camera parameters.
See equation 3 in the paper for more details.
'''
def forward(self, surface_output, surface_sdf_values, surface_points_grad, surface_dists, surface_cam_loc, surface_ray_dirs):
# t -> t(theta)
surface_ray_dirs_0 = surface_ray_dirs.detach()
surface_points_dot = torch.bmm(surface_points_grad.view(-1, 1, 3),
surface_ray_dirs_0.view(-1, 3, 1)).squeeze(-1)
surface_dists_theta = surface_dists - (surface_output - surface_sdf_values) / surface_points_dot
# t(theta) -> x(theta,c,v)
surface_points_theta_c_v = surface_cam_loc + surface_dists_theta * surface_ray_dirs
return surface_points_theta_c_v