mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
updated for WebFace42M
updated readability of the code
5
recognition/arcface_torch/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
**__pycache__/
|
||||
.vscode
|
||||
bak*/
|
||||
work_dirs/
|
||||
models/
|
||||
@@ -5,50 +5,50 @@ identity on a single server.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
|
||||
- Install [PyTorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
|
||||
- (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/), our doc for [install_dali.md](docs/install_dali.md).
|
||||
- `pip install -r requirements.txt`.
|
||||
- Download the dataset from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).
|
||||
|
||||
|
||||
## How to Train
|
||||
|
||||
To train a model, run `train.py` with the path to the configs:
|
||||
To train a model, run `train.py` with the path to the configs.
|
||||
The example commands below show how to run
|
||||
distributed training.
|
||||
|
||||
### 1. Single node, 8 GPUs:
|
||||
### 1. To run on a machine with 8 GPUs:
|
||||
|
||||
```shell
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12581 train.py configs/ms1mv3_r50_lr02
|
||||
```
|
||||
|
||||
### 2. Multiple nodes, each node 8 GPUs:
|
||||
### 2. To run on 2 machines with 8 GPUs each:
|
||||
|
||||
Node 0:
|
||||
|
||||
```shell
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus
|
||||
```
|
||||
|
||||
Node 1:
|
||||
|
||||
```shell
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus
|
||||
```
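The per-node commands above differ only in `--node_rank`. As a minimal sketch of that pattern, the helper below builds the same argument list for every node of a cluster. The function names `launch_command` and `commands_for_cluster` are hypothetical and not part of this repository; the flags and default values are copied from the commands shown above.

```python
def launch_command(node_rank, nnodes, master_addr, config,
                   nproc_per_node=8, master_port=12581):
    """Build the launch arguments for one node (hypothetical helper).

    Only node_rank changes from node to node; every other flag must be
    identical on all machines so they can find the same master process.
    """
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        f"--nnodes={nnodes}",
        f"--node_rank={node_rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        "train.py", config,
    ]


def commands_for_cluster(nnodes, master_addr, config):
    """Return one shell command string per node, ordered by node rank."""
    return [
        " ".join(launch_command(rank, nnodes, master_addr, config))
        for rank in range(nnodes)
    ]


if __name__ == "__main__":
    for cmd in commands_for_cluster(
            2, "ip1", "configs/webface42m_r100_lr01_pfc02_bs4k_16gpus"):
        print(cmd)
```

Generating the commands from one function makes it harder to let the master address or port drift between machines.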
|
||||
|
||||
### 3.Training resnet2060 with 8 GPUs:
|
||||
## Download Datasets or Prepare Datasets
|
||||
|
||||
```shell
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
|
||||
```
|
||||
- [MS1MV3](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-retinaface) (93k IDs, 5.2M images)
|
||||
- [Glint360K](https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc#4-download) (360k IDs, 17.1M images)
|
||||
- [WebFace42M](docs/prepare_webface42m.md) (2M IDs, 42.5M images)
|
||||
|
||||
## Model Zoo
|
||||
|
||||
- The models are available for non-commercial research purposes only.
|
||||
- All models can be found in here.
|
||||
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
|
||||
- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
|
||||
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
|
||||
- [OneDrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
|
||||
|
||||
### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
|
||||
### Performance on IJB-C and [**ICCV2021-MFR**](https://github.com/deepinsight/insightface/blob/master/challenges/mfr/README.md)
|
||||
|
||||
The ICCV2021-MFR test set consists of non-celebrities, so it has very little overlap with publicly available face
recognition training sets such as MS1M and CASIA, which were mostly collected from online celebrities.
|
||||
@@ -57,44 +57,24 @@ As the result, we can evaluate the FAIR performance for different algorithms.
|
||||
For the **ICCV2021-MFR-ALL** set, TAR is measured on the all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
globalised multi-racial test set contains 242,143 identities and 1,624,305 images.
|
||||
|
||||
For the **ICCV2021-MFR-MASK** set, TAR is measured on the mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4).
The mask test set contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
In total there are 13,928 positive pairs and 96,983,824 negative pairs.
|
||||
|
||||
| Datasets | backbone | Training throughout | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
|
||||
| :---: | :--- | :--- | :--- |:--- |:--- |
|
||||
| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
|
||||
| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
|
||||
| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
|
||||
| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
|
||||
| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
|
||||
| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
|
||||
| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
|
||||
| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
|
||||
| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
|
||||
| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
|
||||
|
||||
### Performance on IJB-C and Verification Datasets
|
||||
|
||||
| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
|
||||
| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
|
||||
| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
|
||||
| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
|
||||
| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
|
||||
| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
|
||||
| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
|
||||
| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
|
||||
| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
|
||||
| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
|
||||
| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
|
||||
|
||||
[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.)
|
||||
|
||||
|
||||
## [Speed Benchmark](docs/speed_benchmark.md)
|
||||
| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Training Throughput | log |
|:---|:---|:---|:---|:---|:---|:---|
| MS1MV3 | mobileface | - | - | - | ~13000 | log\|config |
| Glint360K | mobileface | - | - | - | - | log\|config |
| WebFace42M-PartialFC-0.2 | mobileface | 73.80 | 95.40 | 92.64 | (16GPUs)~18583 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_mobilefacenet_pfc02_bs8k_16gpus/training.log)\|[config](configs/webface42m_mobilefacenet_pfc02_bs8k_16gpus.py) |
| MS1MV3 | r100 | 85.39 | 97.00 | 95.36 | ~3400 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100_lr01/training.log)\|[config](configs/ms1mv3_r100_lr02.py) |
| Glint360K | r100 | - | - | - | ~3400 | log\|[config](configs/glint360k_100_lr02.py) |
| WebFace42M-PartialFC-0.2 | r50 | 93.83 | 97.53 | 96.16 | ~5900 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log)\|[config](configs/webface42m_r50_lr01_pfc02_bs4k_8gpus.py) |
| WebFace42M-PartialFC-0.2 | r50 | 94.04 | 97.48 | 95.94 | (32GPUs)~17000 | log\|[config](configs/webface42m_r50_lr01_pfc02_bs4k_32gpus.py) |
| WebFace42M-PartialFC-0.2 | r100 | 96.69 | 97.85 | 96.63 | (16GPUs)~5200 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log)\|[config](configs/webface42m_r100_lr01_pfc02_bs4k_16gpus.py) |
|
||||
|
||||
**Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of
|
||||
classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same
|
||||
|
||||
## Speed Benchmark
|
||||
|
||||
`arcface_torch` can train on large-scale face recognition training sets efficiently and quickly. When the number of
classes in the training set is greater than 1 million, the Partial FC sampling strategy reaches the same
accuracy with several times faster training and a smaller GPU memory footprint.
Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
|
||||
@@ -112,39 +92,27 @@ More details see
|
||||
`-` means training failed because of GPU memory limitations.
|
||||
|
||||
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
|
||||
| :--- | :--- | :--- | :--- |
|
||||
|125000 | 4681 | 4824 | 5004 |
|
||||
|1400000 | **1672** | 3043 | 4738 |
|
||||
|5500000 | **-** | **1389** | 3975 |
|
||||
|8000000 | **-** | **-** | 3565 |
|
||||
|16000000 | **-** | **-** | 2679 |
|
||||
|29000000 | **-** | **-** | **1855** |
|
||||
|:--------------------------------|:--------------|:---------------|:---------------|
|
||||
| 125000 | 4681 | 4824 | 5004 |
|
||||
| 1400000 | **1672** | 3043 | 4738 |
|
||||
| 5500000 | **-** | **1389** | 3975 |
|
||||
| 8000000 | **-** | **-** | 3565 |
|
||||
| 16000000 | **-** | **-** | 2679 |
|
||||
| 29000000 | **-** | **-** | **1855** |
|
||||
|
||||
### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
|
||||
|
||||
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
|
||||
| :--- | :--- | :--- | :--- |
|
||||
|125000 | 7358 | 5306 | 4868 |
|
||||
|1400000 | 32252 | 11178 | 6056 |
|
||||
|5500000 | **-** | 32188 | 9854 |
|
||||
|8000000 | **-** | **-** | 12310 |
|
||||
|16000000 | **-** | **-** | 19950 |
|
||||
|29000000 | **-** | **-** | 32324 |
|
||||
|:--------------------------------|:--------------|:---------------|:---------------|
|
||||
| 125000 | 7358 | 5306 | 4868 |
|
||||
| 1400000 | 32252 | 11178 | 6056 |
|
||||
| 5500000 | **-** | 32188 | 9854 |
|
||||
| 8000000 | **-** | **-** | 12310 |
|
||||
| 16000000 | **-** | **-** | 19950 |
|
||||
| 29000000 | **-** | **-** | 32324 |
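The Speed Benchmark text describes Partial FC as sampling a subset of class centers for each batch. The sketch below illustrates that idea in plain Python under one stated assumption: classes that appear in the batch are always kept, and the remaining slots are filled with randomly chosen other classes. The function name, the `seed` parameter, and the list-based data structures are illustrative only and are not taken from this repository.

```python
import random


def sample_class_centers(labels, num_classes, sample_rate, seed=0):
    """Pick the class centers used for one batch (illustrative sketch).

    Assumption: every class present in `labels` is kept, and random
    other classes fill up to int(num_classes * sample_rate) centers.
    Returns the sorted sampled class ids and the labels remapped to
    positions inside that sampled list.
    """
    rng = random.Random(seed)
    num_sample = int(num_classes * sample_rate)

    positives = sorted(set(labels))
    positive_set = set(positives)

    # Fill the remaining budget with classes that are not in the batch.
    num_negatives = max(num_sample - len(positives), 0)
    negative_pool = [c for c in range(num_classes) if c not in positive_set]
    negatives = rng.sample(negative_pool, num_negatives)

    sampled = sorted(positives + negatives)
    position = {class_id: i for i, class_id in enumerate(sampled)}
    remapped = [position[y] for y in labels]
    return sampled, remapped


if __name__ == "__main__":
    sampled, remapped = sample_class_centers([3, 7, 3], num_classes=10,
                                             sample_rate=0.5)
    print(sampled, remapped)
```

Only the sampled centers take part in the softmax for that batch, which is consistent with the memory and throughput gap between the Model Parallel and Partial FC columns in the tables above.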
|
||||
|
||||
## Evaluation ICCV2021-MFR and IJB-C
|
||||
|
||||
More details see [eval.md](docs/eval.md) in docs.
|
||||
|
||||
## Test
|
||||
|
||||
We tested many versions of PyTorch. Please create an issue if you are having trouble.
|
||||
|
||||
- [x] torch 1.6.0
|
||||
- [x] torch 1.7.1
|
||||
- [x] torch 1.8.0
|
||||
- [x] torch 1.9.0
|
||||
|
||||
## Citation
|
||||
## Citations
|
||||
|
||||
```
|
||||
@inproceedings{deng2019arcface,
|
||||
|
||||
@@ -3,21 +3,20 @@ from easydict import EasyDict as edict
|
||||
# configs for test speed
|
||||
|
||||
config = edict()
|
||||
config.loss = "arcface"
|
||||
config.loss = "cosface"
|
||||
config.network = "r50"
|
||||
config.resume = False
|
||||
config.output = None
|
||||
config.embedding_size = 512
|
||||
config.sample_rate = 1.0
|
||||
config.sample_rate = 0.99
|
||||
config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 128
|
||||
config.batch_size = 64 # total_batch_size = batch_size * num_gpus
|
||||
config.lr = 0.1 # batch size is 512
|
||||
|
||||
config.rec = "synthetic"
|
||||
config.num_classes = 300 * 10000
|
||||
config.num_epoch = 30
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [10, 16, 22]
|
||||
config.val_targets = []
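The new comment `total_batch_size = batch_size * num_gpus` states that `config.batch_size` is a per-GPU value. A small worked sketch of that arithmetic follows. It assumes that the `bs4k`/`bs8k` and `8gpus`/`16gpus` fragments in the config file names listed in the README refer to the global batch size and the GPU count; the function name is hypothetical.

```python
def total_batch_size(per_gpu_batch_size, num_gpus):
    """Global batch size under the rule stated in the config comment."""
    return per_gpu_batch_size * num_gpus


if __name__ == "__main__":
    # batch_size = 64 on 8 GPUs gives the 512 mentioned in the lr comment.
    print(total_batch_size(64, 8))
    # batch_size = 512 on 16 GPUs matches a "bs8k_16gpus" style name.
    print(total_batch_size(512, 16))
    # batch_size = 256 on 16 GPUs matches a "bs4k_16gpus" style name.
    print(total_batch_size(256, 16))
```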
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
# configs for test speed
|
||||
|
||||
config = edict()
|
||||
config.loss = "arcface"
|
||||
config.network = "r50"
|
||||
config.resume = False
|
||||
config.output = None
|
||||
config.embedding_size = 512
|
||||
config.sample_rate = 0.1
|
||||
config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
|
||||
config.rec = "synthetic"
|
||||
config.num_classes = 300 * 10000
|
||||
config.num_epoch = 30
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [10, 16, 22]
|
||||
config.val_targets = []
|
||||
@@ -10,7 +10,6 @@ config.network = "r50"
|
||||
config.resume = False
|
||||
config.output = "ms1mv3_arcface_r50"
|
||||
|
||||
config.dataset = "ms1m-retinaface-t1"
|
||||
config.embedding_size = 512
|
||||
config.sample_rate = 1
|
||||
config.fp16 = False
|
||||
@@ -18,39 +17,31 @@ config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
config.dali = False
|
||||
config.verbose = 2000
|
||||
config.frequent = 10
|
||||
config.score = None
|
||||
|
||||
if config.dataset == "emore":
|
||||
config.rec = "/train_tmp/faces_emore"
|
||||
config.num_classes = 85742
|
||||
config.num_image = 5822653
|
||||
config.num_epoch = 16
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [8, 14, ]
|
||||
config.val_targets = ["lfw", ]
|
||||
# if config.dataset == "emore":
|
||||
# config.rec = "/train_tmp/faces_emore"
|
||||
# config.num_classes = 85742
|
||||
# config.num_image = 5822653
|
||||
# config.num_epoch = 16
|
||||
# config.warmup_epoch = -1
|
||||
# config.val_targets = ["lfw", ]
|
||||
|
||||
elif config.dataset == "ms1m-retinaface-t1":
|
||||
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
||||
config.num_classes = 93431
|
||||
config.num_image = 5179510
|
||||
config.num_epoch = 25
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [11, 17, 22]
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
# elif config.dataset == "ms1m-retinaface-t1":
|
||||
# config.rec = "/train_tmp/ms1m-retinaface-t1"
|
||||
# config.num_classes = 93431
|
||||
# config.num_image = 5179510
|
||||
# config.num_epoch = 25
|
||||
# config.warmup_epoch = -1
|
||||
# config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
|
||||
elif config.dataset == "glint360k":
|
||||
config.rec = "/train_tmp/glint360k"
|
||||
config.num_classes = 360232
|
||||
config.num_image = 17091657
|
||||
config.num_epoch = 20
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [8, 12, 15, 18]
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
|
||||
elif config.dataset == "webface":
|
||||
config.rec = "/train_tmp/faces_webface_112x112"
|
||||
config.num_classes = 10572
|
||||
config.num_image = "forget"
|
||||
config.num_epoch = 34
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [20, 28, 32]
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
# elif config.dataset == "glint360k":
|
||||
# config.rec = "/train_tmp/glint360k"
|
||||
# config.num_classes = 360232
|
||||
# config.num_image = 17091657
|
||||
# config.num_epoch = 20
|
||||
# config.warmup_epoch = -1
|
||||
# config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
|
||||
@@ -5,7 +5,7 @@ from easydict import EasyDict as edict
|
||||
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
||||
|
||||
config = edict()
|
||||
config.loss = "cosface"
|
||||
config.loss = "arcface"
|
||||
config.network = "r100"
|
||||
config.resume = False
|
||||
config.output = None
|
||||
@@ -15,12 +15,13 @@ config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
config.lr = 0.2
|
||||
config.verbose = 2000
|
||||
config.dali = False
|
||||
|
||||
config.rec = "/train_tmp/glint360k"
|
||||
config.num_classes = 360232
|
||||
config.num_image = 17091657
|
||||
config.num_epoch = 20
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [8, 12, 15, 18]
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
config.warmup_epoch = 0
|
||||
config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
|
||||
@@ -13,14 +13,15 @@ config.embedding_size = 512
|
||||
config.sample_rate = 1.0
|
||||
config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 2e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
config.weight_decay = 1e-4
|
||||
config.batch_size = 256
|
||||
config.lr = 0.2
|
||||
config.verbose = 5000
|
||||
config.dali = False
|
||||
|
||||
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
||||
config.num_classes = 93431
|
||||
config.num_image = 5179510
|
||||
config.num_epoch = 30
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [10, 20, 25]
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
config.num_epoch = 40
|
||||
config.warmup_epoch = 2
|
||||
config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
|
||||
@@ -6,7 +6,7 @@ from easydict import EasyDict as edict
|
||||
|
||||
config = edict()
|
||||
config.loss = "arcface"
|
||||
config.network = "r18"
|
||||
config.network = "r100"
|
||||
config.resume = False
|
||||
config.output = None
|
||||
config.embedding_size = 512
|
||||
@@ -15,12 +15,13 @@ config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
config.lr = 0.2
|
||||
config.verbose = 2000
|
||||
config.dali = False
|
||||
|
||||
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
||||
config.num_classes = 93431
|
||||
config.num_image = 5179510
|
||||
config.num_epoch = 25
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [10, 16, 22]
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
config.warmup_epoch = 2
|
||||
config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
|
||||
@@ -1,26 +0,0 @@
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
# make training faster
|
||||
# our RAM is 256G
|
||||
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
||||
|
||||
config = edict()
|
||||
config.loss = "arcface"
|
||||
config.network = "r2060"
|
||||
config.resume = False
|
||||
config.output = None
|
||||
config.embedding_size = 512
|
||||
config.sample_rate = 1.0
|
||||
config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 64
|
||||
config.lr = 0.1 # batch size is 512
|
||||
|
||||
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
||||
config.num_classes = 93431
|
||||
config.num_image = 5179510
|
||||
config.num_epoch = 25
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [10, 16, 22]
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
@@ -1,26 +0,0 @@
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
# make training faster
|
||||
# our RAM is 256G
|
||||
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
||||
|
||||
config = edict()
|
||||
config.loss = "arcface"
|
||||
config.network = "r34"
|
||||
config.resume = False
|
||||
config.output = None
|
||||
config.embedding_size = 512
|
||||
config.sample_rate = 1.0
|
||||
config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
|
||||
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
||||
config.num_classes = 93431
|
||||
config.num_image = 5179510
|
||||
config.num_epoch = 25
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [10, 16, 22]
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
@@ -15,12 +15,13 @@ config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
config.lr = 0.2
|
||||
config.verbose = 2000
|
||||
config.dali = False
|
||||
|
||||
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
||||
config.num_classes = 93431
|
||||
config.num_image = 5179510
|
||||
config.num_epoch = 25
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [10, 16, 22]
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
config.warmup_epoch = 2
|
||||
config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
|
||||
@@ -1,23 +0,0 @@
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
# configs for test speed
|
||||
|
||||
config = edict()
|
||||
config.loss = "arcface"
|
||||
config.network = "r50"
|
||||
config.resume = False
|
||||
config.output = None
|
||||
config.embedding_size = 512
|
||||
config.sample_rate = 1.0
|
||||
config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
|
||||
config.rec = "synthetic"
|
||||
config.num_classes = 100 * 10000
|
||||
config.num_epoch = 30
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [10, 16, 22]
|
||||
config.val_targets = []
|
||||
@@ -10,17 +10,18 @@ config.network = "mbf"
|
||||
config.resume = False
|
||||
config.output = None
|
||||
config.embedding_size = 512
|
||||
config.sample_rate = 0.1
|
||||
config.sample_rate = 0.2
|
||||
config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 2e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
config.weight_decay = 1e-4
|
||||
config.batch_size = 512
|
||||
config.lr = 0.4
|
||||
config.verbose = 10000
|
||||
config.dali = True
|
||||
|
||||
config.rec = "/train_tmp/glint360k"
|
||||
config.num_classes = 360232
|
||||
config.num_image = 17091657
|
||||
config.rec = "/train_tmp/WebFace42M"
|
||||
config.num_classes = 2059906
|
||||
config.num_image = 42474557
|
||||
config.num_epoch = 20
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [8, 12, 15, 18]
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
config.warmup_epoch = 2
|
||||
config.val_targets = []
|
||||
@@ -6,21 +6,22 @@ from easydict import EasyDict as edict
|
||||
|
||||
config = edict()
|
||||
config.loss = "cosface"
|
||||
config.network = "r34"
|
||||
config.network = "r100"
|
||||
config.resume = False
|
||||
config.output = None
|
||||
config.embedding_size = 512
|
||||
config.sample_rate = 1.0
|
||||
config.sample_rate = 0.2
|
||||
config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
config.batch_size = 256
|
||||
config.lr = 0.3
|
||||
config.verbose = 2000
|
||||
config.dali = True
|
||||
|
||||
config.rec = "/train_tmp/glint360k"
|
||||
config.num_classes = 360232
|
||||
config.num_image = 17091657
|
||||
config.rec = "/train_tmp/WebFace42M"
|
||||
config.num_classes = 2059906
|
||||
config.num_image = 42474557
|
||||
config.num_epoch = 20
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [8, 12, 15, 18]
|
||||
config.warmup_epoch = 1
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
@@ -10,17 +10,18 @@ config.network = "r50"
|
||||
config.resume = False
|
||||
config.output = None
|
||||
config.embedding_size = 512
|
||||
config.sample_rate = 1.0
|
||||
config.sample_rate = 0.2
|
||||
config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
config.lr = 0.4
|
||||
config.verbose = 10000
|
||||
config.dali = True
|
||||
|
||||
config.rec = "/train_tmp/glint360k"
|
||||
config.num_classes = 360232
|
||||
config.num_image = 17091657
|
||||
config.rec = "/train_tmp/WebFace42M"
|
||||
config.num_classes = 2059906
|
||||
config.num_image = 42474557
|
||||
config.num_epoch = 20
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [8, 12, 15, 18]
|
||||
config.warmup_epoch = 2
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
@@ -6,21 +6,22 @@ from easydict import EasyDict as edict
|
||||
|
||||
config = edict()
|
||||
config.loss = "cosface"
|
||||
config.network = "r18"
|
||||
config.network = "r50"
|
||||
config.resume = False
|
||||
config.output = None
|
||||
config.embedding_size = 512
|
||||
config.sample_rate = 1.0
|
||||
config.sample_rate = 0.2
|
||||
config.fp16 = True
|
||||
config.momentum = 0.9
|
||||
config.weight_decay = 5e-4
|
||||
config.batch_size = 128
|
||||
config.lr = 0.1 # batch size is 512
|
||||
config.batch_size = 512
|
||||
config.lr = 0.4
|
||||
config.verbose = 10000
|
||||
config.dali = True
|
||||
|
||||
config.rec = "/train_tmp/glint360k"
|
||||
config.num_classes = 360232
|
||||
config.num_image = 17091657
|
||||
config.rec = "/train_tmp/WebFace42M"
|
||||
config.num_classes = 2059906
|
||||
config.num_image = 42474557
|
||||
config.num_epoch = 20
|
||||
config.warmup_epoch = -1
|
||||
config.decay_epoch = [8, 12, 15, 18]
|
||||
config.warmup_epoch = 2
|
||||
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
||||
@@ -2,13 +2,42 @@ import numbers
|
||||
import os
|
||||
import queue as Queue
|
||||
import threading
|
||||
from typing import Iterable
|
||||
|
||||
import mxnet as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import distributed
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
def get_dataloader(
|
||||
root_dir: str,
|
||||
local_rank: int,
|
||||
batch_size: int,
|
||||
dali = False) -> Iterable:
|
||||
if dali and root_dir != "synthetic":
|
||||
rec = os.path.join(root_dir, 'train.rec')
|
||||
idx = os.path.join(root_dir, 'train.idx')
|
||||
return dali_data_iter(
|
||||
batch_size=batch_size, rec_file=rec,
|
||||
idx_file=idx, num_threads=2, local_rank=local_rank)
|
||||
else:
|
||||
if root_dir == "synthetic":
|
||||
train_set = SyntheticDataset()
|
||||
else:
|
||||
train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank)
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
|
||||
train_loader = DataLoaderX(
|
||||
local_rank=local_rank,
|
||||
dataset=train_set,
|
||||
batch_size=batch_size,
|
||||
sampler=train_sampler,
|
||||
num_workers=2,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
)
|
||||
return train_loader
|
||||
|
||||
class BackgroundGenerator(threading.Thread):
|
||||
def __init__(self, generator, local_rank, max_prefetch=6):
|
||||
@@ -108,7 +137,7 @@ class MXFaceDataset(Dataset):
|
||||
|
||||
|
||||
class SyntheticDataset(Dataset):
|
||||
def __init__(self, local_rank):
|
||||
def __init__(self):
|
||||
super(SyntheticDataset, self).__init__()
|
||||
img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
|
||||
img = np.transpose(img, (2, 0, 1))
|
||||
@@ -122,3 +151,59 @@ class SyntheticDataset(Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return 1000000
|
||||
|
||||
|
||||
def dali_data_iter(
|
||||
batch_size: int, rec_file: str, idx_file: str, num_threads: int,
|
||||
initial_fill=32768, random_shuffle=True,
|
||||
prefetch_queue_depth=1, local_rank=0, name="reader",
|
||||
mean=(127.5, 127.5, 127.5),
|
||||
std=(127.5, 127.5, 127.5)):
|
||||
"""
|
||||
Parameters:
|
||||
----------
|
||||
initial_fill: int
|
||||
Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored.
|
||||
|
||||
"""
|
||||
rank: int = distributed.get_rank()
|
||||
world_size: int = distributed.get_world_size()
|
||||
import nvidia.dali.fn as fn
|
||||
import nvidia.dali.types as types
|
||||
from nvidia.dali.pipeline import Pipeline
|
||||
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
|
||||
|
||||
pipe = Pipeline(
|
||||
batch_size=batch_size, num_threads=num_threads,
|
||||
device_id=local_rank, prefetch_queue_depth=prefetch_queue_depth, )
|
||||
condition_flip = fn.random.coin_flip(probability=0.5)
|
||||
with pipe:
|
||||
jpegs, labels = fn.readers.mxnet(
|
||||
path=rec_file, index_path=idx_file, initial_fill=initial_fill,
|
||||
num_shards=world_size, shard_id=rank,
|
||||
random_shuffle=random_shuffle, pad_last_batch=False, name=name)
|
||||
images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
|
||||
images = fn.crop_mirror_normalize(
|
||||
images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip)
|
||||
pipe.set_outputs(images, labels)
|
||||
pipe.build()
|
||||
return DALIWarper(DALIClassificationIterator(pipelines=[pipe], reader_name=name, ))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
class DALIWarper(object):
|
||||
def __init__(self, dali_iter):
|
||||
self.iter = dali_iter
|
||||
|
||||
def __next__(self):
|
||||
data_dict = self.iter.__next__()[0]
|
||||
tensor_data = data_dict['data'].cuda()
|
||||
tensor_label: torch.Tensor = data_dict['label'].cuda().long()
|
||||
tensor_label.squeeze_()
|
||||
return tensor_data, tensor_label
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.iter.reset()
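The wrapper above exposes `__next__`, `__iter__` and `reset`, so a training loop can treat it like a regular loader as long as it calls `reset()` between epochs. The sketch below shows that usage pattern without DALI or a GPU. `FakeDaliIterator` is a hypothetical stand-in that mimics only the shape of the batches the wrapper reads (a list holding one dict with `data` and `label` keys), and `IteratorWrapper` mirrors the wrapper's control flow without the `.cuda()` and `.long()` conversions.

```python
class FakeDaliIterator:
    """Hypothetical stand-in for a DALI iterator; yields [batch_dict]."""

    def __init__(self, batches):
        self._batches = batches
        self._pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._pos >= len(self._batches):
            raise StopIteration
        batch = self._batches[self._pos]
        self._pos += 1
        return [batch]  # one dict per pipeline, as the wrapper expects

    def reset(self):
        self._pos = 0


class IteratorWrapper:
    """Same control flow as the wrapper in the diff, minus tensor ops."""

    def __init__(self, dali_iter):
        self.iter = dali_iter

    def __iter__(self):
        return self

    def __next__(self):
        data_dict = next(self.iter)[0]
        return data_dict["data"], data_dict["label"]

    def reset(self):
        self.iter.reset()


def run_epochs(loader, num_epochs):
    """Iterate the loader for several epochs, resetting after each one."""
    seen_labels = []
    for _ in range(num_epochs):
        for _data, label in loader:
            seen_labels.append(label)
        loader.reset()  # without this, later epochs would be empty
    return seen_labels


if __name__ == "__main__":
    fake = FakeDaliIterator([{"data": "a", "label": 0},
                             {"data": "b", "label": 1}])
    print(run_epochs(IteratorWrapper(fake), num_epochs=2))
```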
|
||||
|
||||
1
recognition/arcface_torch/docs/install_dali.md
Normal file
@@ -0,0 +1 @@
|
||||
TODO
|
||||
22
recognition/arcface_torch/docs/prepare_webface42m.md
Normal file
@@ -0,0 +1,22 @@
|
||||
|
||||
|
||||
|
||||
## 1. Download Datasets and Unzip
|
||||
|
||||
Download WebFace42M from [https://www.face-benchmark.org/download.html](https://www.face-benchmark.org/download.html).
|
||||
|
||||
|
||||
## 2. Create **Pre-shuffle** Rec File for DALI
|
||||
|
||||
Note: A pre-shuffled rec file is very important for DALI; a rec file that is not pre-shuffled can cause performance degradation. The original InsightFace-style rec file
does not support NVIDIA DALI, so you must use [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) to generate a pre-shuffled rec file.
|
||||
|
||||
```shell
|
||||
# 1) Create train.lst with the following command
|
||||
python -m mxnet.tools.im2rec --list --recursive train "Your WebFace42M Root"
|
||||
|
||||
# 2) Create train.rec and train.idx from train.lst with the following command
|
||||
python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train "Your WebFace42M Root"
|
||||
```
|
||||
|
||||
Finally, you will get three files: `train.lst`, `train.rec`, and `train.idx`, of which `train.rec` and `train.idx` are used for training.
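Before starting a long training run it can help to confirm that the two training files exist in the dataset root. The sketch below is a minimal check, assuming the files are named `train.rec` and `train.idx` as produced by the commands above; the function name `missing_training_files` is hypothetical.

```python
from pathlib import Path


def missing_training_files(root_dir, prefix="train"):
    """Return the names of required files that are absent from root_dir."""
    required = [f"{prefix}.rec", f"{prefix}.idx"]
    return [name for name in required
            if not (Path(root_dir) / name).is_file()]


if __name__ == "__main__":
    missing = missing_training_files("/train_tmp/WebFace42M")
    if missing:
        print("missing files:", ", ".join(missing))
    else:
        print("dataset root looks complete")
```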
|
||||
@@ -261,6 +261,8 @@ def test(data_set, backbone, batch_size, nfolds=10):
|
||||
_xnorm_cnt += 1
|
||||
_xnorm /= _xnorm_cnt
|
||||
|
||||
embeddings = embeddings_list[0].copy()
|
||||
embeddings = sklearn.preprocessing.normalize(embeddings)
|
||||
acc1 = 0.0
|
||||
std1 = 0.0
|
||||
embeddings = embeddings_list[0] + embeddings_list[1]
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def get_loss(name):
|
||||
if name == "cosface":
|
||||
return CosFace()
|
||||
elif name == "arcface":
|
||||
return ArcFace()
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
|
||||
class CosFace(nn.Module):
|
||||
def __init__(self, s=64.0, m=0.40):
|
||||
super(CosFace, self).__init__()
|
||||
self.s = s
|
||||
self.m = m
|
||||
|
||||
def forward(self, cosine, label):
|
||||
index = torch.where(label != -1)[0]
|
||||
m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
|
||||
m_hot.scatter_(1, label[index, None], self.m)
|
||||
cosine[index] -= m_hot
|
||||
ret = cosine * self.s
|
||||
return ret
|
||||
|
||||
|
||||
class ArcFace(nn.Module):
|
||||
def __init__(self, s=64.0, m=0.5):
|
||||
super(ArcFace, self).__init__()
|
||||
self.s = s
|
||||
self.m = m
|
||||
|
||||
def forward(self, cosine: torch.Tensor, label):
|
||||
index = torch.where(label != -1)[0]
|
||||
m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
|
||||
m_hot.scatter_(1, label[index, None], self.m)
|
||||
cosine.acos_()
|
||||
cosine[index] += m_hot
|
||||
cosine.cos_().mul_(self.s)
|
||||
return cosine
|
||||
29
recognition/arcface_torch/lr_scheduler.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
|
||||
class PolyScheduler(_LRScheduler):
|
||||
def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
|
||||
self.base_lr = base_lr
|
||||
self.warmup_lr_init = 0.0001
|
||||
self.max_steps: int = max_steps
|
||||
self.warmup_steps: int = warmup_steps
|
||||
self.power = 2
|
||||
super(PolyScheduler, self).__init__(optimizer, last_epoch, False)
|
||||
|
||||
def get_warmup_lr(self):
|
||||
alpha = float(self.last_epoch) / float(self.warmup_steps)
|
||||
return [self.base_lr * alpha for _ in self.optimizer.param_groups]
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch == -1:
|
||||
return [self.warmup_lr_init for _ in self.optimizer.param_groups]
|
||||
if self.last_epoch < self.warmup_steps:
|
||||
return self.get_warmup_lr()
|
||||
else:
|
||||
alpha = pow(
|
||||
1
|
||||
- float(self.last_epoch - self.warmup_steps)
|
||||
/ float(self.max_steps - self.warmup_steps),
|
||||
self.power,
|
||||
)
|
||||
return [self.base_lr * alpha for _ in self.optimizer.param_groups]
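
The schedule computed by `PolyScheduler` can be illustrated without PyTorch. The sketch below mirrors the warmup and decay branches of `get_lr` as a plain function; the function name `poly_lr` and the example hyperparameters are assumptions made for illustration, and the special `last_epoch == -1` case is left out.

```python
def poly_lr(step, base_lr, max_steps, warmup_steps, power=2):
    """Learning rate at a given step: linear warmup, then polynomial decay."""
    if step < warmup_steps:
        # Linear warmup from 0 up to base_lr.
        return base_lr * step / warmup_steps
    # Polynomial decay from base_lr at the end of warmup to 0 at max_steps.
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return base_lr * (1 - progress) ** power


# Hypothetical settings: 100 steps in total, the first 10 used for warmup.
schedule = [poly_lr(s, base_lr=0.1, max_steps=100, warmup_steps=10) for s in range(101)]
print(schedule[0], schedule[10], schedule[100])
```

With `power=2` the rate falls quickly right after warmup and flattens out as it approaches zero near `max_steps`.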
|
||||
@@ -9,9 +9,10 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import prettytable
|
||||
import skimage.transform
|
||||
import torch
|
||||
from sklearn.metrics import roc_curve
|
||||
from sklearn.preprocessing import normalize
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from onnx_helper import ArcFaceORT
|
||||
|
||||
SRC = np.array(
|
||||
@@ -25,6 +26,7 @@ SRC = np.array(
|
||||
SRC[:, 0] += 8.0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
class AlignedDataSet(mx.gluon.data.Dataset):
|
||||
def __init__(self, root, lines, align=True):
|
||||
self.lines = lines
|
||||
@@ -47,24 +49,23 @@ class AlignedDataSet(mx.gluon.data.Dataset):
|
||||
img_2 = np.expand_dims(np.fliplr(img), 0)
|
||||
output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
|
||||
output = np.transpose(output, (0, 3, 1, 2))
|
||||
output = mx.nd.array(output)
|
||||
return output
|
||||
return torch.from_numpy(output)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def extract(model_root, dataset):
|
||||
model = ArcFaceORT(model_path=model_root)
|
||||
model.check()
|
||||
feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))
|
||||
|
||||
def batchify_fn(data):
|
||||
return mx.nd.concat(*data, dim=0)
|
||||
def collate_fn(data):
|
||||
return torch.cat(data, dim=0)
|
||||
|
||||
data_loader = mx.gluon.data.DataLoader(
|
||||
dataset, 128, last_batch='keep', num_workers=4,
|
||||
thread_pool=True, prefetch=16, batchify_fn=batchify_fn)
|
||||
data_loader = DataLoader(
|
||||
dataset, batch_size=128, drop_last=False, num_workers=4, collate_fn=collate_fn, )
|
||||
num_iter = 0
|
||||
for batch in data_loader:
|
||||
batch = batch.asnumpy()
|
||||
batch = batch.numpy()
|
||||
batch = (batch - model.input_mean) / model.input_std
|
||||
feat = model.session.run(model.output_names, {model.input_name: batch})[0]
|
||||
feat = np.reshape(feat, (-1, model.feat_dim * 2))
|
||||
@@ -228,10 +229,12 @@ def main(args):
|
||||
score = verification(template_norm_feats, unique_templates, p1, p2)
|
||||
stop = timeit.default_timer()
|
||||
print('Time: %.2f s. ' % (stop - start))
|
||||
save_path = os.path.join(args.result_dir, "{}_result".format(args.target))
|
||||
result_dir = args.model_root
|
||||
|
||||
save_path = os.path.join(result_dir, "{}_result".format(args.target))
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root))
|
||||
score_save_file = os.path.join(save_path, "{}.npy".format(args.target))
|
||||
np.save(score_save_file, score)
|
||||
files = [score_save_file]
|
||||
methods = []
|
||||
@@ -261,7 +264,6 @@ if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='do ijb test')
|
||||
# general
|
||||
parser.add_argument('--model-root', default='', help='path to load model.')
|
||||
parser.add_argument('--image-path', default='', type=str, help='')
|
||||
parser.add_argument('--result-dir', default='.', type=str, help='')
|
||||
parser.add_argument('--image-path', default='/train_tmp/IJB_release/IJBC', type=str, help='')
|
||||
parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
|
||||
main(parser.parse_args())
|
||||
|
||||
@@ -1,222 +1,384 @@
|
||||
import logging
|
||||
import os
|
||||
import math
|
||||
import collections
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn import Module
|
||||
from torch.nn.functional import normalize, linear
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch import distributed
|
||||
from torch.nn.functional import linear, normalize
|
||||
|
||||
|
||||
class PartialFC(Module):
|
||||
class ArcFace(torch.nn.Module):
|
||||
""" ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
|
||||
"""
|
||||
def __init__(self, s=64.0, margin=0.5):
|
||||
super(ArcFace, self).__init__()
|
||||
self.scale = s
|
||||
self.cos_m = math.cos(margin)
|
||||
self.sin_m = math.sin(margin)
|
||||
self.theta = math.cos(math.pi - margin)
|
||||
self.sinmm = math.sin(math.pi - margin) * margin
self.easy_margin = False
|
||||
|
||||
|
||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor):
|
||||
index = torch.where(labels != -1)[0]
|
||||
target_logit = logits[index, labels[index].view(-1)]
|
||||
|
||||
sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
|
||||
cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin)
|
||||
if self.easy_margin:
|
||||
final_target_logit = torch.where(
|
||||
target_logit > 0, cos_theta_m, target_logit)
|
||||
else:
|
||||
final_target_logit = torch.where(
|
||||
target_logit > self.theta, cos_theta_m, target_logit - self.sinmm)
|
||||
|
||||
logits[index, labels[index].view(-1)] = final_target_logit
|
||||
logits = logits * self.scale
|
||||
return logits
|
||||
|
||||
|
||||
class CosFace(torch.nn.Module):
|
||||
def __init__(self, s=64.0, m=0.40):
|
||||
super(CosFace, self).__init__()
|
||||
self.s = s
|
||||
self.m = m
|
||||
|
||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor):
|
||||
index = torch.where(labels != -1)[0]
|
||||
target_logit = logits[index, labels[index].view(-1)]
|
||||
final_target_logit = target_logit - self.m
|
||||
logits[index, labels[index].view(-1)] = final_target_logit
|
||||
logits = logits * self.s
|
||||
return logits
|
||||
|
||||
|
||||
class PartialFC(torch.nn.Module):
|
||||
"""
|
||||
Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
|
||||
Partial FC: Training 10 Million Identities on a Single Machine
|
||||
See the original paper:
|
||||
https://arxiv.org/abs/2010.05222
|
||||
"""
|
||||
A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
|
||||
|
||||
@torch.no_grad()
|
||||
def __init__(self, rank, local_rank, world_size, batch_size, resume,
|
||||
margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
|
||||
When the sample rate is less than 1, in each iteration, positive class centers and a random subset of
|
||||
negative class centers are selected to compute the margin-based softmax loss, all class
|
||||
centers are still maintained throughout the whole training process, but only a subset is
|
||||
selected and updated in each iteration.
|
||||
|
||||
.. note::
|
||||
When the sample rate is equal to 1, Partial FC is equivalent to model parallelism (the default sample rate is 1).
|
||||
|
||||
Example:
|
||||
--------
|
||||
>>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
|
||||
>>> for img, labels in data_loader:
|
||||
>>> embeddings = net(img)
|
||||
>>> loss = module_pfc(embeddings, labels, optimizer)
|
||||
>>> loss.backward()
|
||||
>>> optimizer.step()
|
||||
"""
|
||||
_version = 1
|
||||
def __init__(
|
||||
self,
|
||||
embedding_size,
|
||||
num_classes,
|
||||
sample_rate = 1.0,
|
||||
fp16: bool = False,
|
||||
margin_loss = "cosface",
|
||||
):
|
||||
"""
|
||||
rank: int
|
||||
Unique process(GPU) ID from 0 to world_size - 1.
|
||||
local_rank: int
|
||||
Unique process(GPU) ID within the server from 0 to 7.
|
||||
world_size: int
|
||||
Number of GPU.
|
||||
batch_size: int
|
||||
Batch size on current rank(GPU).
|
||||
resume: bool
|
||||
Select whether to restore the weight of softmax.
|
||||
margin_softmax: callable
|
||||
A function of margin softmax, eg: cosface, arcface.
|
||||
num_classes: int
|
||||
The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size,
|
||||
required.
|
||||
sample_rate: float
|
||||
The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling
|
||||
can greatly speed up training, and reduce a lot of GPU memory, default is 1.0.
|
||||
Parameters:
|
||||
-----------
|
||||
embedding_size: int
|
||||
The feature dimension, default is 512.
|
||||
prefix: str
|
||||
Path for save checkpoint, default is './'.
|
||||
The dimension of embedding, required
|
||||
num_classes: int
|
||||
Total number of classes, required
|
||||
sample_rate: float
|
||||
The rate of negative centers participating in the calculation, default is 1.0.
|
||||
"""
|
||||
super(PartialFC, self).__init__()
|
||||
#
|
||||
self.num_classes: int = num_classes
|
||||
self.rank: int = rank
|
||||
self.local_rank: int = local_rank
|
||||
self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
|
||||
self.world_size: int = world_size
|
||||
self.batch_size: int = batch_size
|
||||
self.margin_softmax: callable = margin_softmax
|
||||
assert (
|
||||
distributed.is_initialized()
|
||||
), "must initialize distributed before create this"
|
||||
self.rank = distributed.get_rank()
|
||||
self.world_size = distributed.get_world_size()
|
||||
|
||||
self.dist_cross_entropy = DistCrossEntropy()
|
||||
self.embedding_size = embedding_size
|
||||
self.sample_rate: float = sample_rate
|
||||
self.embedding_size: int = embedding_size
|
||||
self.prefix: str = prefix
|
||||
self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
|
||||
self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
|
||||
self.fp16 = fp16
|
||||
self.num_local: int = num_classes // self.world_size + int(
|
||||
self.rank < num_classes % self.world_size
|
||||
)
|
||||
self.class_start: int = num_classes // self.world_size * self.rank + min(
|
||||
self.rank, num_classes % self.world_size
|
||||
)
|
||||
self.num_sample: int = int(self.sample_rate * self.num_local)
|
||||
self.last_batch_size: int = 0
|
||||
self.weight: torch.Tensor
|
||||
self.weight_mom: torch.Tensor
|
||||
self.weight_activated: torch.nn.Parameter
|
||||
self.weight_activated_mom: torch.Tensor
|
||||
self.is_updated: bool = True
|
||||
self.init_weight_update: bool = True
|
||||
|
||||
self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
|
||||
self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))
|
||||
|
||||
if resume:
|
||||
try:
|
||||
self.weight: torch.Tensor = torch.load(self.weight_name)
|
||||
self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
|
||||
if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
|
||||
raise IndexError
|
||||
logging.info("softmax weight resume successfully!")
|
||||
logging.info("softmax weight mom resume successfully!")
|
||||
except (FileNotFoundError, KeyError, IndexError):
|
||||
self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
|
||||
self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
|
||||
logging.info("softmax weight init!")
|
||||
logging.info("softmax weight mom init!")
|
||||
if self.sample_rate < 1:
|
||||
self.register_buffer("weight",
|
||||
tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
|
||||
self.register_buffer("weight_mom",
|
||||
tensor=torch.zeros_like(self.weight))
|
||||
self.register_parameter("weight_activated",
|
||||
param=torch.nn.Parameter(torch.empty(0, 0)))
|
||||
self.register_buffer("weight_activated_mom",
|
||||
tensor=torch.empty(0, 0))
|
||||
self.register_buffer("weight_index",
|
||||
tensor=torch.empty(0, 0))
|
||||
else:
|
||||
self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
|
||||
self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
|
||||
logging.info("softmax weight init successfully!")
|
||||
logging.info("softmax weight mom init successfully!")
|
||||
self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)
|
||||
self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
|
||||
|
||||
self.index = None
|
||||
if int(self.sample_rate) == 1:
|
||||
self.update = lambda: 0
|
||||
self.sub_weight = Parameter(self.weight)
|
||||
self.sub_weight_mom = self.weight_mom
|
||||
# margin_loss
|
||||
if isinstance(margin_loss, str):
|
||||
self.margin_softmax: torch.nn.Module
|
||||
if margin_loss == "cosface":
|
||||
self.margin_softmax = CosFace()
|
||||
elif margin_loss == "arcface":
|
||||
self.margin_softmax = ArcFace()
|
||||
else:
|
||||
raise ValueError("unsupported margin_loss: {}".format(margin_loss))
|
||||
elif isinstance(margin_loss, Callable):
|
||||
self.margin_softmax = margin_loss
|
||||
else:
|
||||
self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))
|
||||
|
||||
def save_params(self):
|
||||
""" Save softmax weight for each rank on prefix
|
||||
"""
|
||||
torch.save(self.weight.data, self.weight_name)
|
||||
torch.save(self.weight_mom, self.weight_mom_name)
|
||||
raise TypeError("margin_loss must be a str or a callable")
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, total_label):
|
||||
def sample(self,
|
||||
labels: torch.Tensor,
|
||||
index_positive: torch.Tensor,
|
||||
optimizer: torch.optim.Optimizer):
|
||||
"""
|
||||
Sample all positive class centers on each rank, and randomly select negative class centers to fill a fixed
|
||||
`num_sample`.
|
||||
This function changes the value of labels in place.
|
||||
|
||||
total_label: tensor
|
||||
Label after all gather, which cross all GPUs.
|
||||
Parameters:
|
||||
-----------
|
||||
labels: torch.Tensor
|
||||
Gathered labels, shifted to local class indices (-1 for classes stored on other ranks).
|
||||
index_positive: torch.Tensor
|
||||
Boolean mask of the labels whose class centers are stored on this rank.
|
||||
optimizer: torch.optim.Optimizer
|
||||
SGD optimizer whose last parameter group holds the Partial FC weight.
|
||||
"""
|
||||
index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
|
||||
total_label[~index_positive] = -1
|
||||
total_label[index_positive] -= self.class_start
|
||||
if int(self.sample_rate) != 1:
|
||||
positive = torch.unique(total_label[index_positive], sorted=True)
|
||||
if self.num_sample - positive.size(0) >= 0:
|
||||
perm = torch.rand(size=[self.num_local], device=self.device)
|
||||
perm[positive] = 2.0
|
||||
index = torch.topk(perm, k=self.num_sample)[1]
|
||||
index = index.sort()[0]
|
||||
else:
|
||||
index = positive
|
||||
self.index = index
|
||||
total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
|
||||
self.sub_weight = Parameter(self.weight[index])
|
||||
self.sub_weight_mom = self.weight_mom[index]
|
||||
positive = torch.unique(labels[index_positive], sorted=True).cuda()
|
||||
if self.num_sample - positive.size(0) >= 0:
|
||||
perm = torch.rand(size=[self.num_local]).cuda()
|
||||
perm[positive] = 2.0
|
||||
index = torch.topk(perm, k=self.num_sample)[1].cuda()
|
||||
index = index.sort()[0].cuda()
|
||||
else:
|
||||
index = positive
|
||||
self.weight_index = index
|
||||
|
||||
def forward(self, total_features, norm_weight):
|
||||
""" Partial fc forward, `logits = X * sample(W)`
|
||||
"""
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
logits = linear(total_features, norm_weight)
|
||||
return logits
|
||||
labels[index_positive] = torch.searchsorted(index, labels[index_positive])
|
||||
|
||||
self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
|
||||
self.weight_activated_mom = self.weight_mom[self.weight_index]
|
||||
|
||||
if isinstance(optimizer, torch.optim.SGD):
|
||||
# TODO the params of partial fc must be last in the params list
|
||||
optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
|
||||
optimizer.param_groups[-1]["params"][0] = self.weight_activated
|
||||
optimizer.state[self.weight_activated][
|
||||
"momentum_buffer"
|
||||
] = self.weight_activated_mom
|
||||
else:
|
||||
raise NotImplementedError("PartialFC sampling only supports torch.optim.SGD")
|
||||
|
||||
@torch.no_grad()
|
||||
def update(self):
|
||||
""" Set updated weight and weight_mom to memory bank.
|
||||
""" partial weight to global
|
||||
"""
|
||||
self.weight_mom[self.index] = self.sub_weight_mom
|
||||
self.weight[self.index] = self.sub_weight
|
||||
if self.init_weight_update:
|
||||
self.init_weight_update = False
|
||||
return
|
||||
|
||||
def prepare(self, label, optimizer):
|
||||
if self.sample_rate < 1:
|
||||
self.weight[self.weight_index] = self.weight_activated
|
||||
self.weight_mom[self.weight_index] = self.weight_activated_mom
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
local_embeddings: torch.Tensor,
|
||||
local_labels: torch.Tensor,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
):
|
||||
"""
|
||||
get sampled class centers for cal softmax.
|
||||
|
||||
label: tensor
|
||||
Label tensor on each rank.
|
||||
optimizer: opt
|
||||
Optimizer for partial fc, which need to get weight mom.
|
||||
"""
|
||||
with torch.cuda.stream(self.stream):
|
||||
total_label = torch.zeros(
|
||||
size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
|
||||
dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
|
||||
self.sample(total_label)
|
||||
optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
|
||||
optimizer.param_groups[-1]['params'][0] = self.sub_weight
|
||||
optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
|
||||
norm_weight = normalize(self.sub_weight)
|
||||
return total_label, norm_weight
|
||||
|
||||
def forward_backward(self, label, features, optimizer):
|
||||
"""
|
||||
Partial fc forward and backward with model parallel
|
||||
|
||||
label: tensor
|
||||
Label tensor on each rank(GPU)
|
||||
features: tensor
|
||||
Features tensor on each rank(GPU)
|
||||
optimizer: optimizer
|
||||
Optimizer for partial fc
|
||||
Parameters:
|
||||
----------
|
||||
local_embeddings: torch.Tensor
|
||||
feature embeddings on each GPU(Rank).
|
||||
local_labels: torch.Tensor
|
||||
labels on each GPU(Rank).
|
||||
|
||||
Returns:
|
||||
--------
|
||||
x_grad: tensor
|
||||
The gradient of features.
|
||||
loss_v: tensor
|
||||
Loss value for cross entropy.
|
||||
-------
|
||||
loss: torch.Tensor
|
||||
Cross-entropy loss computed across all ranks.
|
||||
"""
|
||||
total_label, norm_weight = self.prepare(label, optimizer)
|
||||
total_features = torch.zeros(
|
||||
size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
|
||||
dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
|
||||
total_features.requires_grad = True
|
||||
local_labels.squeeze_()
|
||||
local_labels = local_labels.long()
|
||||
self.update()
|
||||
|
||||
logits = self.forward(total_features, norm_weight)
|
||||
logits = self.margin_softmax(logits, total_label)
|
||||
batch_size = local_embeddings.size(0)
|
||||
if self.last_batch_size == 0:
|
||||
self.last_batch_size = batch_size
|
||||
assert self.last_batch_size == batch_size, (
|
||||
"last batch size does not equal current batch size: {} vs {}".format(
|
||||
self.last_batch_size, batch_size))
|
||||
|
||||
with torch.no_grad():
|
||||
max_fc = torch.max(logits, dim=1, keepdim=True)[0]
|
||||
dist.all_reduce(max_fc, dist.ReduceOp.MAX)
|
||||
_gather_embeddings = [
|
||||
torch.zeros((batch_size, self.embedding_size)).cuda()
|
||||
for _ in range(self.world_size)
|
||||
]
|
||||
_gather_labels = [
|
||||
torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
|
||||
]
|
||||
_list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
|
||||
distributed.all_gather(_gather_labels, local_labels)
|
||||
|
||||
# calculate exp(logits) and all-reduce
|
||||
logits_exp = torch.exp(logits - max_fc)
|
||||
logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
|
||||
dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
|
||||
embeddings = torch.cat(_list_embeddings)
|
||||
labels = torch.cat(_gather_labels)
|
||||
|
||||
# calculate prob
|
||||
logits_exp.div_(logits_sum_exp)
|
||||
labels = labels.view(-1, 1)
|
||||
index_positive = (self.class_start <= labels) & (
|
||||
labels < self.class_start + self.num_local
|
||||
)
|
||||
labels[~index_positive] = -1
|
||||
labels[index_positive] -= self.class_start
|
||||
|
||||
# get one-hot
|
||||
grad = logits_exp
|
||||
index = torch.where(total_label != -1)[0]
|
||||
one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
|
||||
one_hot.scatter_(1, total_label[index, None], 1)
|
||||
if self.sample_rate < 1:
|
||||
self.sample(labels, index_positive, optimizer)
|
||||
|
||||
# calculate loss
|
||||
loss = torch.zeros(grad.size()[0], 1, device=grad.device)
|
||||
loss[index] = grad[index].gather(1, total_label[index, None])
|
||||
dist.all_reduce(loss, dist.ReduceOp.SUM)
|
||||
loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
norm_embeddings = normalize(embeddings)
|
||||
norm_weight_activated = normalize(self.weight_activated)
|
||||
logits = linear(norm_embeddings, norm_weight_activated)
|
||||
if self.fp16:
|
||||
logits = logits.float()
|
||||
logits = logits.clamp(-1, 1)
|
||||
|
||||
# calculate grad
|
||||
grad[index] -= one_hot
|
||||
grad.div_(self.batch_size * self.world_size)
|
||||
logits = self.margin_softmax(logits, labels)
|
||||
loss = self.dist_cross_entropy(logits, labels)
|
||||
return loss
|
||||
|
||||
logits.backward(grad)
|
||||
if total_features.grad is not None:
|
||||
total_features.grad.detach_()
|
||||
x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
|
||||
# feature gradient all-reduce
|
||||
dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
|
||||
x_grad = x_grad * self.world_size
|
||||
# backward backbone
|
||||
return x_grad, loss_v
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
if destination is None:
|
||||
destination = collections.OrderedDict()
|
||||
destination._metadata = collections.OrderedDict()
|
||||
|
||||
for name, module in self._modules.items():
|
||||
if module is not None:
|
||||
module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
|
||||
if self.sample_rate < 1:
|
||||
destination["weight"] = self.weight.detach()
|
||||
else:
|
||||
destination["weight"] = self.weight_activated.data.detach()
|
||||
return destination
|
||||
|
||||
def load_state_dict(self, state_dict, strict: bool = True):
|
||||
if self.sample_rate < 1:
|
||||
self.weight = state_dict["weight"].to(self.weight.device)
|
||||
self.weight_mom.zero_()
|
||||
self.weight_activated.data.zero_()
|
||||
self.weight_activated_mom.zero_()
|
||||
self.weight_index.zero_()
|
||||
else:
|
||||
self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
|
||||
|
||||
class DistCrossEntropyFunc(torch.autograd.Function):
|
||||
"""
|
||||
Cross-entropy loss is calculated in parallel: the softmax denominator is all-reduced across GPUs before the softmax is computed.
|
||||
Implementation for ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
|
||||
""" """
|
||||
batch_size = logits.size(0)
|
||||
# for numerical stability
|
||||
max_logits, _ = torch.max(logits, dim=1, keepdim=True)
|
||||
# local to global
|
||||
distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
|
||||
logits.sub_(max_logits)
|
||||
logits.exp_()
|
||||
sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
|
||||
# local to global
|
||||
distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
|
||||
logits.div_(sum_logits_exp)
|
||||
index = torch.where(label != -1)[0]
|
||||
# loss
|
||||
loss = torch.zeros(batch_size, 1, device=logits.device)
|
||||
loss[index] = logits[index].gather(1, label[index])
|
||||
distributed.all_reduce(loss, distributed.ReduceOp.SUM)
|
||||
ctx.save_for_backward(index, logits, label)
|
||||
return loss.clamp_min_(1e-30).log_().mean() * (-1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, loss_gradient):
|
||||
"""
|
||||
Args:
|
||||
loss_gradient (torch.Tensor): gradient propagated back from the following layer
|
||||
Returns:
|
||||
gradients for each input in forward function
|
||||
`None` gradients for one-hot label
|
||||
"""
|
||||
(
|
||||
index,
|
||||
logits,
|
||||
label,
|
||||
) = ctx.saved_tensors
|
||||
batch_size = logits.size(0)
|
||||
one_hot = torch.zeros(
|
||||
size=[index.size(0), logits.size(1)], device=logits.device
|
||||
)
|
||||
one_hot.scatter_(1, label[index], 1)
|
||||
logits[index] -= one_hot
|
||||
logits.div_(batch_size)
|
||||
return logits * loss_gradient.item(), None
|
||||
|
||||
|
||||
class DistCrossEntropy(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(DistCrossEntropy, self).__init__()
|
||||
|
||||
def forward(self, logit_part, label_part):
|
||||
return DistCrossEntropyFunc.apply(logit_part, label_part)
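
The forward pass of `DistCrossEntropyFunc` above can be followed in a single process. The sketch below is a pure-Python illustration under stated assumptions: each "rank" is simulated by a list of per-sample logit rows, the two `all_reduce` calls are replaced by plain `max` and `sum` over the shards, and the function name and toy data are hypothetical.

```python
import math


def sharded_cross_entropy(logit_shards, labels, class_starts):
    """logit_shards[r][i] holds sample i's logits for the classes owned by rank r."""
    batch = len(labels)
    # Step 1: per-sample global max (all_reduce with MAX in the original).
    global_max = [max(max(shard[i]) for shard in logit_shards) for i in range(batch)]
    # Step 2: per-sample global sum of exp (all_reduce with SUM in the original).
    global_sum = [
        sum(math.exp(v - global_max[i]) for shard in logit_shards for v in shard[i])
        for i in range(batch)
    ]
    # Step 3: only the rank that owns the target class contributes its probability.
    loss = 0.0
    for i, label in enumerate(labels):
        for shard, start in zip(logit_shards, class_starts):
            local = label - start
            if 0 <= local < len(shard[i]):
                prob = math.exp(shard[i][local] - global_max[i]) / global_sum[i]
                loss -= math.log(max(prob, 1e-30))
    return loss / batch


# Toy data: 2 samples, 4 classes split evenly over 2 simulated ranks.
shards = [
    [[1.0, 2.0], [0.5, 0.1]],  # rank 0 owns classes 0 and 1
    [[3.0, 0.0], [0.2, 0.9]],  # rank 1 owns classes 2 and 3
]
labels = [2, 1]
loss = sharded_cross_entropy(shards, labels, class_starts=[0, 2])
print(loss)
```

The result matches an ordinary softmax cross-entropy over the concatenated logits, which is the property that lets each GPU keep only its own slice of the class centers.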
|
||||
|
||||
|
||||
class AllGatherFunc(torch.autograd.Function):
|
||||
"""AllGather op with gradient backward"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, tensor, *gather_list):
|
||||
gather_list = list(gather_list)
|
||||
distributed.all_gather(gather_list, tensor)
|
||||
return tuple(gather_list)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
grad_list = list(grads)
|
||||
rank = distributed.get_rank()
|
||||
grad_out = grad_list[rank]
|
||||
|
||||
dist_ops = [
|
||||
distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
|
||||
if i == rank
|
||||
else distributed.reduce(
|
||||
grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
|
||||
)
|
||||
for i in range(distributed.get_world_size())
|
||||
]
|
||||
for _op in dist_ops:
|
||||
_op.wait()
|
||||
|
||||
grad_out *= len(grad_list) # cooperate with distributed loss function
|
||||
return (grad_out, *[None for _ in range(len(grad_list))])
|
||||
|
||||
|
||||
AllGather = AllGatherFunc.apply
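
The class partitioning in `PartialFC.__init__` and the center sampling in `PartialFC.sample` can be sketched in plain Python. This is an illustration under stated assumptions, not the module's implementation: the function names are hypothetical, and negatives are drawn with `random.sample` instead of the `torch.rand` plus `torch.topk` trick used in the source, which selects the same kind of subset.

```python
import random


def partition(num_classes, world_size, rank):
    """Number of classes stored on `rank` and the first global class index it owns."""
    num_local = num_classes // world_size + int(rank < num_classes % world_size)
    class_start = num_classes // world_size * rank + min(rank, num_classes % world_size)
    return num_local, class_start


def sample_centers(local_labels, num_local, sample_rate, rng):
    """Keep every positive center, then fill up to num_sample with random negatives.

    `local_labels` are assumed to be already shifted into [0, num_local).
    Returns the sorted sampled indices and the labels remapped into them.
    """
    num_sample = int(sample_rate * num_local)
    positive = sorted(set(local_labels))
    if num_sample >= len(positive):
        positive_set = set(positive)
        negatives = [c for c in range(num_local) if c not in positive_set]
        index = sorted(positive + rng.sample(negatives, num_sample - len(positive)))
    else:
        # More positives than the sampling budget: keep all positives only.
        index = positive
    remap = {center: i for i, center in enumerate(index)}
    return index, [remap[c] for c in local_labels]


# 10 classes over 3 ranks: the remainder class goes to the lowest rank.
print([partition(10, 3, r) for r in range(3)])

index, remapped = sample_centers([5, 2, 5], num_local=10, sample_rate=0.5, rng=random.Random(0))
print(index, remapped)
```

Because the sampled indices are sorted, remapping a label is a position lookup, which corresponds to the `torch.searchsorted` call in the source.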
|
||||
|
||||
@@ -1,2 +1,9 @@
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \
|
||||
--nproc_per_node=8 \
|
||||
--nnodes=1 \
|
||||
--node_rank=0 \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port=12345 train.py "$@"
|
||||
|
||||
ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
|
||||
|
||||
@@ -12,7 +12,7 @@ def convert_onnx(net, path_module, output, opset=11, simplify=False):
|
||||
img = torch.from_numpy(img).unsqueeze(0).float()
|
||||
|
||||
weight = torch.load(path_module)
|
||||
net.load_state_dict(weight)
|
||||
net.load_state_dict(weight, strict=True)
|
||||
net.eval()
|
||||
torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
|
||||
model = onnx.load(output)
|
||||
@@ -38,22 +38,16 @@ if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
input_file = args.input
|
||||
if os.path.isdir(input_file):
|
||||
input_file = os.path.join(input_file, "backbone.pth")
|
||||
input_file = os.path.join(input_file, "model.pt")
|
||||
assert os.path.exists(input_file)
|
||||
model_name = os.path.basename(os.path.dirname(input_file)).lower()
|
||||
params = model_name.split("_")
|
||||
if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
|
||||
if args.network is None:
|
||||
args.network = params[2]
|
||||
# model_name = os.path.basename(os.path.dirname(input_file)).lower()
|
||||
# params = model_name.split("_")
|
||||
# if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
|
||||
# if args.network is None:
|
||||
# args.network = params[2]
|
||||
assert args.network is not None
|
||||
print(args)
|
||||
backbone_onnx = get_model(args.network, dropout=0)
|
||||
|
||||
output_path = args.output
|
||||
if output_path is None:
|
||||
output_path = os.path.join(os.path.dirname(__file__), 'onnx')
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
assert os.path.isdir(output_path)
|
||||
output_file = os.path.join(output_path, "%s.onnx" % model_name)
|
||||
convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
|
||||
if args.output is None:
|
||||
args.output = os.path.join(os.path.dirname(args.input), "model.onnx")
|
||||
convert_onnx(backbone_onnx, input_file, args.output, simplify=args.simplify)
|
||||
|
||||
@@ -3,139 +3,146 @@ import logging
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.data.distributed
from torch.nn.utils import clip_grad_norm_
from torch import distributed
from torch.utils.tensorboard import SummaryWriter

import losses
from backbones import get_model
from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX
from dataset import get_dataloader
from torch.utils.data import DataLoader
from lr_scheduler import PolyScheduler
from partial_fc import PartialFC
from utils.utils_amp import MaxClipGradScaler
from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
from utils.utils_callbacks import CallBackLogging, CallBackVerification
from utils.utils_config import get_config
from utils.utils_logging import AverageMeter, init_logging


def main(args):
    cfg = get_config(args.config)
    try:
        world_size = int(os.environ['WORLD_SIZE'])
        rank = int(os.environ['RANK'])
        dist.init_process_group('nccl')
    except KeyError:
        world_size = 1
        rank = 0
        dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size)
try:
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    distributed.init_process_group("nccl")
except KeyError:
    world_size = 1
    rank = 0
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12584",
        rank=rank,
        world_size=world_size,
    )


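Both the old and the new initialisation read `WORLD_SIZE` and `RANK` from the environment and fall back to a single process when either is missing. The fallback can be sketched in isolation as below; `read_world_size_and_rank` is a hypothetical name, and the sketch deliberately leaves out the `init_process_group` call.

```python
import os
from typing import Mapping, Tuple


def read_world_size_and_rank(env: Mapping[str, str] = os.environ) -> Tuple[int, int]:
    """Hypothetical helper: return (world_size, rank), defaulting to one process."""
    try:
        return int(env["WORLD_SIZE"]), int(env["RANK"])
    except KeyError:
        # Launcher variables are missing: treat the run as one process with rank 0.
        return 1, 0
```

As in the diff, a partially populated environment (only one of the two variables set) also takes the single-process branch.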
def main(args):
    torch.cuda.set_device(args.local_rank)
    cfg = get_config(args.config)

    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)
    os.makedirs(cfg.output, exist_ok=True)
    init_logging(rank, cfg.output)

    if cfg.rec == "synthetic":
        train_set = SyntheticDataset(local_rank=local_rank)
    else:
        train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
    train_loader = DataLoaderX(
        local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,
        sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)
    backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)

    if cfg.resume:
        try:
            backbone_pth = os.path.join(cfg.output, "backbone.pth")
            backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
            if rank == 0:
                logging.info("backbone resume successfully!")
        except (FileNotFoundError, KeyError, IndexError, RuntimeError):
            if rank == 0:
                logging.info("resume fail, backbone init successfully!")
    summary_writer = (
        SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))
        if rank == 0
        else None
    )
    train_loader = get_dataloader(
        cfg.rec, local_rank=args.local_rank, batch_size=cfg.batch_size, dali=cfg.dali)
    backbone = get_model(
        cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size
    ).cuda()

    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank])
        module=backbone, broadcast_buffers=False, device_ids=[args.local_rank])
    backbone.train()
    margin_softmax = losses.get_loss(cfg.loss)
    module_partial_fc = PartialFC(
        rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume,
        batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
        sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)
        cfg.embedding_size,
        cfg.num_classes,
        cfg.sample_rate,
        cfg.fp16
    )
    module_partial_fc.train().cuda()

    opt_backbone = torch.optim.SGD(
        params=[{'params': backbone.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)
    opt_pfc = torch.optim.SGD(
        params=[{'params': module_partial_fc.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)

    num_image = len(train_set)
    # TODO the params of partial fc must be last in the params list
    opt = torch.optim.SGD(
        params=[
            {"params": backbone.parameters(), },
            {"params": module_partial_fc.parameters(), },
        ],
        lr=cfg.lr,
        momentum=0.9,
        weight_decay=cfg.weight_decay
    )
    total_batch_size = cfg.batch_size * world_size
    cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch
    cfg.total_step = num_image // total_batch_size * cfg.num_epoch

    def lr_step_func(current_step):
        cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch]
        if current_step < cfg.warmup_step:
            return current_step / cfg.warmup_step
        else:
            return 0.1 ** len([m for m in cfg.decay_step if m <= current_step])

    scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_backbone, lr_lambda=lr_step_func)
    scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_pfc, lr_lambda=lr_step_func)
    cfg.warmup_step = cfg.num_image // total_batch_size * cfg.warmup_epoch
    cfg.total_step = cfg.num_image // total_batch_size * cfg.num_epoch
    lr_scheduler = PolyScheduler(
        optimizer=opt,
        base_lr=cfg.lr,
        max_steps=cfg.total_step,
        warmup_steps=cfg.warmup_step
    )

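The lines above swap a step-decay `LambdaLR` for a `PolyScheduler`. The sketch below writes both learning-rate multipliers as plain functions for comparison. `step_decay_factor` follows the removed `lr_step_func`. `poly_decay_factor` is only an assumed shape for a polynomial schedule with linear warmup: `lr_scheduler.py` is not part of this excerpt, so the exponent `power=2.0` and the function names are assumptions rather than the repository's implementation.

```python
from typing import Sequence


def step_decay_factor(step: int, warmup_steps: int, decay_steps: Sequence[int]) -> float:
    """Multiplier of the removed `lr_step_func`: linear warmup, then x0.1 per milestone."""
    if step < warmup_steps:
        return step / warmup_steps
    return 0.1 ** len([m for m in decay_steps if m <= step])


def poly_decay_factor(step: int, warmup_steps: int, max_steps: int, power: float = 2.0) -> float:
    """Assumed polynomial schedule: linear warmup, then (1 - progress) ** power."""
    if step < warmup_steps:
        return step / warmup_steps
    # Fraction of the post-warmup budget already consumed, in [0, 1].
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return max(0.0, 1.0 - progress) ** power
```

Multiplying either factor by the base learning rate gives the rate at a given step. The step variant drops in discrete jumps at the milestones, while the polynomial variant decays smoothly to zero at `max_steps`.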
    for key, value in cfg.items():
        num_space = 25 - len(key)
        logging.info(": " + key + " " * num_space + str(value))

    val_target = cfg.val_targets
    callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)
    callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)
    callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)
    callback_verification = CallBackVerification(
        val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer
    )
    callback_logging = CallBackLogging(
        frequent=cfg.frequent,
        total_step=cfg.total_step,
        batch_size=cfg.batch_size,
        writer=summary_writer
    )

    loss = AverageMeter()
    loss_am = AverageMeter()
    start_epoch = 0
    global_step = 0
    grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
    for epoch in range(start_epoch, cfg.num_epoch):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            global_step += 1
            features = F.normalize(backbone(img))
            x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
            if cfg.fp16:
                features.backward(grad_amp.scale(x_grad))
                grad_amp.unscale_(opt_backbone)
                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
                grad_amp.step(opt_backbone)
                grad_amp.update()
            else:
                features.backward(x_grad)
                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
                opt_backbone.step()
    amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)

            opt_pfc.step()
            module_partial_fc.update()
            opt_backbone.zero_grad()
            opt_pfc.zero_grad()
            loss.update(loss_v, 1)
            callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)
            callback_verification(global_step, backbone)
            scheduler_backbone.step()
            scheduler_pfc.step()
        callback_checkpoint(global_step, backbone, module_partial_fc)
    dist.destroy_process_group()
    for epoch in range(start_epoch, cfg.num_epoch):

        if isinstance(train_loader, DataLoader):
            train_loader.sampler.set_epoch(epoch)
        for _, (img, local_labels) in enumerate(train_loader):
            global_step += 1
            local_embeddings = backbone(img)
            loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)

            if cfg.fp16:
                amp.scale(loss).backward()
                amp.step(opt)
                amp.update()
            else:
                loss.backward()
                opt.step()

            opt.zero_grad()
            lr_scheduler.step()

            with torch.no_grad():
                loss_am.update(loss.item(), 1)
                callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)

                if global_step % cfg.verbose == 0 and global_step > 200:
                    callback_verification(global_step, backbone)

        path_pfc = os.path.join(cfg.output, "softmax_fc_gpu_{}.pt".format(rank))
        torch.save(module_partial_fc.state_dict(), path_pfc)
        if rank == 0:
            path_module = os.path.join(cfg.output, "model.pt")
            torch.save(backbone.module.state_dict(), path_module)

        if cfg.dali:
            train_loader.reset()

    if rank == 0:
        path_module = os.path.join(cfg.output, "model.pt")
        torch.save(backbone.module.state_dict(), path_module)
    distributed.destroy_process_group()


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
    parser.add_argument('config', type=str, help='py config file')
    parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
    parser = argparse.ArgumentParser(description="Distributed Arcface Training in Pytorch")
    parser.add_argument("config", type=str, help="py config file")
    parser.add_argument("--local_rank", type=int, default=0, help="local_rank")
    main(parser.parse_args())

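In the new loop the verification frequency is checked by the caller (`global_step % cfg.verbose == 0 and global_step > 200`) rather than inside the callback. A minimal sketch of that gate as a standalone predicate follows; `should_verify` and its `min_step` parameter are hypothetical names introduced only for illustration.

```python
def should_verify(global_step: int, verbose: int, min_step: int = 200) -> bool:
    """Mirror `global_step % cfg.verbose == 0 and global_step > 200` from the new loop."""
    # Verification runs only on multiples of `verbose` that lie past the first `min_step` steps.
    return global_step % verbose == 0 and global_step > min_step
```

With `verbose=2000`, the first verification pass would happen at step 2000 and then every 2000 steps after that.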
@@ -1,7 +1,5 @@
# coding: utf-8

import os
from pathlib import Path
import sys

import matplotlib.pyplot as plt
import numpy as np
@@ -10,10 +8,11 @@ from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
from prettytable import PrettyTable
from sklearn.metrics import roc_curve, auc

image_path = "/data/anxiang/IJB_release/IJBC"
files = [
    "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
]
with open(sys.argv[1], "r") as f:
    files = f.readlines()

files = [x.strip() for x in files]
image_path = "/train_tmp/IJB_release/IJBC"


def read_template_pair_list(path):
@@ -31,7 +30,7 @@ p1, p2, label = read_template_pair_list(
methods = []
scores = []
for file in files:
    methods.append(file.split('/')[-2])
    methods.append(file)
    scores.append(np.load(file))

methods = np.array(methods)
@@ -53,7 +52,7 @@ for method in methods:
             label=('[%s (AUC = %0.4f %%)]' %
                    (method.split('-')[-1], roc_auc * 100)))
    tpr_fpr_row = []
    tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
    tpr_fpr_row.append(method)
    for fpr_iter in np.arange(len(x_labels)):
        _, min_index = min(
            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))

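The last context lines of this hunk find the ROC point whose false-positive rate is closest to a target value, presumably to read off the true-positive rate there. A minimal pure-Python sketch of that lookup follows, using plain sequences instead of NumPy arrays; the function name `tpr_at_fpr` is hypothetical.

```python
from typing import Sequence


def tpr_at_fpr(fpr: Sequence[float], tpr: Sequence[float], target_fpr: float) -> float:
    """Return the TPR at the ROC point whose FPR is closest to `target_fpr`."""
    # Same idiom as the diff: pair each distance with its index and take the minimum.
    # Ties on distance resolve to the lower index because tuples compare element-wise.
    _, min_index = min(zip((abs(f - target_fpr) for f in fpr), range(len(fpr))))
    return tpr[min_index]
```

Calling it once per target FPR (for example 1e-4 or 1e-3) would give the values for one row of a TPR@FPR table.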
@@ -1,88 +0,0 @@
from typing import Dict, List

import torch

if torch.__version__ < '1.9':
    Iterable = torch._six.container_abcs.Iterable
else:
    import collections

    Iterable = collections.abc.Iterable
from torch.cuda.amp import GradScaler


class _MultiDeviceReplicator(object):
    """
    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert master_tensor.is_cuda
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval


class MaxClipGradScaler(GradScaler):
    def __init__(self, init_scale, max_scale: float, growth_interval=100):
        GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
        self.max_scale = max_scale

    def scale_clip(self):
        if self.get_scale() == self.max_scale:
            self.set_growth_factor(1)
        elif self.get_scale() < self.max_scale:
            self.set_growth_factor(2)
        elif self.get_scale() > self.max_scale:
            self._scale.fill_(self.max_scale)
            self.set_growth_factor(1)

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Arguments:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs
        self.scale_clip()
        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, list) or isinstance(val, tuple):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)
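The deleted `MaxClipGradScaler.scale_clip` caps the loss scale at `max_scale` by switching the growth factor between 2 and 1. The toy model below reproduces only that capping rule without any torch dependency. `ToyClippedScale` is a hypothetical class, and its `grow` method folds the clip and a successful growth interval into one call, which is a simplification of how the real scaler updates.

```python
class ToyClippedScale:
    """Hypothetical stand-in for the scale/growth-factor pair, with no torch dependency."""

    def __init__(self, init_scale: float, max_scale: float) -> None:
        self.scale = init_scale
        self.max_scale = max_scale
        self.growth_factor = 2.0

    def clip(self) -> None:
        # Mirrors `scale_clip`: stop growing at the cap, pull back if past it.
        if self.scale == self.max_scale:
            self.growth_factor = 1.0
        elif self.scale < self.max_scale:
            self.growth_factor = 2.0
        else:
            self.scale = self.max_scale
            self.growth_factor = 1.0

    def grow(self) -> None:
        # Stand-in for one successful growth interval in the real scaler.
        self.clip()
        self.scale *= self.growth_factor
```

Starting below the cap, repeated `grow()` calls double the scale until it reaches `max_scale` and then hold it there. A scale that overshoots the cap is pulled back to it on the next call.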
@@ -7,12 +7,14 @@ import torch

from eval import verification
from utils.utils_logging import AverageMeter
from torch.utils.tensorboard import SummaryWriter
from torch import distributed


class CallBackVerification(object):
    def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
        self.frequent: int = frequent
        self.rank: int = rank

    def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112)):
        self.rank: int = distributed.get_rank()
        self.highest_acc: float = 0.0
        self.highest_acc_list: List[float] = [0.0] * len(val_targets)
        self.ver_list: List[object] = []
@@ -20,6 +22,8 @@ class CallBackVerification(object):
        if self.rank is 0:
            self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)

        self.summary_writer = summary_writer

    def ver_test(self, backbone: torch.nn.Module, global_step: int):
        results = []
        for i in range(len(self.ver_list)):
@@ -27,6 +31,10 @@ class CallBackVerification(object):
                self.ver_list[i], backbone, 10, 10)
            logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
            logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))

            self.summary_writer: SummaryWriter
            self.summary_writer.add_scalar(tag=self.ver_name_list[i], scalar_value=acc2, global_step=global_step, )

            if acc2 > self.highest_acc_list[i]:
                self.highest_acc_list[i] = acc2
            logging.info(
@@ -42,20 +50,20 @@ class CallBackVerification(object):
                self.ver_name_list.append(name)

    def __call__(self, num_update, backbone: torch.nn.Module):
        if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0:
        if self.rank is 0 and num_update > 0:
            backbone.eval()
            self.ver_test(backbone, num_update)
            backbone.train()


class CallBackLogging(object):
    def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
    def __init__(self, frequent, total_step, batch_size, writer=None):
        self.frequent: int = frequent
        self.rank: int = rank
        self.rank: int = distributed.get_rank()
        self.world_size: int = distributed.get_world_size()
        self.time_start = time.time()
        self.total_step: int = total_step
        self.batch_size: int = batch_size
        self.world_size: int = world_size
        self.writer = writer

        self.init = False
@@ -100,18 +108,3 @@ class CallBackLogging(object):
            else:
                self.init = True
                self.tic = time.time()


class CallBackModelCheckpoint(object):
    def __init__(self, rank, output="./"):
        self.rank: int = rank
        self.output: str = output

    def __call__(self, global_step, backbone, partial_fc, ):
        if global_step > 100 and self.rank == 0:
            path_module = os.path.join(self.output, "backbone.pth")
            torch.save(backbone.module.state_dict(), path_module)
            logging.info("Pytorch Model Saved in '{}'".format(path_module))

        if global_step > 100 and partial_fc is not None:
            partial_fc.save_params()

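Several lines in this file gate work on `self.rank is 0`. Identity comparison against an integer literal happens to work in CPython because small integers are cached, but that is an implementation detail, and Python 3.8 and later emit a `SyntaxWarning` for it. A minimal sketch of the same gate written with `==` follows; `RankZeroCallback` is a hypothetical class, not part of the repository.

```python
from typing import Callable


class RankZeroCallback:
    """Hypothetical callback wrapper: run `fn` only on rank 0 and after step 0."""

    def __init__(self, rank: int, fn: Callable[[int], None]) -> None:
        self.rank = rank
        self.fn = fn

    def __call__(self, num_update: int) -> None:
        # Value comparison, not identity, so the check does not rely on int caching.
        if self.rank == 0 and num_update > 0:
            self.fn(num_update)
```

Only the rank-0 process with a positive step count reaches `fn`. Every other call is a no-op.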