Files
insightface/reconstruction/PBIDR/code/training/train.py
2022-03-19 14:24:51 +08:00

284 lines
14 KiB
Python

import os
from datetime import datetime
from pyhocon import ConfigFactory
import sys
import torch
import torch.nn as nn
import utils.general as utils
import utils.plots as plt
class IFTrainRunner():
    """Training harness for PBIDR implicit-surface reconstruction.

    Builds the experiment/checkpoint directory tree, dataset and dataloaders,
    model, loss, optimizer and LR scheduler from a pyhocon config file,
    optionally resumes from a previous timestamped run, and optionally sets
    up learnable per-image camera poses.
    """

    def __init__(self,**kwargs):
        # Global torch setup: fp32 default dtype, single intra-op thread.
        torch.set_default_dtype(torch.float32)
        torch.set_num_threads(1)

        self.conf = ConfigFactory.parse_file(kwargs['conf'])
        self.batch_size = kwargs['batch_size']
        self.nepochs = kwargs['nepochs']
        # After this epoch, run() switches the diffuse/albedo sub-networks to
        # eval mode (see note in run()).
        self.nepoch_freeze = kwargs['nepoch_freeze']
        self.exps_folder_name = kwargs['exps_folder_name']
        self.GPU_INDEX = kwargs['gpu_index']
        # When True, camera poses are optimized jointly with the model.
        self.train_cameras = kwargs['train_cameras']

        # Experiment name: config prefix + CLI suffix (+ optional scan id).
        self.expname = self.conf.get_string('train.expname') + kwargs['expname']
        # CLI scan_id (-1 == unset) takes precedence over the config value.
        scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int('dataset.scan_id', default=-1)
        if scan_id != -1:
            self.expname = self.expname + '_{0}'.format(scan_id)

        # Resolve which previous run (timestamp) to resume from, if any.
        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join('../',kwargs['exps_folder_name'],self.expname)):
                timestamps = os.listdir(os.path.join('../',kwargs['exps_folder_name'],self.expname))
                if (len(timestamps)) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    # Timestamps are zero-padded '%Y_%m_%d_%H_%M_%S' strings,
                    # so the lexicographic max is the most recent run.
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        # Create ../<exps_folder>/<expname>/<new timestamp>/{plots,checkpoints}.
        utils.mkdir_ifnotexists(os.path.join('../',self.exps_folder_name))
        self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
        utils.mkdir_ifnotexists(self.expdir)
        self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))

        self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        # create checkpoints dirs
        self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)
        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"
        self.scheduler_params_subdir = "SchedulerParameters"
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir))
        if self.train_cameras:
            self.optimizer_cam_params_subdir = "OptimizerCamParameters"
            self.cam_params_subdir = "CamParameters"
            utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir))
            utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.cam_params_subdir))

        # Snapshot the config next to the run for reproducibility.
        # NOTE(review): os.system with an interpolated path is shell-injection
        # prone and Unix-only; shutil.copy would be safer — confirm before changing.
        os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf')))

        # NOTE(review): setting CUDA_VISIBLE_DEVICES after torch has been
        # imported only takes effect if CUDA has not been initialized yet — confirm.
        if (not self.GPU_INDEX == 'ignore'):
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        print('shell command : {0}'.format(' '.join(sys.argv)))

        print('Loading data ...')
        dataset_conf = self.conf.get_config('dataset')
        if kwargs['scan_id'] != -1:
            dataset_conf['scan_id'] = kwargs['scan_id']
        # Dataset/model/loss classes are loaded by dotted name from the config.
        self.train_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(self.train_cameras,
                                                                                          **dataset_conf)
        print('Finish loading data ...')

        # Two loaders over the same dataset: one for optimization, one that
        # yields 'plot.plot_nimgs' images for periodic visualization in run().
        self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset,
                                                            batch_size=self.batch_size,
                                                            shuffle=True,
                                                            collate_fn=self.train_dataset.collate_fn
                                                            )
        self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset,
                                                           batch_size=self.conf.get_int('plot.plot_nimgs'),
                                                           shuffle=True,
                                                           collate_fn=self.train_dataset.collate_fn
                                                           )

        self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=self.conf.get_config('model'), \
                                                                                id=scan_id, datadir=dataset_conf['data_dir'])
        if torch.cuda.is_available():
            self.model.cuda()

        self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss'))

        self.lr = self.conf.get_float('train.learning_rate')
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        # LR decays by sched_factor at each milestone epoch.
        self.sched_milestones = self.conf.get_list('train.sched_milestones', default=[])
        self.sched_factor = self.conf.get_float('train.sched_factor', default=0.0)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, self.sched_milestones, gamma=self.sched_factor)

        # settings for camera optimization
        if self.train_cameras:
            num_images = len(self.train_dataset)
            # One 7-D learnable vector per image; presumably quaternion +
            # translation (initialized from get_pose_init) — confirm against
            # the dataset class.
            self.pose_vecs = torch.nn.Embedding(num_images, 7, sparse=True).cuda()
            self.pose_vecs.weight.data.copy_(self.train_dataset.get_pose_init())
            # Sparse embedding gradients require SparseAdam.
            self.optimizer_cam = torch.optim.SparseAdam(list(self.pose_vecs.parameters()), self.conf.get_float('train.learning_rate_cam'))

        self.start_epoch = 0
        # Resume model/optimizer/scheduler (and camera) state from the
        # checkpoint named by kwargs['checkpoint'] inside the old run dir.
        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
            self.model.load_state_dict(saved_model_state["model_state_dict"])
            self.start_epoch = saved_model_state['epoch']

            data = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth"))
            self.scheduler.load_state_dict(data["scheduler_state_dict"])

            if self.train_cameras:
                data = torch.load(
                    os.path.join(old_checkpnts_dir, self.optimizer_cam_params_subdir, str(kwargs['checkpoint']) + ".pth"))
                self.optimizer_cam.load_state_dict(data["optimizer_cam_state_dict"])

                data = torch.load(
                    os.path.join(old_checkpnts_dir, self.cam_params_subdir, str(kwargs['checkpoint']) + ".pth"))
                self.pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

        self.num_pixels = self.conf.get_int('train.num_pixels')
        self.total_pixels = self.train_dataset.total_pixels
        self.img_res = self.train_dataset.img_res
        self.n_batches = len(self.train_dataloader)
        self.plot_freq = self.conf.get_int('train.plot_freq')
        self.plot_conf = self.conf.get_config('plot')

        # Re-apply the loss-alpha decay for every milestone already passed,
        # so a resumed run matches a run trained straight through.
        self.alpha_milestones = self.conf.get_list('train.alpha_milestones', default=[])
        self.alpha_factor = self.conf.get_float('train.alpha_factor', default=0.0)
        for acc in self.alpha_milestones:
            if self.start_epoch > acc:
                self.loss.alpha = self.loss.alpha * self.alpha_factor
