# Mirror of https://gitcode.com/gh_mirrors/eas/EasyFace.git
# (synced 2026-05-11 01:42:39 +00:00; original file: 97 lines, 2.3 KiB, Python)
# Copyright (c) Megvii Inc. All rights reserved.
|
|
# Copyright © Alibaba, Inc. and its affiliates.
|
|
|
|
import functools
|
|
import os
|
|
from collections import defaultdict, deque
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = ['AverageMeter', 'MeterBuffer', 'gpu_mem_usage']
|
|
|
|
|
|
def gpu_mem_usage(device=None):
    """
    Return the peak GPU memory allocated by tensors on a CUDA device (MB).

    Args:
        device: CUDA device to query (``torch.device``, device index, or
            ``None``). ``None`` (the default) queries the current device,
            which is what the zero-argument form of this function always did.

    Returns:
        float: ``torch.cuda.max_memory_allocated(device)`` converted from
        bytes to megabytes (1 MB = 1024 * 1024 bytes). Note this is the
        *peak* allocation reported by PyTorch, not the amount currently
        allocated.
    """
    mem_usage_bytes = torch.cuda.max_memory_allocated(device)
    return mem_usage_bytes / (1024 * 1024)
|
|
|
|
|
|
class AverageMeter:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    The most recent ``window_size`` values are kept in a bounded deque and
    back :attr:`median`, :attr:`avg` and :attr:`latest`. A running total and
    count over *every* value recorded since construction or the last
    :meth:`reset` back :attr:`total` and :attr:`global_avg`.
    """

    def __init__(self, window_size=50):
        # deque(maxlen=...) silently drops the oldest value once full.
        self._deque = deque(maxlen=window_size)
        self._total = 0.0
        self._count = 0

    def update(self, value):
        """Record ``value`` in the window and in the running total/count."""
        self._deque.append(value)
        self._count += 1
        self._total += value

    @property
    def median(self):
        """Median of the values in the window, or ``nan`` if it is empty."""
        if not self._deque:
            # numpy would also produce nan here, but with a RuntimeWarning
            # ("Mean of empty slice") on every read; return it quietly.
            return float('nan')
        d = np.array(list(self._deque))
        return np.median(d)

    @property
    def avg(self):
        """Mean of the values in the window, or ``nan`` if it is empty."""
        if not self._deque:
            # Same as median: avoid numpy's empty-array RuntimeWarning.
            return float('nan')
        d = np.array(list(self._deque))
        return d.mean()

    @property
    def global_avg(self):
        """Mean of every value since construction/last reset (0.0 if none)."""
        # max(..., 1e-5) avoids ZeroDivisionError before the first update;
        # _total is still 0.0 at that point, so the result is 0.0.
        return self._total / max(self._count, 1e-5)

    @property
    def latest(self):
        """Most recently recorded value still in the window, or ``None``."""
        return self._deque[-1] if len(self._deque) > 0 else None

    @property
    def total(self):
        """Sum of every value recorded since construction or the last reset."""
        return self._total

    def reset(self):
        """Clear the window and zero the running total and count."""
        self._deque.clear()
        self._total = 0.0
        self._count = 0

    def clear(self):
        """Clear only the window; ``total`` and ``global_avg`` are kept."""
        self._deque.clear()
|
|
|
|
|
|
class MeterBuffer(defaultdict):
    """Dictionary of ``AverageMeter`` objects keyed by metric name.

    Accessing a missing key creates a new ``AverageMeter`` that uses this
    buffer's ``window_size``.
    """

    def __init__(self, window_size=20):
        factory = functools.partial(AverageMeter, window_size=window_size)
        super().__init__(factory)

    def reset(self):
        """Call ``reset()`` on every meter in the buffer."""
        for v in self.values():
            v.reset()

    def get_filtered_meter(self, filter_key='time'):
        """Return a new dict of the meters whose name contains ``filter_key``."""
        return {k: v for k, v in self.items() if filter_key in k}

    def update(self, values=None, **kwargs):
        """Feed each ``name -> value`` pair to the meter of that name.

        Args:
            values: optional mapping of metric name to value.
            **kwargs: additional name/value pairs; on a duplicate name these
                take precedence over ``values``.

        Tensor values are ``detach()``-ed before being recorded. The caller's
        ``values`` mapping is not modified.
        """
        # Merge into a copy: updating ``values`` in place would leak kwargs
        # into the caller's dict.
        merged = dict(values) if values is not None else {}
        merged.update(kwargs)
        for k, v in merged.items():
            if isinstance(v, torch.Tensor):
                v = v.detach()
            self[k].update(v)

    def clear_meters(self):
        """Call ``clear()`` on every meter in the buffer."""
        for v in self.values():
            v.clear()
|