Files
insightface/reconstruction/PBIDR/code/utils/plots.py
2022-03-19 14:24:51 +08:00

424 lines
17 KiB
Python

import plotly.graph_objs as go
import plotly.offline as offline
import numpy as np
import torch
from skimage import measure
import torchvision
import trimesh
from PIL import Image
from utils import rend_util
import pickle
def plot_latent(model, latent, indices, model_outputs, pose, rgb_gt, path, epoch, img_res, plot_nimgs, max_depth, resolution):
    """Write rendering, depth and 3D-surface diagnostics for a latent-conditioned model.

    Produces, under ``path``, the rendering and depth PNGs (through
    ``plot_images`` / ``plot_depth_maps``), the surface mesh exported by
    ``get_surface_trace`` and an interactive ``surface_{epoch}.html`` Plotly
    figure holding the surface, one cone per camera and the sampled
    ray/surface intersection points.

    Args:
        model: object exposing ``implicit_network``; it is called on
            ``[latent, xyz]`` rows and column 0 of its output is used as the SDF.
        latent: latent code; expanded to one row per query point and
            concatenated in front of the coordinates.
        indices: per-image identifiers, only used to name the Plotly traces.
        model_outputs: dict with keys ``'network_object_mask'``, ``'points'``
            and ``'rgb_values'``.
        pose: camera poses, forwarded to ``rend_util``.
        rgb_gt: ground-truth colours; its shape is unpacked as
            ``(batch_size, num_samples, _)``.
        path: output directory.
        epoch: value embedded in the output file names.
        img_res: image resolution forwarded to the image plotting helpers.
        plot_nimgs: forwarded as the grid row size of the image plots.
        max_depth: depth assigned to samples outside ``network_object_mask``.
        resolution: grid resolution forwarded to ``get_surface_trace``.
    """
    # arrange data to plot
    batch_size, num_samples, _ = rgb_gt.shape

    network_object_mask = model_outputs['network_object_mask']
    points = model_outputs['points'].reshape(batch_size, num_samples, 3)
    rgb_eval = model_outputs['rgb_values']
    rgb_eval = rgb_eval.reshape(batch_size, num_samples, 3)

    # Samples outside the mask keep the max_depth fill value.
    depth = torch.ones(batch_size * num_samples).cuda().float() * max_depth
    depth[network_object_mask] = rend_util.get_depth(points, pose).reshape(-1)[network_object_mask]
    depth = depth.reshape(batch_size, num_samples, 1)
    network_object_mask = network_object_mask.reshape(batch_size, -1)

    cam_loc, cam_dir = rend_util.get_camera_for_plot(pose)

    # plot rendered images
    plot_images(rgb_eval, rgb_gt, path, epoch, plot_nimgs, img_res)

    # plot depth maps
    plot_depth_maps(depth, path, epoch, plot_nimgs, img_res)

    data = []

    # plot surface
    surface_traces = get_surface_trace(path=path,
                                       epoch=epoch,
                                       sdf=lambda x: model.implicit_network(torch.cat([latent.expand(len(x), -1), x], 1))[:, 0],
                                       resolution=resolution
                                       )
    # get_surface_trace returns None when the SDF has no zero crossing on the
    # grid; skip the surface instead of failing on ``None[0]``.
    if surface_traces is not None:
        data.append(surface_traces[0])

    # plot cameras locations
    for i, loc, direction in zip(indices, cam_loc, cam_dir):
        data.append(get_3D_quiver_trace(loc.unsqueeze(0), direction.unsqueeze(0), name='camera_{0}'.format(i)))

    # plot at most 2048 randomly chosen intersection points per image
    for i, p, m in zip(indices, points, network_object_mask):
        p = p[m]
        sampling_idx = torch.randperm(p.shape[0])[:2048]
        p = p[sampling_idx, :]

        val = model.implicit_network(torch.cat([latent.expand(len(p), -1), p], 1))
        caption = ["sdf: {0} ".format(v[0].item()) for v in val]

        data.append(get_3D_scatter_trace(p, name='intersection_points_{0}'.format(i), caption=caption))

    fig = go.Figure(data=data)
    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
                      yaxis=dict(range=[-3, 3], autorange=False),
                      zaxis=dict(range=[-3, 3], autorange=False),
                      aspectratio=dict(x=1, y=1, z=1))
    fig.update_layout(scene=scene_dict, width=1400, height=1400, showlegend=True)
    filename = '{0}/surface_{1}.html'.format(path, epoch)
    offline.plot(fig, filename=filename, auto_open=False)
