mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
Fix bug
1. Model parallel: fix the rescaling of x_grad (the feature gradient) by world_size around the reduce-scatter
This commit is contained in:
@@ -153,10 +153,9 @@ class PartialFC(Module):
|
||||
logits.backward(grad)
|
||||
if total_features.grad is not None:
|
||||
total_features.grad.detach_()
|
||||
x_grad: torch.Tensor = torch.zeros_like(features)
|
||||
x_grad.mul_(self.world_size)
|
||||
|
||||
x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
|
||||
# feature gradient all-reduce
|
||||
dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
|
||||
x_grad = x_grad * self.world_size
|
||||
# backward backbone
|
||||
return x_grad, loss_v
|
||||
|
||||
Reference in New Issue
Block a user