mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-20 00:10:28 +00:00
refine repo structure
This commit is contained in:
@@ -1,12 +1,19 @@
|
||||
import mxnet as mx
|
||||
import mxnet.optimizer as optimizer
|
||||
from mxnet.ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs)
|
||||
from mxnet.ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as
|
||||
NDabs)
|
||||
#from mxnet.ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
|
||||
# mp_sgd_update, mp_sgd_mom_update, square, ftrl_update)
|
||||
|
||||
|
||||
class ONadam(optimizer.Optimizer):
|
||||
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
|
||||
schedule_decay=0.004, **kwargs):
|
||||
def __init__(self,
|
||||
learning_rate=0.001,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
epsilon=1e-8,
|
||||
schedule_decay=0.004,
|
||||
**kwargs):
|
||||
super(ONadam, self).__init__(learning_rate=learning_rate, **kwargs)
|
||||
self.beta1 = beta1
|
||||
self.beta2 = beta2
|
||||
@@ -15,12 +22,14 @@ class ONadam(optimizer.Optimizer):
|
||||
self.m_schedule = 1.
|
||||
|
||||
def create_state(self, index, weight):
|
||||
return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean
|
||||
zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance
|
||||
return (
|
||||
zeros(weight.shape, weight.context, dtype=weight.dtype), # mean
|
||||
zeros(weight.shape, weight.context,
|
||||
dtype=weight.dtype)) # variance
|
||||
|
||||
def update(self, index, weight, grad, state):
|
||||
assert(isinstance(weight, NDArray))
|
||||
assert(isinstance(grad, NDArray))
|
||||
assert (isinstance(weight, NDArray))
|
||||
assert (isinstance(grad, NDArray))
|
||||
self._update_count(index)
|
||||
lr = self._get_lr(index)
|
||||
wd = self._get_wd(index)
|
||||
@@ -34,8 +43,11 @@ class ONadam(optimizer.Optimizer):
|
||||
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
|
||||
|
||||
# warming momentum schedule
|
||||
momentum_t = self.beta1 * (1. - 0.5 * (pow(0.96, t * self.schedule_decay)))
|
||||
momentum_t_1 = self.beta1 * (1. - 0.5 * (pow(0.96, (t + 1) * self.schedule_decay)))
|
||||
momentum_t = self.beta1 * (1. - 0.5 *
|
||||
(pow(0.96, t * self.schedule_decay)))
|
||||
momentum_t_1 = self.beta1 * (1. - 0.5 *
|
||||
(pow(0.96,
|
||||
(t + 1) * self.schedule_decay)))
|
||||
self.m_schedule = self.m_schedule * momentum_t
|
||||
m_schedule_next = self.m_schedule * momentum_t_1
|
||||
|
||||
@@ -51,4 +63,3 @@ class ONadam(optimizer.Optimizer):
|
||||
|
||||
# update weight
|
||||
weight[:] -= lr * m_t_bar / (sqrt(v_t_prime) + self.epsilon)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user