From cc4d6623023841a845ea71aa801c19af7266d1ad Mon Sep 17 00:00:00 2001
From: AnXiang
Date: Mon, 15 Mar 2021 18:55:11 +0800
Subject: [PATCH] FIX BUG 1. model parallel x_grad rescale

---
 recognition/arcface_torch/partial_fc.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/recognition/arcface_torch/partial_fc.py b/recognition/arcface_torch/partial_fc.py
index 172ea7c..9325b19 100644
--- a/recognition/arcface_torch/partial_fc.py
+++ b/recognition/arcface_torch/partial_fc.py
@@ -153,10 +153,9 @@ class PartialFC(Module):
         logits.backward(grad)
         if total_features.grad is not None:
             total_features.grad.detach_()
-        x_grad: torch.Tensor = torch.zeros_like(features)
-        x_grad.mul_(self.world_size)
-
+        x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
         # feature gradient all-reduce
         dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
+        x_grad = x_grad * self.world_size
         # backward backbone
         return x_grad, loss_v