mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-21 17:17:49 +00:00
change model save logic
This commit is contained in:
@@ -11,5 +11,5 @@ MODELDIR="../model-$NETWORK-$JOB"
|
||||
mkdir -p "$MODELDIR"
|
||||
PREFIX="$MODELDIR/model"
|
||||
LOGFILE="$MODELDIR/log"
|
||||
CUDA_VISIBLE_DEVICES='0,1,2,3' python -u train_softmax.py --data-dir $DATA_DIR --network "$NETWORK" --loss-type 0 --lr 0.1 --prefix "$PREFIX" --per-batch-size 128 --image-size '112,112' --version-input 1 --version-output E --version-unit 3 > "$LOGFILE" 2>&1 &
|
||||
CUDA_VISIBLE_DEVICES='0,1,2,3' python -u train_softmax.py --data-dir $DATA_DIR --network "$NETWORK" --loss-type 0 --lr 0.1 --prefix "$PREFIX" --per-batch-size 128 --image-size '112,112' --version-input 1 --version-output E --version-unit 3 --use-se > "$LOGFILE" 2>&1 &
|
||||
|
||||
|
||||
@@ -504,8 +504,9 @@ def train_net(args):
|
||||
print('VACC: %f'%(acc_value))
|
||||
|
||||
|
||||
highest_acc = [0.0]
|
||||
last_save_acc = [0.0]
|
||||
highest_acc = []
|
||||
for i in xrange(len(ver_list)):
|
||||
highest_acc.append(0.0)
|
||||
global_step = [0]
|
||||
save_step = [0]
|
||||
if len(args.lr_steps)==0:
|
||||
@@ -534,15 +535,17 @@ def train_net(args):
|
||||
|
||||
if mbatch>=0 and mbatch%args.verbose==0:
|
||||
acc_list = ver_test(mbatch)
|
||||
acc = acc_list[0]
|
||||
save_step[0]+=1
|
||||
msave = save_step[0]
|
||||
do_save = False
|
||||
if acc>=highest_acc[0]:
|
||||
highest_acc[0] = acc
|
||||
if acc>=0.99:
|
||||
do_save = True
|
||||
if mbatch>lr_steps[-1] and mbatch%10000==0:
|
||||
lfw_score = acc_list[0]
|
||||
for i in xrange(len(acc_list)):
|
||||
acc = acc_list[i]
|
||||
if acc>=highest_acc[i]:
|
||||
highest_acc[i] = acc
|
||||
if lfw_score>=0.99:
|
||||
do_save = True
|
||||
if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0:
|
||||
do_save = True
|
||||
if do_save:
|
||||
print('saving', msave, acc)
|
||||
@@ -555,13 +558,12 @@ def train_net(args):
|
||||
# X = np.concatenate(embeddings_list, axis=0)
|
||||
# print('saving lfw npy', X.shape)
|
||||
# np.save(lfw_npy, X)
|
||||
print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[0]))
|
||||
#print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[0]))
|
||||
if mbatch<=args.beta_freeze:
|
||||
_beta = args.beta
|
||||
else:
|
||||
move = max(0, mbatch-args.beta_freeze)
|
||||
_beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power))
|
||||
#_beta = max(args.beta_min, args.beta*math.pow(0.7, move//500))
|
||||
#print('beta', _beta)
|
||||
os.environ['BETA'] = str(_beta)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user