# Mirror of https://github.com/deepinsight/insightface.git
# (synced 2025-12-30 08:02:27 +00:00); 30 lines, 1.1 KiB, Python.
from torch.optim.lr_scheduler import _LRScheduler


class PolyScheduler(_LRScheduler):
    """Linear warmup followed by polynomial decay of the learning rate.

    While ``last_epoch < warmup_steps`` the learning rate ramps linearly from
    0 towards ``base_lr``. Afterwards it decays as
    ``base_lr * (1 - progress) ** power``, where ``progress`` is the fraction
    of the ``max_steps - warmup_steps`` decay steps already taken. From
    ``max_steps`` onward the learning rate is held at 0.

    Every param group receives the same value computed from ``base_lr``; the
    groups' own initial learning rates are not used.
    """

    def __init__(
        self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1, power=2
    ):
        """
        Args:
            optimizer: Wrapped optimizer.
            base_lr: Peak learning rate, reached at the end of warmup.
            max_steps: Step index at which the decayed learning rate hits 0.
            warmup_steps: Number of linear-warmup steps.
            last_epoch: Index of the last step taken; -1 starts from scratch.
            power: Exponent of the polynomial decay (default 2).
        """
        # These attributes must exist before the base constructor runs,
        # because it invokes get_lr() through its initial step.
        self.base_lr = base_lr
        # Only used by get_lr() when last_epoch == -1.
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = power
        # ``verbose`` is deprecated in recent PyTorch releases and passing it
        # explicitly emits a warning, so rely on its default instead.
        super().__init__(optimizer, last_epoch)

    def get_warmup_lr(self):
        """Return the linear-warmup learning rate for every param group."""
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        """Compute the learning rate of every param group at ``last_epoch``."""
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()

        decay_steps = self.max_steps - self.warmup_steps
        if decay_steps > 0:
            elapsed = self.last_epoch - self.warmup_steps
            progress = float(elapsed) / float(decay_steps)
        else:
            # Empty decay window (warmup ends at or after max_steps): the
            # schedule is already finished once warmup is over.
            progress = 1.0
        # Clamp so steps beyond max_steps keep the lr at 0. Otherwise
        # (1 - progress) goes negative: an even power turns that back into a
        # growing lr, and a fractional power yields a complex number.
        progress = min(progress, 1.0)
        alpha = (1.0 - progress) ** self.power
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]