mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-14 04:12:35 +00:00
refine repo structure
This commit is contained in:
@@ -26,13 +26,16 @@ class DistSampleClassifier(Module):
|
||||
def __init__(self, rank, local_rank, world_size):
|
||||
super(DistSampleClassifier, self).__init__()
|
||||
self.sample_rate = cfg.sample_rate
|
||||
self.num_local = cfg.num_classes // world_size + int(rank < cfg.num_classes % world_size)
|
||||
self.class_start = cfg.num_classes // world_size * rank + min(rank, cfg.num_classes % world_size)
|
||||
self.num_local = cfg.num_classes // world_size + int(
|
||||
rank < cfg.num_classes % world_size)
|
||||
self.class_start = cfg.num_classes // world_size * rank + min(
|
||||
rank, cfg.num_classes % world_size)
|
||||
self.num_sample = int(self.sample_rate * self.num_local)
|
||||
self.local_rank = local_rank
|
||||
self.world_size = world_size
|
||||
|
||||
self.weight = torch.empty(size=(self.num_local, cfg.embedding_size), device=local_rank)
|
||||
self.weight = torch.empty(size=(self.num_local, cfg.embedding_size),
|
||||
device=local_rank)
|
||||
self.weight_mom = torch.zeros_like(self.weight)
|
||||
self.stream = torch.cuda.Stream(local_rank)
|
||||
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
@@ -48,7 +51,8 @@ class DistSampleClassifier(Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, total_label):
|
||||
P = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
|
||||
P = (self.class_start <=
|
||||
total_label) & (total_label < self.class_start + self.num_local)
|
||||
total_label[~P] = -1
|
||||
total_label[P] -= self.class_start
|
||||
if int(self.sample_rate) != 1:
|
||||
@@ -81,11 +85,15 @@ class DistSampleClassifier(Module):
|
||||
|
||||
def prepare(self, label, optimizer):
|
||||
with torch.cuda.stream(self.stream):
|
||||
total_label = torch.zeros(label.size()[0] * self.world_size, device=self.local_rank, dtype=torch.long)
|
||||
dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
|
||||
total_label = torch.zeros(label.size()[0] * self.world_size,
|
||||
device=self.local_rank,
|
||||
dtype=torch.long)
|
||||
dist.all_gather(list(total_label.chunk(self.world_size, dim=0)),
|
||||
label)
|
||||
self.sample(total_label)
|
||||
optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
|
||||
optimizer.param_groups[-1]['params'][0] = self.sub_weight
|
||||
optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
|
||||
optimizer.state[
|
||||
self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
|
||||
norm_weight = F.normalize(self.sub_weight)
|
||||
return total_label, norm_weight
|
||||
|
||||
Reference in New Issue
Block a user