mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-11 18:22:41 +00:00
74 lines
1.7 KiB
Python
74 lines
1.7 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import time
|
|
|
|
|
|
class StopWatch(object):
    """Simple stopwatch for measuring elapsed time between events.

    Call :meth:`start` first. :meth:`split` then returns the seconds elapsed
    since the previous split (or since ``start``), and :meth:`duration`
    returns the seconds between ``start`` and ``stop``.
    """

    # Measure intervals with a monotonic, high-resolution clock:
    # ``time.time()`` follows the system wall clock, which can be adjusted
    # (NTP, manual changes) and produce negative or inflated durations.
    # Fall back to ``time.time`` on interpreters without ``perf_counter``
    # (Python 2), which this module still targets via ``__future__`` imports.
    _clock = staticmethod(getattr(time, "perf_counter", time.time))

    def __init__(self):
        # Initialise explicitly so misuse produces a clear RuntimeError
        # rather than an AttributeError about a missing attribute.
        self.start_time = None
        self.last_split = None
        self.stop_time = None

    def start(self):
        """Start (or restart) the stopwatch and reset the split reference."""
        self.start_time = self._clock()
        self.last_split = self.start_time
        # Forget any previous stop so duration() cannot mix two runs.
        self.stop_time = None

    def split(self):
        """Return seconds elapsed since the previous split (or ``start``).

        Raises:
            RuntimeError: if :meth:`start` has not been called.
        """
        if self.last_split is None:
            raise RuntimeError("StopWatch.split() called before start()")
        now = self._clock()
        duration = now - self.last_split
        self.last_split = now
        return duration

    def stop(self):
        """Record the stop time used by :meth:`duration`."""
        self.stop_time = self._clock()

    def duration(self):
        """Return seconds between ``start`` and ``stop``.

        Raises:
            RuntimeError: if the stopwatch has not been started and stopped.
        """
        if self.start_time is None or self.stop_time is None:
            raise RuntimeError(
                "StopWatch.duration() requires start() and stop() first"
            )
        return self.stop_time - self.start_time
|
|
|
|
|
|
class TrainMetric(object):
    """Periodically print training loss and throughput.

    Args:
        desc: label prefixed to every printed line.
        calculate_batches: print a report once every this many steps;
            must be >= 1.
        batch_size: samples per step, used to convert the elapsed time
            between reports into samples per second.

    Raises:
        ValueError: if ``calculate_batches`` is less than 1.
    """

    def __init__(self, desc="train", calculate_batches=1, batch_size=256):
        if calculate_batches < 1:
            # Zero would raise ZeroDivisionError in the modulo check on every
            # step; a negative interval would report negative throughput.
            raise ValueError(
                "calculate_batches must be >= 1, got {}".format(calculate_batches)
            )
        self.desc = desc
        self.calculate_batches = calculate_batches
        # Number of samples processed between two consecutive reports.
        self.num_samples = calculate_batches * batch_size
        self.fmt = "{}: iter {}, loss {}, throughput: {:.3f}"

        # The timer starts at construction, so the first report's interval
        # also includes any time spent before the first step.
        self.timer = StopWatch()
        self.timer.start()

    def metric_cb(self, step):
        """Return a callback that reports metrics for training step ``step``.

        A line is printed only when ``step + 1`` is a multiple of
        ``calculate_batches``; the printed iteration number is ``step`` itself.
        """

        def callback(loss):
            # NOTE(review): assumes ``loss`` exposes ``.mean()`` (e.g. an
            # ndarray-like value) — confirm against the caller.
            if (step + 1) % self.calculate_batches == 0:
                elapsed = self.timer.split()
                # A coarse or non-monotonic clock can report a zero or
                # negative interval; avoid ZeroDivisionError and report NaN
                # since throughput cannot be measured in that case.
                if elapsed > 0:
                    throughput = self.num_samples / elapsed
                else:
                    throughput = float("nan")
                print(
                    self.fmt.format(self.desc, step, loss.mean(), throughput)
                )

        return callback
|
|
|
|
|
|
class ValidationMetric(object):
    """Print the time elapsed between consecutive validation callbacks.

    Args:
        desc: label prefixed to every printed line.
    """

    def __init__(self, desc="validation"):
        watch = StopWatch()
        self.desc = desc
        self.fmt = "{}: time: {:.3f}"
        # Timing begins as soon as the metric object is constructed.
        watch.start()
        self.timer = watch

    def metric_cb(self):
        """Build a callback that prints seconds elapsed since the last report."""

        def callback(metrics):
            # ``metrics`` is part of the callback signature but not used here.
            elapsed = self.timer.split()
            print(self.fmt.format(self.desc, elapsed))

        return callback
|