Mirror of https://github.com/deepinsight/insightface.git
Synced 2026-05-16 13:46:15 +00:00
284 lines · 14 KiB · Python
import os
import shutil
import sys
from datetime import datetime

import torch
import torch.nn as nn
from pyhocon import ConfigFactory

import utils.general as utils
import utils.plots as plt
class IFTrainRunner():
    """Training harness: builds the experiment directory tree, loads the
    dataset/model/loss/optimizer from a pyhocon config, and optionally
    resumes from a previous run's checkpoints."""

    def __init__(self, **kwargs):
        """Set up one training run.

        Expected ``kwargs`` (all read unconditionally): ``conf`` (path to a
        pyhocon config file), ``batch_size``, ``nepochs``, ``nepoch_freeze``,
        ``exps_folder_name``, ``gpu_index``, ``train_cameras``, ``expname``,
        ``scan_id``, ``is_continue``, ``timestamp``, ``checkpoint``.
        """
        torch.set_default_dtype(torch.float32)
        torch.set_num_threads(1)

        # --- experiment configuration -------------------------------------
        self.conf = ConfigFactory.parse_file(kwargs['conf'])
        self.batch_size = kwargs['batch_size']
        self.nepochs = kwargs['nepochs']
        self.nepoch_freeze = kwargs['nepoch_freeze']
        self.exps_folder_name = kwargs['exps_folder_name']
        self.GPU_INDEX = kwargs['gpu_index']
        self.train_cameras = kwargs['train_cameras']

        self.expname = self.conf.get_string('train.expname') + kwargs['expname']
        # A scan_id passed on the command line (!= -1) overrides the config's.
        scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int('dataset.scan_id', default=-1)
        if scan_id != -1:
            self.expname = self.expname + '_{0}'.format(scan_id)

        # --- resolve which previous run (timestamp) to continue, if any ---
        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], self.expname)):
                timestamps = os.listdir(os.path.join('../', kwargs['exps_folder_name'], self.expname))
                if len(timestamps) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    # Timestamps are zero-padded '%Y_%m_%d_%H_%M_%S', so the
                    # lexicographic maximum is the most recent run.
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        # --- output directory tree ----------------------------------------
        utils.mkdir_ifnotexists(os.path.join('../', self.exps_folder_name))
        self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
        utils.mkdir_ifnotexists(self.expdir)
        self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))

        self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        # create checkpoints dirs
        self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)
        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"
        self.scheduler_params_subdir = "SchedulerParameters"

        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir))

        if self.train_cameras:
            self.optimizer_cam_params_subdir = "OptimizerCamParameters"
            self.cam_params_subdir = "CamParameters"

            utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir))
            utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.cam_params_subdir))

        # Snapshot the config used for this run. shutil.copy replaces the
        # previous os.system('cp -r ...'), which was shell-dependent and broke
        # on paths containing spaces or shell metacharacters.
        shutil.copy(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf'))

        # Must be set before any CUDA initialization to take effect.
        if self.GPU_INDEX != 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        print('shell command : {0}'.format(' '.join(sys.argv)))

        print('Loading data ...')

        dataset_conf = self.conf.get_config('dataset')
        if kwargs['scan_id'] != -1:
            dataset_conf['scan_id'] = kwargs['scan_id']

        self.train_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(self.train_cameras,
                                                                                          **dataset_conf)

        print('Finish loading data ...')

        self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset,
                                                            batch_size=self.batch_size,
                                                            shuffle=True,
                                                            collate_fn=self.train_dataset.collate_fn
                                                            )
        # Separate loader for periodic plotting; draws plot.plot_nimgs images.
        self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset,
                                                           batch_size=self.conf.get_int('plot.plot_nimgs'),
                                                           shuffle=True,
                                                           collate_fn=self.train_dataset.collate_fn
                                                           )

        self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=self.conf.get_config('model'),
                                                                                id=scan_id, datadir=dataset_conf['data_dir'])
        if torch.cuda.is_available():
            self.model.cuda()

        self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss'))

        self.lr = self.conf.get_float('train.learning_rate')

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        self.sched_milestones = self.conf.get_list('train.sched_milestones', default=[])
        self.sched_factor = self.conf.get_float('train.sched_factor', default=0.0)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, self.sched_milestones, gamma=self.sched_factor)

        # settings for camera optimization
        if self.train_cameras:
            num_images = len(self.train_dataset)

            # One 7-dof pose vector per training image, optimized sparsely.
            # NOTE(review): placed on GPU unconditionally, unlike the model
            # above which checks torch.cuda.is_available() — confirm intent.
            self.pose_vecs = torch.nn.Embedding(num_images, 7, sparse=True).cuda()
            self.pose_vecs.weight.data.copy_(self.train_dataset.get_pose_init())
            self.optimizer_cam = torch.optim.SparseAdam(list(self.pose_vecs.parameters()), self.conf.get_float('train.learning_rate_cam'))

        # --- optionally resume from a previous run's checkpoints ----------
        self.start_epoch = 0
        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

            # Use the subdir-name attributes defined above (previously the
            # model/optimizer paths hard-coded the same strings inline).
            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, self.model_params_subdir, str(kwargs['checkpoint']) + ".pth"))
            self.model.load_state_dict(saved_model_state["model_state_dict"])
            self.start_epoch = saved_model_state['epoch']

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.optimizer_params_subdir, str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth"))
            self.scheduler.load_state_dict(data["scheduler_state_dict"])

            if self.train_cameras:
                data = torch.load(
                    os.path.join(old_checkpnts_dir, self.optimizer_cam_params_subdir, str(kwargs['checkpoint']) + ".pth"))
                self.optimizer_cam.load_state_dict(data["optimizer_cam_state_dict"])

                data = torch.load(
                    os.path.join(old_checkpnts_dir, self.cam_params_subdir, str(kwargs['checkpoint']) + ".pth"))
                self.pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

        self.num_pixels = self.conf.get_int('train.num_pixels')
        self.total_pixels = self.train_dataset.total_pixels
        self.img_res = self.train_dataset.img_res
        self.n_batches = len(self.train_dataloader)
        self.plot_freq = self.conf.get_int('train.plot_freq')
        self.plot_conf = self.conf.get_config('plot')

        self.alpha_milestones = self.conf.get_list('train.alpha_milestones', default=[])
        self.alpha_factor = self.conf.get_float('train.alpha_factor', default=0.0)
        # When resuming, replay the alpha decays that already happened so the
        # loss weighting matches where the previous run left off.
        for acc in self.alpha_milestones:
            if self.start_epoch > acc:
                self.loss.alpha = self.loss.alpha * self.alpha_factor