def plot(model, indices, model_outputs, pose, rgb_gt, path, epoch, img_res, plot_nimgs, max_depth, resolution):
    """Write rendering, depth and 3D-surface diagnostics for one evaluation batch.

    Produces, under ``path``, the rendering and depth PNGs (through
    ``plot_images`` / ``plot_depth_maps``), the surface mesh exported by
    ``get_surface_trace`` and an interactive ``surface_{epoch}.html`` Plotly
    figure holding the surface, one cone per camera and the sampled
    ray/surface intersection points.

    Args:
        model: object exposing ``implicit_network``; column 0 of its output
            is used as the SDF value.
        indices: per-image identifiers, only used to name the Plotly traces.
        model_outputs: dict with keys ``'network_object_mask'``, ``'points'``
            and ``'rgb_values'``.
        pose: camera poses, forwarded to ``rend_util``.
        rgb_gt: ground-truth colours; its shape is unpacked as
            ``(batch_size, num_samples, _)``.
        path: output directory.
        epoch: value embedded in the output file names.
        img_res: image resolution forwarded to the image plotting helpers.
        plot_nimgs: forwarded as the grid row size of the image plots.
        max_depth: depth assigned to samples outside ``network_object_mask``.
        resolution: grid resolution forwarded to ``get_surface_trace``.
    """
    # arrange data to plot
    batch_size, num_samples, _ = rgb_gt.shape

    network_object_mask = model_outputs['network_object_mask']
    points = model_outputs['points'].reshape(batch_size, num_samples, 3)
    rgb_eval = model_outputs['rgb_values']
    rgb_eval = rgb_eval.reshape(batch_size, num_samples, 3)

    # Samples outside the mask keep the max_depth fill value.
    depth = torch.ones(batch_size * num_samples).cuda().float() * max_depth
    depth[network_object_mask] = rend_util.get_depth(points, pose).reshape(-1)[network_object_mask]
    depth = depth.reshape(batch_size, num_samples, 1)
    network_object_mask = network_object_mask.reshape(batch_size, -1)

    cam_loc, cam_dir = rend_util.get_camera_for_plot(pose)

    # plot rendered images
    plot_images(rgb_eval, rgb_gt, path, epoch, plot_nimgs, img_res)

    # plot depth maps
    plot_depth_maps(depth, path, epoch, plot_nimgs, img_res)

    data = []

    # plot surface
    surface_traces = get_surface_trace(path=path,
                                       epoch=epoch,
                                       sdf=lambda x: model.implicit_network(x)[:, 0],
                                       resolution=resolution
                                       )
    # get_surface_trace returns None when the SDF has no zero crossing on the
    # grid; skip the surface instead of failing on ``None[0]``.
    if surface_traces is not None:
        data.append(surface_traces[0])

    # plot cameras locations
    for i, loc, direction in zip(indices, cam_loc, cam_dir):
        data.append(get_3D_quiver_trace(loc.unsqueeze(0), direction.unsqueeze(0), name='camera_{0}'.format(i)))

    # plot at most 2048 randomly chosen intersection points per image
    for i, p, m in zip(indices, points, network_object_mask):
        p = p[m]
        sampling_idx = torch.randperm(p.shape[0])[:2048]
        p = p[sampling_idx, :]

        val = model.implicit_network(p)
        caption = ["sdf: {0} ".format(v[0].item()) for v in val]

        data.append(get_3D_scatter_trace(p, name='intersection_points_{0}'.format(i), caption=caption))

    fig = go.Figure(data=data)
    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
                      yaxis=dict(range=[-3, 3], autorange=False),
                      zaxis=dict(range=[-3, 3], autorange=False),
                      aspectratio=dict(x=1, y=1, z=1))
    fig.update_layout(scene=scene_dict, width=1400, height=1400, showlegend=True)
    filename = '{0}/surface_{1}.html'.format(path, epoch)
    offline.plot(fig, filename=filename, auto_open=False)
