diff --git a/recognition/arcface_torch/.gitignore b/recognition/arcface_torch/.gitignore new file mode 100644 index 0000000..e3b9c68 --- /dev/null +++ b/recognition/arcface_torch/.gitignore @@ -0,0 +1,5 @@ +**__pycache__/ +.vscode +bak*/ +work_dirs/ +models/ \ No newline at end of file diff --git a/recognition/arcface_torch/README.md b/recognition/arcface_torch/README.md index 2ee63a8..003b6e2 100644 --- a/recognition/arcface_torch/README.md +++ b/recognition/arcface_torch/README.md @@ -5,50 +5,50 @@ identity on a single server. ## Requirements -- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md). +- Install [PyTorch](http://pytorch.org) (torch>=1.6.0); see [install.md](docs/install.md). +- (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/); see [install_dali.md](docs/install_dali.md). - `pip install -r requirements.txt`. -- Download the dataset - from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_) - . - + ## How to Train -To train a model, run `train.py` with the path to the configs: +To train a model, run `train.py` with the path to the configs. +The example commands below show how to run +distributed training. -### 1. Single node, 8 GPUs: +### 1. To run on a machine with 8 GPUs: ```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12581 train.py configs/ms1mv3_r50_lr02 ``` -### 2. Multiple nodes, each node 8 GPUs: +### 2. To run on 2 machines with 8 GPUs each: Node 0: ```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus ``` Node 1: ```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus ``` -### 3.Training resnet2060 with 8 GPUs: +## Download Datasets or Prepare Datasets -```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py -``` +- [MS1MV3](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-retinaface) (93k IDs, 5.2M images) +- [Glint360K](https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc#4-download) (360k IDs, 17.1M images) +- [WebFace42M](docs/prepare_webface42m.md) (2M IDs, 42.5M images) ## Model Zoo - The models are available for non-commercial research purposes only. - All models can be found here. 
-- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw -- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d) +- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw +- [OneDrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d) -### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/) +### Performance on IJB-C and [**ICCV2021-MFR**](https://github.com/deepinsight/insightface/blob/master/challenges/mfr/README.md) The ICCV2021-MFR testset consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities. @@ -57,44 +57,24 @@ As a result, we can evaluate the FAIR performance for different algorithms. For the **ICCV2021-MFR-ALL** set, TAR is measured on an all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The globalised multi-racial testset contains 242,143 identities and 1,624,305 images. -For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4). -Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images. -There are totally 13,928 positive pairs and 96,983,824 negative pairs. - -| Datasets | backbone | Training throughout | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** | -| :---: | :--- | :--- | :--- |:--- |:--- | -| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** | -| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** | -| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** | -| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** | -| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** | -| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** | -| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** | -| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** | -| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** | -| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** | - -### Performance on IJB-C and Verification Datasets - -| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log | -| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- | -| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)| -| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)| -| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)| -| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)| -| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)| -| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)| -| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)| -| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 
|[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)| -| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)| - -[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.) -## [Speed Benchmark](docs/speed_benchmark.md) +| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Training Throughput | log | +|:-------------------------|:-----------|:-------------|:------------|:------------|:--------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| MS1MV3 | mobileface | - | - | - | ~13000 | log\|config | +| Glint360K | mobileface | - | - | - | - | log\|config | +| WebFace42M-PartialFC-0.2 | mobileface | 73.80 | 95.40 | 92.64 | (16GPUs)~18583 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_mobilefacenet_pfc02_bs8k_16gpus/training.log)\|[config](configs/webface42m_mobilefacenet_pfc02_bs8k_16gpus.py) | +| MS1MV3 | r100 | 85.39 | 97.00 | 95.36 | ~3400 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100_lr01/training.log)\|[config](configs/ms1mv3_r100_lr02.py) | +| Glint360K | r100 | - | - | - | ~3400 | log\|[config](configs/glint360k_100_lr02.py) | +| WebFace42M-PartialFC-0.2 | r50 | 93.83 | 97.53 | 96.16 | ~5900 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log)\|[config](configs/webface42m_r50_lr01_pfc02_bs4k_8gpus.py) | +| WebFace42M-PartialFC-0.2 | r50 | 94.04 | 97.48 | 95.94 | (32GPUs)~17000 | log\|[config](configs/webface42m_r50_lr01_pfc02_bs4k_32gpus.py) | +| WebFace42M-PartialFC-0.2 | r100 | 96.69 | 97.85 | 96.63 | (16GPUs)~5200 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log)\|[config](configs/webface42m_r100_lr01_pfc02_bs4k_16gpus.py) | -**Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of -classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same + +## Speed Benchmark + +`arcface_torch` can train large-scale face recognition training sets efficiently and quickly. When the number of +classes in the training set is greater than 1 million, the partial FC sampling strategy will get the same accuracy with several times faster training performance and smaller GPU memory. Partial FC is a sparse variant of the model parallel architecture for large scale face recognition. Partial FC uses a sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a @@ -112,39 +92,27 @@ More details see `-` means training failed because of GPU memory limitations. 
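The sampling strategy described above — keeping all positive class centers for the batch and filling the rest of the budget with random negatives — can be sketched in a few lines. This is an editor's illustration of the idea implemented in `partial_fc.py` later in this diff, not part of the patch, and the function name is hypothetical:

```python
import torch

def sample_class_centers(weight: torch.Tensor, positive: torch.Tensor, num_sample: int) -> torch.Tensor:
    """Keep every class center hit by the current batch (positive), then fill
    the budget of num_sample centers with randomly chosen negatives."""
    num_local = weight.size(0)  # number of class centers stored on this rank
    if num_sample >= positive.size(0):
        perm = torch.rand(num_local)  # random score per center in [0, 1)
        perm[positive] = 2.0          # positives outrank any random score, so topk always keeps them
        index = torch.topk(perm, k=num_sample)[1].sort()[0]
    else:
        index = positive
    return weight[index]  # only this subset enters the softmax this iteration
```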
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | -| :--- | :--- | :--- | :--- | -|125000 | 4681 | 4824 | 5004 | -|1400000 | **1672** | 3043 | 4738 | -|5500000 | **-** | **1389** | 3975 | -|8000000 | **-** | **-** | 3565 | -|16000000 | **-** | **-** | 2679 | -|29000000 | **-** | **-** | **1855** | +|:--------------------------------|:--------------|:---------------|:---------------| +| 125000 | 4681 | 4824 | 5004 | +| 1400000 | **1672** | 3043 | 4738 | +| 5500000 | **-** | **1389** | 3975 | +| 8000000 | **-** | **-** | 3565 | +| 16000000 | **-** | **-** | 2679 | +| 29000000 | **-** | **-** | **1855** | ### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better) | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | -| :--- | :--- | :--- | :--- | -|125000 | 7358 | 5306 | 4868 | -|1400000 | 32252 | 11178 | 6056 | -|5500000 | **-** | 32188 | 9854 | -|8000000 | **-** | **-** | 12310 | -|16000000 | **-** | **-** | 19950 | -|29000000 | **-** | **-** | 32324 | +|:--------------------------------|:--------------|:---------------|:---------------| +| 125000 | 7358 | 5306 | 4868 | +| 1400000 | 32252 | 11178 | 6056 | +| 5500000 | **-** | 32188 | 9854 | +| 8000000 | **-** | **-** | 12310 | +| 16000000 | **-** | **-** | 19950 | +| 29000000 | **-** | **-** | 32324 | -## Evaluation ICCV2021-MFR and IJB-C -More details see [eval.md](docs/eval.md) in docs. - -## Test - -We tested many versions of PyTorch. Please create an issue if you are having trouble. - -- [x] torch 1.6.0 -- [x] torch 1.7.1 -- [x] torch 1.8.0 -- [x] torch 1.9.0 - -## Citation +## Citations ``` @inproceedings{deng2019arcface, diff --git a/recognition/arcface_torch/configs/3millions.py b/recognition/arcface_torch/configs/3millions.py index c9edc2f..559ebe3 100644 --- a/recognition/arcface_torch/configs/3millions.py +++ b/recognition/arcface_torch/configs/3millions.py @@ -3,21 +3,20 @@ from easydict import EasyDict as edict # configs for test speed config = edict() -config.loss = "arcface" +config.loss = "cosface" config.network = "r50" config.resume = False config.output = None config.embedding_size = 512 -config.sample_rate = 1.0 +config.sample_rate = 0.99 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 -config.batch_size = 128 +config.batch_size = 64 # total_batch_size = batch_size * num_gpus config.lr = 0.1 # batch size is 512 config.rec = "synthetic" config.num_classes = 300 * 10000 config.num_epoch = 30 config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] config.val_targets = [] diff --git a/recognition/arcface_torch/configs/3millions_pfc.py b/recognition/arcface_torch/configs/3millions_pfc.py deleted file mode 100644 index 77caafd..0000000 --- a/recognition/arcface_torch/configs/3millions_pfc.py +++ /dev/null @@ -1,23 +0,0 @@ -from easydict import EasyDict as edict - -# configs for test speed - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 0.1 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "synthetic" -config.num_classes = 300 * 10000 -config.num_epoch = 30 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = [] diff --git a/recognition/arcface_torch/configs/base.py b/recognition/arcface_torch/configs/base.py index 78e4b36..5c96d42 100644 --- 
a/recognition/arcface_torch/configs/base.py +++ b/recognition/arcface_torch/configs/base.py @@ -10,7 +10,6 @@ config.network = "r50" config.resume = False config.output = "ms1mv3_arcface_r50" -config.dataset = "ms1m-retinaface-t1" config.embedding_size = 512 config.sample_rate = 1 config.fp16 = False @@ -18,39 +17,31 @@ config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 +config.dali = False +config.verbose = 2000 +config.frequent = 10 +config.score = None -if config.dataset == "emore": - config.rec = "/train_tmp/faces_emore" - config.num_classes = 85742 - config.num_image = 5822653 - config.num_epoch = 16 - config.warmup_epoch = -1 - config.decay_epoch = [8, 14, ] - config.val_targets = ["lfw", ] +# if config.dataset == "emore": +# config.rec = "/train_tmp/faces_emore" +# config.num_classes = 85742 +# config.num_image = 5822653 +# config.num_epoch = 16 +# config.warmup_epoch = -1 +# config.val_targets = ["lfw", ] -elif config.dataset == "ms1m-retinaface-t1": - config.rec = "/train_tmp/ms1m-retinaface-t1" - config.num_classes = 93431 - config.num_image = 5179510 - config.num_epoch = 25 - config.warmup_epoch = -1 - config.decay_epoch = [11, 17, 22] - config.val_targets = ["lfw", "cfp_fp", "agedb_30"] +# elif config.dataset == "ms1m-retinaface-t1": +# config.rec = "/train_tmp/ms1m-retinaface-t1" +# config.num_classes = 93431 +# config.num_image = 5179510 +# config.num_epoch = 25 +# config.warmup_epoch = -1 +# config.val_targets = ["lfw", "cfp_fp", "agedb_30"] -elif config.dataset == "glint360k": - config.rec = "/train_tmp/glint360k" - config.num_classes = 360232 - config.num_image = 17091657 - config.num_epoch = 20 - config.warmup_epoch = -1 - config.decay_epoch = [8, 12, 15, 18] - config.val_targets = ["lfw", "cfp_fp", "agedb_30"] - -elif config.dataset == "webface": - config.rec = "/train_tmp/faces_webface_112x112" - config.num_classes = 10572 - config.num_image = "forget" - config.num_epoch = 34 - config.warmup_epoch = -1 - config.decay_epoch = [20, 28, 32] - config.val_targets = ["lfw", "cfp_fp", "agedb_30"] +# elif config.dataset == "glint360k": +# config.rec = "/train_tmp/glint360k" +# config.num_classes = 360232 +# config.num_image = 17091657 +# config.num_epoch = 20 +# config.warmup_epoch = -1 +# config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/recognition/arcface_torch/configs/glint360k_r100.py b/recognition/arcface_torch/configs/glint360k_100_lr02.py similarity index 73% rename from recognition/arcface_torch/configs/glint360k_r100.py rename to recognition/arcface_torch/configs/glint360k_100_lr02.py index 93d0701..22511fe 100644 --- a/recognition/arcface_torch/configs/glint360k_r100.py +++ b/recognition/arcface_torch/configs/glint360k_100_lr02.py @@ -5,7 +5,7 @@ from easydict import EasyDict as edict # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() -config.loss = "cosface" +config.loss = "arcface" config.network = "r100" config.resume = False config.output = None @@ -15,12 +15,13 @@ config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 -config.lr = 0.1 # batch size is 512 +config.lr = 0.2 +config.verbose = 2000 +config.dali = False config.rec = "/train_tmp/glint360k" config.num_classes = 360232 config.num_image = 17091657 config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git 
a/recognition/arcface_torch/configs/ms1mv3_mbf.py b/recognition/arcface_torch/configs/ms1mv3_mobileface_lr02.py similarity index 66% rename from recognition/arcface_torch/configs/ms1mv3_mbf.py rename to recognition/arcface_torch/configs/ms1mv3_mobileface_lr02.py index b8a00d6..f5dcaa1 100644 --- a/recognition/arcface_torch/configs/ms1mv3_mbf.py +++ b/recognition/arcface_torch/configs/ms1mv3_mobileface_lr02.py @@ -13,14 +13,15 @@ config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 -config.weight_decay = 2e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 +config.weight_decay = 1e-4 +config.batch_size = 256 +config.lr = 0.2 +config.verbose = 5000 +config.dali = False config.rec = "/train_tmp/ms1m-retinaface-t1" config.num_classes = 93431 config.num_image = 5179510 -config.num_epoch = 30 -config.warmup_epoch = -1 -config.decay_epoch = [10, 20, 25] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] +config.num_epoch = 40 +config.warmup_epoch = 2 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/recognition/arcface_torch/configs/ms1mv3_r18.py b/recognition/arcface_torch/configs/ms1mv3_r100_lr02.py similarity index 73% rename from recognition/arcface_torch/configs/ms1mv3_r18.py rename to recognition/arcface_torch/configs/ms1mv3_r100_lr02.py index eb4e0d3..9df1a28 100644 --- a/recognition/arcface_torch/configs/ms1mv3_r18.py +++ b/recognition/arcface_torch/configs/ms1mv3_r100_lr02.py @@ -6,7 +6,7 @@ from easydict import EasyDict as edict config = edict() config.loss = "arcface" -config.network = "r18" +config.network = "r100" config.resume = False config.output = None config.embedding_size = 512 @@ -15,12 +15,13 @@ config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 -config.lr = 0.1 # batch size is 512 +config.lr = 0.2 +config.verbose = 2000 +config.dali = False config.rec = "/train_tmp/ms1m-retinaface-t1" config.num_classes = 93431 config.num_image = 5179510 config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] +config.warmup_epoch = 2 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/recognition/arcface_torch/configs/ms1mv3_r2060.py b/recognition/arcface_torch/configs/ms1mv3_r2060.py deleted file mode 100644 index 23ad81e..0000000 --- a/recognition/arcface_torch/configs/ms1mv3_r2060.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "r2060" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 64 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/recognition/arcface_torch/configs/ms1mv3_r34.py b/recognition/arcface_torch/configs/ms1mv3_r34.py deleted file mode 100644 index 5f78337..0000000 --- a/recognition/arcface_torch/configs/ms1mv3_r34.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = 
"arcface" -config.network = "r34" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/recognition/arcface_torch/configs/ms1mv3_r50.py b/recognition/arcface_torch/configs/ms1mv3_r50_lr02.py similarity index 77% rename from recognition/arcface_torch/configs/ms1mv3_r50.py rename to recognition/arcface_torch/configs/ms1mv3_r50_lr02.py index 08ba55d..2eefde4 100644 --- a/recognition/arcface_torch/configs/ms1mv3_r50.py +++ b/recognition/arcface_torch/configs/ms1mv3_r50_lr02.py @@ -15,12 +15,13 @@ config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 -config.lr = 0.1 # batch size is 512 +config.lr = 0.2 +config.verbose = 2000 +config.dali = False config.rec = "/train_tmp/ms1m-retinaface-t1" config.num_classes = 93431 config.num_image = 5179510 config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] +config.warmup_epoch = 2 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/recognition/arcface_torch/configs/speed.py b/recognition/arcface_torch/configs/speed.py deleted file mode 100644 index 45e9523..0000000 --- a/recognition/arcface_torch/configs/speed.py +++ /dev/null @@ -1,23 +0,0 @@ -from easydict import EasyDict as edict - -# configs for test speed - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "synthetic" -config.num_classes = 100 * 10000 -config.num_epoch = 30 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = [] diff --git a/recognition/arcface_torch/configs/glint360k_mbf.py b/recognition/arcface_torch/configs/webface42m_mobilefacenet_pfc02_bs8k_16gpus.py similarity index 50% rename from recognition/arcface_torch/configs/glint360k_mbf.py rename to recognition/arcface_torch/configs/webface42m_mobilefacenet_pfc02_bs8k_16gpus.py index 46ae777..279f87e 100644 --- a/recognition/arcface_torch/configs/glint360k_mbf.py +++ b/recognition/arcface_torch/configs/webface42m_mobilefacenet_pfc02_bs8k_16gpus.py @@ -10,17 +10,18 @@ config.network = "mbf" config.resume = False config.output = None config.embedding_size = 512 -config.sample_rate = 0.1 +config.sample_rate = 0.2 config.fp16 = True config.momentum = 0.9 -config.weight_decay = 2e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 +config.weight_decay = 1e-4 +config.batch_size = 512 +config.lr = 0.4 +config.verbose = 10000 +config.dali = True -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] +config.warmup_epoch = 2 +config.val_targets = [] diff --git a/recognition/arcface_torch/configs/glint360k_r34.py 
b/recognition/arcface_torch/configs/webface42m_r100_lr01_pfc02_bs4k_16gpus.py similarity index 59% rename from recognition/arcface_torch/configs/glint360k_r34.py rename to recognition/arcface_torch/configs/webface42m_r100_lr01_pfc02_bs4k_16gpus.py index fda2701..d218346 100644 --- a/recognition/arcface_torch/configs/glint360k_r34.py +++ b/recognition/arcface_torch/configs/webface42m_r100_lr01_pfc02_bs4k_16gpus.py @@ -6,21 +6,22 @@ from easydict import EasyDict as edict config = edict() config.loss = "cosface" -config.network = "r34" +config.network = "r100" config.resume = False config.output = None config.embedding_size = 512 -config.sample_rate = 1.0 +config.sample_rate = 0.2 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 +config.batch_size = 256 +config.lr = 0.3 +config.verbose = 2000 +config.dali = True -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] +config.warmup_epoch = 1 config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/recognition/arcface_torch/configs/glint360k_r50.py b/recognition/arcface_torch/configs/webface42m_r50_lr01_pfc02_bs4k_32gpus.py similarity index 66% rename from recognition/arcface_torch/configs/glint360k_r50.py rename to recognition/arcface_torch/configs/webface42m_r50_lr01_pfc02_bs4k_32gpus.py index 37e7922..6843dae 100644 --- a/recognition/arcface_torch/configs/glint360k_r50.py +++ b/recognition/arcface_torch/configs/webface42m_r50_lr01_pfc02_bs4k_32gpus.py @@ -10,17 +10,18 @@ config.network = "r50" config.resume = False config.output = None config.embedding_size = 512 -config.sample_rate = 1.0 +config.sample_rate = 0.2 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 -config.lr = 0.1 # batch size is 512 +config.lr = 0.4 +config.verbose = 10000 +config.dali = True -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] +config.warmup_epoch = 2 config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/recognition/arcface_torch/configs/glint360k_r18.py b/recognition/arcface_torch/configs/webface42m_r50_lr01_pfc02_bs4k_8gpus.py similarity index 59% rename from recognition/arcface_torch/configs/glint360k_r18.py rename to recognition/arcface_torch/configs/webface42m_r50_lr01_pfc02_bs4k_8gpus.py index 7a8db34..07ed345 100644 --- a/recognition/arcface_torch/configs/glint360k_r18.py +++ b/recognition/arcface_torch/configs/webface42m_r50_lr01_pfc02_bs4k_8gpus.py @@ -6,21 +6,22 @@ from easydict import EasyDict as edict config = edict() config.loss = "cosface" -config.network = "r18" +config.network = "r50" config.resume = False config.output = None config.embedding_size = 512 -config.sample_rate = 1.0 +config.sample_rate = 0.2 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 +config.batch_size = 512 +config.lr = 0.4 +config.verbose = 10000 +config.dali = True -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 +config.rec = "/train_tmp/WebFace42M" 
+config.num_classes = 2059906 +config.num_image = 42474557 config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] +config.warmup_epoch = 2 config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/recognition/arcface_torch/dataset.py b/recognition/arcface_torch/dataset.py index 96bbb8b..80d562e 100644 --- a/recognition/arcface_torch/dataset.py +++ b/recognition/arcface_torch/dataset.py @@ -2,13 +2,42 @@ import numbers import os import queue as Queue import threading +from typing import Iterable import mxnet as mx import numpy as np import torch +from torch import distributed from torch.utils.data import DataLoader, Dataset from torchvision import transforms +def get_dataloader( + root_dir: str, + local_rank: int, + batch_size: int, + dali = False) -> Iterable: + if dali and root_dir != "synthetic": + rec = os.path.join(root_dir, 'train.rec') + idx = os.path.join(root_dir, 'train.idx') + return dali_data_iter( + batch_size=batch_size, rec_file=rec, + idx_file=idx, num_threads=2, local_rank=local_rank) + else: + if root_dir == "synthetic": + train_set = SyntheticDataset() + else: + train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True) + train_loader = DataLoaderX( + local_rank=local_rank, + dataset=train_set, + batch_size=batch_size, + sampler=train_sampler, + num_workers=2, + pin_memory=True, + drop_last=True, + ) + return train_loader class BackgroundGenerator(threading.Thread): def __init__(self, generator, local_rank, max_prefetch=6): @@ -108,7 +137,7 @@ class MXFaceDataset(Dataset): class SyntheticDataset(Dataset): - def __init__(self, local_rank): + def __init__(self): super(SyntheticDataset, self).__init__() img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) img = np.transpose(img, (2, 0, 1)) @@ -122,3 +151,59 @@ class SyntheticDataset(Dataset): def __len__(self): return 1000000 + + +def dali_data_iter( + batch_size: int, rec_file: str, idx_file: str, num_threads: int, + initial_fill=32768, random_shuffle=True, + prefetch_queue_depth=1, local_rank=0, name="reader", + mean=(127.5, 127.5, 127.5), + std=(127.5, 127.5, 127.5)): + """ + Parameters: + ---------- + initial_fill: int + Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored. 
+ + """ + rank: int = distributed.get_rank() + world_size: int = distributed.get_world_size() + import nvidia.dali.fn as fn + import nvidia.dali.types as types + from nvidia.dali.pipeline import Pipeline + from nvidia.dali.plugin.pytorch import DALIClassificationIterator + + pipe = Pipeline( + batch_size=batch_size, num_threads=num_threads, + device_id=local_rank, prefetch_queue_depth=prefetch_queue_depth, ) + condition_flip = fn.random.coin_flip(probability=0.5) + with pipe: + jpegs, labels = fn.readers.mxnet( + path=rec_file, index_path=idx_file, initial_fill=initial_fill, + num_shards=world_size, shard_id=rank, + random_shuffle=random_shuffle, pad_last_batch=False, name=name) + images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB) + images = fn.crop_mirror_normalize( + images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip) + pipe.set_outputs(images, labels) + pipe.build() + return DALIWarper(DALIClassificationIterator(pipelines=[pipe], reader_name=name, )) + + +@torch.no_grad() +class DALIWarper(object): + def __init__(self, dali_iter): + self.iter = dali_iter + + def __next__(self): + data_dict = self.iter.__next__()[0] + tensor_data = data_dict['data'].cuda() + tensor_label: torch.Tensor = data_dict['label'].cuda().long() + tensor_label.squeeze_() + return tensor_data, tensor_label + + def __iter__(self): + return self + + def reset(self): + self.iter.reset() diff --git a/recognition/arcface_torch/docs/install_dali.md b/recognition/arcface_torch/docs/install_dali.md new file mode 100644 index 0000000..1333ed7 --- /dev/null +++ b/recognition/arcface_torch/docs/install_dali.md @@ -0,0 +1 @@ +TODO diff --git a/recognition/arcface_torch/docs/prepare_webface42m.md b/recognition/arcface_torch/docs/prepare_webface42m.md new file mode 100644 index 0000000..d91bb33 --- /dev/null +++ b/recognition/arcface_torch/docs/prepare_webface42m.md @@ -0,0 +1,22 @@ + + + +## 1. Download Datasets and Unzip + +Download WebFace42M from [https://www.face-benchmark.org/download.html](https://www.face-benchmark.org/download.html). + + +## 2. Create a **Pre-shuffled** Rec File for DALI + +Note: a pre-shuffled rec file is very important for DALI, and training on a rec file that has not been pre-shuffled +can cause performance degradation. The original InsightFace-style rec file does not work with NVIDIA DALI, so you must use [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) to generate a pre-shuffled rec file, as follows. + +```shell +# 1) create train.lst using the following command +python -m mxnet.tools.im2rec --list --recursive train "Your WebFace42M Root" + +# 2) create train.rec and train.idx from train.lst using the following command +python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train "Your WebFace42M Root" +``` + +Finally, you will get three files: `train.lst`, `train.rec`, and `train.idx`, of which `train.rec` and `train.idx` are used for training. 
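As a quick sanity check after running `im2rec` (an editor's sketch, not part of the patch; the paths are placeholders), the generated `train.idx`/`train.rec` pair can be read back with MXNet's indexed RecordIO reader:

```python
import mxnet as mx

# open the idx/rec pair produced by im2rec (placeholder paths)
record = mx.recordio.MXIndexedRecordIO("train.idx", "train.rec", "r")

# read one packed sample by index and decode its header and JPEG payload
item = record.read_idx(0)
header, img = mx.recordio.unpack_img(item)  # img is a decoded BGR numpy array
print(header.label, img.shape)
record.close()
```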
diff --git a/recognition/arcface_torch/eval/verification.py b/recognition/arcface_torch/eval/verification.py index 253343b..edacf8d 100644 --- a/recognition/arcface_torch/eval/verification.py +++ b/recognition/arcface_torch/eval/verification.py @@ -261,6 +261,8 @@ def test(data_set, backbone, batch_size, nfolds=10): _xnorm_cnt += 1 _xnorm /= _xnorm_cnt + embeddings = embeddings_list[0].copy() + embeddings = sklearn.preprocessing.normalize(embeddings) acc1 = 0.0 std1 = 0.0 embeddings = embeddings_list[0] + embeddings_list[1] diff --git a/recognition/arcface_torch/losses.py b/recognition/arcface_torch/losses.py deleted file mode 100644 index 87aeaa1..0000000 --- a/recognition/arcface_torch/losses.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -from torch import nn - - -def get_loss(name): - if name == "cosface": - return CosFace() - elif name == "arcface": - return ArcFace() - else: - raise ValueError() - - -class CosFace(nn.Module): - def __init__(self, s=64.0, m=0.40): - super(CosFace, self).__init__() - self.s = s - self.m = m - - def forward(self, cosine, label): - index = torch.where(label != -1)[0] - m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) - m_hot.scatter_(1, label[index, None], self.m) - cosine[index] -= m_hot - ret = cosine * self.s - return ret - - -class ArcFace(nn.Module): - def __init__(self, s=64.0, m=0.5): - super(ArcFace, self).__init__() - self.s = s - self.m = m - - def forward(self, cosine: torch.Tensor, label): - index = torch.where(label != -1)[0] - m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) - m_hot.scatter_(1, label[index, None], self.m) - cosine.acos_() - cosine[index] += m_hot - cosine.cos_().mul_(self.s) - return cosine diff --git a/recognition/arcface_torch/lr_scheduler.py b/recognition/arcface_torch/lr_scheduler.py new file mode 100644 index 0000000..4248964 --- /dev/null +++ b/recognition/arcface_torch/lr_scheduler.py @@ -0,0 +1,29 @@ +from torch.optim.lr_scheduler import _LRScheduler + + +class PolyScheduler(_LRScheduler): + def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1): + self.base_lr = base_lr + self.warmup_lr_init = 0.0001 + self.max_steps: int = max_steps + self.warmup_steps: int = warmup_steps + self.power = 2 + super(PolyScheduler, self).__init__(optimizer, last_epoch, False) + + def get_warmup_lr(self): + alpha = float(self.last_epoch) / float(self.warmup_steps) + return [self.base_lr * alpha for _ in self.optimizer.param_groups] + + def get_lr(self): + if self.last_epoch == -1: + return [self.warmup_lr_init for _ in self.optimizer.param_groups] + if self.last_epoch < self.warmup_steps: + return self.get_warmup_lr() + else: + alpha = pow( + 1 + - float(self.last_epoch - self.warmup_steps) + / float(self.max_steps - self.warmup_steps), + self.power, + ) + return [self.base_lr * alpha for _ in self.optimizer.param_groups] diff --git a/recognition/arcface_torch/onnx_ijbc.py b/recognition/arcface_torch/onnx_ijbc.py index 05b50bf..31c491b 100644 --- a/recognition/arcface_torch/onnx_ijbc.py +++ b/recognition/arcface_torch/onnx_ijbc.py @@ -9,9 +9,10 @@ import numpy as np import pandas as pd import prettytable import skimage.transform +import torch from sklearn.metrics import roc_curve from sklearn.preprocessing import normalize - +from torch.utils.data import DataLoader from onnx_helper import ArcFaceORT SRC = np.array( @@ -25,6 +26,7 @@ SRC = np.array( SRC[:, 0] += 8.0 +@torch.no_grad() class AlignedDataSet(mx.gluon.data.Dataset): def __init__(self, root, 
lines, align=True): self.lines = lines @@ -47,24 +49,23 @@ class AlignedDataSet(mx.gluon.data.Dataset): img_2 = np.expand_dims(np.fliplr(img), 0) output = np.concatenate((img_1, img_2), axis=0).astype(np.float32) output = np.transpose(output, (0, 3, 1, 2)) - output = mx.nd.array(output) - return output + return torch.from_numpy(output) +@torch.no_grad() def extract(model_root, dataset): model = ArcFaceORT(model_path=model_root) model.check() feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim)) - def batchify_fn(data): - return mx.nd.concat(*data, dim=0) + def collate_fn(data): + return torch.cat(data, dim=0) - data_loader = mx.gluon.data.DataLoader( - dataset, 128, last_batch='keep', num_workers=4, - thread_pool=True, prefetch=16, batchify_fn=batchify_fn) + data_loader = DataLoader( + dataset, batch_size=128, drop_last=False, num_workers=4, collate_fn=collate_fn, ) num_iter = 0 for batch in data_loader: - batch = batch.asnumpy() + batch = batch.numpy() batch = (batch - model.input_mean) / model.input_std feat = model.session.run(model.output_names, {model.input_name: batch})[0] feat = np.reshape(feat, (-1, model.feat_dim * 2)) @@ -228,10 +229,12 @@ def main(args): score = verification(template_norm_feats, unique_templates, p1, p2) stop = timeit.default_timer() print('Time: %.2f s. ' % (stop - start)) - save_path = os.path.join(args.result_dir, "{}_result".format(args.target)) + result_dir = args.model_root + + save_path = os.path.join(result_dir, "{}_result".format(args.target)) if not os.path.exists(save_path): os.makedirs(save_path) - score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root)) + score_save_file = os.path.join(save_path, "{}.npy".format(args.target)) np.save(score_save_file, score) files = [score_save_file] methods = [] @@ -261,7 +264,6 @@ if __name__ == '__main__': parser = argparse.ArgumentParser(description='do ijb test') # general parser.add_argument('--model-root', default='', help='path to load model.') - parser.add_argument('--image-path', default='', type=str, help='') - parser.add_argument('--result-dir', default='.', type=str, help='') + parser.add_argument('--image-path', default='/train_tmp/IJB_release/IJBC', type=str, help='') parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') main(parser.parse_args()) diff --git a/recognition/arcface_torch/partial_fc.py b/recognition/arcface_torch/partial_fc.py index 17e2d25..3f4e4fc 100644 --- a/recognition/arcface_torch/partial_fc.py +++ b/recognition/arcface_torch/partial_fc.py @@ -1,222 +1,384 @@ -import logging -import os +import math +import collections +from typing import Callable import torch -import torch.distributed as dist -from torch.nn import Module -from torch.nn.functional import normalize, linear -from torch.nn.parameter import Parameter +from torch import distributed +from torch.nn.functional import linear, normalize -class PartialFC(Module): +class ArcFace(torch.nn.Module): + """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): + """ + def __init__(self, s=64.0, margin=0.5): + super(ArcFace, self).__init__() + self.scale = s + self.cos_m = math.cos(margin) + self.sin_m = math.sin(margin) + self.theta = math.cos(math.pi - margin) + self.sinmm = math.sin(math.pi - margin) * margin + self.easy_margin = False + + + def forward(self, logits: torch.Tensor, labels: torch.Tensor): + index = torch.where(labels != -1)[0] + target_logit = logits[index, labels[index].view(-1)] + + sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2)) + cos_theta_m = target_logit * self.cos_m 
- sin_theta * self.sin_m # cos(target+margin) + if self.easy_margin: + final_target_logit = torch.where( + target_logit > 0, cos_theta_m, target_logit) + else: + final_target_logit = torch.where( + target_logit > self.theta, cos_theta_m, target_logit - self.sinmm) + + logits[index, labels[index].view(-1)] = final_target_logit + logits = logits * self.scale + return logits + + +class CosFace(torch.nn.Module): + def __init__(self, s=64.0, m=0.40): + super(CosFace, self).__init__() + self.s = s + self.m = m + + def forward(self, logits: torch.Tensor, labels: torch.Tensor): + index = torch.where(labels != -1)[0] + target_logit = logits[index, labels[index].view(-1)] + final_target_logit = target_logit - self.m + logits[index, labels[index].view(-1)] = final_target_logit + logits = logits * self.s + return logits + + +class PartialFC(torch.nn.Module): """ - Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint, - Partial FC: Training 10 Million Identities on a Single Machine - See the original paper: https://arxiv.org/abs/2010.05222 - """ + A distributed, sparsely updating variant of the FC layer, named Partial FC (PFC). - @torch.no_grad() - def __init__(self, rank, local_rank, world_size, batch_size, resume, - margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"): + When the sample rate is less than 1, in each iteration positive class centers and a random subset of + negative class centers are selected to compute the margin-based softmax loss. All class + centers are still maintained throughout the whole training process, but only the selected + subset is updated in each iteration. + + .. note:: + When the sample rate is equal to 1, Partial FC is equivalent to model parallelism (the default sample rate is 1). + + Example: + -------- + >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2) + >>> for img, labels in data_loader: + >>> embeddings = net(img) + >>> loss = module_pfc(embeddings, labels, optimizer) + >>> loss.backward() + >>> optimizer.step() + """ + _version = 1 + def __init__( + self, + embedding_size, + num_classes, + sample_rate = 1.0, + fp16: bool = False, + margin_loss = "cosface", + ): """ - rank: int - Unique process(GPU) ID from 0 to world_size - 1. - local_rank: int - Unique process(GPU) ID within the server from 0 to 7. - world_size: int - Number of GPU. - batch_size: int - Batch size on current rank(GPU). - resume: bool - Select whether to restore the weight of softmax. - margin_softmax: callable - A function of margin softmax, eg: cosface, arcface. - num_classes: int - The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size, - required. - sample_rate: float - The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling - can greatly speed up training, and reduce a lot of GPU memory, default is 1.0. + Parameters: + ----------- embedding_size: int 
""" super(PartialFC, self).__init__() - # - self.num_classes: int = num_classes - self.rank: int = rank - self.local_rank: int = local_rank - self.device: torch.device = torch.device("cuda:{}".format(self.local_rank)) - self.world_size: int = world_size - self.batch_size: int = batch_size - self.margin_softmax: callable = margin_softmax + assert ( + distributed.is_initialized() + ), "must initialize distributed before create this" + self.rank = distributed.get_rank() + self.world_size = distributed.get_world_size() + + self.dist_cross_entropy = DistCrossEntropy() + self.embedding_size = embedding_size self.sample_rate: float = sample_rate - self.embedding_size: int = embedding_size - self.prefix: str = prefix - self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size) - self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size) + self.fp16 = fp16 + self.num_local: int = num_classes // self.world_size + int( + self.rank < num_classes % self.world_size + ) + self.class_start: int = num_classes // self.world_size * self.rank + min( + self.rank, num_classes % self.world_size + ) self.num_sample: int = int(self.sample_rate * self.num_local) + self.last_batch_size: int = 0 + self.weight: torch.Tensor + self.weight_mom: torch.Tensor + self.weight_activated: torch.nn.Parameter + self.weight_activated_mom: torch.Tensor + self.is_updated: bool = True + self.init_weight_update: bool = True - self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank)) - self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank)) - - if resume: - try: - self.weight: torch.Tensor = torch.load(self.weight_name) - self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name) - if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local: - raise IndexError - logging.info("softmax weight resume successfully!") - logging.info("softmax weight mom resume successfully!") - except (FileNotFoundError, KeyError, IndexError): - self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) - self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) - logging.info("softmax weight init!") - logging.info("softmax weight mom init!") + if self.sample_rate < 1: + self.register_buffer("weight", + tensor=torch.normal(0, 0.01, (self.num_local, embedding_size))) + self.register_buffer("weight_mom", + tensor=torch.zeros_like(self.weight)) + self.register_parameter("weight_activated", + param=torch.nn.Parameter(torch.empty(0, 0))) + self.register_buffer("weight_activated_mom", + tensor=torch.empty(0, 0)) + self.register_buffer("weight_index", + tensor=torch.empty(0, 0)) else: - self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) - self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) - logging.info("softmax weight init successfully!") - logging.info("softmax weight mom init successfully!") - self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank) + self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size))) - self.index = None - if int(self.sample_rate) == 1: - self.update = lambda: 0 - self.sub_weight = Parameter(self.weight) - self.sub_weight_mom = self.weight_mom + # margin_loss + if isinstance(margin_loss, str): + self.margin_softmax: torch.nn.Module + if margin_loss == "cosface": + self.margin_softmax = CosFace() + elif margin_loss == 
"arcface": + self.margin_softmax = ArcFace() + else: + raise + elif isinstance(margin_loss, Callable): + self.margin_softmax = margin_loss else: - self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank)) - - def save_params(self): - """ Save softmax weight for each rank on prefix - """ - torch.save(self.weight.data, self.weight_name) - torch.save(self.weight_mom, self.weight_mom_name) + raise @torch.no_grad() - def sample(self, total_label): + def sample(self, + labels: torch.Tensor, + index_positive: torch.Tensor, + optimizer: torch.optim.Optimizer): """ - Sample all positive class centers in each rank, and random select neg class centers to filling a fixed - `num_sample`. + This functions will change the value of labels - total_label: tensor - Label after all gather, which cross all GPUs. + Parameters: + ----------- + labels: torch.Tensor + pass + index_positive: torch.Tensor + pass + optimizer: torch.optim.Optimizer + pass """ - index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local) - total_label[~index_positive] = -1 - total_label[index_positive] -= self.class_start - if int(self.sample_rate) != 1: - positive = torch.unique(total_label[index_positive], sorted=True) - if self.num_sample - positive.size(0) >= 0: - perm = torch.rand(size=[self.num_local], device=self.device) - perm[positive] = 2.0 - index = torch.topk(perm, k=self.num_sample)[1] - index = index.sort()[0] - else: - index = positive - self.index = index - total_label[index_positive] = torch.searchsorted(index, total_label[index_positive]) - self.sub_weight = Parameter(self.weight[index]) - self.sub_weight_mom = self.weight_mom[index] + positive = torch.unique(labels[index_positive], sorted=True).cuda() + if self.num_sample - positive.size(0) >= 0: + perm = torch.rand(size=[self.num_local]).cuda() + perm[positive] = 2.0 + index = torch.topk(perm, k=self.num_sample)[1].cuda() + index = index.sort()[0].cuda() + else: + index = positive + self.weight_index = index - def forward(self, total_features, norm_weight): - """ Partial fc forward, `logits = X * sample(W)` - """ - torch.cuda.current_stream().wait_stream(self.stream) - logits = linear(total_features, norm_weight) - return logits + labels[index_positive] = torch.searchsorted(index, labels[index_positive]) + + self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index]) + self.weight_activated_mom = self.weight_mom[self.weight_index] + + if isinstance(optimizer, torch.optim.SGD): + # TODO the params of partial fc must be last in the params list + optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None) + optimizer.param_groups[-1]["params"][0] = self.weight_activated + optimizer.state[self.weight_activated][ + "momentum_buffer" + ] = self.weight_activated_mom + else: + raise @torch.no_grad() def update(self): - """ Set updated weight and weight_mom to memory bank. + """ partial weight to global """ - self.weight_mom[self.index] = self.sub_weight_mom - self.weight[self.index] = self.sub_weight + if self.init_weight_update: + self.init_weight_update = False + return - def prepare(self, label, optimizer): + if self.sample_rate < 1: + self.weight[self.weight_index] = self.weight_activated + self.weight_mom[self.weight_index] = self.weight_activated_mom + + + def forward( + self, + local_embeddings: torch.Tensor, + local_labels: torch.Tensor, + optimizer: torch.optim.Optimizer, + ): """ - get sampled class centers for cal softmax. - - label: tensor - Label tensor on each rank. 
- optimizer: opt - Optimizer for partial fc, which need to get weight mom. - """ - with torch.cuda.stream(self.stream): - total_label = torch.zeros( - size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long) - dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label) - self.sample(total_label) - optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None) - optimizer.param_groups[-1]['params'][0] = self.sub_weight - optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom - norm_weight = normalize(self.sub_weight) - return total_label, norm_weight - - def forward_backward(self, label, features, optimizer): - """ - Partial fc forward and backward with model parallel - - label: tensor - Label tensor on each rank(GPU) - features: tensor - Features tensor on each rank(GPU) - optimizer: optimizer - Optimizer for partial fc + Parameters: + ---------- + local_embeddings: torch.Tensor + feature embeddings on each GPU(Rank). + local_labels: torch.Tensor + labels on each GPU(Rank). Returns: - -------- - x_grad: tensor - The gradient of features. - loss_v: tensor - Loss value for cross entropy. + ------- + loss: torch.Tensor + the margin-based softmax cross-entropy loss, averaged over the global batch. """ - total_label, norm_weight = self.prepare(label, optimizer) - total_features = torch.zeros( - size=[self.batch_size * self.world_size, self.embedding_size], device=self.device) - dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data) - total_features.requires_grad = True + local_labels.squeeze_() + local_labels = local_labels.long() + self.update() - logits = self.forward(total_features, norm_weight) - logits = self.margin_softmax(logits, total_label) + batch_size = local_embeddings.size(0) + if self.last_batch_size == 0: + self.last_batch_size = batch_size + assert self.last_batch_size == batch_size, ( + "last batch size does not equal current batch size: {} vs {}".format( + self.last_batch_size, batch_size)) - with torch.no_grad(): - max_fc = torch.max(logits, dim=1, keepdim=True)[0] - dist.all_reduce(max_fc, dist.ReduceOp.MAX) + _gather_embeddings = [ + torch.zeros((batch_size, self.embedding_size)).cuda() + for _ in range(self.world_size) + ] + _gather_labels = [ + torch.zeros(batch_size).long().cuda() for _ in range(self.world_size) + ] + _list_embeddings = AllGather(local_embeddings, *_gather_embeddings) + distributed.all_gather(_gather_labels, local_labels) - # calculate exp(logits) and all-reduce - logits_exp = torch.exp(logits - max_fc) - logits_sum_exp = logits_exp.sum(dim=1, keepdims=True) - dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM) + embeddings = torch.cat(_list_embeddings) + labels = torch.cat(_gather_labels) - # calculate prob - logits_exp.div_(logits_sum_exp) + labels = labels.view(-1, 1) + index_positive = (self.class_start <= labels) & ( + labels < self.class_start + self.num_local + ) + labels[~index_positive] = -1 + labels[index_positive] -= self.class_start - # get one-hot - grad = logits_exp - index = torch.where(total_label != -1)[0] - one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device) - one_hot.scatter_(1, total_label[index, None], 1) + if self.sample_rate < 1: + self.sample(labels, index_positive, optimizer) - # calculate loss - loss = torch.zeros(grad.size()[0], 1, device=grad.device) - loss[index] = grad[index].gather(1, total_label[index, None]) - dist.all_reduce(loss, dist.ReduceOp.SUM) - loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1) + with torch.cuda.amp.autocast(self.fp16): + norm_embeddings = 
normalize(embeddings) + norm_weight_activated = normalize(self.weight_activated) + logits = linear(norm_embeddings, norm_weight_activated) + if self.fp16: + logits = logits.float() + logits = logits.clamp(-1, 1) - # calculate grad - grad[index] -= one_hot - grad.div_(self.batch_size * self.world_size) + logits = self.margin_softmax(logits, labels) + loss = self.dist_cross_entropy(logits, labels) + return loss - logits.backward(grad) - if total_features.grad is not None: - total_features.grad.detach_() - x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True) - # feature gradient all-reduce - dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0))) - x_grad = x_grad * self.world_size - # backward backbone - return x_grad, loss_v + def state_dict(self, destination=None, prefix="", keep_vars=False): + if destination is None: + destination = collections.OrderedDict() + destination._metadata = collections.OrderedDict() + + for name, module in self._modules.items(): + if module is not None: + module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars) + if self.sample_rate < 1: + destination["weight"] = self.weight.detach() + else: + destination["weight"] = self.weight_activated.data.detach() + return destination + + def load_state_dict(self, state_dict, strict: bool = True): + if self.sample_rate < 1: + self.weight = state_dict["weight"].to(self.weight.device) + self.weight_mom.zero_() + self.weight_activated.data.zero_() + self.weight_activated_mom.zero_() + self.weight_index.zero_() + else: + self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device) + +class DistCrossEntropyFunc(torch.autograd.Function): + """ + CrossEntropy loss is calculated in parallel: the softmax denominator is all-reduced across GPUs before the loss is computed.
+ """ + + @staticmethod + def forward(ctx, logits: torch.Tensor, label: torch.Tensor): + """ """ + batch_size = logits.size(0) + # for numerical stability + max_logits, _ = torch.max(logits, dim=1, keepdim=True) + # local to global + distributed.all_reduce(max_logits, distributed.ReduceOp.MAX) + logits.sub_(max_logits) + logits.exp_() + sum_logits_exp = torch.sum(logits, dim=1, keepdim=True) + # local to global + distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM) + logits.div_(sum_logits_exp) + index = torch.where(label != -1)[0] + # loss + loss = torch.zeros(batch_size, 1, device=logits.device) + loss[index] = logits[index].gather(1, label[index]) + distributed.all_reduce(loss, distributed.ReduceOp.SUM) + ctx.save_for_backward(index, logits, label) + return loss.clamp_min_(1e-30).log_().mean() * (-1) + + @staticmethod + def backward(ctx, loss_gradient): + """ + Args: + loss_gradient (torch.Tensor): gradient passed back from the next layer + Returns: + gradients for each input in forward function + `None` gradients for one-hot label + """ + ( + index, + logits, + label, + ) = ctx.saved_tensors + batch_size = logits.size(0) + one_hot = torch.zeros( + size=[index.size(0), logits.size(1)], device=logits.device + ) + one_hot.scatter_(1, label[index], 1) + logits[index] -= one_hot + logits.div_(batch_size) + return logits * loss_gradient.item(), None + + +class DistCrossEntropy(torch.nn.Module): + def __init__(self): + super(DistCrossEntropy, self).__init__() + + def forward(self, logit_part, label_part): + return DistCrossEntropyFunc.apply(logit_part, label_part) + + +class AllGatherFunc(torch.autograd.Function): + """AllGather op with gradient backward""" + + @staticmethod + def forward(ctx, tensor, *gather_list): + gather_list = list(gather_list) + distributed.all_gather(gather_list, tensor) + return tuple(gather_list) + + @staticmethod + def backward(ctx, *grads): + grad_list = list(grads) + rank = distributed.get_rank() + grad_out = grad_list[rank] + + dist_ops = [ + distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True) + if i == rank + else distributed.reduce( + grad_list[i], i, distributed.ReduceOp.SUM, async_op=True + ) + for i in range(distributed.get_world_size()) + ] + for _op in dist_ops: + _op.wait() + + grad_out *= len(grad_list) # cooperate with distributed loss function + return (grad_out, *[None for _ in range(len(grad_list))]) + + +AllGather = AllGatherFunc.apply diff --git a/recognition/arcface_torch/run.sh b/recognition/arcface_torch/run.sh index 61af4b4..4069075 100644 --- a/recognition/arcface_torch/run.sh +++ b/recognition/arcface_torch/run.sh @@ -1,2 +1,9 @@ -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ +--nproc_per_node=8 \ +--nnodes=1 \ +--node_rank=0 \ +--master_addr="127.0.0.1" \ +--master_port=12345 train.py $@ + ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh diff --git a/recognition/arcface_torch/torch2onnx.py b/recognition/arcface_torch/torch2onnx.py index fc26ab8..63ce2c5 100644 --- a/recognition/arcface_torch/torch2onnx.py +++ b/recognition/arcface_torch/torch2onnx.py @@ -12,7 +12,7 @@ def convert_onnx(net, path_module, output, opset=11, simplify=False): img = torch.from_numpy(img).unsqueeze(0).float() weight = 
diff --git a/recognition/arcface_torch/run.sh b/recognition/arcface_torch/run.sh
index 61af4b4..4069075 100644
--- a/recognition/arcface_torch/run.sh
+++ b/recognition/arcface_torch/run.sh
@@ -1,2 +1,9 @@
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
+
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \
+--nproc_per_node=8 \
+--nnodes=1 \
+--node_rank=0 \
+--master_addr="127.0.0.1" \
+--master_port=12345 train.py "$@"
+
 ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
diff --git a/recognition/arcface_torch/torch2onnx.py b/recognition/arcface_torch/torch2onnx.py
index fc26ab8..63ce2c5 100644
--- a/recognition/arcface_torch/torch2onnx.py
+++ b/recognition/arcface_torch/torch2onnx.py
@@ -12,7 +12,7 @@ def convert_onnx(net, path_module, output, opset=11, simplify=False):
     img = torch.from_numpy(img).unsqueeze(0).float()
     weight = torch.load(path_module)
-    net.load_state_dict(weight)
+    net.load_state_dict(weight, strict=True)
     net.eval()
     torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
     model = onnx.load(output)
@@ -38,22 +38,16 @@ if __name__ == '__main__':
     args = parser.parse_args()
     input_file = args.input
     if os.path.isdir(input_file):
-        input_file = os.path.join(input_file, "backbone.pth")
+        input_file = os.path.join(input_file, "model.pt")
     assert os.path.exists(input_file)
-    model_name = os.path.basename(os.path.dirname(input_file)).lower()
-    params = model_name.split("_")
-    if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
-        if args.network is None:
-            args.network = params[2]
+    # model_name = os.path.basename(os.path.dirname(input_file)).lower()
+    # params = model_name.split("_")
+    # if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
+    #     if args.network is None:
+    #         args.network = params[2]
     assert args.network is not None
     print(args)
     backbone_onnx = get_model(args.network, dropout=0)
-
-    output_path = args.output
-    if output_path is None:
-        output_path = os.path.join(os.path.dirname(__file__), 'onnx')
-    if not os.path.exists(output_path):
-        os.makedirs(output_path)
-    assert os.path.isdir(output_path)
-    output_file = os.path.join(output_path, "%s.onnx" % model_name)
-    convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
+    if args.output is None:
+        args.output = os.path.join(os.path.dirname(args.input), "model.onnx")
+    convert_onnx(backbone_onnx, input_file, args.output, simplify=args.simplify)
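The `__main__` block above is a thin CLI around `convert_onnx`, so the same conversion can be driven programmatically. A minimal sketch under assumed paths (a checkpoint saved as `work_dirs/ms1mv3_r50/model.pt`; `"r50"` is only an example network name):

```python
from backbones import get_model
from torch2onnx import convert_onnx  # importing is assumed side-effect free

backbone = get_model("r50", dropout=0)   # hypothetical network name
convert_onnx(
    backbone,
    "work_dirs/ms1mv3_r50/model.pt",     # hypothetical checkpoint path
    "work_dirs/ms1mv3_r50/model.onnx",   # written next to the checkpoint
    simplify=False,                      # simplification left off in this sketch
)
```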
"synthetic": - train_set = SyntheticDataset(local_rank=local_rank) - else: - train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank) - - train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True) - train_loader = DataLoaderX( - local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size, - sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True) - backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank) - - if cfg.resume: - try: - backbone_pth = os.path.join(cfg.output, "backbone.pth") - backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank))) - if rank == 0: - logging.info("backbone resume successfully!") - except (FileNotFoundError, KeyError, IndexError, RuntimeError): - if rank == 0: - logging.info("resume fail, backbone init successfully!") + summary_writer = ( + SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard")) + if rank == 0 + else None + ) + train_loader = get_dataloader( + cfg.rec, local_rank=args.local_rank, batch_size=cfg.batch_size, dali=cfg.dali) + backbone = get_model( + cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size + ).cuda() backbone = torch.nn.parallel.DistributedDataParallel( - module=backbone, broadcast_buffers=False, device_ids=[local_rank]) + module=backbone, broadcast_buffers=False, device_ids=[args.local_rank]) backbone.train() - margin_softmax = losses.get_loss(cfg.loss) module_partial_fc = PartialFC( - rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume, - batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes, - sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output) + cfg.embedding_size, + cfg.num_classes, + cfg.sample_rate, + cfg.fp16 + ) + module_partial_fc.train().cuda() - opt_backbone = torch.optim.SGD( - params=[{'params': backbone.parameters()}], - lr=cfg.lr / 512 * cfg.batch_size * world_size, - momentum=0.9, weight_decay=cfg.weight_decay) - opt_pfc = torch.optim.SGD( - params=[{'params': module_partial_fc.parameters()}], - lr=cfg.lr / 512 * cfg.batch_size * world_size, - momentum=0.9, weight_decay=cfg.weight_decay) - - num_image = len(train_set) + # TODO the params of partial fc must be last in the params list + opt = torch.optim.SGD( + params=[ + {"params": backbone.parameters(), }, + {"params": module_partial_fc.parameters(), }, + ], + lr=cfg.lr, + momentum=0.9, + weight_decay=cfg.weight_decay + ) total_batch_size = cfg.batch_size * world_size - cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch - cfg.total_step = num_image // total_batch_size * cfg.num_epoch - - def lr_step_func(current_step): - cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch] - if current_step < cfg.warmup_step: - return current_step / cfg.warmup_step - else: - return 0.1 ** len([m for m in cfg.decay_step if m <= current_step]) - - scheduler_backbone = torch.optim.lr_scheduler.LambdaLR( - optimizer=opt_backbone, lr_lambda=lr_step_func) - scheduler_pfc = torch.optim.lr_scheduler.LambdaLR( - optimizer=opt_pfc, lr_lambda=lr_step_func) + cfg.warmup_step = cfg.num_image // total_batch_size * cfg.warmup_epoch + cfg.total_step = cfg.num_image // total_batch_size * cfg.num_epoch + lr_scheduler = PolyScheduler( + optimizer=opt, + base_lr=cfg.lr, + max_steps=cfg.total_step, + warmup_steps=cfg.warmup_step + ) for key, value in cfg.items(): num_space = 25 - len(key) logging.info(": " + 
     for key, value in cfg.items():
         num_space = 25 - len(key)
         logging.info(": " + key + " " * num_space + str(value))

-    val_target = cfg.val_targets
-    callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)
-    callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)
-    callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)
+    callback_verification = CallBackVerification(
+        val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer
+    )
+    callback_logging = CallBackLogging(
+        frequent=cfg.frequent,
+        total_step=cfg.total_step,
+        batch_size=cfg.batch_size,
+        writer=summary_writer
+    )

-    loss = AverageMeter()
+    loss_am = AverageMeter()
     start_epoch = 0
     global_step = 0
-    grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
-    for epoch in range(start_epoch, cfg.num_epoch):
-        train_sampler.set_epoch(epoch)
-        for step, (img, label) in enumerate(train_loader):
-            global_step += 1
-            features = F.normalize(backbone(img))
-            x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
-            if cfg.fp16:
-                features.backward(grad_amp.scale(x_grad))
-                grad_amp.unscale_(opt_backbone)
-                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
-                grad_amp.step(opt_backbone)
-                grad_amp.update()
-            else:
-                features.backward(x_grad)
-                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
-                opt_backbone.step()
+    amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)

-            opt_pfc.step()
-            module_partial_fc.update()
-            opt_backbone.zero_grad()
-            opt_pfc.zero_grad()
-            loss.update(loss_v, 1)
-            callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)
-            callback_verification(global_step, backbone)
-        scheduler_backbone.step()
-        scheduler_pfc.step()
-        callback_checkpoint(global_step, backbone, module_partial_fc)
-    dist.destroy_process_group()
+    for epoch in range(start_epoch, cfg.num_epoch):
+
+        if isinstance(train_loader, DataLoader):
+            train_loader.sampler.set_epoch(epoch)
+        for _, (img, local_labels) in enumerate(train_loader):
+            global_step += 1
+            local_embeddings = backbone(img)
+            loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)
+
+            if cfg.fp16:
+                amp.scale(loss).backward()
+                amp.step(opt)
+                amp.update()
+            else:
+                loss.backward()
+                opt.step()
+
+            opt.zero_grad()
+            lr_scheduler.step()
+
+            with torch.no_grad():
+                loss_am.update(loss.item(), 1)
+                callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)
+
+                if global_step % cfg.verbose == 0 and global_step > 200:
+                    callback_verification(global_step, backbone)
+
+        path_pfc = os.path.join(cfg.output, "softmax_fc_gpu_{}.pt".format(rank))
+        torch.save(module_partial_fc.state_dict(), path_pfc)
+        if rank == 0:
+            path_module = os.path.join(cfg.output, "model.pt")
+            torch.save(backbone.module.state_dict(), path_module)
+
+        if cfg.dali:
+            train_loader.reset()
+
+    if rank == 0:
+        path_module = os.path.join(cfg.output, "model.pt")
+        torch.save(backbone.module.state_dict(), path_module)
+    distributed.destroy_process_group()


 if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
-    parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
-    parser.add_argument('config', type=str, help='py config file')
-    parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
+    parser = argparse.ArgumentParser(description="Distributed ArcFace Training in PyTorch")
+    parser.add_argument("config", type=str, help="py config file")
+    parser.add_argument("--local_rank", type=int, default=0, help="local_rank")
     main(parser.parse_args())
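`PolyScheduler` comes from `lr_scheduler`, which this diff does not show. As a rough mental model only, the sketch below assumes linear warmup over `warmup_steps` followed by polynomial decay (power 2) until `max_steps`; the repository's actual implementation may differ:

```python
from torch.optim.lr_scheduler import _LRScheduler

class PolySchedulerSketch(_LRScheduler):
    """Assumed behaviour: linear warmup, then (1 - progress)^2 decay."""

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
        self.base_lr = base_lr
        self.max_steps = max_steps
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            scale = step / max(1, self.warmup_steps)     # warmup ramp
        else:
            remain = max(1, self.max_steps - self.warmup_steps)
            progress = (step - self.warmup_steps) / remain
            scale = (1.0 - progress) ** 2                # polynomial decay
        return [self.base_lr * scale for _ in self.optimizer.param_groups]
```

The call site in `main` would be unchanged: `PolySchedulerSketch(optimizer=opt, base_lr=cfg.lr, max_steps=cfg.total_step, warmup_steps=cfg.warmup_step)`.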
diff --git a/recognition/arcface_torch/utils/plot.py b/recognition/arcface_torch/utils/plot.py
index ccc588e..7f1d39d 100644
--- a/recognition/arcface_torch/utils/plot.py
+++ b/recognition/arcface_torch/utils/plot.py
@@ -1,7 +1,5 @@
-# coding: utf-8
-
 import os
-from pathlib import Path
+import sys

 import matplotlib.pyplot as plt
 import numpy as np
@@ -10,10 +8,11 @@ from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
 from prettytable import PrettyTable
 from sklearn.metrics import roc_curve, auc

-image_path = "/data/anxiang/IJB_release/IJBC"
-files = [
-    "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
-]
+with open(sys.argv[1], "r") as f:
+    files = f.readlines()
+
+files = [x.strip() for x in files]
+image_path = "/train_tmp/IJB_release/IJBC"


 def read_template_pair_list(path):
@@ -31,7 +30,7 @@ p1, p2, label = read_template_pair_list(
 methods = []
 scores = []
 for file in files:
-    methods.append(file.split('/')[-2])
+    methods.append(file)
     scores.append(np.load(file))

 methods = np.array(methods)
@@ -53,7 +52,7 @@ for method in methods:
              label=('[%s (AUC = %0.4f %%)]' % (method.split('-')[-1], roc_auc * 100)))
     tpr_fpr_row = []
-    tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
+    tpr_fpr_row.append(method)
     for fpr_iter in np.arange(len(x_labels)):
         _, min_index = min(
             list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
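The table this script prints looks up, for each target FPR, the ROC point whose measured FPR is closest (the `min(zip(abs(fpr - x_labels[...]), ...))` idiom above). A self-contained sketch of that lookup with hypothetical scores and labels:

```python
import numpy as np
from sklearn.metrics import roc_curve

rng = np.random.default_rng(0)
scores = rng.random(10_000)            # hypothetical 1:1 similarity scores
labels = rng.integers(0, 2, 10_000)    # hypothetical same/different labels

fpr, tpr, _ = roc_curve(labels, scores)
for target in (1e-4, 1e-3, 1e-2, 1e-1):
    idx = int(np.argmin(np.abs(fpr - target)))  # nearest measured FPR
    print("TPR @ FPR=%g: %.4f" % (target, tpr[idx]))
```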
diff --git a/recognition/arcface_torch/utils/utils_amp.py b/recognition/arcface_torch/utils/utils_amp.py
deleted file mode 100644
index 9ac2a03..0000000
--- a/recognition/arcface_torch/utils/utils_amp.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from typing import Dict, List
-
-import torch
-
-if torch.__version__ < '1.9':
-    Iterable = torch._six.container_abcs.Iterable
-else:
-    import collections
-
-    Iterable = collections.abc.Iterable
-from torch.cuda.amp import GradScaler
-
-
-class _MultiDeviceReplicator(object):
-    """
-    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
-    """
-
-    def __init__(self, master_tensor: torch.Tensor) -> None:
-        assert master_tensor.is_cuda
-        self.master = master_tensor
-        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
-
-    def get(self, device) -> torch.Tensor:
-        retval = self._per_device_tensors.get(device, None)
-        if retval is None:
-            retval = self.master.to(device=device, non_blocking=True, copy=True)
-            self._per_device_tensors[device] = retval
-        return retval
-
-
-class MaxClipGradScaler(GradScaler):
-    def __init__(self, init_scale, max_scale: float, growth_interval=100):
-        GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
-        self.max_scale = max_scale
-
-    def scale_clip(self):
-        if self.get_scale() == self.max_scale:
-            self.set_growth_factor(1)
-        elif self.get_scale() < self.max_scale:
-            self.set_growth_factor(2)
-        elif self.get_scale() > self.max_scale:
-            self._scale.fill_(self.max_scale)
-            self.set_growth_factor(1)
-
-    def scale(self, outputs):
-        """
-        Multiplies ('scales') a tensor or list of tensors by the scale factor.
-
-        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
-        unmodified.
-
-        Arguments:
-            outputs (Tensor or iterable of Tensors): Outputs to scale.
-        """
-        if not self._enabled:
-            return outputs
-        self.scale_clip()
-        # Short-circuit for the common case.
-        if isinstance(outputs, torch.Tensor):
-            assert outputs.is_cuda
-            if self._scale is None:
-                self._lazy_init_scale_growth_tracker(outputs.device)
-            assert self._scale is not None
-            return outputs * self._scale.to(device=outputs.device, non_blocking=True)
-
-        # Invoke the more complex machinery only if we're treating multiple outputs.
-        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale
-
-        def apply_scale(val):
-            if isinstance(val, torch.Tensor):
-                assert val.is_cuda
-                if len(stash) == 0:
-                    if self._scale is None:
-                        self._lazy_init_scale_growth_tracker(val.device)
-                    assert self._scale is not None
-                    stash.append(_MultiDeviceReplicator(self._scale))
-                return val * stash[0].get(val.device)
-            elif isinstance(val, Iterable):
-                iterable = map(apply_scale, val)
-                if isinstance(val, list) or isinstance(val, tuple):
-                    return type(val)(iterable)
-                else:
-                    return iterable
-            else:
-                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
-
-        return apply_scale(outputs)
diff --git a/recognition/arcface_torch/utils/utils_callbacks.py b/recognition/arcface_torch/utils/utils_callbacks.py
index bd2f56c..97fe403 100644
--- a/recognition/arcface_torch/utils/utils_callbacks.py
+++ b/recognition/arcface_torch/utils/utils_callbacks.py
@@ -7,12 +7,14 @@ import torch

 from eval import verification
 from utils.utils_logging import AverageMeter
+from torch.utils.tensorboard import SummaryWriter
+from torch import distributed


 class CallBackVerification(object):
-    def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
-        self.frequent: int = frequent
-        self.rank: int = rank
+
+    def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112)):
+        self.rank: int = distributed.get_rank()
         self.highest_acc: float = 0.0
         self.highest_acc_list: List[float] = [0.0] * len(val_targets)
         self.ver_list: List[object] = []
@@ -20,6 +22,8 @@ class CallBackVerification(object):
         if self.rank is 0:
             self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)

+        self.summary_writer = summary_writer
+
     def ver_test(self, backbone: torch.nn.Module, global_step: int):
         results = []
         for i in range(len(self.ver_list)):
@@ -27,6 +31,10 @@ class CallBackVerification(object):
                 self.ver_list[i], backbone, 10, 10)
             logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
             logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
+
+            self.summary_writer: SummaryWriter
+            self.summary_writer.add_scalar(tag=self.ver_name_list[i], scalar_value=acc2, global_step=global_step, )
+
             if acc2 > self.highest_acc_list[i]:
                 self.highest_acc_list[i] = acc2
             logging.info(
@@ -42,20 +50,20 @@ class CallBackVerification(object):
             self.ver_name_list.append(name)

     def __call__(self, num_update, backbone: torch.nn.Module):
-        if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0:
+        if self.rank == 0 and num_update > 0:
             backbone.eval()
             self.ver_test(backbone, num_update)
             backbone.train()


 class CallBackLogging(object):
-    def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
+    def __init__(self, frequent, total_step, batch_size, writer=None):
         self.frequent: int = frequent
-        self.rank: int = rank
+        self.rank: int = distributed.get_rank()
+        self.world_size: int = distributed.get_world_size()
         self.time_start = time.time()
         self.total_step: int = total_step
         self.batch_size: int = batch_size
-        self.world_size: int = world_size
         self.writer = writer

         self.init = False
@@ -100,18 +108,3 @@ class CallBackLogging(object):
             else:
                 self.init = True
                 self.tic = time.time()
-
-
-class CallBackModelCheckpoint(object):
-    def __init__(self, rank, output="./"):
-        self.rank: int = rank
-        self.output: str = output
-
-    def __call__(self, global_step, backbone, partial_fc, ):
-        if global_step > 100 and self.rank == 0:
-            path_module = os.path.join(self.output, "backbone.pth")
-            torch.save(backbone.module.state_dict(), path_module)
-            logging.info("Pytorch Model Saved in '{}'".format(path_module))
-
-        if global_step > 100 and partial_fc is not None:
-            partial_fc.save_params()
diff --git a/recognition/arcface_torch/utils/utils_os.py b/recognition/arcface_torch/utils/utils_os.py
deleted file mode 100644
index e69de29..0000000
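With `CallBackModelCheckpoint` removed, checkpointing now lives in the training loop itself: every rank saves its PartialFC shard as `softmax_fc_gpu_{rank}.pt` (needed only to resume training), and rank 0 saves the backbone as `model.pt`. A sketch of loading that backbone for inference; the `"r50"` name, the 512-d embedding size, and the path are assumptions:

```python
import torch
from backbones import get_model

net = get_model("r50", dropout=0.0, fp16=False, num_features=512)
net.load_state_dict(torch.load("work_dirs/ms1mv3_r50/model.pt", map_location="cpu"))
net.eval()

with torch.no_grad():
    # one aligned 112x112 face crop in, one 512-d embedding out
    embedding = net(torch.randn(1, 3, 112, 112))
print(embedding.shape)  # torch.Size([1, 512])
```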