refine repo structure

This commit is contained in:
nttstar
2020-11-06 13:59:21 +08:00
parent 9fc3cc9c0b
commit b774d6a1b7
309 changed files with 24974 additions and 34253 deletions

View File

@@ -1,12 +1,19 @@
import mxnet as mx
import mxnet.optimizer as optimizer
from mxnet.ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs)
from mxnet.ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as
NDabs)
#from mxnet.ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
# mp_sgd_update, mp_sgd_mom_update, square, ftrl_update)
class ONadam(optimizer.Optimizer):
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
schedule_decay=0.004, **kwargs):
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
schedule_decay=0.004,
**kwargs):
super(ONadam, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
@@ -15,12 +22,14 @@ class ONadam(optimizer.Optimizer):
self.m_schedule = 1.
def create_state(self, index, weight):
return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean
zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance
return (
zeros(weight.shape, weight.context, dtype=weight.dtype), # mean
zeros(weight.shape, weight.context,
dtype=weight.dtype)) # variance
def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
assert (isinstance(weight, NDArray))
assert (isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
@@ -34,8 +43,11 @@ class ONadam(optimizer.Optimizer):
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
# warming momentum schedule
momentum_t = self.beta1 * (1. - 0.5 * (pow(0.96, t * self.schedule_decay)))
momentum_t_1 = self.beta1 * (1. - 0.5 * (pow(0.96, (t + 1) * self.schedule_decay)))
momentum_t = self.beta1 * (1. - 0.5 *
(pow(0.96, t * self.schedule_decay)))
momentum_t_1 = self.beta1 * (1. - 0.5 *
(pow(0.96,
(t + 1) * self.schedule_decay)))
self.m_schedule = self.m_schedule * momentum_t
m_schedule_next = self.m_schedule * momentum_t_1
@@ -51,4 +63,3 @@ class ONadam(optimizer.Optimizer):
# update weight
weight[:] -= lr * m_t_bar / (sqrt(v_t_prime) + self.epsilon)