import os
from glob import glob

import torch


def mkdir_ifnotexists(directory):
    """Create `directory` if it does not already exist.

    Note: unlike os.makedirs, os.mkdir requires the parent directory to exist.
    """
    if not os.path.exists(directory):
        os.mkdir(directory)
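
# Usage sketch (hypothetical path): ensure an output directory exists before
# writing results into it.
#
#     mkdir_ifnotexists('./exps')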


def get_class(kls):
    """Resolve a dotted path such as 'package.module.ClassName' to the object it names."""
    parts = kls.split('.')
    module = ".".join(parts[:-1])
    # __import__ returns the top-level package, so walk down attribute by
    # attribute from there to the named object.
    m = __import__(module)
    for comp in parts[1:]:
        m = getattr(m, comp)
    return m
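
# Example of what get_class resolves (torch is already imported above; the
# optimizer setup is hypothetical): __import__('torch.optim') returns the
# top-level torch package, and the getattr loop then walks to Adam.
#
#     adam_cls = get_class('torch.optim.Adam')
#     optimizer = adam_cls(model.parameters(), lr=1e-4)  # `model` is hypothetical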


def glob_imgs(path):
    """Return the paths of all PNG/JPEG images directly under `path`."""
    imgs = []
    for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG']:
        imgs.extend(glob(os.path.join(path, ext)))
    return imgs
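
# Usage sketch (hypothetical directory): collect every image in a scan folder.
# On platforms where glob matching is case-insensitive (e.g. Windows), the
# '*.jpg' and '*.JPG' patterns match the same files, so the list may contain
# duplicates there.
#
#     image_paths = sorted(glob_imgs('./data/scan1/image'))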


def split_input(model_input, total_pixels):
    '''
    Split the input into chunks that fit in CUDA memory at high resolutions.
    Decrease n_pixels in case of a CUDA out-of-memory error.
    '''
    n_pixels = 10000
    split = []
    # Chunk the pixel indices into groups of n_pixels and gather the
    # corresponding slice of each per-pixel tensor along dim 1.
    for indx in torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0):
        data = model_input.copy()
        data['uv'] = torch.index_select(model_input['uv'], 1, indx)
        data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx)
        split.append(data)
    return split
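
# Hedged usage sketch: the tensors in `model_input` are assumed to be laid out
# with the pixel dimension second, since chunking happens along dim=1. All
# names below are hypothetical.
#
#     model_input = {
#         'uv': uv,                    # (B, total_pixels, 2)
#         'object_mask': object_mask,  # (B, total_pixels)
#     }
#     for chunk in split_input(model_input, total_pixels):
#         out = model(chunk)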


def split_input_albedo(model_input, total_pixels):
    '''
    Same as split_input, but also splits the ground-truth 'rgb' values.
    Decrease n_pixels in case of a CUDA out-of-memory error.
    '''
    n_pixels = 10000
    split = []
    for indx in torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0):
        data = model_input.copy()
        data['uv'] = torch.index_select(model_input['uv'], 1, indx)
        data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx)
        data['rgb'] = torch.index_select(model_input['rgb'], 1, indx)
        split.append(data)
    return split


def merge_output(res, total_pixels, batch_size):
    ''' Merge the split outputs back into full-resolution tensors. '''
    model_outputs = {}
    for entry in res[0]:
        if res[0][entry] is None:
            continue
        if len(res[0][entry].shape) == 1:
            # Scalar-per-pixel outputs: concatenate the chunks along the pixel
            # dimension and flatten to (batch_size * total_pixels,).
            model_outputs[entry] = torch.cat(
                [r[entry].reshape(batch_size, -1, 1) for r in res],
                1).reshape(batch_size * total_pixels)
        else:
            # Vector-per-pixel outputs: keep the channel dimension and flatten
            # to (batch_size * total_pixels, channels).
            model_outputs[entry] = torch.cat(
                [r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res],
                1).reshape(batch_size * total_pixels, -1)
    return model_outputs
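
# End-to-end sketch of the split/merge round trip (all names hypothetical).
# This assumes each model call returns a dict whose tensors are reshapeable to
# (batch_size, chunk_pixels, C), e.g. (batch_size * chunk_pixels, C) or
# (batch_size * chunk_pixels,):
#
#     chunks = split_input(model_input, total_pixels)
#     res = [model(chunk) for chunk in chunks]
#     model_outputs = merge_output(res, total_pixels, batch_size)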