diff --git a/recognition/arcface_torch/partial_fc.py b/recognition/arcface_torch/partial_fc.py
index 172ea7c..9325b19 100644
--- a/recognition/arcface_torch/partial_fc.py
+++ b/recognition/arcface_torch/partial_fc.py
@@ -153,10 +153,9 @@ class PartialFC(Module):
         logits.backward(grad)
         if total_features.grad is not None:
             total_features.grad.detach_()
-        x_grad: torch.Tensor = torch.zeros_like(features)
-        x_grad.mul_(self.world_size)
-
+        x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
         # feature gradient all-reduce
         dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
+        x_grad = x_grad * self.world_size
         # backward backbone
         return x_grad, loss_v