mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-16 13:46:15 +00:00
21 lines
900 B
Python
21 lines
900 B
Python
import torch.nn as nn
|
|
import torch
|
|
|
|
class SampleNetwork(nn.Module):
|
|
'''
|
|
Represent the intersection (sample) point as differentiable function of the implicit geometry and camera parameters.
|
|
See equation 3 in the paper for more details.
|
|
'''
|
|
|
|
def forward(self, surface_output, surface_sdf_values, surface_points_grad, surface_dists, surface_cam_loc, surface_ray_dirs):
|
|
# t -> t(theta)
|
|
surface_ray_dirs_0 = surface_ray_dirs.detach()
|
|
surface_points_dot = torch.bmm(surface_points_grad.view(-1, 1, 3),
|
|
surface_ray_dirs_0.view(-1, 3, 1)).squeeze(-1)
|
|
surface_dists_theta = surface_dists - (surface_output - surface_sdf_values) / surface_points_dot
|
|
|
|
# t(theta) -> x(theta,c,v)
|
|
surface_points_theta_c_v = surface_cam_loc + surface_dists_theta * surface_ray_dirs
|
|
|
|
return surface_points_theta_c_v
|