1. model parallel x_grad rescale
This commit is contained in:
AnXiang
2021-03-15 18:55:11 +08:00
parent 48a9e0bb52
commit cc4d662302

View File

@@ -153,10 +153,9 @@ class PartialFC(Module):
logits.backward(grad)
if total_features.grad is not None:
total_features.grad.detach_()
x_grad: torch.Tensor = torch.zeros_like(features)
x_grad.mul_(self.world_size)
x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
# feature gradient reduce-scatter: reduce across ranks, each rank keeps only its own chunk
dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
x_grad = x_grad * self.world_size
# backward backbone
return x_grad, loss_v