def get_3D_scatter_trace(points, name='', size=3, caption=None):
    """Build a Plotly ``Scatter3d`` marker trace from a 2-D array of 3D points.

    Args:
        points: 2-D tensor with three columns (x, y, z); ``.cpu()`` is called
            on each column.
        name: trace name.
        size: marker size.
        caption: optional per-point text passed as the trace ``text``.

    Returns:
        The ``go.Scatter3d`` trace.

    Raises:
        AssertionError: if ``points`` is not 2-D with exactly three columns.
    """
    # Check the rank first so that 1-D input fails with the assertion message
    # rather than an IndexError from ``shape[1]``.
    assert len(points.shape) == 2, "3d scatter plot input points are not correctely shaped "
    assert points.shape[1] == 3, "3d scatter plot input points are not correctely shaped "

    trace = go.Scatter3d(
        x=points[:, 0].cpu(),
        y=points[:, 1].cpu(),
        z=points[:, 2].cpu(),
        mode='markers',
        name=name,
        marker=dict(
            size=size,
            line=dict(
                width=2,
            ),
            opacity=1.0,
        ), text=caption)

    return trace
def get_3D_quiver_trace(points, directions, color='#bd1540', name=''):
    """Build a Plotly ``Cone`` trace drawing one cone per (point, direction) pair.

    Args:
        points: 2-D tensor with three columns holding the cone anchor positions.
        directions: 2-D tensor with three columns holding the cone directions.
        color: single colour used for both ends of the colour scale.
        name: trace name.

    Returns:
        The ``go.Cone`` trace.

    Raises:
        AssertionError: if either input is not 2-D with exactly three columns.
    """
    # Check the rank before indexing ``shape[1]`` so that 1-D input fails with
    # the assertion message rather than an IndexError.
    assert len(points.shape) == 2, "3d cone plot input points are not correctely shaped "
    assert points.shape[1] == 3, "3d cone plot input points are not correctely shaped "
    assert len(directions.shape) == 2, "3d cone plot input directions are not correctely shaped "
    assert directions.shape[1] == 3, "3d cone plot input directions are not correctely shaped "

    trace = go.Cone(
        name=name,
        x=points[:, 0].cpu(),
        y=points[:, 1].cpu(),
        z=points[:, 2].cpu(),
        u=directions[:, 0].cpu(),
        v=directions[:, 1].cpu(),
        w=directions[:, 2].cpu(),
        sizemode='absolute',
        sizeref=0.125,
        showscale=False,
        colorscale=[[0, color], [1, color]],
        anchor="tail"
    )

    return trace
def get_surface_trace(path, epoch, sdf, resolution=100, return_mesh=False):
    """Extract the zero level set of ``sdf`` on a uniform grid and export it.

    The SDF is sampled on the grid from ``get_grid_uniform(resolution)``.
    When the sampled values change sign, marching cubes is run, the mesh is
    written to ``{path}/surface_{epoch}.ply`` and either the mesh or a
    one-element list with a Plotly ``Mesh3d`` trace is returned.

    Args:
        path: output directory for the exported ``.ply``.
        epoch: value embedded in the exported file name.
        sdf: callable mapping a batch of grid points to one value per point.
        resolution: number of grid samples per axis.
        return_mesh: return the ``trimesh.Trimesh`` instead of the trace list.

    Returns:
        ``None`` when the sampled SDF is strictly positive or strictly
        negative everywhere (nothing is exported); otherwise the mesh or
        ``[go.Mesh3d]`` depending on ``return_mesh``.
    """
    grid = get_grid_uniform(resolution)
    points = grid['grid_points']

    # Query the SDF in chunks of 100000 points (presumably to bound peak memory).
    z = [sdf(pnts).detach().cpu().numpy() for pnts in torch.split(points, 100000, dim=0)]
    z = np.concatenate(z, axis=0)

    # No sign change on the grid: there is no zero level set to extract.
    if np.min(z) > 0 or np.max(z) < 0:
        return None

    z = z.astype(np.float32)

    # scikit-image 0.19 removed ``marching_cubes_lewiner``; ``marching_cubes``
    # (Lewiner method by default) replaces it. Prefer the old name when it
    # still exists so behaviour on older releases is unchanged.
    marching_cubes = getattr(measure, 'marching_cubes_lewiner', None) or measure.marching_cubes

    xs, ys, zs = grid['xyz']
    voxel = xs[2] - xs[1]
    verts, faces, normals, values = marching_cubes(
        volume=z.reshape(ys.shape[0], xs.shape[0], zs.shape[0]).transpose([1, 0, 2]),
        level=0,
        spacing=(voxel, voxel, voxel))

    # Marching cubes returns vertices relative to the volume origin; shift
    # them to the grid's first corner.
    verts = verts + np.array([xs[0], ys[0], zs[0]])

    tri_i, tri_j, tri_k = faces.transpose()

    traces = [go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                        i=tri_i, j=tri_j, k=tri_k, name='implicit_surface',
                        opacity=1.0)]

    # NOTE(review): normals are negated on export, presumably to match this
    # SDF's sign convention -- confirm.
    meshexport = trimesh.Trimesh(verts, faces, vertex_normals=-normals)
    meshexport.export('{0}/surface_{1}.ply'.format(path, epoch), 'ply')

    if return_mesh:
        return meshexport
    return traces
def get_surface_high_res_mesh(sdf, resolution=100):
    """Extract a mesh of ``sdf`` on a grid aligned with the surface's principal axes.

    Steps:
      1. Run marching cubes on a fixed 100^3 uniform grid and keep the
         connected component with the largest area.
      2. Sample 10000 surface points, centre them and rotate them into the
         eigenbasis of their scatter matrix.
      3. Build a grid around the rotated points with ``get_grid`` at
         ``resolution``, rotate it back, re-evaluate ``sdf`` and run marching
         cubes again.

    Args:
        sdf: callable mapping a batch of 3D points to one value per point.
        resolution: sample count along the shortest axis of the aligned grid.

    Returns:
        A ``trimesh.Trimesh``, or ``None`` when the SDF sampled on the
        aligned grid has no sign change.
    """
    # scikit-image 0.19 removed ``marching_cubes_lewiner``; ``marching_cubes``
    # (Lewiner method by default) replaces it. Prefer the old name when it
    # still exists so behaviour on older releases is unchanged.
    marching_cubes = getattr(measure, 'marching_cubes_lewiner', None) or measure.marching_cubes

    # get low res mesh to sample point cloud
    grid = get_grid_uniform(100)
    points = grid['grid_points']
    z = [sdf(pnts).detach().cpu().numpy() for pnts in torch.split(points, 100000, dim=0)]
    z = np.concatenate(z, axis=0)

    z = z.astype(np.float32)

    verts, faces, normals, values = marching_cubes(
        volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
                         grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
        level=0,
        spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
                 grid['xyz'][0][2] - grid['xyz'][0][1],
                 grid['xyz'][0][2] - grid['xyz'][0][1]))

    verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])

    mesh_low_res = trimesh.Trimesh(verts, faces, vertex_normals=-normals)

    # Keep only the largest connected component.
    components = mesh_low_res.split(only_watertight=False)
    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the same dtype.
    areas = np.array([c.area for c in components], dtype=float)
    mesh_low_res = components[areas.argmax()]

    recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
    recon_pc = torch.from_numpy(recon_pc).float().cuda()

    # Center and align the recon pc
    s_mean = recon_pc.mean(dim=0)
    s_cov = recon_pc - s_mean
    s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
    try:
        eigvecs = torch.eig(s_cov, True)[1]
    except (AttributeError, RuntimeError):
        # ``torch.eig`` is deprecated since 1.9 and unavailable in recent
        # PyTorch. ``s_cov`` is X^T X, hence symmetric, so ``eigh`` applies.
        # The eigenvector order may differ, which only permutes the axes of
        # the aligned frame.
        eigvecs = torch.linalg.eigh(s_cov)[1]
    vecs = eigvecs.transpose(0, 1)
    if torch.det(vecs) < 0:
        # Swap two rows to turn the reflection into a proper rotation.
        vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs)
    helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),
                       (recon_pc - s_mean).unsqueeze(-1)).squeeze()

    grid_aligned = get_grid(helper.cpu(), resolution)

    grid_points = grid_aligned['grid_points']

    # Rotate the aligned grid back into the original frame, chunk by chunk.
    g = []
    for pnts in torch.split(grid_points, 100000, dim=0):
        g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),
                           pnts.unsqueeze(-1)).squeeze() + s_mean)
    grid_points = torch.cat(g, dim=0)

    # MC to new grid
    points = grid_points
    z = [sdf(pnts).detach().cpu().numpy() for pnts in torch.split(points, 100000, dim=0)]
    z = np.concatenate(z, axis=0)

    meshexport = None
    if not (np.min(z) > 0 or np.max(z) < 0):

        z = z.astype(np.float32)

        verts, faces, normals, values = marching_cubes(
            volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0],
                             grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]),
            level=0,
            spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1]))

        # Rotate the vertices back and offset them by the first (rotated) grid point.
        verts = torch.from_numpy(verts).cuda().float()
        verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),
                          verts.unsqueeze(-1)).squeeze()
        verts = (verts + grid_points[0]).cpu().numpy()

        meshexport = trimesh.Trimesh(verts, faces, vertex_normals=-normals)

    return meshexport
