fix some bug

This commit is contained in:
GuoxiaWang
2021-10-14 19:05:59 +08:00
parent 0c8f81369d
commit 7831bcf977
12 changed files with 77 additions and 26 deletions

View File

@@ -67,19 +67,15 @@ If you want to use customed dataset, you can arrange your data according to the
python -m paddle.distributed.launch --gpus=0 tools/train.py \
--config_file configs/ms1mv2_mobileface.py \
--is_static False \
--backbone MobileFaceNet_128 \
--classifier LargeScaleClassifier \
--embedding_size 128 \
--sample_ratio 1.0 \
--loss ArcFace \
--batch_size 1024 \
--batch_size 512 \
--dataset MS1M_v2 \
--num_classes 85742 \
--data_dir MS1M_v2/ \
--label_file MS1M_v2/label.txt \
--fp16 False \
--train_unit 'epoch' \
--output ./MS1M_v2_arcface_MobileFaceNet_128_1.0_fp32
--fp16 False
```
### 4.2 Single node, 8 GPUs:
@@ -156,7 +152,7 @@ sh scripts/inference.sh
| Model structure | lfw | cfp_fp | agedb30 | CPU time cost | GPU time cost | Inference model |
| ------------------------- | ------ | ------- | ------- | -------| -------- |---- |
| MobileFace-Paddle | 0.9945 | 0.9343 | 0.9613 | 4.3ms | 2.3ms | [download link](https://paddle-model-ecology.bj.bcebos.com/model/insight-face/mobileface_v1.0_infer.tar) |
| MobileFace-Paddle | 0.9952 | 0.9280 | 0.9612 | 4.3ms | 2.3ms | [download link](https://paddle-model-ecology.bj.bcebos.com/model/insight-face/mobileface_v1.0_infer.tar) |
| MobileFace-mxnet | 0.9950 | 0.8894 | 0.9591 | 7.3ms | 4.7ms | - |
* Note: MobileFaceNet-Paddle training using MobileFaceNet_128

View File

@@ -0,0 +1,54 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from easydict import EasyDict as edict
config = edict()
config.is_static = False
config.backbone = 'MobileFaceNet_128'
config.classifier = 'LargeScaleClassifier'
config.embedding_size = 128
config.model_parallel = True
config.sample_ratio = 1.0
config.loss = 'ArcFace'
config.dropout = 0.0
config.lr = 0.1 # for global batch size = 512
config.lr_decay = 0.1
config.weight_decay = 5e-4
config.momentum = 0.9
config.train_unit = 'epoch' # 'step' or 'epoch'
config.warmup_num = 0
config.train_num = 25
config.decay_boundaries = [10, 16, 22]
config.use_synthetic_dataset = False
config.dataset = "MS1M_v2"
config.data_dir = "./MS1M_v2"
config.label_file = "./MS1M_v2/label.txt"
config.is_bin = False
config.num_classes = 85742 # 85742 for MS1M_v2, 93431 for MS1M_v3
config.batch_size = 128 # global batch size 1024 of 8 GPU
config.num_workers = 8
config.do_validation_while_train = True
config.validation_interval_step = 2000
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
config.logdir = './log'
config.log_interval_step = 100
config.output = './MS1M_v2_arcface_MobileFaceNet_128_0.1'
config.resume = False
config.checkpoint_dir = None
config.max_num_last_checkpoint = 1

View File

@@ -110,18 +110,18 @@ class MobileFaceNet(nn.Layer):
n = m.weight.shape[1] * m.weight.shape[2] * m.weight.shape[3]
m.weight = paddle.create_parameter(
shape=m.weight.shape,
dtype='float32',
dtype=m.weight.dtype,
default_initializer=nn.initializer.Normal(
mean=0.0, std=math.sqrt(2.0 / n)))
# nn.init.normal_(m.weight, 0, 0.1)
elif isinstance(m, (nn.BatchNorm, nn.BatchNorm2D, nn.GroupNorm)):
m.weight = paddle.create_parameter(
shape=m.weight.shape,
dtype='float32',
dtype=m.weight.dtype,
default_initializer=nn.initializer.Constant(value=1.0))
m.bias = paddle.create_parameter(
shape=m.bias.shape,
dtype='float32',
dtype=m.bias.dtype,
default_initializer=nn.initializer.Constant(value=0.0))
def _make_layer(self, block, setting):

View File

@@ -18,7 +18,6 @@ import os
import paddle
import paddle.nn as nn
class LargeScaleClassifier(nn.Layer):
"""
Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
@@ -77,10 +76,11 @@ class LargeScaleClassifier(nn.Layer):
self.weight.stop_gradient = True
def step(self, optimizer):
warnings.warn(
"Explicitly call the function paddle._C_ops.sparse_momentum is a temporary manner. "
"We will merge it to optimizer in the future, please don't follow.")
if int(self.sample_ratio) < 1:
warnings.warn(
"Explicitly call the function paddle._C_ops.sparse_momentum is a temporary manner. "
"We will merge it to optimizer in the future, please don't follow.")
found_inf = paddle.logical_not(
paddle.all(paddle.isfinite(self._parameter_list[0].grad)))
if found_inf:

View File

@@ -43,7 +43,7 @@ class LSCGradScaler(GradScaler):
# if self._scale >= self.max_loss_scaling:
# self._scale = paddle.to_tensor([self.max_loss_scaling], dtype='float32')
# unscale the grad
# unscale the grad
self._unscale(optimizer)
if self._found_inf:
@@ -92,6 +92,7 @@ class LSCGradScaler(GradScaler):
self._found_inf = paddle.logical_not(
paddle.all(paddle.isfinite(grad)))
if self._found_inf:
print('Found inf or nan in classifier, dtype is', dtype)
return
for dtype in param_grads_dict:

View File

@@ -143,7 +143,7 @@ class Checkpoint(object):
tensor = paddle.load(path, return_numpy=True)
if dtype:
assert dtype in ['float32', 'float16']
tensor = tensor.astype('float32')
tensor = tensor.astype(dtype)
if 'dist@' in name and '@rank@' in name:
if '.w' in name and 'velocity' not in name:

View File

@@ -16,9 +16,9 @@ python tools/inference.py \
--export_type paddle \
--model_file MS1M_v3_arcface_static_128_fp16_0.1/FresResNet50/exported_model/FresResNet50.pdmodel \
--params_file MS1M_v3_arcface_static_128_fp16_0.1/FresResNet50/exported_model/FresResNet50.pdiparams \
--image_path /wangguoxia/plsc/MS1M_v3/images/00000001.jpg
--image_path MS1M_v3/images/00000001.jpg
python tools/inference.py \
--export_type onnx \
--onnx_file MS1M_v3_arcface_static_128_fp16_0.1/FresResNet50/exported_model/FresResNet50.onnx \
--image_path /wangguoxia/plsc/MS1M_v3/images/00000001.jpg
--image_path MS1M_v3/images/00000001.jpg

View File

@@ -25,8 +25,8 @@ python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 tools/train.py \
--batch_size 128 \
--dataset MS1M_v3 \
--num_classes 93431 \
--data_dir /wangguoxia/plsc/MS1M_v3/ \
--label_file /wangguoxia/plsc/MS1M_v3/label.txt \
--data_dir MS1M_v3/ \
--label_file MS1M_v3/label.txt \
--is_bin False \
--log_interval_step 100 \
--validation_interval_step 2000 \

View File

@@ -25,8 +25,8 @@ python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 tools/train.py \
--batch_size 128 \
--dataset MS1M_v3 \
--num_classes 93431 \
--data_dir /wangguoxia/plsc/MS1M_v3/ \
--label_file /wangguoxia/plsc/MS1M_v3/label.txt \
--data_dir MS1M_v3/ \
--label_file MS1M_v3/label.txt \
--is_bin False \
--log_interval_step 100 \
--validation_interval_step 2000 \

View File

@@ -17,6 +17,6 @@ python tools/validation.py \
--backbone FresResNet50 \
--embedding_size 512 \
--checkpoint_dir MS1M_v3_arcface_dynamic_128_fp16_0.1/FresResNet50/24 \
--data_dir /wangguoxia/plsc/MS1M_v3/ \
--data_dir MS1M_v3/ \
--val_targets lfw,cfp_fp,agedb_30 \
--batch_size 128

View File

@@ -17,6 +17,6 @@ python tools/validation.py \
--backbone FresResNet50 \
--embedding_size 512 \
--checkpoint_dir MS1M_v3_arcface_static_128_fp16_0.1/FresResNet50/24 \
--data_dir /wangguoxia/plsc/MS1M_v3/ \
--data_dir MS1M_v3/ \
--val_targets lfw,cfp_fp,agedb_30 \
--batch_size 128

View File

@@ -120,7 +120,7 @@ class Checkpoint(object):
tensor = paddle.load(path, return_numpy=True)
if dtype:
assert dtype in ['float32', 'float16']
tensor = tensor.astype('float32')
tensor = tensor.astype(dtype)
if 'dist@' in name and '@rank@' in name:
if '.w' in name and 'velocity' not in name: