"""Command-line entry point that configures and launches an IFTrainRunner."""
import argparse
import random
import sys

import GPUtil
import numpy as np
import torch

# Make the local `training` package importable. The path is resolved against
# the current working directory, so the script must be launched from a
# directory where '../code' points at the project sources.
sys.path.append('../code')
from training.train import IFTrainRunner  # noqa: E402  (needs the path above)


def setup_seed(seed):
    """Seed every random number generator used during training.

    Seeds PyTorch (CPU and all CUDA devices), NumPy, and Python's ``random``
    module. It also configures cuDNN for deterministic algorithm selection.

    Args:
        seed: Integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes convolution algorithms and can pick different
    # ones from run to run, which defeats `deterministic` above.
    torch.backends.cudnn.benchmark = False


def _parse_args():
    """Build the command-line parser and return the parsed options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
    parser.add_argument('--nepoch', type=int, default=400, help='number of epochs to train for')
    # NOTE(review): this help text is a copy of --nepoch's; confirm what
    # nepoch_freeze controls inside IFTrainRunner and update the text.
    parser.add_argument('--nepoch_freeze', type=int, default=1000, help='number of epochs to train for')
    parser.add_argument('--conf', type=str, default='./confs/test.conf')
    parser.add_argument('--expname', type=str, default='')
    parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]')
    parser.add_argument('--is_continue', default=False, action="store_true",
                        help='If set, indicates continuing from a previous run.')
    parser.add_argument('--timestamp', default='latest', type=str,
                        help='The timestamp of the run to be used in case of continuing from a previous run.')
    parser.add_argument('--checkpoint', default='latest', type=str,
                        help='The checkpoint epoch number of the run to be used in case of continuing from a previous run.')
    parser.add_argument('--train_cameras', default=False, action="store_true",
                        help='If set, optimizing also camera location.')
    parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.')
    return parser.parse_args()


def _select_gpu(gpu_arg):
    """Resolve the ``--gpu`` option to the value passed as ``gpu_index``.

    Any value other than ``'auto'`` is returned unchanged, as a string.
    For ``'auto'``, GPUtil is asked for a single device. Candidates are
    ordered by memory usage, and any GPU with load or memory usage above
    50% is excluded. The returned value is whatever GPUtil reports as the
    device id.

    Args:
        gpu_arg: The raw ``--gpu`` command-line value.

    Raises:
        SystemExit: ``'auto'`` was requested but GPUtil found no GPU that
            satisfies the limits.
    """
    if gpu_arg != "auto":
        return gpu_arg
    device_ids = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5,
                                     includeNan=False, excludeID=[], excludeUUID=[])
    if not device_ids:
        # Previously this fell through to device_ids[0] and died with a bare
        # IndexError; tell the user what happened and how to work around it.
        raise SystemExit(
            "No available GPU found (load and memory usage must both be below 50%); "
            "pass --gpu <id> to select one explicitly."
        )
    return device_ids[0]


if __name__ == '__main__':
    opt = _parse_args()
    gpu = _select_gpu(opt.gpu)

    setup_seed(0)
    trainrunner = IFTrainRunner(conf=opt.conf,
                                batch_size=opt.batch_size,
                                nepochs=opt.nepoch,
                                nepoch_freeze=opt.nepoch_freeze,
                                expname=opt.expname,
                                gpu_index=gpu,
                                exps_folder_name='exps',
                                is_continue=opt.is_continue,
                                timestamp=opt.timestamp,
                                checkpoint=opt.checkpoint,
                                scan_id=opt.scan_id,
                                train_cameras=opt.train_cameras
                                )
    trainrunner.run()