mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-16 21:47:47 +00:00
155 lines
5.7 KiB
Python
155 lines
5.7 KiB
Python
import os
|
|
import torch
|
|
import numpy as np
|
|
|
|
import utils.general as utils
|
|
from utils import rend_util
|
|
|
|
class IFDataset(torch.utils.data.Dataset):
|
|
"""Dataset for a class of objects, where each datapoint is a SceneInstanceDataset."""
|
|
|
|
def __init__(self,
|
|
train_cameras,
|
|
data_dir,
|
|
img_res,
|
|
scan_id=0,
|
|
cam_file=None
|
|
):
|
|
|
|
self.instance_dir = os.path.join('../data', data_dir, 'scan{0}'.format(scan_id))
|
|
|
|
self.total_pixels = img_res[0] * img_res[1]
|
|
self.img_res = img_res
|
|
|
|
assert os.path.exists(self.instance_dir), "Data directory is empty"
|
|
|
|
self.sampling_idx = None
|
|
self.train_cameras = train_cameras
|
|
|
|
image_dir = '{0}/image'.format(self.instance_dir)
|
|
image_paths = sorted(utils.glob_imgs(image_dir))
|
|
mask_dir = '{0}/mask'.format(self.instance_dir)
|
|
mask_paths = sorted(utils.glob_imgs(mask_dir))
|
|
|
|
self.n_images = len(image_paths)
|
|
|
|
self.cam_file = '{0}/cameras.npz'.format(self.instance_dir)
|
|
if cam_file is not None:
|
|
self.cam_file = '{0}/{1}'.format(self.instance_dir, cam_file)
|
|
|
|
camera_dict = np.load(self.cam_file)
|
|
scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
|
|
world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
|
|
|
|
|
|
self.intrinsics_all = []
|
|
self.pose_all = []
|
|
for scale_mat, world_mat in zip(scale_mats, world_mats):
|
|
P = world_mat @ scale_mat
|
|
P = P[:3, :4]
|
|
intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
|
|
self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
|
|
self.pose_all.append(torch.from_numpy(pose).float())
|
|
|
|
self.rgb_images = []
|
|
for path in image_paths:
|
|
rgb = rend_util.load_rgb(path)
|
|
rgb = rgb.reshape(3, -1).transpose(1, 0)
|
|
self.rgb_images.append(torch.from_numpy(rgb).float())
|
|
|
|
self.object_masks = []
|
|
for path in mask_paths:
|
|
object_mask = rend_util.load_mask_white_bg(path)
|
|
object_mask = object_mask.reshape(-1)
|
|
self.object_masks.append(torch.from_numpy(object_mask).bool())
|
|
|
|
def __len__(self):
|
|
return self.n_images
|
|
|
|
def __getitem__(self, idx):
|
|
uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
|
|
uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
|
|
uv = uv.reshape(2, -1).transpose(1, 0)
|
|
|
|
sample = {
|
|
"object_mask": self.object_masks[idx],
|
|
"uv": uv,
|
|
"intrinsics": self.intrinsics_all[idx],
|
|
}
|
|
|
|
ground_truth = {
|
|
"rgb": self.rgb_images[idx]
|
|
}
|
|
|
|
if self.sampling_idx is not None:
|
|
ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :]
|
|
sample["object_mask"] = self.object_masks[idx][self.sampling_idx]
|
|
sample["uv"] = uv[self.sampling_idx, :]
|
|
|
|
if not self.train_cameras:
|
|
sample["pose"] = self.pose_all[idx]
|
|
|
|
return idx, sample, ground_truth
|
|
|
|
def collate_fn(self, batch_list):
|
|
# get list of dictionaries and returns input, ground_true as dictionary for all batch instances
|
|
batch_list = zip(*batch_list)
|
|
|
|
all_parsed = []
|
|
for entry in batch_list:
|
|
if type(entry[0]) is dict:
|
|
# make them all into a new dict
|
|
ret = {}
|
|
for k in entry[0].keys():
|
|
ret[k] = torch.stack([obj[k] for obj in entry])
|
|
all_parsed.append(ret)
|
|
else:
|
|
all_parsed.append(torch.LongTensor(entry))
|
|
|
|
return tuple(all_parsed)
|
|
|
|
def change_sampling_idx(self, sampling_size):
|
|
if sampling_size == -1:
|
|
self.sampling_idx = None
|
|
else:
|
|
self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size]
|
|
|
|
def get_scale_mat(self):
|
|
return np.load(self.cam_file)['scale_mat_0']
|
|
|
|
def get_gt_pose(self, scaled=False):
|
|
# Load gt pose without normalization to unit sphere
|
|
camera_dict = np.load(self.cam_file)
|
|
world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
|
|
scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
|
|
|
|
pose_all = []
|
|
for scale_mat, world_mat in zip(scale_mats, world_mats):
|
|
P = world_mat
|
|
if scaled:
|
|
P = world_mat @ scale_mat
|
|
P = P[:3, :4]
|
|
_, pose = rend_util.load_K_Rt_from_P(None, P)
|
|
pose_all.append(torch.from_numpy(pose).float())
|
|
|
|
return torch.cat([p.float().unsqueeze(0) for p in pose_all], 0)
|
|
|
|
def get_pose_init(self):
|
|
# get noisy initializations obtained with the linear method
|
|
cam_file = '{0}/cameras_linear_init.npz'.format(self.instance_dir)
|
|
camera_dict = np.load(cam_file)
|
|
scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
|
|
world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
|
|
|
|
init_pose = []
|
|
for scale_mat, world_mat in zip(scale_mats, world_mats):
|
|
P = world_mat @ scale_mat
|
|
P = P[:3, :4]
|
|
_, pose = rend_util.load_K_Rt_from_P(None, P)
|
|
init_pose.append(pose)
|
|
init_pose = torch.cat([torch.Tensor(pose).float().unsqueeze(0) for pose in init_pose], 0).cuda()
|
|
init_quat = rend_util.rot_to_quat(init_pose[:, :3, :3])
|
|
init_quat = torch.cat([init_quat, init_pose[:, :3, 3]], 1)
|
|
|
|
return init_quat
|