# insightface/recognition/partial_fc/pytorch/sgd.py
# (snapshot dated 2020-11-06 13:59:21 +08:00 — 72 lines, 2.6 KiB, Python)
import torch
from torch.optim.optimizer import Optimizer, required
class SGD(Optimizer):
    r"""Stochastic gradient descent (optionally with momentum), with an
    extra ``rescale`` factor applied to every gradient before the update.

    This mirrors :class:`torch.optim.SGD`; the only difference is that each
    parameter's gradient is divided **in place** by ``rescale`` at the start
    of :meth:`step` (e.g. to undo loss scaling or to average gradients
    accumulated over several batches — TODO confirm intended use against
    callers).

    Args:
        params: iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float): learning rate (required, must be >= 0).
        momentum (float, optional): momentum factor (default: 0).
        dampening (float, optional): dampening for momentum (default: 0).
        weight_decay (float, optional): weight decay / L2 penalty
            (default: 0).
        nesterov (bool, optional): enables Nesterov momentum
            (default: False).
        rescale (float, optional): each gradient is divided by this value
            before being used (default: 1, i.e. no rescaling).

    Raises:
        ValueError: if ``lr``, ``momentum`` or ``weight_decay`` is negative,
            or if ``nesterov`` is requested without positive momentum or
            with non-zero dampening.
    """

    def __init__(self,
                 params,
                 lr=required,
                 momentum=0,
                 dampening=0,
                 weight_decay=0,
                 nesterov=False,
                 rescale=1):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay))
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        defaults = dict(lr=lr,
                        momentum=momentum,
                        dampening=dampening,
                        weight_decay=weight_decay,
                        nesterov=nesterov)
        # ``rescale`` lives on the optimizer itself, not in the per-group
        # hyper-parameters, so it must be handled explicitly during
        # (de)serialization -- see __getstate__/__setstate__ below.
        self.rescale = rescale
        super(SGD, self).__init__(params, defaults)

    def __getstate__(self):
        # BUGFIX: Optimizer.__getstate__ exports only defaults/state/
        # param_groups, so ``rescale`` was silently dropped on pickling and
        # step() crashed with AttributeError after loading. Export it too.
        state = super(SGD, self).__getstate__()
        state['rescale'] = self.rescale
        return state

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        # Checkpoints written before ``rescale`` was serialized lack it;
        # fall back to the neutral value 1 (no rescaling).
        self.__dict__.setdefault('rescale', 1)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): a closure that reevaluates the
                model and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            for p in group['params']:
                if p.grad is None:
                    continue
                # NOTE: rescaling is in place -- p.grad itself is divided
                # by ``rescale``, so it must not be applied twice per grad.
                p.grad.data.div_(self.rescale)
                d_p = p.grad.data
                if weight_decay != 0:
                    # In place: this also modifies p.grad (classic SGD
                    # decoupled-from-autograd style using .data).
                    d_p.add_(other=p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # First step: buffer starts as a detached copy of
                        # the (rescaled, decayed) gradient; no dampening.
                        buf = param_state['momentum_buffer'] = torch.clone(
                            d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(other=d_p,
                                                alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(other=buf, alpha=momentum)
                    else:
                        d_p = buf
                p.data.add_(other=d_p, alpha=-group['lr'])
        return loss