mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-23 02:07:47 +00:00
refine repo structure
This commit is contained in:
@@ -21,8 +21,10 @@ class DistributedOptimizer(mx.optimizer.Optimizer):
|
||||
|
||||
if isinstance(index, (tuple, list)):
|
||||
for i in range(len(index)):
|
||||
hvd.allreduce_(grad[i], average=False,
|
||||
name=self._prefix + str(index[i]), priority=-i)
|
||||
hvd.allreduce_(grad[i],
|
||||
average=False,
|
||||
name=self._prefix + str(index[i]),
|
||||
priority=-i)
|
||||
else:
|
||||
hvd.allreduce_(grad, average=False, name=self._prefix + str(index))
|
||||
|
||||
@@ -58,6 +60,12 @@ class MemoryBankSGDOptimizer(object):
|
||||
if self.momentum > 0:
|
||||
kwargs['momentum'] = self.momentum
|
||||
if state is not None:
|
||||
nd.sgd_mom_update(weight, grad, state, out=weight, lr=lr, wd=self.wd, **kwargs)
|
||||
nd.sgd_mom_update(weight,
|
||||
grad,
|
||||
state,
|
||||
out=weight,
|
||||
lr=lr,
|
||||
wd=self.wd,
|
||||
**kwargs)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
Reference in New Issue
Block a user