def save_checkpoints(self, epoch):
torch.save(
{"epoch": epoch, "model_state_dict": self.model.state_dict()},
os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
torch.save(
{"epoch": epoch, "model_state_dict": self.model.state_dict()},
os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
torch.save(
{"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
torch.save(
{"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))
torch.save(
{"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
os.path.join(self.checkpoints_path, self.scheduler_params_subdir, str(epoch) + ".pth"))
torch.save(
{"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
os.path.join(self.checkpoints_path, self.scheduler_params_subdir, "latest.pth"))
if self.train_cameras:
torch.save(
{"epoch": epoch, "optimizer_cam_state_dict": self.optimizer_cam.state_dict()},
os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir, str(epoch) + ".pth"))
torch.save(
{"epoch": epoch, "optimizer_cam_state_dict": self.optimizer_cam.state_dict()},
os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir, "latest.pth"))
torch.save(
{"epoch": epoch, "pose_vecs_state_dict": self.pose_vecs.state_dict()},
os.path.join(self.checkpoints_path, self.cam_params_subdir, str(epoch) + ".pth"))
torch.save(
{"epoch": epoch, "pose_vecs_state_dict": self.pose_vecs.state_dict()},
os.path.join(self.checkpoints_path, self.cam_params_subdir, "latest.pth"))
def run(self):
    """Main optimization loop: iterate epochs, periodically checkpoint and
    export visualization meshes, and train on sampled pixel batches."""
    print("training...")

    for epoch in range(self.start_epoch, self.nepochs + 1):

        # Decay the loss alpha at the configured milestone epochs
        # (mirrors the catch-up decay applied in __init__ when resuming).
        if epoch in self.alpha_milestones:
            self.loss.alpha = self.loss.alpha * self.alpha_factor

        # Checkpoint every 100 epochs (epoch 0 skipped).
        if epoch % 100 == 0 and epoch != 0:
            self.save_checkpoints(epoch)

        # Periodic visualization: export displacement meshes as .obj files.
        if epoch % self.plot_freq == 0 and epoch != 0:
            self.model.eval()
            if self.train_cameras:
                self.pose_vecs.eval()
            # -1 presumably means "use all pixels" for plotting — confirm
            # against the dataset's change_sampling_idx implementation.
            self.train_dataset.change_sampling_idx(-1)
            indices, model_input, ground_truth = next(iter(self.plot_dataloader))

            model_input["intrinsics"] = model_input["intrinsics"].cuda()
            model_input["uv"] = model_input["uv"].cuda()
            model_input["object_mask"] = model_input["object_mask"].cuda()
            # model_input[""] = ground_truth["rgb"].cuda()

            if self.train_cameras:
                # Poses come from the learnable per-image embeddings.
                pose_input = self.pose_vecs(indices.cuda())
                model_input['pose'] = pose_input
            else:
                model_input['pose'] = model_input['pose'].cuda()

            detail_3dmm, detail_3dmm_subdivision_full = plt.get_displacement_mesh(self.model)
            detail_3dmm.export('{0}/Detailed_3dmm_{1}.obj'.format(self.plots_dir, epoch), 'obj')
            detail_3dmm_subdivision_full.export('{0}/Subdivide_full_{1}.obj'.format(self.plots_dir, epoch), 'obj')

            # Restore training mode after plotting.
            self.model.train()
            if self.train_cameras:
                self.pose_vecs.train()

        # Back to training-time pixel subsampling.
        self.train_dataset.change_sampling_idx(self.num_pixels)

        # NOTE(review): .eval() only toggles eval-mode behavior (dropout/
        # batch-norm); it does NOT stop gradients, so these sub-networks are
        # not truly frozen — confirm whether requires_grad_(False) was intended.
        if epoch > self.nepoch_freeze:
            print("Freeze Diffuse Part...")
            self.model.diffuse_network.eval()
            self.model.albedo_network.eval()

        for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader):
            model_input["intrinsics"] = model_input["intrinsics"].cuda()
            model_input["uv"] = model_input["uv"].cuda()
            model_input["object_mask"] = model_input["object_mask"].cuda()

            if self.train_cameras:
                pose_input = self.pose_vecs(indices.cuda())
                model_input['pose'] = pose_input
            else:
                model_input['pose'] = model_input['pose'].cuda()

            model_outputs = self.model(model_input)
            loss_output = self.loss(model_outputs, ground_truth)
            loss = loss_output['loss']

            self.optimizer.zero_grad()
            if self.train_cameras:
                self.optimizer_cam.zero_grad()

            loss.backward()

            self.optimizer.step()
            if self.train_cameras:
                self.optimizer_cam.step()

            # Per-batch progress line with each loss component.
            # NOTE(review): scheduler.get_lr() is deprecated in newer torch
            # in favor of get_last_lr() — confirm the pinned torch version.
            print(
                '{0} [{1}] ({2}/{3}): loss = {4}, rgb_loss = {5}, normal_loss = {6}, reg_loss = {7}, eikonal_loss = {8}, mask_loss = {9}, alpha = {10}, lr = {11}'
                .format(self.expname, epoch, data_index, self.n_batches, loss.item(),
                        loss_output['rgb_loss'].item(),
                        loss_output['normal_loss'].item(),
                        loss_output['reg_loss'].item(),
                        loss_output['eikonal_loss'].item(),
                        loss_output['mask_loss'].item(),
                        self.loss.alpha,
                        self.scheduler.get_lr()[0]))

        # Step the LR scheduler once per epoch (sched_milestones are epochs).
        self.scheduler.step()