#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Qingping Zheng
@Contact : qingpingzheng2014@gmail.com
@File : encoding.py
@Time : 10/01/21 00:00 PM
@Desc :
@License : Licensed under the Apache License, Version 2.0 (the "License");
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import threading

import torch
import torch.cuda.comm as comm

from torch.autograd import Variable, Function
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

# e.g. "1.1" or "0.3"; used below to gate grad-mode handling in worker threads
torch_ver = torch.__version__[:3]

__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
           'patch_replication_callback']


def allreduce(*inputs):
    """Cross-GPU all-reduce autograd operation, used to compute the mean and
    variance across devices in SyncBN.
    """
    return AllReduce.apply(*inputs)
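
# A minimal usage sketch (hypothetical tensors; assumes at least two visible
# CUDA devices). The first argument is the number of tensors contributed per
# GPU; the remaining arguments are those tensors, grouped device by device.
# Each group is summed element-wise across devices and the totals are
# broadcast back to every device:
#
#   sum0, sqsum0 = x0.sum(0), (x0 ** 2).sum(0)   # statistics on cuda:0
#   sum1, sqsum1 = x1.sum(0), (x1 ** 2).sum(0)   # statistics on cuda:1
#   out = allreduce(2, sum0, sqsum0, sum1, sqsum1)
#   # out == (total_sum@0, total_sqsum@0, total_sum@1, total_sqsum@1)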


class AllReduce(Function):
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        ctx.num_inputs = num_inputs
        # one entry per device: the device of the first tensor in each group
        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
        # regroup the flat argument list into per-device chunks
        inputs = [inputs[i:i + num_inputs]
                  for i in range(0, len(inputs), num_inputs)]
        # sort by device id before the reduce sum so the order is deterministic
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *inputs):
        # the gradient of an all-reduce is itself an all-reduce
        inputs = [i.data for i in inputs]
        inputs = [inputs[i:i + ctx.num_inputs]
                  for i in range(0, len(inputs), ctx.num_inputs)]
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])


class Reduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        # sum all per-device tensors onto a single device
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
        # broadcast the incoming gradient back to every contributing device
        return Broadcast.apply(ctx.target_gpus, gradOutput)
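
# Sketch of how `Reduce` is used by DataParallelCriterion below (hypothetical
# per-device scalar losses): summing onto one device and dividing by the
# number of replicas recovers the mean loss while keeping the autograd graph:
#
#   loss = Reduce.apply(loss_gpu0, loss_gpu1) / 2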


class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices, chunking along the
    batch dimension.
    In the forward pass, the module is replicated on each device, and each
    replica handles a portion of the input. During the backward pass,
    gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; please use the compatible
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs, so that each chunk is
    the same size and each GPU processes the same number of samples.

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic
        Segmentation." *The IEEE Conference on Computer Vision and Pattern
        Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
        # keep the per-device outputs in place instead of gathering them
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules
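
# Note: unlike torch.nn.DataParallel, the forward pass of DataParallelModel
# returns a list with one output chunk per device (gather is a no-op above);
# that list is exactly what DataParallelCriterion expects as its `inputs`.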


class DataParallelCriterion(DataParallel):
    """
    Calculate the loss on multiple GPUs, which balances memory usage for
    semantic segmentation.

    The targets are split across the specified devices by chunking along the
    batch dimension. Please use together with
    :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic
        Segmentation." *The IEEE Conference on Computer Vision and Pattern
        Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
    def forward(self, inputs, *targets, **kwargs):
        # the inputs are already scattered (one chunk per device);
        # only the targets are scattered here
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
        # sum the per-device losses and average over the replicas
        return Reduce.apply(*outputs) / len(outputs)
        #return self.gather(outputs, self.output_device).mean()


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    if torch_ver != "0.3":
        grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        if torch_ver != "0.3":
            # propagate the caller's grad mode into the worker thread
            torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            if not isinstance(input, tuple):
                input = (input,)
            with torch.cuda.device(device):
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target,
                                          kwargs, device),)
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # single replica: run the criterion inline on the calling thread
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
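
# Sketch of the calling convention (hypothetical per-device values), matching
# the call in DataParallelCriterion.forward:
#
#   losses = _criterion_parallel_apply(
#       [crit_gpu0, crit_gpu1],      # one criterion replica per device
#       [y_gpu0, y_gpu1],            # model outputs, one chunk per device
#       [(t_gpu0,), (t_gpu1,)])      # scattered targets, one tuple per device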


###########################################################################
# Adapted from Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
#
class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each
    module created by the original replication.

    The callback will be invoked with arguments
    `__data_parallel_replicate__(ctx, copy_id)`.

    Note that, as all replicas are isomorphic, we assign each sub-module a
    context (shared among the copies of this sub-module on different
    devices). Through this context, different copies can share information.

    We guarantee that the callback on the master copy (the first copy) is
    invoked before the callback on any slave copy.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)
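
# A minimal sketch of a module that opts into the callback (hypothetical
# class; the only requirement is a `__data_parallel_replicate__` method):
#
#   class MySyncModule(torch.nn.Module):
#       def __data_parallel_replicate__(self, ctx, copy_id):
#           # `ctx` is shared by all replicas of this sub-module across
#           # devices; copy_id 0 (the master copy) is always called first
#           self._sync_ctx, self._copy_id = ctx, copy_id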


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object to add the replication
    callback. Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
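
# Note: DataParallelModel above already runs execute_replication_callbacks in
# its own `replicate` override, so this patch is only needed for a plain
# torch.nn.DataParallel (or another subclass that does not call it).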