def get_displacement_mesh(model):
    """Displace ``model.mesh`` and its subdivision along their vertex normals.

    Every vertex is moved by ``-sdf * normal``, where ``sdf`` is column 0 of
    ``model.implicit_network`` evaluated at the vertex.

    Returns:
        ``(displaced_mesh, displaced_subdivided_mesh)``.
    """

    def displace(base_mesh):
        # Work on a copy so the caller's mesh is left untouched.
        mesh = base_mesh.copy()
        vertices = np.array(mesh.vertices)
        query = torch.tensor(vertices.astype(np.float32)).cuda().requires_grad_(True)
        offsets = model.implicit_network(query)[:, 0:1].detach().cpu().numpy()
        moved = -1.0 * offsets * mesh.vertex_normals + vertices
        displaced = trimesh.Trimesh(vertices=moved, faces=mesh.faces,
                                    process=False, visual=mesh.visual)
        print('Detailed mesh created!')
        return displaced

    base = model.mesh
    refined = base.subdivide()
    displaced_base = displace(base)
    displaced_refined = displace(refined)
    return displaced_base, displaced_refined
def get_displacement_animation(model):
    """Evaluate the implicit network at the vertices of ``model.mesh`` and its subdivision.

    Returns:
        ``(values_base, values_subdivided)``: NumPy arrays holding column 0
        of ``model.implicit_network`` (kept as a trailing axis of size 1) at
        every vertex of the base mesh and of the subdivided mesh.
    """

    def vertex_values(source_mesh):
        mesh = source_mesh.copy()
        vertices = np.array(mesh.vertices).astype(np.float32)
        query = torch.tensor(vertices).cuda().requires_grad_(True)
        result = model.implicit_network(query)[:, 0:1].detach().cpu().numpy()
        print('Detailed mesh created!')
        return result

    base = model.mesh
    refined = base.subdivide()
    values_base = vertex_values(base)
    values_refined = vertex_values(refined)
    return values_base, values_refined
def get_NormalMaps(model):
    """Return the mesh vertex normals and the network gradients at the vertices of ``model.mesh``.

    Returns:
        ``(smooth_normals, detailed_normals)``: the vertex normals of a copy
        of ``model.mesh`` as a NumPy array, and the squeezed output of
        ``model.implicit_network.gradient`` at the same vertices.
    """
    mesh = model.mesh.copy()
    vertices = np.array(mesh.vertices).astype(np.float32)
    query = torch.tensor(vertices).cuda().requires_grad_(True)

    gradients = model.implicit_network.gradient(query)
    detailed_normals = gradients.squeeze().detach().cpu().numpy()

    smooth_normals = np.asarray(mesh.vertex_normals)
    return smooth_normals, detailed_normals
def get_grid_uniform(resolution, device=None):
    """Build a uniform ``resolution``^3 grid of points covering [-1, 1]^3.

    Args:
        resolution: number of samples per axis.
        device: target device for the returned points. ``None`` (default)
            keeps the historical behaviour of moving them with ``.cuda()``.

    Returns:
        dict with keys
        ``"grid_points"`` (float tensor of shape ``(resolution**3, 3)``),
        ``"shortest_axis_length"`` (always ``2.0``),
        ``"xyz"`` (the three 1-D coordinate arrays) and
        ``"shortest_axis_index"`` (always ``0``).
    """
    x = np.linspace(-1.0, 1.0, resolution)
    y = x
    z = x

    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)
    grid_points = grid_points.cuda() if device is None else grid_points.to(device)

    return {"grid_points": grid_points,
            "shortest_axis_length": 2.0,
            "xyz": [x, y, z],
            "shortest_axis_index": 0}
def get_grid(points, resolution, device=None):
    """Build a regular grid around ``points`` with the same step on every axis.

    The bounding box of ``points`` is padded by 0.2 on each side. The axis
    with the smallest extent is sampled with ``resolution`` points
    (``np.linspace``); the other two axes reuse the resulting step
    (``np.arange``), so they usually hold more samples.

    Args:
        points: CPU tensor of 3D points (``.numpy()`` is called on its
            per-column min/max); assumed shape ``(N, 3)`` -- TODO confirm.
        resolution: number of samples along the shortest axis.
        device: target device for the returned points. ``None`` (default)
            keeps the historical behaviour of moving them with ``.cuda()``.

    Returns:
        dict with keys ``"grid_points"`` (float tensor, one row per grid
        node), ``"shortest_axis_length"``, ``"xyz"`` (the three 1-D
        coordinate arrays) and ``"shortest_axis_index"``.
    """
    eps = 0.2
    input_min = torch.min(points, dim=0)[0].squeeze().numpy()
    input_max = torch.max(points, dim=0)[0].squeeze().numpy()
    bounding_box = input_max - input_min
    shortest_axis = np.argmin(bounding_box)

    # The shortest axis fixes the grid step used on all three axes.
    shortest_coords = np.linspace(input_min[shortest_axis] - eps,
                                  input_max[shortest_axis] + eps, resolution)
    length = np.max(shortest_coords) - np.min(shortest_coords)
    step = length / (shortest_coords.shape[0] - 1)

    axes = []
    for axis in range(3):
        if axis == shortest_axis:
            axes.append(shortest_coords)
        else:
            # One extra step on the upper bound so the padded box is covered.
            axes.append(np.arange(input_min[axis] - eps, input_max[axis] + step + eps, step))
    x, y, z = axes

    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)
    grid_points = grid_points.cuda() if device is None else grid_points.to(device)

    return {"grid_points": grid_points,
            "shortest_axis_length": length,
            "xyz": [x, y, z],
            "shortest_axis_index": shortest_axis}
def plot_depth_maps(depth_maps, path, epoch, plot_nrow, img_res):
    """Save the depth maps as a normalised image grid at ``{path}/depth_{epoch}.png``.

    Args:
        depth_maps: per-sample depth values, converted to images by ``lin2img``.
        path: output directory.
        epoch: value embedded in the file name.
        plot_nrow: ``nrow`` argument of ``torchvision.utils.make_grid``.
        img_res: image resolution forwarded to ``lin2img``.
    """
    # Replicate the single depth channel three times to get an RGB-like image.
    depth_images = lin2img(depth_maps, img_res).repeat(1, 3, 1, 1)

    # Each image is normalised independently (scale_each=True).
    grid = torchvision.utils.make_grid(depth_images,
                                       scale_each=True,
                                       normalize=True,
                                       nrow=plot_nrow)

    # channels-first -> channels-last, then to 8-bit
    pixels = grid.cpu().detach().numpy().transpose(1, 2, 0)
    pixels = (pixels * 255).astype(np.uint8)

    Image.fromarray(pixels).save('{0}/depth_{1}.png'.format(path, epoch))
def plot_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res):
    """Save rendered and ground-truth images in one grid at ``{path}/rendering_{epoch}.png``.

    Args:
        rgb_points: rendered colours; mapped with ``(v + 1) / 2`` before plotting.
        ground_true: ground-truth colours; moved to CUDA and mapped the same way.
        path: output directory.
        epoch: value embedded in the file name.
        plot_nrow: ``nrow`` argument of ``torchvision.utils.make_grid``.
        img_res: image resolution forwarded to ``lin2img``.
    """
    reference = (ground_true.cuda() + 1.) / 2.
    rendered = (rgb_points + 1.) / 2.

    # Rendered images first, ground truth after, stacked along the batch axis.
    stacked = torch.cat((rendered, reference), dim=0)
    images = lin2img(stacked, img_res)

    grid = torchvision.utils.make_grid(images,
                                       scale_each=False,
                                       normalize=False,
                                       nrow=plot_nrow)

    # channels-first -> channels-last, then to 8-bit
    pixels = grid.cpu().detach().numpy().transpose(1, 2, 0)
    pixels = (pixels * 255).astype(np.uint8)

    Image.fromarray(pixels).save('{0}/rendering_{1}.png'.format(path, epoch))
def lin2img(tensor, img_res):
    """Reshape a ``(batch, samples, channels)`` tensor into ``(batch, channels, H, W)``.

    ``H`` and ``W`` are ``img_res[0]`` and ``img_res[1]``; sample ``h * W + w``
    of the input becomes pixel ``(h, w)`` of the output.
    """
    n_batch, _, n_channels = tensor.shape
    height, width = img_res[0], img_res[1]
    channels_first = tensor.permute(0, 2, 1)
    return channels_first.view(n_batch, n_channels, height, width)