updated for WebFace42M

updated readability of the code
This commit is contained in:
AnXiang
2022-01-14 17:43:40 +08:00
parent c2b52f44a6
commit bb221e6e6d
31 changed files with 798 additions and 752 deletions

recognition/arcface_torch/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
**__pycache__/
.vscode
bak*/
work_dirs/
models/

View File

@@ -5,50 +5,50 @@ identity on a single server.
## Requirements
- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
- Install [PyTorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
- (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/), our doc for [install_dali.md](docs/install_dali.md).
- `pip install -r requirements.txt`.
- Download the dataset from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).
## How to Train
To train a model, run `train.py` with the path to the configs:
To train a model, run `train.py` with the path to the configs.
The example commands below show how to run
distributed training.
### 1. Single node, 8 GPUs:
### 1. To run on a machine with 8 GPUs:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12581 train.py configs/ms1mv3_r50_lr02
```
### 2. Multiple nodes, each node 8 GPUs:
### 2. To run on 2 machines with 8 GPUs each:
Node 0:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus
```
Node 1:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus
```
### 3. Training resnet2060 with 8 GPUs:
## Download Datasets or Prepare Datasets
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
```
- [MS1MV3](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-retinaface) (93k IDs, 5.2M images)
- [Glint360K](https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc#4-download) (360k IDs, 17.1M images)
- [WebFace42M](docs/prepare_webface42m.md) (2M IDs, 42.5M images)
## Model Zoo
- The models are available for non-commercial research purposes only.
- All models can be found here.
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
- [OneDrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
### Performance on IJB-C and [**ICCV2021-MFR**](https://github.com/deepinsight/insightface/blob/master/challenges/mfr/README.md)
The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
@@ -57,44 +57,24 @@ As the result, we can evaluate the FAIR performance for different algorithms.
For the **ICCV2021-MFR-ALL** set, TAR is measured on the all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
globalised multi-racial test set contains 242,143 identities and 1,624,305 images.
For the **ICCV2021-MFR-MASK** set, TAR is measured on the mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4).
The mask test set contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
In total there are 13,928 positive pairs and 96,983,824 negative pairs.
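For intuition, TAR at a fixed FAR can be estimated directly from the similarity scores of the positive and negative pairs; a minimal sketch (illustrative only, not the official MFR evaluation code):
```python
import numpy as np

def tar_at_far(pos_scores: np.ndarray, neg_scores: np.ndarray, far: float) -> float:
    # Pick the threshold that accepts a fraction `far` of the negative pairs,
    # then report the fraction of positive pairs above that threshold.
    threshold = np.quantile(neg_scores, 1.0 - far)
    return float(np.mean(pos_scores > threshold))

# e.g. tar_at_far(pos, neg, 1e-6) for the ICCV2021-MFR-ALL protocol
```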
| Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
| :---: | :--- | :--- | :--- |:--- |:--- |
| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
### Performance on IJB-C and Verification Datasets
| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
[comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
## [Speed Benchmark](docs/speed_benchmark.md)
| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Training Throughput | log |
|:-------------------------|:-----------|:-------------|:------------|:------------|:--------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| MS1MV3 | mobileface | - | - | - | ~13000 | log\|config |
| Glint360K | mobileface | - | - | - | - | log\|config |
| WebFace42M-PartialFC-0.2 | mobileface | 73.80 | 95.40 | 92.64 | (16GPUs)~18583 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_mobilefacenet_pfc02_bs8k_16gpus/training.log)\|[config](configs/webface42m_mobilefacenet_pfc02_bs8k_16gpus.py) |
| MS1MV3 | r100 | 85.39 | 97.00 | 95.36 | ~3400 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100_lr01/training.log)\|[config](configs/ms1mv3_r100_lr02.py) |
| Glint360K | r100 | - | - | - | ~3400 | log\|[config](configs/glint360k_100_lr02.py) |
| WebFace42M-PartialFC-0.2 | r50 | 93.83 | 97.53 | 96.16 | ~5900 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log)\|[config](configs/webface42m_r50_lr01_pfc02_bs4k_8gpus.py) |
| WebFace42M-PartialFC-0.2 | r50 | 94.04 | 97.48 | 95.94 | (32GPUs)~17000 | log\|[config](configs/webface42m_r50_lr01_pfc02_bs4k_32gpus.py) |
| WebFace42M-PartialFC-0.2 | r100 | 96.69 | 97.85 | 96.63 | (16GPUs)~5200 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log)\|[config](configs/webface42m_r100_lr01_pfc02_bs4k_16gpus.py) |
**Arcface Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of
classes in the training set is greater than 300K and the training is sufficient, the partial FC sampling strategy will get the same
## Speed Benchmark
`arcface_torch` can train large-scale face recognition training sets efficiently and quickly. When the number of
classes in the training set is greater than 1 million, the partial FC sampling strategy will get the same
accuracy with several times faster training performance and smaller GPU memory.
Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
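A minimal sketch of this sampling idea (illustrative only; the real implementation is `PartialFC.sample` in `partial_fc.py`):
```python
import torch

def sample_centers(weight: torch.Tensor, labels: torch.Tensor, num_sample: int):
    # weight: (num_local, dim) class centers held by this rank
    # labels: batch labels already remapped into [0, num_local)
    positive = torch.unique(labels, sorted=True)
    if num_sample >= positive.numel():
        perm = torch.rand(weight.size(0), device=weight.device)
        perm[positive] = 2.0  # force every positive center into the top-k
        index = torch.topk(perm, k=num_sample)[1].sort()[0]
    else:
        index = positive
    return weight[index], index  # sub-matrix used for this iteration's logits
```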
@@ -112,39 +92,27 @@ More details see
`-` means training failed because of GPU memory limitations.
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 4681 | 4824 | 5004 |
|1400000 | **1672** | 3043 | 4738 |
|5500000 | **-** | **1389** | 3975 |
|8000000 | **-** | **-** | 3565 |
|16000000 | **-** | **-** | 2679 |
|29000000 | **-** | **-** | **1855** |
|:--------------------------------|:--------------|:---------------|:---------------|
| 125000 | 4681 | 4824 | 5004 |
| 1400000 | **1672** | 3043 | 4738 |
| 5500000 | **-** | **1389** | 3975 |
| 8000000 | **-** | **-** | 3565 |
| 16000000 | **-** | **-** | 2679 |
| 29000000 | **-** | **-** | **1855** |
### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 7358 | 5306 | 4868 |
|1400000 | 32252 | 11178 | 6056 |
|5500000 | **-** | 32188 | 9854 |
|8000000 | **-** | **-** | 12310 |
|16000000 | **-** | **-** | 19950 |
|29000000 | **-** | **-** | 32324 |
|:--------------------------------|:--------------|:---------------|:---------------|
| 125000 | 7358 | 5306 | 4868 |
| 1400000 | 32252 | 11178 | 6056 |
| 5500000 | **-** | 32188 | 9854 |
| 8000000 | **-** | **-** | 12310 |
| 16000000 | **-** | **-** | 19950 |
| 29000000 | **-** | **-** | 32324 |
## Evaluation on ICCV2021-MFR and IJB-C
For more details, see [eval.md](docs/eval.md) in docs.
## Test
We tested many versions of PyTorch. Please create an issue if you are having trouble.
- [x] torch 1.6.0
- [x] torch 1.7.1
- [x] torch 1.8.0
- [x] torch 1.9.0
## Citation
## Citations
```
@inproceedings{deng2019arcface,

View File

@@ -3,21 +3,20 @@ from easydict import EasyDict as edict
# configs for test speed
config = edict()
config.loss = "arcface"
config.loss = "cosface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.sample_rate = 0.99
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.batch_size = 64 # total_batch_size = batch_size * num_gpus
config.lr = 0.1 # batch size is 512
config.rec = "synthetic"
config.num_classes = 300 * 10000
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = []

View File

@@ -1,23 +0,0 @@
from easydict import EasyDict as edict
# configs for test speed
config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.1
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.rec = "synthetic"
config.num_classes = 300 * 10000
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = []

View File

@@ -10,7 +10,6 @@ config.network = "r50"
config.resume = False
config.output = "ms1mv3_arcface_r50"
config.dataset = "ms1m-retinaface-t1"
config.embedding_size = 512
config.sample_rate = 1
config.fp16 = False
@@ -18,39 +17,31 @@ config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.dali = False
config.verbose = 2000
config.frequent = 10
config.score = None
if config.dataset == "emore":
config.rec = "/train_tmp/faces_emore"
config.num_classes = 85742
config.num_image = 5822653
config.num_epoch = 16
config.warmup_epoch = -1
config.decay_epoch = [8, 14, ]
config.val_targets = ["lfw", ]
# if config.dataset == "emore":
# config.rec = "/train_tmp/faces_emore"
# config.num_classes = 85742
# config.num_image = 5822653
# config.num_epoch = 16
# config.warmup_epoch = -1
# config.val_targets = ["lfw", ]
elif config.dataset == "ms1m-retinaface-t1":
config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 25
config.warmup_epoch = -1
config.decay_epoch = [11, 17, 22]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
# elif config.dataset == "ms1m-retinaface-t1":
# config.rec = "/train_tmp/ms1m-retinaface-t1"
# config.num_classes = 93431
# config.num_image = 5179510
# config.num_epoch = 25
# config.warmup_epoch = -1
# config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
elif config.dataset == "glint360k":
config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
elif config.dataset == "webface":
config.rec = "/train_tmp/faces_webface_112x112"
config.num_classes = 10572
config.num_image = "forget"
config.num_epoch = 34
config.warmup_epoch = -1
config.decay_epoch = [20, 28, 32]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
# elif config.dataset == "glint360k":
# config.rec = "/train_tmp/glint360k"
# config.num_classes = 360232
# config.num_image = 17091657
# config.num_epoch = 20
# config.warmup_epoch = -1
# config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

View File

@@ -5,7 +5,7 @@ from easydict import EasyDict as edict
# mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict()
config.loss = "cosface"
config.loss = "arcface"
config.network = "r100"
config.resume = False
config.output = None
@@ -15,12 +15,13 @@ config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.lr = 0.2
config.verbose = 2000
config.dali = False
config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
config.warmup_epoch = 0
config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]

View File

@@ -13,14 +13,15 @@ config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 2e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.weight_decay = 1e-4
config.batch_size = 256
config.lr = 0.2
config.verbose = 5000
config.dali = False
config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 20, 25]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
config.num_epoch = 40
config.warmup_epoch = 2
config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]

View File

@@ -6,7 +6,7 @@ from easydict import EasyDict as edict
config = edict()
config.loss = "arcface"
config.network = "r18"
config.network = "r100"
config.resume = False
config.output = None
config.embedding_size = 512
@@ -15,12 +15,13 @@ config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.lr = 0.2
config.verbose = 2000
config.dali = False
config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 25
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
config.warmup_epoch = 2
config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]

View File

@@ -1,26 +0,0 @@
from easydict import EasyDict as edict
# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict()
config.loss = "arcface"
config.network = "r2060"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 64
config.lr = 0.1 # batch size is 512
config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 25
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

View File

@@ -1,26 +0,0 @@
from easydict import EasyDict as edict
# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict()
config.loss = "arcface"
config.network = "r34"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 25
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

View File

@@ -15,12 +15,13 @@ config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.lr = 0.2
config.verbose = 2000
config.dali = False
config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 25
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
config.warmup_epoch = 2
config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]

View File

@@ -1,23 +0,0 @@
from easydict import EasyDict as edict
# configs for test speed
config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.rec = "synthetic"
config.num_classes = 100 * 10000
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = []

View File

@@ -10,17 +10,18 @@ config.network = "mbf"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.1
config.sample_rate = 0.2
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 2e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.weight_decay = 1e-4
config.batch_size = 512
config.lr = 0.4
config.verbose = 10000
config.dali = True
config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.rec = "/train_tmp/WebFace42M"
config.num_classes = 2059906
config.num_image = 42474557
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
config.warmup_epoch = 2
config.val_targets = []

View File

@@ -6,21 +6,22 @@ from easydict import EasyDict as edict
config = edict()
config.loss = "cosface"
config.network = "r34"
config.network = "r100"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.sample_rate = 0.2
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.batch_size = 256
config.lr = 0.3
config.verbose = 2000
config.dali = True
config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.rec = "/train_tmp/WebFace42M"
config.num_classes = 2059906
config.num_image = 42474557
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.warmup_epoch = 1
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

View File

@@ -10,17 +10,18 @@ config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.sample_rate = 0.2
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.lr = 0.4
config.verbose = 10000
config.dali = True
config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.rec = "/train_tmp/WebFace42M"
config.num_classes = 2059906
config.num_image = 42474557
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.warmup_epoch = 2
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

View File

@@ -6,21 +6,22 @@ from easydict import EasyDict as edict
config = edict()
config.loss = "cosface"
config.network = "r18"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.sample_rate = 0.2
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1 # batch size is 512
config.batch_size = 512
config.lr = 0.4
config.verbose = 10000
config.dali = True
config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.rec = "/train_tmp/WebFace42M"
config.num_classes = 2059906
config.num_image = 42474557
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.warmup_epoch = 2
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

View File

@@ -2,13 +2,42 @@ import numbers
import os
import queue as Queue
import threading
from typing import Iterable
import mxnet as mx
import numpy as np
import torch
from torch import distributed
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
def get_dataloader(
root_dir: str,
local_rank: int,
batch_size: int,
dali = False) -> Iterable:
if dali and root_dir != "synthetic":
rec = os.path.join(root_dir, 'train.rec')
idx = os.path.join(root_dir, 'train.idx')
return dali_data_iter(
batch_size=batch_size, rec_file=rec,
idx_file=idx, num_threads=2, local_rank=local_rank)
else:
if root_dir == "synthetic":
train_set = SyntheticDataset()
else:
train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
train_loader = DataLoaderX(
local_rank=local_rank,
dataset=train_set,
batch_size=batch_size,
sampler=train_sampler,
num_workers=2,
pin_memory=True,
drop_last=True,
)
return train_loader
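# Hypothetical usage sketch (assumes torch.distributed is initialized and the
# CUDA device is set, as train.py does before building the loader):
#
#   train_loader = get_dataloader("/train_tmp/WebFace42M", local_rank=0,
#                                 batch_size=128, dali=True)
#   for img, label in train_loader:
#       ...  # with dali=True, img and label are already CUDA tensors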
class BackgroundGenerator(threading.Thread):
def __init__(self, generator, local_rank, max_prefetch=6):
@@ -108,7 +137,7 @@ class MXFaceDataset(Dataset):
class SyntheticDataset(Dataset):
def __init__(self, local_rank):
def __init__(self):
super(SyntheticDataset, self).__init__()
img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
img = np.transpose(img, (2, 0, 1))
@@ -122,3 +151,59 @@ class SyntheticDataset(Dataset):
def __len__(self):
return 1000000
def dali_data_iter(
batch_size: int, rec_file: str, idx_file: str, num_threads: int,
initial_fill=32768, random_shuffle=True,
prefetch_queue_depth=1, local_rank=0, name="reader",
mean=(127.5, 127.5, 127.5),
std=(127.5, 127.5, 127.5)):
"""
Parameters:
----------
initial_fill: int
Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored.
"""
rank: int = distributed.get_rank()
world_size: int = distributed.get_world_size()
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
pipe = Pipeline(
batch_size=batch_size, num_threads=num_threads,
device_id=local_rank, prefetch_queue_depth=prefetch_queue_depth, )
condition_flip = fn.random.coin_flip(probability=0.5)
with pipe:
jpegs, labels = fn.readers.mxnet(
path=rec_file, index_path=idx_file, initial_fill=initial_fill,
num_shards=world_size, shard_id=rank,
random_shuffle=random_shuffle, pad_last_batch=False, name=name)
images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
images = fn.crop_mirror_normalize(
images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip)
pipe.set_outputs(images, labels)
pipe.build()
return DALIWarper(DALIClassificationIterator(pipelines=[pipe], reader_name=name, ))
@torch.no_grad()
class DALIWarper(object):
def __init__(self, dali_iter):
self.iter = dali_iter
def __next__(self):
data_dict = self.iter.__next__()[0]
tensor_data = data_dict['data'].cuda()
tensor_label: torch.Tensor = data_dict['label'].cuda().long()
tensor_label.squeeze_()
return tensor_data, tensor_label
def __iter__(self):
return self
def reset(self):
self.iter.reset()

View File

@@ -0,0 +1 @@
TODO

View File

@@ -0,0 +1,22 @@
## 1. Download Datasets and Unzip
Download WebFace42M from [https://www.face-benchmark.org/download.html](https://www.face-benchmark.org/download.html).
## 2. Create a **Pre-shuffled** Rec File for DALI
Note: a pre-shuffled rec file is very important for DALI; a rec file that has not been pre-shuffled can cause performance degradation. The original insightface-style rec file
does not support NVIDIA DALI, so you must use [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) to generate a pre-shuffled rec file.
```shell
# 1) create train.lst using the following command
python -m mxnet.tools.im2rec --list --recursive train "Your WebFace42M Root"
# 2) create train.rec and train.idx from train.lst using the following command
python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train "Your WebFace42M Root"
```
Finally, you will get three files: `train.lst`, `train.rec`, and `train.idx`, of which `train.rec` and `train.idx` are used for training.
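To sanity-check the generated files, you can read one record back with MXNet's indexed record reader (a minimal sketch; the index `0` and the paths are illustrative):
```python
import mxnet as mx

# Open the index/record pair produced by im2rec in read mode.
record = mx.recordio.MXIndexedRecordIO("train.idx", "train.rec", "r")
# Unpack one packed record into its header (label) and encoded image bytes.
header, img_bytes = mx.recordio.unpack(record.read_idx(0))
print(header.label)  # identity label of this sample
```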

View File

@@ -261,6 +261,8 @@ def test(data_set, backbone, batch_size, nfolds=10):
_xnorm_cnt += 1
_xnorm /= _xnorm_cnt
embeddings = embeddings_list[0].copy()
embeddings = sklearn.preprocessing.normalize(embeddings)
acc1 = 0.0
std1 = 0.0
embeddings = embeddings_list[0] + embeddings_list[1]

View File

@@ -1,42 +0,0 @@
import torch
from torch import nn
def get_loss(name):
if name == "cosface":
return CosFace()
elif name == "arcface":
return ArcFace()
else:
raise ValueError()
class CosFace(nn.Module):
def __init__(self, s=64.0, m=0.40):
super(CosFace, self).__init__()
self.s = s
self.m = m
def forward(self, cosine, label):
index = torch.where(label != -1)[0]
m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
m_hot.scatter_(1, label[index, None], self.m)
cosine[index] -= m_hot
ret = cosine * self.s
return ret
class ArcFace(nn.Module):
def __init__(self, s=64.0, m=0.5):
super(ArcFace, self).__init__()
self.s = s
self.m = m
def forward(self, cosine: torch.Tensor, label):
index = torch.where(label != -1)[0]
m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
m_hot.scatter_(1, label[index, None], self.m)
cosine.acos_()
cosine[index] += m_hot
cosine.cos_().mul_(self.s)
return cosine

View File

@@ -0,0 +1,29 @@
from torch.optim.lr_scheduler import _LRScheduler
class PolyScheduler(_LRScheduler):
def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
self.base_lr = base_lr
self.warmup_lr_init = 0.0001
self.max_steps: int = max_steps
self.warmup_steps: int = warmup_steps
self.power = 2
super(PolyScheduler, self).__init__(optimizer, last_epoch, False)
def get_warmup_lr(self):
alpha = float(self.last_epoch) / float(self.warmup_steps)
return [self.base_lr * alpha for _ in self.optimizer.param_groups]
def get_lr(self):
if self.last_epoch == -1:
return [self.warmup_lr_init for _ in self.optimizer.param_groups]
if self.last_epoch < self.warmup_steps:
return self.get_warmup_lr()
else:
alpha = pow(
1
- float(self.last_epoch - self.warmup_steps)
/ float(self.max_steps - self.warmup_steps),
self.power,
)
return [self.base_lr * alpha for _ in self.optimizer.param_groups]
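# Hypothetical usage sketch (names illustrative): linear warmup for
# `warmup_steps`, then polynomial (power-2) decay to zero at `max_steps`.
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   scheduler = PolyScheduler(optimizer, base_lr=0.1, max_steps=100000, warmup_steps=1000)
#   for step in range(100000):
#       ...  # forward / backward
#       optimizer.step()
#       scheduler.step()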

View File

@@ -9,9 +9,10 @@ import numpy as np
import pandas as pd
import prettytable
import skimage.transform
import torch
from sklearn.metrics import roc_curve
from sklearn.preprocessing import normalize
from torch.utils.data import DataLoader
from onnx_helper import ArcFaceORT
SRC = np.array(
@@ -25,6 +26,7 @@ SRC = np.array(
SRC[:, 0] += 8.0
@torch.no_grad()
class AlignedDataSet(mx.gluon.data.Dataset):
def __init__(self, root, lines, align=True):
self.lines = lines
@@ -47,24 +49,23 @@ class AlignedDataSet(mx.gluon.data.Dataset):
img_2 = np.expand_dims(np.fliplr(img), 0)
output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
output = np.transpose(output, (0, 3, 1, 2))
output = mx.nd.array(output)
return output
return torch.from_numpy(output)
@torch.no_grad()
def extract(model_root, dataset):
model = ArcFaceORT(model_path=model_root)
model.check()
feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))
def batchify_fn(data):
return mx.nd.concat(*data, dim=0)
def collate_fn(data):
return torch.cat(data, dim=0)
data_loader = mx.gluon.data.DataLoader(
dataset, 128, last_batch='keep', num_workers=4,
thread_pool=True, prefetch=16, batchify_fn=batchify_fn)
data_loader = DataLoader(
dataset, batch_size=128, drop_last=False, num_workers=4, collate_fn=collate_fn, )
num_iter = 0
for batch in data_loader:
batch = batch.asnumpy()
batch = batch.numpy()
batch = (batch - model.input_mean) / model.input_std
feat = model.session.run(model.output_names, {model.input_name: batch})[0]
feat = np.reshape(feat, (-1, model.feat_dim * 2))
@@ -228,10 +229,12 @@ def main(args):
score = verification(template_norm_feats, unique_templates, p1, p2)
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))
save_path = os.path.join(args.result_dir, "{}_result".format(args.target))
result_dir = args.model_root
save_path = os.path.join(result_dir, "{}_result".format(args.target))
if not os.path.exists(save_path):
os.makedirs(save_path)
score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root))
score_save_file = os.path.join(save_path, "{}.npy".format(args.target))
np.save(score_save_file, score)
files = [score_save_file]
methods = []
@@ -261,7 +264,6 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description='do ijb test')
# general
parser.add_argument('--model-root', default='', help='path to load model.')
parser.add_argument('--image-path', default='', type=str, help='')
parser.add_argument('--result-dir', default='.', type=str, help='')
parser.add_argument('--image-path', default='/train_tmp/IJB_release/IJBC', type=str, help='')
parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
main(parser.parse_args())

View File

@@ -1,222 +1,384 @@
import logging
import os
import math
import collections
from typing import Callable
import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn.functional import normalize, linear
from torch.nn.parameter import Parameter
from torch import distributed
from torch.nn.functional import linear, normalize
class PartialFC(Module):
class ArcFace(torch.nn.Module):
""" ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
"""
def __init__(self, s=64.0, margin=0.5):
super(ArcFace, self).__init__()
self.scale = s
self.easy_margin = False  # referenced in forward(); the easy-margin variant is off by default
self.cos_m = math.cos(margin)
self.sin_m = math.sin(margin)
self.theta = math.cos(math.pi - margin)
self.sinmm = math.sin(math.pi - margin) * margin
def forward(self, logits: torch.Tensor, labels: torch.Tensor):
index = torch.where(labels != -1)[0]
target_logit = logits[index, labels[index].view(-1)]
sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin)
if self.easy_margin:
final_target_logit = torch.where(
target_logit > 0, cos_theta_m, target_logit)
else:
final_target_logit = torch.where(
target_logit > self.theta, cos_theta_m, target_logit - self.sinmm)
logits[index, labels[index].view(-1)] = final_target_logit
logits = logits * self.scale
return logits
class CosFace(torch.nn.Module):
def __init__(self, s=64.0, m=0.40):
super(CosFace, self).__init__()
self.s = s
self.m = m
def forward(self, logits: torch.Tensor, labels: torch.Tensor):
index = torch.where(labels != -1)[0]
target_logit = logits[index, labels[index].view(-1)]
final_target_logit = target_logit - self.m
logits[index, labels[index].view(-1)] = final_target_logit
logits = logits * self.s
return logits
class PartialFC(torch.nn.Module):
"""
Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
Partial FC: Training 10 Million Identities on a Single Machine
See the original paper:
https://arxiv.org/abs/2010.05222
"""
A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
@torch.no_grad()
def __init__(self, rank, local_rank, world_size, batch_size, resume,
margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
When the sample rate is less than 1, in each iteration positive class centers and a random subset of
negative class centers are selected to compute the margin-based softmax loss. All class
centers are still maintained throughout the whole training process, but only the selected subset is
updated in each iteration.
.. note::
When the sample rate is equal to 1, Partial FC is equivalent to model parallelism (the default sample rate is 1).
Example:
--------
>>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
>>> for img, labels in data_loader:
>>> embeddings = net(img)
>>> loss = module_pfc(embeddings, labels, optimizer)
>>> loss.backward()
>>> optimizer.step()
"""
_version = 1
def __init__(
self,
embedding_size,
num_classes,
sample_rate = 1.0,
fp16: bool = False,
margin_loss = "cosface",
):
"""
rank: int
Unique process(GPU) ID from 0 to world_size - 1.
local_rank: int
Unique process(GPU) ID within the server from 0 to 7.
world_size: int
Number of GPU.
batch_size: int
Batch size on current rank(GPU).
resume: bool
Select whether to restore the weight of softmax.
margin_softmax: callable
A function of margin softmax, eg: cosface, arcface.
num_classes: int
The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size,
required.
sample_rate: float
The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling
can greatly speed up training, and reduce a lot of GPU memory, default is 1.0.
Parameters:
-----------
embedding_size: int
The feature dimension, default is 512.
prefix: str
Path for save checkpoint, default is './'.
The dimension of embedding, required
num_classes: int
Total number of classes, required
sample_rate: float
The rate of negative centers participating in the calculation, default is 1.0.
"""
super(PartialFC, self).__init__()
#
self.num_classes: int = num_classes
self.rank: int = rank
self.local_rank: int = local_rank
self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
self.world_size: int = world_size
self.batch_size: int = batch_size
self.margin_softmax: callable = margin_softmax
assert (
distributed.is_initialized()
), "must initialize distributed before create this"
self.rank = distributed.get_rank()
self.world_size = distributed.get_world_size()
self.dist_cross_entropy = DistCrossEntropy()
self.embedding_size = embedding_size
self.sample_rate: float = sample_rate
self.embedding_size: int = embedding_size
self.prefix: str = prefix
self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
self.fp16 = fp16
self.num_local: int = num_classes // self.world_size + int(
self.rank < num_classes % self.world_size
)
self.class_start: int = num_classes // self.world_size * self.rank + min(
self.rank, num_classes % self.world_size
)
self.num_sample: int = int(self.sample_rate * self.num_local)
self.last_batch_size: int = 0
self.weight: torch.Tensor
self.weight_mom: torch.Tensor
self.weight_activated: torch.nn.Parameter
self.weight_activated_mom: torch.Tensor
self.is_updated: bool = True
self.init_weight_update: bool = True
self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))
if resume:
try:
self.weight: torch.Tensor = torch.load(self.weight_name)
self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
raise IndexError
logging.info("softmax weight resume successfully!")
logging.info("softmax weight mom resume successfully!")
except (FileNotFoundError, KeyError, IndexError):
self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
logging.info("softmax weight init!")
logging.info("softmax weight mom init!")
if self.sample_rate < 1:
self.register_buffer("weight",
tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
self.register_buffer("weight_mom",
tensor=torch.zeros_like(self.weight))
self.register_parameter("weight_activated",
param=torch.nn.Parameter(torch.empty(0, 0)))
self.register_buffer("weight_activated_mom",
tensor=torch.empty(0, 0))
self.register_buffer("weight_index",
tensor=torch.empty(0, 0))
else:
self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
logging.info("softmax weight init successfully!")
logging.info("softmax weight mom init successfully!")
self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)
self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
self.index = None
if int(self.sample_rate) == 1:
self.update = lambda: 0
self.sub_weight = Parameter(self.weight)
self.sub_weight_mom = self.weight_mom
# margin_loss
if isinstance(margin_loss, str):
self.margin_softmax: torch.nn.Module
if margin_loss == "cosface":
self.margin_softmax = CosFace()
elif margin_loss == "arcface":
self.margin_softmax = ArcFace()
else:
raise ValueError(f"unknown margin_loss: {margin_loss}")
elif isinstance(margin_loss, Callable):
self.margin_softmax = margin_loss
else:
self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))
def save_params(self):
""" Save softmax weight for each rank on prefix
"""
torch.save(self.weight.data, self.weight_name)
torch.save(self.weight_mom, self.weight_mom_name)
raise TypeError("margin_loss must be a str or a callable")
@torch.no_grad()
def sample(self, total_label):
def sample(self,
labels: torch.Tensor,
index_positive: torch.Tensor,
optimizer: torch.optim.Optimizer):
"""
Sample all positive class centers on each rank, and randomly select negative class centers to fill a fixed
`num_sample`.
This function will change the value of labels.
total_label: tensor
Label after all gather, which cross all GPUs.
Parameters:
-----------
labels: torch.Tensor
    Labels gathered from all GPUs, remapped to this rank's local class indices.
index_positive: torch.Tensor
    Boolean mask selecting the labels that fall in this rank's class range.
optimizer: torch.optim.Optimizer
    Optimizer whose last parameter group holds the sampled partial-FC weights.
"""
index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
total_label[~index_positive] = -1
total_label[index_positive] -= self.class_start
if int(self.sample_rate) != 1:
positive = torch.unique(total_label[index_positive], sorted=True)
if self.num_sample - positive.size(0) >= 0:
perm = torch.rand(size=[self.num_local], device=self.device)
perm[positive] = 2.0
index = torch.topk(perm, k=self.num_sample)[1]
index = index.sort()[0]
else:
index = positive
self.index = index
total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
self.sub_weight = Parameter(self.weight[index])
self.sub_weight_mom = self.weight_mom[index]
positive = torch.unique(labels[index_positive], sorted=True).cuda()
if self.num_sample - positive.size(0) >= 0:
perm = torch.rand(size=[self.num_local]).cuda()
perm[positive] = 2.0
index = torch.topk(perm, k=self.num_sample)[1].cuda()
index = index.sort()[0].cuda()
else:
index = positive
self.weight_index = index
def forward(self, total_features, norm_weight):
""" Partial fc forward, `logits = X * sample(W)`
"""
torch.cuda.current_stream().wait_stream(self.stream)
logits = linear(total_features, norm_weight)
return logits
labels[index_positive] = torch.searchsorted(index, labels[index_positive])
self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
self.weight_activated_mom = self.weight_mom[self.weight_index]
if isinstance(optimizer, torch.optim.SGD):
# TODO the params of partial fc must be last in the params list
optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
optimizer.param_groups[-1]["params"][0] = self.weight_activated
optimizer.state[self.weight_activated][
"momentum_buffer"
] = self.weight_activated_mom
else:
raise TypeError("optimizer must be an instance of torch.optim.SGD")
@torch.no_grad()
def update(self):
""" Set updated weight and weight_mom to memory bank.
""" partial weight to global
"""
self.weight_mom[self.index] = self.sub_weight_mom
self.weight[self.index] = self.sub_weight
if self.init_weight_update:
self.init_weight_update = False
return
def prepare(self, label, optimizer):
if self.sample_rate < 1:
self.weight[self.weight_index] = self.weight_activated
self.weight_mom[self.weight_index] = self.weight_activated_mom
def forward(
self,
local_embeddings: torch.Tensor,
local_labels: torch.Tensor,
optimizer: torch.optim.Optimizer,
):
"""
get sampled class centers for cal softmax.
label: tensor
Label tensor on each rank.
optimizer: opt
Optimizer for partial fc, which need to get weight mom.
"""
with torch.cuda.stream(self.stream):
total_label = torch.zeros(
size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
self.sample(total_label)
optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
optimizer.param_groups[-1]['params'][0] = self.sub_weight
optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
norm_weight = normalize(self.sub_weight)
return total_label, norm_weight
def forward_backward(self, label, features, optimizer):
"""
Partial fc forward and backward with model parallel
label: tensor
Label tensor on each rank(GPU)
features: tensor
Features tensor on each rank(GPU)
optimizer: optimizer
Optimizer for partial fc
Parameters:
----------
local_embeddings: torch.Tensor
feature embeddings on each GPU(Rank).
local_labels: torch.Tensor
labels on each GPU(Rank).
Returns:
--------
x_grad: tensor
The gradient of features.
loss_v: tensor
Loss value for cross entropy.
-------
loss: torch.Tensor
    The margin-softmax cross-entropy loss averaged over the global batch.
"""
total_label, norm_weight = self.prepare(label, optimizer)
total_features = torch.zeros(
size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
total_features.requires_grad = True
local_labels.squeeze_()
local_labels = local_labels.long()
self.update()
logits = self.forward(total_features, norm_weight)
logits = self.margin_softmax(logits, total_label)
batch_size = local_embeddings.size(0)
if self.last_batch_size == 0:
self.last_batch_size = batch_size
assert self.last_batch_size == batch_size, (
"last batch size do not equal current batch size: {} vs {}".format(
self.last_batch_size, batch_size))
with torch.no_grad():
max_fc = torch.max(logits, dim=1, keepdim=True)[0]
dist.all_reduce(max_fc, dist.ReduceOp.MAX)
_gather_embeddings = [
torch.zeros((batch_size, self.embedding_size)).cuda()
for _ in range(self.world_size)
]
_gather_labels = [
torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
]
_list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
distributed.all_gather(_gather_labels, local_labels)
# calculate exp(logits) and all-reduce
logits_exp = torch.exp(logits - max_fc)
logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
embeddings = torch.cat(_list_embeddings)
labels = torch.cat(_gather_labels)
# calculate prob
logits_exp.div_(logits_sum_exp)
labels = labels.view(-1, 1)
index_positive = (self.class_start <= labels) & (
labels < self.class_start + self.num_local
)
labels[~index_positive] = -1
labels[index_positive] -= self.class_start
# get one-hot
grad = logits_exp
index = torch.where(total_label != -1)[0]
one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
one_hot.scatter_(1, total_label[index, None], 1)
if self.sample_rate < 1:
self.sample(labels, index_positive, optimizer)
# calculate loss
loss = torch.zeros(grad.size()[0], 1, device=grad.device)
loss[index] = grad[index].gather(1, total_label[index, None])
dist.all_reduce(loss, dist.ReduceOp.SUM)
loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
with torch.cuda.amp.autocast(self.fp16):
norm_embeddings = normalize(embeddings)
norm_weight_activated = normalize(self.weight_activated)
logits = linear(norm_embeddings, norm_weight_activated)
if self.fp16:
logits = logits.float()
logits = logits.clamp(-1, 1)
# calculate grad
grad[index] -= one_hot
grad.div_(self.batch_size * self.world_size)
logits = self.margin_softmax(logits, labels)
loss = self.dist_cross_entropy(logits, labels)
return loss
logits.backward(grad)
if total_features.grad is not None:
total_features.grad.detach_()
x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
# feature gradient all-reduce
dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
x_grad = x_grad * self.world_size
# backward backbone
return x_grad, loss_v
def state_dict(self, destination=None, prefix="", keep_vars=False):
if destination is None:
destination = collections.OrderedDict()
destination._metadata = collections.OrderedDict()
for name, module in self._modules.items():
if module is not None:
module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
if self.sample_rate < 1:
destination["weight"] = self.weight.detach()
else:
destination["weight"] = self.weight_activated.data.detach()
return destination
def load_state_dict(self, state_dict, strict: bool = True):
if self.sample_rate < 1:
self.weight = state_dict["weight"].to(self.weight.device)
self.weight_mom.zero_()
self.weight_activated.data.zero_()
self.weight_activated_mom.zero_()
self.weight_index.zero_()
else:
self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
class DistCrossEntropyFunc(torch.autograd.Function):
"""
CrossEntropy loss computed in parallel across GPUs: the max logit and the softmax denominator are all-reduced globally, then the loss is computed locally on each rank.
"""
@staticmethod
def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
""" """
batch_size = logits.size(0)
# for numerical stability
max_logits, _ = torch.max(logits, dim=1, keepdim=True)
# local to global
distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
logits.sub_(max_logits)
logits.exp_()
sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
# local to global
distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
logits.div_(sum_logits_exp)
index = torch.where(label != -1)[0]
# loss
loss = torch.zeros(batch_size, 1, device=logits.device)
loss[index] = logits[index].gather(1, label[index])
distributed.all_reduce(loss, distributed.ReduceOp.SUM)
ctx.save_for_backward(index, logits, label)
return loss.clamp_min_(1e-30).log_().mean() * (-1)
@staticmethod
def backward(ctx, loss_gradient):
"""
Args:
    loss_gradient (torch.Tensor): the gradient propagated back from the layer above
Returns:
    the gradient for the logits input of forward, and
    `None` for the label input
"""
(
index,
logits,
label,
) = ctx.saved_tensors
batch_size = logits.size(0)
one_hot = torch.zeros(
size=[index.size(0), logits.size(1)], device=logits.device
)
one_hot.scatter_(1, label[index], 1)
logits[index] -= one_hot
logits.div_(batch_size)
return logits * loss_gradient.item(), None
class DistCrossEntropy(torch.nn.Module):
def __init__(self):
super(DistCrossEntropy, self).__init__()
def forward(self, logit_part, label_part):
return DistCrossEntropyFunc.apply(logit_part, label_part)
class AllGatherFunc(torch.autograd.Function):
"""AllGather op with gradient backward"""
@staticmethod
def forward(ctx, tensor, *gather_list):
gather_list = list(gather_list)
distributed.all_gather(gather_list, tensor)
return tuple(gather_list)
@staticmethod
def backward(ctx, *grads):
grad_list = list(grads)
rank = distributed.get_rank()
grad_out = grad_list[rank]
dist_ops = [
distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
if i == rank
else distributed.reduce(
grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
)
for i in range(distributed.get_world_size())
]
for _op in dist_ops:
_op.wait()
grad_out *= len(grad_list) # cooperate with distributed loss function
return (grad_out, *[None for _ in range(len(grad_list))])
AllGather = AllGatherFunc.apply

View File

@@ -1,2 +1,9 @@
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \
--nproc_per_node=8 \
--nnodes=1 \
--node_rank=0 \
--master_addr="127.0.0.1" \
--master_port=12345 train.py $@
ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh

View File

@@ -12,7 +12,7 @@ def convert_onnx(net, path_module, output, opset=11, simplify=False):
img = torch.from_numpy(img).unsqueeze(0).float()
weight = torch.load(path_module)
net.load_state_dict(weight)
net.load_state_dict(weight, strict=True)
net.eval()
torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
model = onnx.load(output)
@@ -38,22 +38,16 @@ if __name__ == '__main__':
args = parser.parse_args()
input_file = args.input
if os.path.isdir(input_file):
input_file = os.path.join(input_file, "backbone.pth")
input_file = os.path.join(input_file, "model.pt")
assert os.path.exists(input_file)
model_name = os.path.basename(os.path.dirname(input_file)).lower()
params = model_name.split("_")
if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
if args.network is None:
args.network = params[2]
# model_name = os.path.basename(os.path.dirname(input_file)).lower()
# params = model_name.split("_")
# if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
# if args.network is None:
# args.network = params[2]
assert args.network is not None
print(args)
backbone_onnx = get_model(args.network, dropout=0)
output_path = args.output
if output_path is None:
output_path = os.path.join(os.path.dirname(__file__), 'onnx')
if not os.path.exists(output_path):
os.makedirs(output_path)
assert os.path.isdir(output_path)
output_file = os.path.join(output_path, "%s.onnx" % model_name)
convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
if args.output is None:
args.output = os.path.join(os.path.dirname(args.input), "model.onnx")
convert_onnx(backbone_onnx, input_file, args.output, simplify=args.simplify)

View File

@@ -3,139 +3,146 @@ import logging
import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.data.distributed
from torch.nn.utils import clip_grad_norm_
from torch import distributed
from torch.utils.tensorboard import SummaryWriter
import losses
from backbones import get_model
from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX
from dataset import get_dataloader
from torch.utils.data import DataLoader
from lr_scheduler import PolyScheduler
from partial_fc import PartialFC
from utils.utils_amp import MaxClipGradScaler
from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
from utils.utils_callbacks import CallBackLogging, CallBackVerification
from utils.utils_config import get_config
from utils.utils_logging import AverageMeter, init_logging
def main(args):
cfg = get_config(args.config)
try:
world_size = int(os.environ['WORLD_SIZE'])
rank = int(os.environ['RANK'])
dist.init_process_group('nccl')
except KeyError:
world_size = 1
rank = 0
dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size)
try:
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
distributed.init_process_group("nccl")
except KeyError:
world_size = 1
rank = 0
distributed.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:12584",
rank=rank,
+        world_size=world_size,
+    )

 def main(args):
+    torch.cuda.set_device(args.local_rank)
     cfg = get_config(args.config)
-    local_rank = args.local_rank
-    torch.cuda.set_device(local_rank)
     os.makedirs(cfg.output, exist_ok=True)
     init_logging(rank, cfg.output)

-    if cfg.rec == "synthetic":
-        train_set = SyntheticDataset(local_rank=local_rank)
-    else:
-        train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
-    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
-    train_loader = DataLoaderX(
-        local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,
-        sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)
-    backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)
-    if cfg.resume:
-        try:
-            backbone_pth = os.path.join(cfg.output, "backbone.pth")
-            backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
-            if rank == 0:
-                logging.info("backbone resume successfully!")
-        except (FileNotFoundError, KeyError, IndexError, RuntimeError):
-            if rank == 0:
-                logging.info("resume fail, backbone init successfully!")
+    summary_writer = (
+        SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))
+        if rank == 0
+        else None
+    )
+    train_loader = get_dataloader(
+        cfg.rec, local_rank=args.local_rank, batch_size=cfg.batch_size, dali=cfg.dali)
+    backbone = get_model(
+        cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size
+    ).cuda()
     backbone = torch.nn.parallel.DistributedDataParallel(
-        module=backbone, broadcast_buffers=False, device_ids=[local_rank])
+        module=backbone, broadcast_buffers=False, device_ids=[args.local_rank])
     backbone.train()
     margin_softmax = losses.get_loss(cfg.loss)
     module_partial_fc = PartialFC(
-        rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume,
-        batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
-        sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)
+        margin_softmax,
+        cfg.embedding_size,
+        cfg.num_classes,
+        cfg.sample_rate,
+        cfg.fp16
+    )
+    module_partial_fc.train().cuda()

-    opt_backbone = torch.optim.SGD(
-        params=[{'params': backbone.parameters()}],
-        lr=cfg.lr / 512 * cfg.batch_size * world_size,
-        momentum=0.9, weight_decay=cfg.weight_decay)
-    opt_pfc = torch.optim.SGD(
-        params=[{'params': module_partial_fc.parameters()}],
-        lr=cfg.lr / 512 * cfg.batch_size * world_size,
-        momentum=0.9, weight_decay=cfg.weight_decay)
-    num_image = len(train_set)
+    # TODO the params of partial fc must be last in the params list
+    opt = torch.optim.SGD(
+        params=[
+            {"params": backbone.parameters(), },
+            {"params": module_partial_fc.parameters(), },
+        ],
+        lr=cfg.lr,
+        momentum=0.9,
+        weight_decay=cfg.weight_decay
+    )

     total_batch_size = cfg.batch_size * world_size
-    cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch
-    cfg.total_step = num_image // total_batch_size * cfg.num_epoch
-
-    def lr_step_func(current_step):
-        cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch]
-        if current_step < cfg.warmup_step:
-            return current_step / cfg.warmup_step
-        else:
-            return 0.1 ** len([m for m in cfg.decay_step if m <= current_step])
-
-    scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
-        optimizer=opt_backbone, lr_lambda=lr_step_func)
-    scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
-        optimizer=opt_pfc, lr_lambda=lr_step_func)
+    cfg.warmup_step = cfg.num_image // total_batch_size * cfg.warmup_epoch
+    cfg.total_step = cfg.num_image // total_batch_size * cfg.num_epoch
+    lr_scheduler = PolyScheduler(
+        optimizer=opt,
+        base_lr=cfg.lr,
+        max_steps=cfg.total_step,
+        warmup_steps=cfg.warmup_step
+    )

     for key, value in cfg.items():
         num_space = 25 - len(key)
         logging.info(": " + key + " " * num_space + str(value))

-    val_target = cfg.val_targets
-    callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)
-    callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)
-    callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)
+    callback_verification = CallBackVerification(
+        val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer
+    )
+    callback_logging = CallBackLogging(
+        frequent=cfg.frequent,
+        total_step=cfg.total_step,
+        batch_size=cfg.batch_size,
+        writer=summary_writer
+    )

-    loss = AverageMeter()
+    loss_am = AverageMeter()
     start_epoch = 0
     global_step = 0
-    grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
-    for epoch in range(start_epoch, cfg.num_epoch):
-        train_sampler.set_epoch(epoch)
-        for step, (img, label) in enumerate(train_loader):
-            global_step += 1
-            features = F.normalize(backbone(img))
-            x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
-            if cfg.fp16:
-                features.backward(grad_amp.scale(x_grad))
-                grad_amp.unscale_(opt_backbone)
-                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
-                grad_amp.step(opt_backbone)
-                grad_amp.update()
-            else:
-                features.backward(x_grad)
-                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
-                opt_backbone.step()
+    amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
-            opt_pfc.step()
-            module_partial_fc.update()
-            opt_backbone.zero_grad()
-            opt_pfc.zero_grad()
-            loss.update(loss_v, 1)
-            callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)
-            callback_verification(global_step, backbone)
-        scheduler_backbone.step()
-        scheduler_pfc.step()
-        callback_checkpoint(global_step, backbone, module_partial_fc)
-    dist.destroy_process_group()
+    for epoch in range(start_epoch, cfg.num_epoch):
+        if isinstance(train_loader, DataLoader):
+            train_loader.sampler.set_epoch(epoch)
+        for _, (img, local_labels) in enumerate(train_loader):
+            global_step += 1
+            local_embeddings = backbone(img)
+            loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)
+            if cfg.fp16:
+                amp.scale(loss).backward()
+                amp.step(opt)
+                amp.update()
+            else:
+                loss.backward()
+                opt.step()
+            opt.zero_grad()
+            lr_scheduler.step()
+            with torch.no_grad():
+                loss_am.update(loss.item(), 1)
+                callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)
+                if global_step % cfg.verbose == 0 and global_step > 200:
+                    callback_verification(global_step, backbone)
+        path_pfc = os.path.join(cfg.output, "softmax_fc_gpu_{}.pt".format(rank))
+        torch.save(module_partial_fc.state_dict(), path_pfc)
+        if rank == 0:
+            path_module = os.path.join(cfg.output, "model.pt")
+            torch.save(backbone.module.state_dict(), path_module)
+        if cfg.dali:
+            train_loader.reset()
+    if rank == 0:
+        path_module = os.path.join(cfg.output, "model.pt")
+        torch.save(backbone.module.state_dict(), path_module)
+    distributed.destroy_process_group()

 if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
-    parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
-    parser.add_argument('config', type=str, help='py config file')
-    parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
+    parser = argparse.ArgumentParser(description="Distributed Arcface Training in Pytorch")
+    parser.add_argument("config", type=str, help="py config file")
+    parser.add_argument("--local_rank", type=int, default=0, help="local_rank")
     main(parser.parse_args())
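
Note: `PolyScheduler` replaces the step-decay `LambdaLR` pair above, but its implementation is not part of this diff. Below is a minimal sketch that matches the call site `PolyScheduler(optimizer, base_lr, max_steps, warmup_steps)`; the linear warmup and the polynomial power of 2 are assumptions for illustration, not taken from this commit.

```python
from torch.optim.lr_scheduler import _LRScheduler


class PolyScheduler(_LRScheduler):
    """Linear warmup, then polynomial decay towards zero (sketch)."""

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, power=2.0):
        self.base_lr = base_lr
        self.max_steps = max_steps
        self.warmup_steps = warmup_steps
        self.power = power
        super().__init__(optimizer)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            # ramp linearly from 0 to base_lr over the warmup steps
            factor = step / max(1, self.warmup_steps)
        else:
            # decay polynomially from base_lr to 0 over the remaining steps
            progress = (step - self.warmup_steps) / max(1, self.max_steps - self.warmup_steps)
            factor = (1.0 - progress) ** self.power
        return [self.base_lr * factor for _ in self.optimizer.param_groups]
```

The scheduler is stepped once per iteration in the new loop, so `max_steps` equals `cfg.num_image // total_batch_size * cfg.num_epoch` as computed above.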

View File

@@ -1,7 +1,5 @@
 # coding: utf-8
-import os
-from pathlib import Path
+import sys

 import matplotlib.pyplot as plt
 import numpy as np
@@ -10,10 +8,11 @@ from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
from prettytable import PrettyTable
from sklearn.metrics import roc_curve, auc
-image_path = "/data/anxiang/IJB_release/IJBC"
-files = [
-    "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
-]
+with open(sys.argv[1], "r") as f:
+    files = f.readlines()
+    files = [x.strip() for x in files]
+image_path = "/train_tmp/IJB_release/IJBC"
def read_template_pair_list(path):
@@ -31,7 +30,7 @@ p1, p2, label = read_template_pair_list(
methods = []
scores = []
for file in files:
-    methods.append(file.split('/')[-2])
+    methods.append(file)
scores.append(np.load(file))
methods = np.array(methods)
@@ -53,7 +52,7 @@ for method in methods:
label=('[%s (AUC = %0.4f %%)]' %
(method.split('-')[-1], roc_auc * 100)))
tpr_fpr_row = []
tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
tpr_fpr_row.append(method)
for fpr_iter in np.arange(len(x_labels)):
_, min_index = min(
list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))

View File

@@ -1,88 +0,0 @@
from typing import Dict, List
import torch
if torch.__version__ < '1.9':
Iterable = torch._six.container_abcs.Iterable
else:
import collections
Iterable = collections.abc.Iterable
from torch.cuda.amp import GradScaler
class _MultiDeviceReplicator(object):
"""
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
"""
def __init__(self, master_tensor: torch.Tensor) -> None:
assert master_tensor.is_cuda
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
def get(self, device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(device=device, non_blocking=True, copy=True)
self._per_device_tensors[device] = retval
return retval
class MaxClipGradScaler(GradScaler):
def __init__(self, init_scale, max_scale: float, growth_interval=100):
GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
self.max_scale = max_scale
def scale_clip(self):
if self.get_scale() == self.max_scale:
self.set_growth_factor(1)
elif self.get_scale() < self.max_scale:
self.set_growth_factor(2)
elif self.get_scale() > self.max_scale:
self._scale.fill_(self.max_scale)
self.set_growth_factor(1)
def scale(self, outputs):
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Arguments:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
if not self._enabled:
return outputs
self.scale_clip()
# Short-circuit for the common case.
if isinstance(outputs, torch.Tensor):
assert outputs.is_cuda
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
# Invoke the more complex machinery only if we're treating multiple outputs.
stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale
def apply_scale(val):
if isinstance(val, torch.Tensor):
assert val.is_cuda
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_MultiDeviceReplicator(self._scale))
return val * stash[0].get(val.device)
elif isinstance(val, Iterable):
iterable = map(apply_scale, val)
if isinstance(val, list) or isinstance(val, tuple):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
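
For context on what this deleted class did: `MaxClipGradScaler` behaves like a stock `GradScaler` whose scale is clamped at `max_scale`, with growth frozen once the cap is hit. A minimal sketch of how the old loop drove it (requires a CUDA device because `scale()` asserts its inputs are on GPU; the model and sizes here are illustrative):

```python
import torch

model = torch.nn.Linear(8, 4).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
# old train.py used MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100)
grad_amp = MaxClipGradScaler(init_scale=128, max_scale=128 * 128, growth_interval=100)

out = model(torch.randn(2, 8, device="cuda"))
grad = torch.ones_like(out)          # upstream gradient, e.g. x_grad from PartialFC
out.backward(grad_amp.scale(grad))   # scale() applies scale_clip() before scaling
grad_amp.step(opt)                   # unscales gradients, then steps the optimizer
grad_amp.update()                    # grows the scale until max_scale is reached
```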

View File

@@ -7,12 +7,14 @@ import torch
from eval import verification
from utils.utils_logging import AverageMeter
+from torch.utils.tensorboard import SummaryWriter
+from torch import distributed
class CallBackVerification(object):
-    def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
-        self.frequent: int = frequent
-        self.rank: int = rank
+    def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112)):
+        self.rank: int = distributed.get_rank()
self.highest_acc: float = 0.0
self.highest_acc_list: List[float] = [0.0] * len(val_targets)
self.ver_list: List[object] = []
@@ -20,6 +22,8 @@ class CallBackVerification(object):
         if self.rank == 0:
             self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
+        self.summary_writer = summary_writer
def ver_test(self, backbone: torch.nn.Module, global_step: int):
results = []
for i in range(len(self.ver_list)):
@@ -27,6 +31,10 @@ class CallBackVerification(object):
self.ver_list[i], backbone, 10, 10)
logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
+            self.summary_writer: SummaryWriter
+            self.summary_writer.add_scalar(tag=self.ver_name_list[i], scalar_value=acc2, global_step=global_step, )
if acc2 > self.highest_acc_list[i]:
self.highest_acc_list[i] = acc2
logging.info(
@@ -42,20 +50,20 @@ class CallBackVerification(object):
self.ver_name_list.append(name)
def __call__(self, num_update, backbone: torch.nn.Module):
-        if self.rank == 0 and num_update > 0 and num_update % self.frequent == 0:
+        if self.rank == 0 and num_update > 0:
backbone.eval()
self.ver_test(backbone, num_update)
backbone.train()
class CallBackLogging(object):
-    def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
+    def __init__(self, frequent, total_step, batch_size, writer=None):
         self.frequent: int = frequent
-        self.rank: int = rank
+        self.rank: int = distributed.get_rank()
+        self.world_size: int = distributed.get_world_size()
         self.time_start = time.time()
         self.total_step: int = total_step
         self.batch_size: int = batch_size
-        self.world_size: int = world_size
         self.writer = writer
         self.init = False
@@ -100,18 +108,3 @@ class CallBackLogging(object):
else:
self.init = True
self.tic = time.time()
-class CallBackModelCheckpoint(object):
-    def __init__(self, rank, output="./"):
-        self.rank: int = rank
-        self.output: str = output
-
-    def __call__(self, global_step, backbone, partial_fc, ):
-        if global_step > 100 and self.rank == 0:
-            path_module = os.path.join(self.output, "backbone.pth")
-            torch.save(backbone.module.state_dict(), path_module)
-            logging.info("Pytorch Model Saved in '{}'".format(path_module))
-        if global_step > 100 and partial_fc is not None:
-            partial_fc.save_params()
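
The refactored callbacks now read rank and world size from `torch.distributed` instead of taking them as constructor arguments, so a process group must exist before either class is built. A minimal single-process sketch (the `gloo` backend and the dummy values are illustrative; real training initializes `nccl` as in `train.py`, and this must run from the repository root so `utils.utils_callbacks` imports):

```python
from torch import distributed

from utils.utils_callbacks import CallBackLogging

# single-process group, only so distributed.get_rank()/get_world_size() resolve
distributed.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:12584", rank=0, world_size=1)

callback_logging = CallBackLogging(
    frequent=50,        # log every 50 steps
    total_step=100000,  # used for the throughput/ETA estimate
    batch_size=128,
    writer=None,        # or a SummaryWriter on rank 0
)
```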