mirror of https://github.com/deepinsight/insightface.git
import torch
from torch.optim.optimizer import Optimizer, required


class SGD(Optimizer):
    """Stochastic gradient descent with optional momentum, dampening,
    weight decay and Nesterov momentum.

    Follows the classic torch.optim.SGD update, with one extra knob:
    ``rescale``. Every gradient is divided in place by ``rescale`` at the
    start of ``step()`` before the update is applied.
    """

    def __init__(self,
                 params,
                 lr=required,
                 momentum=0,
                 dampening=0,
                 weight_decay=0,
                 nesterov=False,
                 rescale=1):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr,
                        momentum=momentum,
                        dampening=dampening,
                        weight_decay=weight_decay,
                        nesterov=nesterov)
        self.rescale = rescale
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                # Divide the gradient by the rescale factor in place. Note
                # that a second call to step() without a fresh backward pass
                # would divide it again.
                p.grad.data.div_(self.rescale)
                d_p = p.grad.data
                # L2 regularization folded into the gradient:
                # d_p += weight_decay * p
                if weight_decay != 0:
                    d_p.add_(alpha=weight_decay, other=p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # First update for this parameter: seed the buffer
                        # with the (rescaled) gradient itself.
                        buf = param_state['momentum_buffer'] = torch.clone(
                            d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(other=d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(alpha=momentum, other=buf)
                    else:
                        d_p = buf

                p.data.add_(other=d_p, alpha=-group['lr'])

        return loss
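

# A minimal usage sketch: the linear model, random data and rescale value
# below are made-up placeholders, not taken from the original module.
# rescale=2 simply halves every gradient inside step() before the SGD update.
if __name__ == "__main__":
    model = torch.nn.Linear(8, 2)
    optimizer = SGD(model.parameters(),
                    lr=0.1,
                    momentum=0.9,
                    weight_decay=5e-4,
                    rescale=2)

    x = torch.randn(4, 8)
    target = torch.randn(4, 2)
    loss = torch.nn.functional.mse_loss(model(x), target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # gradients are divided by `rescale` inside step()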