Files
insightface/recognition/arcface_torch/lr_scheduler.py
AnXiang bb221e6e6d updated for WebFace42M
updated readability of the code
2022-01-14 17:48:51 +08:00

30 lines
1.1 KiB
Python

from torch.optim.lr_scheduler import _LRScheduler
class PolyScheduler(_LRScheduler):
    """Linear-warmup followed by polynomial-decay learning-rate schedule.

    For step ``t`` (``self.last_epoch``):

    * ``t < warmup_steps``: ``lr = base_lr * t / warmup_steps``, a linear ramp
      that starts at 0 when ``t == 0``.
    * ``t >= warmup_steps``: ``lr = base_lr * (1 - p) ** power`` with
      ``p = (min(t, max_steps) - warmup_steps) / (max_steps - warmup_steps)``,
      so the rate reaches 0 at ``max_steps`` and stays at 0 afterwards.

    Every param group receives the same value derived from ``base_lr``;
    per-group initial learning rates are not used.

    Args:
        optimizer: Wrapped optimizer.
        base_lr: Peak learning rate, reached at the end of warmup.
        max_steps: Step at which the decayed learning rate reaches 0.
        warmup_steps: Number of linear-warmup steps (0 disables warmup).
        last_epoch: Index of the last step, forwarded to the base class.
        power: Exponent of the polynomial decay. Defaults to 2, the value
            that used to be hard-coded.
    """

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1, *, power=2):
        self.base_lr = base_lr
        # Only returned by get_lr() when last_epoch == -1.
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = power
        # Everything get_lr() reads is set before the base constructor runs,
        # since PyTorch's constructor performs an initial step.
        # `verbose` is no longer passed: newer PyTorch releases deprecate it,
        # and False was already its default.
        super(PolyScheduler, self).__init__(optimizer, last_epoch)

    def get_warmup_lr(self):
        """Return ``base_lr * last_epoch / warmup_steps`` for every param group."""
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        """Return the learning rate of every param group at ``self.last_epoch``."""
        if self.last_epoch == -1:
            # Presumably unreachable via the base class's step(), which advances
            # last_epoch before calling get_lr(); kept for backward compatibility.
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        decay_steps = self.max_steps - self.warmup_steps
        if decay_steps <= 0:
            # No decay window (max_steps <= warmup_steps): the schedule is
            # over once warmup ends. This used to divide by zero.
            return [0.0 for _ in self.optimizer.param_groups]
        # Clamp progress at max_steps. Without it, (1 - p) ** power grows
        # again for p > 1 when power is even, and is complex for fractional
        # powers.
        elapsed = min(self.last_epoch - self.warmup_steps, decay_steps)
        alpha = pow(1 - elapsed / float(decay_steps), self.power)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]