|
|
def save_checkpoints(self, epoch):
|
|
torch.save(
|
|
{"epoch": epoch, "model_state_dict": self.model.state_dict()},
|
|
os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
|
|
torch.save(
|
|
{"epoch": epoch, "model_state_dict": self.model.state_dict()},
|
|
os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
|
|
|
|
torch.save(
|
|
{"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
|
|
os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
|
|
torch.save(
|
|
{"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
|
|
os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))
|
|
|
|
torch.save(
|
|
{"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
|
|
os.path.join(self.checkpoints_path, self.scheduler_params_subdir, str(epoch) + ".pth"))
|
|
torch.save(
|
|
{"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
|
|
os.path.join(self.checkpoints_path, self.scheduler_params_subdir, "latest.pth"))
|
|
|
|
if self.train_cameras:
|
|
torch.save(
|
|
{"epoch": epoch, "optimizer_cam_state_dict": self.optimizer_cam.state_dict()},
|
|
os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir, str(epoch) + ".pth"))
|
|
torch.save(
|
|
{"epoch": epoch, "optimizer_cam_state_dict": self.optimizer_cam.state_dict()},
|
|
os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir, "latest.pth"))
|
|
|
|
torch.save(
|
|
{"epoch": epoch, "pose_vecs_state_dict": self.pose_vecs.state_dict()},
|
|
os.path.join(self.checkpoints_path, self.cam_params_subdir, str(epoch) + ".pth"))
|
|
torch.save(
|
|
{"epoch": epoch, "pose_vecs_state_dict": self.pose_vecs.state_dict()},
|
|
os.path.join(self.checkpoints_path, self.cam_params_subdir, "latest.pth"))
|
|
|
|
def run(self):
|
|
print("training...")
|
|
|
|
for epoch in range(self.start_epoch, self.nepochs + 1):
|
|
|
|
if epoch in self.alpha_milestones:
|
|
self.loss.alpha = self.loss.alpha * self.alpha_factor
|
|
|
|
if epoch % 100 == 0 and epoch != 0:
|
|
self.save_checkpoints(epoch)
|
|
|
|
if epoch % self.plot_freq == 0 and epoch != 0:
|
|
self.model.eval()
|
|
if self.train_cameras:
|
|
self.pose_vecs.eval()
|
|
self.train_dataset.change_sampling_idx(-1)
|
|
indices, model_input, ground_truth = next(iter(self.plot_dataloader))
|
|
|
|
model_input["intrinsics"] = model_input["intrinsics"].cuda()
|
|
model_input["uv"] = model_input["uv"].cuda()
|
|
model_input["object_mask"] = model_input["object_mask"].cuda()
|
|
# model_input[""] = ground_truth["rgb"].cuda()
|
|
|
|
if self.train_cameras:
|
|
pose_input = self.pose_vecs(indices.cuda())
|
|
model_input['pose'] = pose_input
|
|
else:
|
|
model_input['pose'] = model_input['pose'].cuda()
|
|
|
|
detail_3dmm, detail_3dmm_subdivision_full = plt.get_displacement_mesh(self.model)
|
|
detail_3dmm.export('{0}/Detailed_3dmm_{1}.obj'.format(self.plots_dir, epoch), 'obj')
|
|
detail_3dmm_subdivision_full.export('{0}/Subdivide_full_{1}.obj'.format(self.plots_dir, epoch), 'obj')
|
|
|
|
self.model.train()
|
|
if self.train_cameras:
|
|
self.pose_vecs.train()
|
|
|
|
self.train_dataset.change_sampling_idx(self.num_pixels)
|
|
|
|
if epoch > self.nepoch_freeze:
|
|
print("Freeze Diffuse Part...")
|
|
self.model.diffuse_network.eval()
|
|
self.model.albedo_network.eval()
|
|
|
|
for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader):
|
|
|
|
model_input["intrinsics"] = model_input["intrinsics"].cuda()
|
|
model_input["uv"] = model_input["uv"].cuda()
|
|
model_input["object_mask"] = model_input["object_mask"].cuda()
|
|
|
|
if self.train_cameras:
|
|
pose_input = self.pose_vecs(indices.cuda())
|
|
model_input['pose'] = pose_input
|
|
else:
|
|
model_input['pose'] = model_input['pose'].cuda()
|
|
|
|
model_outputs = self.model(model_input)
|
|
loss_output = self.loss(model_outputs, ground_truth)
|
|
|
|
loss = loss_output['loss']
|
|
|
|
self.optimizer.zero_grad()
|
|
if self.train_cameras:
|
|
self.optimizer_cam.zero_grad()
|
|
|
|
loss.backward()
|
|
|
|
self.optimizer.step()
|
|
if self.train_cameras:
|
|
self.optimizer_cam.step()
|
|
|
|
print(
|
|
'{0} [{1}] ({2}/{3}): loss = {4}, rgb_loss = {5}, normal_loss = {6}, reg_loss = {7}, eikonal_loss = {8}, mask_loss = {9}, alpha = {10}, lr = {11}'
|
|
.format(self.expname, epoch, data_index, self.n_batches, loss.item(),
|
|
loss_output['rgb_loss'].item(),
|
|
loss_output['normal_loss'].item(),
|
|
loss_output['reg_loss'].item(),
|
|
loss_output['eikonal_loss'].item(),
|
|
loss_output['mask_loss'].item(),
|
|
self.loss.alpha,
|
|
self.scheduler.get_lr()[0]))
|
|
|
|
self.scheduler.step() |