import math
import os
import random

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler as _DistributedSampler


def setup_seed(seed, cuda_deterministic=True):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    if cuda_deterministic:  # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:  # faster, less reproducible
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True


def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals
    # num_workers * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)


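# A minimal usage sketch (not part of the original module): wire setup_seed and
# worker_init_fn into a DataLoader. DataLoader passes only worker_id to its
# worker_init_fn, so the extra arguments are bound with functools.partial.
# The function name and its dataset/batch_size arguments are hypothetical.
def _example_build_dataloader(dataset, batch_size, num_workers, rank, seed=2048):
    from functools import partial

    from torch.utils.data import DataLoader

    # Seed the main process once; each worker then derives its own
    # reproducible seed from (num_workers, rank, worker_id, seed).
    setup_seed(seed, cuda_deterministic=True)
    init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=init_fn,
    )

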
def get_dist_info():
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    return rank, world_size


def sync_random_seed(seed=None, device="cuda"):
    """Make sure different ranks share the same seed.

    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`,
    because the seed should be identical across all processes
    in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, optional): The seed. Defaults to None.
        device (str): The device to put the seed on. Defaults to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)

    dist.broadcast(random_num, src=0)

    return random_num.item()


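# A small sketch (not part of the original file) of the calling contract above:
# every rank in the process group must reach sync_random_seed, so it must not
# be wrapped in an `if rank == 0:` guard, otherwise the broadcast deadlocks.
# The function name below is hypothetical.
def _example_shared_shuffle_seed(num_items):
    # Rank 0 draws a random seed; every other rank receives that same value
    # through the broadcast, so the permutation below is identical on all ranks.
    shared_seed = sync_random_seed(seed=None, device="cuda")
    g = torch.Generator()
    g.manual_seed(shared_seed)
    return torch.randperm(num_items, generator=g).tolist()

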
class DistributedSampler(_DistributedSampler):
    def __init__(
        self,
        dataset,
        num_replicas=None,  # world_size
        rank=None,  # local_rank
        shuffle=True,
        seed=0,
    ):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed)

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        # in case that indices is shorter than half of total_size
        indices = (indices * math.ceil(self.total_size / len(indices)))[
            : self.total_size
        ]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
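

# A minimal training-loop sketch (hypothetical names, not part of this module):
# the sampler's __iter__ seeds its generator with self.epoch + self.seed, so
# set_epoch() must be called every epoch to get a new, rank-consistent shuffle.
def _example_epoch_loop(dataset, batch_size, num_epochs):
    from torch.utils.data import DataLoader

    sampler = DistributedSampler(dataset, shuffle=True, seed=2048)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # updates self.epoch used in __iter__
        for batch in loader:
            pass  # forward/backward/optimizer step would go here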