add EasyFace
129
.gitignore
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
/package
|
||||
/temp
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
.vscode
|
||||
.idea
|
||||
|
||||
# custom
|
||||
*.pkl
|
||||
*.pkl.json
|
||||
*.log.json
|
||||
*.whl
|
||||
*.tar.gz
|
||||
*.swp
|
||||
*.log
|
||||
*.tar.gz
|
||||
source.sh
|
||||
tensorboard.sh
|
||||
.DS_Store
|
||||
replace.sh
|
||||
result.png
|
||||
result.jpg
|
||||
|
||||
# Pytorch
|
||||
*.pth
|
||||
*.pt
|
||||
203
LICENSE
Normal file
@@ -0,0 +1,203 @@
|
||||
Copyright 2022-2023 Alibaba ModelScope. All rights reserved.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020-2022 Alibaba ModelScope.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
296
README.md
Normal file
@@ -0,0 +1,296 @@
|
||||
<div align="center">
|
||||
<img src="demo/modelscope.gif" width="40%" height="40%" />
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<!-- [](https://easy-cv.readthedocs.io/en/latest/) -->
|
||||
[License](https://github.com/modelscope/modelscope/blob/master/LICENSE)
|
||||
</div>
|
||||
|
||||
|
||||
<h4 align="center">
|
||||
<a href=#EasyFace> 特性 </a> |
|
||||
<a href=#安装> 安装 </a> |
|
||||
<a href=#单模型推理> 单模型推理</a> |
|
||||
<a href=#单模型训练和微调> 单模型训练/微调</a> |
|
||||
<a href=#单模型选型和对比> 单模型选型/对比</a>
|
||||
<!--- <a href=#人脸识别系统多模块一键选型/对比> 人脸识别系统多模块一键选型/对比</a> -->
|
||||
</h4>
|
||||
|
||||
## EasyFace
|
||||
|
||||
**EasyFace**旨在快速选型/了解/对比/体验人脸相关sota模型,依托于[**Modelscope**](https://modelscope.cn/home)开发库和[**Pytorch**](https://pytorch.org)框架,EasyFace具有以下特性:
|
||||
- 快速体验/对比/选型Sota的人脸相关模型, 涉及人脸检测,人脸识别,人脸关键点,人脸表情识别,人脸活体检测等领域,目前支持人脸检测相关sota模型。
|
||||
- 5行代码即可进行模型推理,10行代码进行模型训练/Finetune, 20行代码对比不同模型在自建/公开数据集上的精度以及可视化结果。
|
||||
- 基于现有模型快速搭建[**创空间**](https://modelscope.cn/studios/damo/face_album/summary)应用。
|
||||
|
||||
## News 📢
|
||||
|
||||
<!--- 🔥 **`2023-03-20`**:新增DamoFR人脸识别模型,基于Vit Backbone 围绕data-centric以及patch-level hard example mining策略重新设计了Transformer-based Small/Medium/Large 人脸识别backbone,效果sota,已release不同算力下的sota人脸识别,口罩人脸识别DamoFR模型,[**paper**]() and [**project**]();-->
|
||||
|
||||
🔥 **`2023-03-10`**:新增DamoFD(ICLR23)人脸检测关键点模型,基于SCRFD框架进一步搜索了FD-friendly backbone结构。 在0.5/2.5/10/34 GFlops VGA分辨率的算力约束条件下性能均超过SCRFD。其中提出的轻量级的检测器DDSAR-0.5G在VGA分辨率0.5GFlops条件下WiderFace上hard集精度为71.03(超过SCRFD 2.5个点),欢迎大家一键使用(支持训练和推理),[**paper**](https://openreview.net/forum?id=NkJOhtNKX91)。
|
||||
|
||||
🔥 **`2023-03-10`**:新增4个人脸检测模型,包括DamoFD,MogFace,RetinaFace,Mtcnn。
|
||||
|
||||
## 支持模型列表
|
||||
**对应模型的推理和训练单元测试放在face_project目录下**
|
||||
|
||||
### 推理
|
||||
|
||||
🔥 **`人脸检测`**:DamoFD,MogFace,RetinaFace,Mtcnn。
|
||||
|
||||
### 训练
|
||||
🔥 **`人脸检测`**:DamoFD。
|
||||
|
||||
## 安装
|
||||
```
|
||||
conda create --offline -n EasyFace python=3.8
|
||||
conda activate EasyFace
|
||||
# pytorch >= 1.3.0
|
||||
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 --extra-index-url https://download.pytorch.org/whl/cu102
|
||||
git clone https://github.com/ly19965/FaceMaas
|
||||
cd FaceMaas
|
||||
pip install -r requirements.txt
|
||||
mim install mmcv-full
|
||||
```
|
||||
|
||||
## 单模型推理
|
||||
从支持推理的模型列表里选择想体验的模型, e.g.人脸检测模型DamoFD_0.5g
|
||||
|
||||
### 单张图片推理
|
||||
```python
|
||||
import cv2
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
face_detection = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd')
|
||||
# 支持 url image and abs dir image path
|
||||
img_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_detection2.jpeg'
|
||||
result = face_detection(img_path)
|
||||
|
||||
# 提供可视化结果
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_result
|
||||
from modelscope.preprocessors.image import LoadImage
|
||||
img = LoadImage.convert_to_ndarray(img_path)
|
||||
cv2.imwrite('srcImg.jpg', img)
|
||||
img_draw = draw_face_detection_result('srcImg.jpg', result)
|
||||
import matplotlib.pyplot as plt
|
||||
plt.imshow(img_draw)
|
||||
```
|
||||
|
||||
### Mini公开数据集推理
|
||||
```python
|
||||
import os.path as osp
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd'
|
||||
val_set = MsDataset.load('widerface_mini_train_val', namespace='ly261666', split='validation')#, download_mode=DownloadMode.FORCE_REDOWNLOAD)
|
||||
img_base_path = next(iter(val_set))[1]
|
||||
img_dir = osp.join(img_base_path, 'val_data')
|
||||
img_gt = osp.join(img_base_path, 'val_label.txt')
|
||||
gt_info = gen_gt_info(img_gt)
|
||||
pred_info = {}
|
||||
iou_th = 0.5
|
||||
thresh_num = 1000
|
||||
face_detection_func = pipeline(Tasks.face_detection, model=model_id)
|
||||
count_face = 0
|
||||
pr_curve = np.zeros((thresh_num, 2)).astype('float')
|
||||
for img_name in os.listdir(img_dir):
|
||||
abs_img_name = osp.join(img_dir, img_name)
|
||||
result = face_detection_func(abs_img_name)
|
||||
pred_info = np.concatenate([result['boxes'], np.array(result['scores'])[:,np.newaxis]], axis=1)
|
||||
gt_box = np.array(gt_info[img_name])
|
||||
pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
|
||||
_img_pr_info, fp = img_pr_info(thresh_num, pred_info, proposal_list, pred_recall)
|
||||
pr_curve += _img_pr_info
|
||||
count_face += gt_box.shape[0]
|
||||
|
||||
pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
|
||||
propose = pr_curve[:, 0]
|
||||
recall = pr_curve[:, 1]
|
||||
for srecall in np.arange(0.1, 1.0001, 0.1):
|
||||
rindex = len(np.where(recall<=srecall)[0])-1
|
||||
rthresh = 1.0 - float(rindex)/thresh_num
|
||||
print('Recall-Precision-Thresh:', recall[rindex], propose[rindex], rthresh)
|
||||
ap = voc_ap(recall, propose)
|
||||
print('ap: %.5f, iou_th: %.2f'%(ap, iou_th))
|
||||
```
|
||||
|
||||
## 单模型训练和微调
|
||||
从支持训练的模型列表里选择想体验的模型, e.g.人脸检测模型DamoFD_0.5g
|
||||
|
||||
### 训练
|
||||
|
||||
```python
|
||||
import os
|
||||
import tempfile
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd'
|
||||
ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan') # remove '_mini' for full dataset
|
||||
|
||||
data_path = ms_ds_widerface.config_kwargs['split_config']
|
||||
train_dir = data_path['train']
|
||||
val_dir = data_path['validation']
|
||||
|
||||
def get_name(dir_name):
|
||||
names = [i for i in os.listdir(dir_name) if not i.startswith('_')]
|
||||
return names[0]
|
||||
|
||||
train_root = train_dir + '/' + get_name(train_dir) + '/'
|
||||
val_root = val_dir + '/' + get_name(val_dir) + '/'
|
||||
cache_path = snapshot_download(model_id)
|
||||
tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(tmp_dir):
|
||||
os.makedirs(tmp_dir)
|
||||
|
||||
def _cfg_modify_fn(cfg):
|
||||
cfg.checkpoint_config.interval = 1
|
||||
cfg.log_config.interval = 10
|
||||
cfg.evaluation.interval = 1
|
||||
cfg.data.workers_per_gpu = 1
|
||||
cfg.data.samples_per_gpu = 4
|
||||
return cfg
|
||||
|
||||
kwargs = dict(
|
||||
cfg_file=os.path.join(cache_path, 'DamoFD_lms.py'),
|
||||
work_dir=tmp_dir,
|
||||
train_root=train_root,
|
||||
val_root=val_root,
|
||||
total_epochs=1, # run #epochs
|
||||
cfg_modify_fn=_cfg_modify_fn)
|
||||
|
||||
trainer = build_trainer(name=Trainers.face_detection_scrfd, default_args=kwargs)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### 模型微调
|
||||
|
||||
```python
|
||||
import os
|
||||
import tempfile
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.utils.constant import ModelFile
|
||||
|
||||
model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd'
|
||||
ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan') # remove '_mini' for full dataset
|
||||
|
||||
data_path = ms_ds_widerface.config_kwargs['split_config']
|
||||
train_dir = data_path['train']
|
||||
val_dir = data_path['validation']
|
||||
|
||||
def get_name(dir_name):
|
||||
names = [i for i in os.listdir(dir_name) if not i.startswith('_')]
|
||||
return names[0]
|
||||
|
||||
train_root = train_dir + '/' + get_name(train_dir) + '/'
|
||||
val_root = val_dir + '/' + get_name(val_dir) + '/'
|
||||
cache_path = snapshot_download(model_id)
|
||||
tmp_dir = tempfile.TemporaryDirectory().name
|
||||
pretrain_epochs = 640
|
||||
ft_epochs = 1
|
||||
total_epochs = pretrain_epochs + ft_epochs
|
||||
if not os.path.exists(tmp_dir):
|
||||
os.makedirs(tmp_dir)
|
||||
|
||||
def _cfg_modify_fn(cfg):
|
||||
cfg.checkpoint_config.interval = 1
|
||||
cfg.log_config.interval = 10
|
||||
cfg.evaluation.interval = 1
|
||||
cfg.data.workers_per_gpu = 1
|
||||
cfg.data.samples_per_gpu = 4
|
||||
return cfg
|
||||
|
||||
kwargs = dict(
|
||||
cfg_file=os.path.join(cache_path, 'DamoFD_lms.py'),
|
||||
work_dir=tmp_dir,
|
||||
train_root=train_root,
|
||||
val_root=val_root,
|
||||
resume_from=os.path.join(cache_path, ModelFile.TORCH_MODEL_FILE),
|
||||
total_epochs=total_epochs, # run #epochs
|
||||
cfg_modify_fn=_cfg_modify_fn)
|
||||
|
||||
trainer = build_trainer(name=Trainers.face_detection_scrfd, default_args=kwargs)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## 单模型选型和对比
|
||||
```python
|
||||
import os.path as osp
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
model_id_list = ['damo/cv_ddsar_face-detection_iclr23-damofd', 'damo/cv_resnet101_face-detection_cvpr22papermogface', 'damo/cv_resnet50_face-detection_retinaface', 'damo/cv_manual_face-detection_mtcnn']
|
||||
val_set = MsDataset.load('widerface_mini_train_val', namespace='ly261666', split='validation')#, download_mode=DownloadMode.FORCE_REDOWNLOAD)
|
||||
img_base_path = next(iter(val_set))[1]
|
||||
img_dir = osp.join(img_base_path, 'val_data')
|
||||
img_gt = osp.join(img_base_path, 'val_label.txt')
|
||||
gt_info = gen_gt_info(img_gt)
|
||||
pred_info = {}
|
||||
iou_th = 0.5
|
||||
thresh_num = 1000
|
||||
count_face = 0
|
||||
conf_th = 0.01
|
||||
final_info = ""
|
||||
pr_curve = np.zeros((thresh_num, 2)).astype('float')
|
||||
for model_id in model_id_list:
|
||||
pr_curve = np.zeros((thresh_num, 2)).astype('float')
|
||||
count_face = 0
|
||||
if 'mtcnn' in model_id:
|
||||
face_detection_func = pipeline(Tasks.face_detection, model=model_id, conf_th=0.7) # Mtcnn only support high conf threshold
|
||||
elif 'damofd' in model_id:
|
||||
face_detection_func = pipeline(Tasks.face_detection, model=model_id) # Revise conf_th in DamoFD_lms.py
|
||||
else:
|
||||
face_detection_func = pipeline(Tasks.face_detection, model=model_id, conf_th=0.01)
|
||||
for idx, img_name in enumerate(os.listdir(img_dir)):
|
||||
print ('model_id: {}, inference img: {} {}/{}'.format(model_id, img_name, idx+1, len(os.listdir(img_dir))))
|
||||
abs_img_name = osp.join(img_dir, img_name)
|
||||
result = face_detection_func(abs_img_name)
|
||||
pred_info = np.concatenate([result['boxes'], np.array(result['scores'])[:,np.newaxis]], axis=1)
|
||||
gt_box = np.array(gt_info[img_name])
|
||||
pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
|
||||
_img_pr_info, fp = img_pr_info(thresh_num, pred_info, proposal_list, pred_recall)
|
||||
pr_curve += _img_pr_info
|
||||
count_face += gt_box.shape[0]
|
||||
|
||||
pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
|
||||
propose = pr_curve[:, 0]
|
||||
recall = pr_curve[:, 1]
|
||||
for srecall in np.arange(0.1, 1.0001, 0.1):
|
||||
rindex = len(np.where(recall<=srecall)[0])-1
|
||||
rthresh = 1.0 - float(rindex)/thresh_num
|
||||
print('Recall-Precision-Thresh:', recall[rindex], propose[rindex], rthresh)
|
||||
ap = voc_ap(recall, propose)
|
||||
result_info = 'model_id: {}, ap: {:.5f}, iou_th: {:.2f}'.format(model_id, ap, iou_th)
|
||||
print(result_info)
|
||||
final_info += result_info + '\n'
|
||||
print("Overall Result:")
|
||||
print(final_info)
|
||||
```
|
||||
|
||||
|
||||
<!--- ## 人脸识别系统多模块一键选型/对比 -->
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
BIN
data/test/images/face_detection.png
Normal file
|
After Width: | Height: | Size: 638 KiB |
BIN
data/test/images/face_detection2.jpeg
Normal file
|
After Width: | Height: | Size: 48 KiB |
BIN
data/test/images/face_liveness_ir.jpg
Normal file
|
After Width: | Height: | Size: 46 KiB |
BIN
data/test/images/face_liveness_rgb.png
Normal file
|
After Width: | Height: | Size: 22 KiB |
BIN
data/test/images/face_liveness_xc.png
Normal file
|
After Width: | Height: | Size: 3.0 MiB |
BIN
data/test/images/face_recognition_1.png
Normal file
|
After Width: | Height: | Size: 452 KiB |
BIN
data/test/images/face_recognition_2.png
Normal file
|
After Width: | Height: | Size: 349 KiB |
BIN
data/test/images/face_reconstruction.jpg
Normal file
|
After Width: | Height: | Size: 1.8 MiB |
BIN
data/test/images/facial_expression_recognition.jpg
Normal file
|
After Width: | Height: | Size: 158 KiB |
BIN
data/test/images/ir_face_recognition_1.png
Normal file
|
After Width: | Height: | Size: 20 KiB |
BIN
data/test/images/ir_face_recognition_2.png
Normal file
|
After Width: | Height: | Size: 20 KiB |
BIN
data/test/images/mask_face_recognition_1.jpg
Normal file
|
After Width: | Height: | Size: 114 KiB |
BIN
data/test/images/mask_face_recognition_2.jpg
Normal file
|
After Width: | Height: | Size: 114 KiB |
BIN
data/test/images/mog_face_detection.jpg
Normal file
|
After Width: | Height: | Size: 85 KiB |
BIN
data/test/images/mtcnn_face_detection.jpg
Normal file
|
After Width: | Height: | Size: 85 KiB |
BIN
data/test/images/retina_face_detection.jpg
Normal file
|
After Width: | Height: | Size: 85 KiB |
BIN
data/test/images/ulfd_face_detection.jpg
Normal file
|
After Width: | Height: | Size: 85 KiB |
BIN
demo/modelscope.gif
Normal file
|
After Width: | Height: | Size: 35 KiB |
276
face_project/face_detection/DamoFD/README.md
Normal file
@@ -0,0 +1,276 @@
|
||||
|
||||
<div align="center">
|
||||
<img src="demo/DamoFD.jpg" width="100%" height="10%" />
|
||||
</div>
|
||||
<h4 align="center">
|
||||
<a href=#DamoFD模型介绍> 模型介绍 </a> |
|
||||
<a href=#快速使用> 快速使用 </a> |
|
||||
<a href=#单图片推理> 单图片推理 </a> |
|
||||
<a href=#多图片推理和评测> 多图片推理/评测 </a> |
|
||||
<a href=#模型训练> 模型训练 </a> |
|
||||
<a href=#模型微调> 模型微调 </a>
|
||||
</h4>
|
||||
|
||||
# DamoFD模型介绍
|
||||
人脸检测关键点模型DamoFD,被ICLR2023录取([论文地址](https://openreview.net/forum?id=NkJOhtNKX91)), 这个项目中开源的模型是在DamoFD增加了关键点分支,论文原文代码见[项目地址](),论文解析详见[解析]()。
|
||||
|
||||
## 快速使用
|
||||
|
||||
DamoFD为当前SOTA的人脸检测关键点方法,论文已被ICLR23录取([论文地址](https://openreview.net/forum?id=NkJOhtNKX91))。DamoFD提供了family-based 人脸检测关键点模型,分别为`DamoFD-0.5G, DamoFD-2.5G, DamoFD-10G, DamoFD-34G`,性能均明显超过[SCRFD](https://arxiv.org/abs/2105.04714)。在这个界面中,我们提供几个有关`推理/评测/训练/微调`脚本帮助大家迅速/一键使用DamoFD, 代码范例中的实例均集成在如下几个unit test脚本里:
|
||||
- `DamoFD-0.5G: 训练,微调`:train_damofd_500m.py; 推理,评测:test_damofd_500m.py
|
||||
- `DamoFD-2.5G: 训练,微调`:train_damofd_2500m.py; 推理,评测:test_damofd_2500m.py
|
||||
- `DamoFD-10G: 训练,微调`:train_damofd_10g.py; 推理,评测:test_damofd_10g.py
|
||||
- `DamoFD-34G: 训练,微调`:train_damofd_34g.py; 推理,评测:test_damofd_34g.py
|
||||
- `Usage on DamoFD-0.5G`:
|
||||
```bash
|
||||
PYTHONPATH=. python face_project/face_detection/DamoFD/train_damofd_500m.py
|
||||
PYTHONPATH=. python face_project/face_detection/DamoFD/test_damofd_500m.py
|
||||
```
|
||||
|
||||
## 代码范例
|
||||
我们以DamoFD-0.5G为例,提供了推理/评测/训练/微调代码范例和解析:
|
||||
|
||||
### 单图片推理
|
||||
```python
|
||||
import cv2
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
face_detection = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd')
|
||||
# 支持 url image and abs dir image path
|
||||
img_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_detection2.jpeg'
|
||||
result = face_detection(img_path)
|
||||
|
||||
# 提供可视化结果
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_result
|
||||
from modelscope.preprocessors.image import LoadImage
|
||||
img = LoadImage.convert_to_ndarray(img_path)
|
||||
cv2.imwrite('srcImg.jpg', img)
|
||||
img_draw = draw_face_detection_result('srcImg.jpg', result)
|
||||
import matplotlib.pyplot as plt
|
||||
plt.imshow(img_draw)
|
||||
```
|
||||
|
||||
### 多图片推理和评测
|
||||
- 我们提供了100张测试图片,可运行下面代码一键使用(下载数据集+推理);
|
||||
- 也支持测试自建数据集,需要按如下格式建立数据集:
|
||||
```
|
||||
img_base_path/
|
||||
val_data/
|
||||
test_1.jpg
|
||||
...
|
||||
test_N.jpg
|
||||
val_label.txt
|
||||
## val_label.txt format
|
||||
test_1.jpg
|
||||
x0 y0 w h
|
||||
x0 y0 w h
|
||||
...
|
||||
test_N.jpg
|
||||
x0 y0 w h
|
||||
x0 y0 w h
|
||||
...
|
||||
```
|
||||
|
||||
```python
|
||||
import os.path as osp
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd'
|
||||
val_set = MsDataset.load('widerface_mini_train_val', namespace='ly261666', split='validation')#, download_mode=DownloadMode.FORCE_REDOWNLOAD)
|
||||
img_base_path = next(iter(val_set))[1]
|
||||
img_dir = osp.join(img_base_path, 'val_data')
|
||||
img_gt = osp.join(img_base_path, 'val_label.txt')
|
||||
gt_info = gen_gt_info(img_gt)
|
||||
pred_info = {}
|
||||
iou_th = 0.5
|
||||
thresh_num = 1000
|
||||
face_detection_func = pipeline(Tasks.face_detection, model=model_id)
|
||||
count_face = 0
|
||||
pr_curve = np.zeros((thresh_num, 2)).astype('float')
|
||||
for img_name in os.listdir(img_dir):
|
||||
abs_img_name = osp.join(img_dir, img_name)
|
||||
result = face_detection_func(abs_img_name)
|
||||
pred_info = np.concatenate([result['boxes'], np.array(result['scores'])[:,np.newaxis]], axis=1)
|
||||
gt_box = np.array(gt_info[img_name])
|
||||
pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
|
||||
_img_pr_info, fp = img_pr_info(thresh_num, pred_info, proposal_list, pred_recall)
|
||||
pr_curve += _img_pr_info
|
||||
count_face += gt_box.shape[0]
|
||||
|
||||
pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
|
||||
propose = pr_curve[:, 0]
|
||||
recall = pr_curve[:, 1]
|
||||
for srecall in np.arange(0.1, 1.0001, 0.1):
|
||||
rindex = len(np.where(recall<=srecall)[0])-1
|
||||
rthresh = 1.0 - float(rindex)/thresh_num
|
||||
print('Recall-Precision-Thresh:', recall[rindex], propose[rindex], rthresh)
|
||||
ap = voc_ap(recall, propose)
|
||||
print('ap: %.5f, iou_th: %.2f'%(ap, iou_th))
|
||||
```
|
||||
Result:
|
||||
```
|
||||
Recall-Precision-Thresh: 0.09902038655017209 1.0 0.746
|
||||
Recall-Precision-Thresh: 0.19989409584326184 0.993421052631579 0.632
|
||||
Recall-Precision-Thresh: 0.2991792427852793 0.9519797809604044 0.499
|
||||
Recall-Precision-Thresh: 0.39925867090283296 0.8308539944903581 0.367
|
||||
Recall-Precision-Thresh: 0.4495631453534551 0.7237851662404092 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.4495631453534551 0.7237851662404092 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.4495631453534551 0.7237851662404092 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.4495631453534551 0.7237851662404092 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.4495631453534551 0.7237851662404092 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.4495631453534551 0.7237851662404092 0.0010000000000000009
|
||||
ap: 0.42606, iou_th: 0.50
|
||||
```
|
||||
|
||||
### 模型训练
|
||||
- 我们提供了Wider Face 和 Wider Face mini的训练集,可运行下面代码一键使用(下载数据集+训练);
|
||||
- 也支持训练自建数据集,需要按如下格式建立数据集:
|
||||
```
|
||||
# <image_path> image_width image_height
|
||||
bbox_x1 bbox_y1 bbox_x2 bbox_y2 (<keypoint,3>*N)
|
||||
...
|
||||
...
|
||||
# <image_path> image_width image_height
|
||||
bbox_x1 bbox_y1 bbox_x2 bbox_y2 (<keypoint,3>*N)
|
||||
...
|
||||
...
|
||||
```
|
||||
|
||||
```python
|
||||
import os
|
||||
import tempfile
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd'
|
||||
ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan') # remove '_mini' for full dataset
|
||||
|
||||
data_path = ms_ds_widerface.config_kwargs['split_config']
|
||||
train_dir = data_path['train']
|
||||
val_dir = data_path['validation']
|
||||
|
||||
def get_name(dir_name):
|
||||
names = [i for i in os.listdir(dir_name) if not i.startswith('_')]
|
||||
return names[0]
|
||||
|
||||
train_root = train_dir + '/' + get_name(train_dir) + '/'
|
||||
val_root = val_dir + '/' + get_name(val_dir) + '/'
|
||||
cache_path = snapshot_download(model_id)
|
||||
tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(tmp_dir):
|
||||
os.makedirs(tmp_dir)
|
||||
|
||||
def _cfg_modify_fn(cfg):
|
||||
cfg.checkpoint_config.interval = 1
|
||||
cfg.log_config.interval = 10
|
||||
cfg.evaluation.interval = 1
|
||||
cfg.data.workers_per_gpu = 1
|
||||
cfg.data.samples_per_gpu = 4
|
||||
return cfg
|
||||
|
||||
kwargs = dict(
|
||||
cfg_file=os.path.join(cache_path, 'DamoFD_lms.py'),
|
||||
work_dir=tmp_dir,
|
||||
train_root=train_root,
|
||||
val_root=val_root,
|
||||
total_epochs=1, # run #epochs
|
||||
cfg_modify_fn=_cfg_modify_fn)
|
||||
|
||||
trainer = build_trainer(name=Trainers.face_detection_scrfd, default_args=kwargs)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### 模型微调
|
||||
- 我们提供了Wider Face 和 Wider Face mini的训练集,可运行下面代码一键使用(下载数据集+训练);
|
||||
- 网络结构在'modelscope/modelscope/models/cv/face_detection/scrfd/damofd_detect.py', 训练细节在'trainers/cv/face_detection_scrfd_trainer.py'。可以修改这两个文件中的代码来自定义网络结构和训练流程;
|
||||
- 也支持微调自建数据集,需要按如下格式建立数据集:
|
||||
```
|
||||
# <image_path> image_width image_height
|
||||
bbox_x1 bbox_y1 bbox_x2 bbox_y2 (<keypoint,3>*N)
|
||||
...
|
||||
...
|
||||
# <image_path> image_width image_height
|
||||
bbox_x1 bbox_y1 bbox_x2 bbox_y2 (<keypoint,3>*N)
|
||||
...
|
||||
...
|
||||
```
|
||||
|
||||
```python
|
||||
import os
|
||||
import tempfile
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.utils.constant import ModelFile
|
||||
|
||||
model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd'
|
||||
ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan') # remove '_mini' for full dataset
|
||||
|
||||
data_path = ms_ds_widerface.config_kwargs['split_config']
|
||||
train_dir = data_path['train']
|
||||
val_dir = data_path['validation']
|
||||
|
||||
def get_name(dir_name):
|
||||
names = [i for i in os.listdir(dir_name) if not i.startswith('_')]
|
||||
return names[0]
|
||||
|
||||
train_root = train_dir + '/' + get_name(train_dir) + '/'
|
||||
val_root = val_dir + '/' + get_name(val_dir) + '/'
|
||||
cache_path = snapshot_download(model_id)
|
||||
tmp_dir = tempfile.TemporaryDirectory().name
|
||||
pretrain_epochs = 640
|
||||
ft_epochs = 1
|
||||
total_epochs = pretrain_epochs + ft_epochs
|
||||
if not os.path.exists(tmp_dir):
|
||||
os.makedirs(tmp_dir)
|
||||
|
||||
def _cfg_modify_fn(cfg):
|
||||
cfg.checkpoint_config.interval = 1
|
||||
cfg.log_config.interval = 10
|
||||
cfg.evaluation.interval = 1
|
||||
cfg.data.workers_per_gpu = 1
|
||||
cfg.data.samples_per_gpu = 4
|
||||
return cfg
|
||||
|
||||
kwargs = dict(
|
||||
cfg_file=os.path.join(cache_path, 'DamoFD_lms.py'),
|
||||
work_dir=tmp_dir,
|
||||
train_root=train_root,
|
||||
val_root=val_root,
|
||||
resume_from=os.path.join(cache_path, ModelFile.TORCH_MODEL_FILE),
|
||||
total_epochs=total_epochs, # run #epochs
|
||||
cfg_modify_fn=_cfg_modify_fn)
|
||||
|
||||
trainer = build_trainer(name=Trainers.face_detection_scrfd, default_args=kwargs)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
|
||||
## 模型效果
|
||||

|
||||
|
||||
<!---
|
||||
## 引用
|
||||
如果你觉得该模型对你有所帮助,请考虑引用下面的相关论文:
|
||||
|
||||
```BibTeX
|
||||
@inproceedings{liu2022mogface,
|
||||
title={MogFace: Towards a Deeper Appreciation on Face Detection},
|
||||
author={Liu, Yang and Wang, Fei and Deng, Jiankang and Zhou, Zhipeng and Sun, Baigui and Li, Hao},
|
||||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
||||
pages={4093--4102},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
-->
|
||||
|
||||
BIN
face_project/face_detection/DamoFD/demo/DamoFD.jpg
Normal file
|
After Width: | Height: | Size: 68 KiB |
BIN
face_project/face_detection/DamoFD/demo/DamoFD_ap.jpg
Normal file
|
After Width: | Height: | Size: 158 KiB |
74
face_project/face_detection/DamoFD/test_damofd_10g.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_result
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
from modelscope.utils.constant import DownloadMode
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
|
||||
class DamoFDFaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
    """Inference and evaluation tests for the DamoFD-10G face detector.

    ``test_run_modelhub`` runs the pipeline on a single local image, while
    ``test_run_with_dataset`` accumulates a precision/recall curve over the
    ``widerface_mini_train_val`` validation split and prints the resulting AP.
    """

    def setUp(self) -> None:
        self.task = Tasks.face_detection
        self.model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd-10G'
        self.img_path = 'data/test/images/mog_face_detection.jpg'

    def show_result(self, img_path, detection_result):
        """Draw ``detection_result`` on ``img_path`` and save it as result.png.

        The image is written to the current working directory and its
        absolute path is printed.
        """
        img = draw_face_detection_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        """Run the pipeline on the bundled test image and save the drawing."""
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)

        result = face_detection(self.img_path)
        self.show_result(self.img_path, result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_dataset(self):
        """Evaluate the detector on the mini WIDER FACE validation images."""
        # NOTE: add download_mode=DownloadMode.FORCE_REDOWNLOAD to this call to
        # request a fresh download (left disabled, as in the original code).
        val_set = MsDataset.load(
            'widerface_mini_train_val',
            namespace='ly261666',
            split='validation')
        img_base_path = next(iter(val_set))[1]
        img_dir = osp.join(img_base_path, 'val_data')
        img_gt = osp.join(img_base_path, 'val_label.txt')
        gt_info = gen_gt_info(img_gt)
        iou_th = 0.5
        thresh_num = 1000
        face_detection_func = pipeline(
            Tasks.face_detection, model=self.model_id)

        img_names = os.listdir(img_dir)
        # Fail with a clear message instead of a NameError at the final
        # show_result() call when the image directory is unexpectedly empty.
        self.assertTrue(img_names, f'no validation images found in {img_dir}')

        count_face = 0
        # np.zeros already produces float64, so no extra astype() is needed.
        pr_curve = np.zeros((thresh_num, 2), dtype=float)
        for img_name in img_names:
            abs_img_name = osp.join(img_dir, img_name)
            result = face_detection_func(abs_img_name)
            # One row per detection: the box values followed by its score.
            pred_info = np.concatenate(
                [result['boxes'],
                 np.array(result['scores'])[:, np.newaxis]],
                axis=1)
            gt_box = np.array(gt_info[img_name])
            pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
            # The second return value is not used by this test.
            _img_pr_info, _ = img_pr_info(thresh_num, pred_info,
                                          proposal_list, pred_recall)
            pr_curve += _img_pr_info
            count_face += gt_box.shape[0]

        pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
        # Column 0 is reported as precision, column 1 as recall.
        propose = pr_curve[:, 0]
        recall = pr_curve[:, 1]
        for srecall in np.arange(0.1, 1.0001, 0.1):
            # Index of the last threshold bin whose recall is <= the target.
            rindex = len(np.where(recall <= srecall)[0]) - 1
            rthresh = 1.0 - float(rindex) / thresh_num
            print('Recall-Precision-Thresh:', recall[rindex], propose[rindex],
                  rthresh)
        ap = voc_ap(recall, propose)
        print('ap: %.5f, iou_th: %.2f' % (ap, iou_th))
        # Visualise the detections of the last processed image.
        self.show_result(abs_img_name, result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
78
face_project/face_detection/DamoFD/test_damofd_2500m.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_result
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
from modelscope.utils.constant import DownloadMode
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
|
||||
class DamoFDFaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
    """Inference and evaluation tests for the DamoFD-2.5G face detector.

    ``test_run_modelhub`` runs the pipeline on a single local image, while
    ``test_run_with_dataset`` accumulates a precision/recall curve over the
    ``widerface_mini_train_val`` validation split and prints the resulting AP.
    """

    def setUp(self) -> None:
        self.task = Tasks.face_detection
        self.model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd-2.5G'
        self.img_path = 'data/test/images/mog_face_detection.jpg'

    def show_result(self, img_path, detection_result):
        """Draw ``detection_result`` on ``img_path`` and save it as result.png.

        The image is written to the current working directory and its
        absolute path is printed.
        """
        img = draw_face_detection_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_dataset(self):
        """Evaluate the detector on the mini WIDER FACE validation images."""
        # NOTE: add download_mode=DownloadMode.FORCE_REDOWNLOAD to this call to
        # request a fresh download (left disabled, as in the original code).
        val_set = MsDataset.load(
            'widerface_mini_train_val',
            namespace='ly261666',
            split='validation')
        img_base_path = next(iter(val_set))[1]
        img_dir = osp.join(img_base_path, 'val_data')
        img_gt = osp.join(img_base_path, 'val_label.txt')
        gt_info = gen_gt_info(img_gt)
        iou_th = 0.5
        thresh_num = 1000
        face_detection_func = pipeline(
            Tasks.face_detection, model=self.model_id)

        img_names = os.listdir(img_dir)
        # Fail with a clear message instead of a NameError at the final
        # show_result() call when the image directory is unexpectedly empty.
        self.assertTrue(img_names, f'no validation images found in {img_dir}')

        count_face = 0
        # np.zeros already produces float64, so no extra astype() is needed.
        pr_curve = np.zeros((thresh_num, 2), dtype=float)
        for img_name in img_names:
            abs_img_name = osp.join(img_dir, img_name)
            result = face_detection_func(abs_img_name)
            # One row per detection: the box values followed by its score.
            pred_info = np.concatenate(
                [result['boxes'],
                 np.array(result['scores'])[:, np.newaxis]],
                axis=1)
            gt_box = np.array(gt_info[img_name])
            pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
            # The second return value is not used by this test.
            _img_pr_info, _ = img_pr_info(thresh_num, pred_info,
                                          proposal_list, pred_recall)
            pr_curve += _img_pr_info
            count_face += gt_box.shape[0]

        pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
        # Column 0 is reported as precision, column 1 as recall.
        propose = pr_curve[:, 0]
        recall = pr_curve[:, 1]
        for srecall in np.arange(0.1, 1.0001, 0.1):
            # Index of the last threshold bin whose recall is <= the target.
            rindex = len(np.where(recall <= srecall)[0]) - 1
            rthresh = 1.0 - float(rindex) / thresh_num
            print('Recall-Precision-Thresh:', recall[rindex], propose[rindex],
                  rthresh)
        ap = voc_ap(recall, propose)
        print('ap: %.5f, iou_th: %.2f' % (ap, iou_th))
        # Visualise the detections of the last processed image.
        self.show_result(abs_img_name, result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        """Run the pipeline on the bundled test image and save the drawing."""
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)

        result = face_detection(self.img_path)
        self.show_result(self.img_path, result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_demo_compatibility(self):
        """Run the compatibility check provided by DemoCompatibilityCheck."""
        self.compatibility_check()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
78
face_project/face_detection/DamoFD/test_damofd_34g.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_result
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
from modelscope.utils.constant import DownloadMode
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
|
||||
class DamoFDFaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
    """Inference and evaluation tests for the DamoFD-34G face detector.

    ``test_run_modelhub`` runs the pipeline on a single local image, while
    ``test_run_with_dataset`` accumulates a precision/recall curve over the
    ``widerface_mini_train_val`` validation split and prints the resulting AP.
    """

    def setUp(self) -> None:
        self.task = Tasks.face_detection
        self.model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd-34G'
        self.img_path = 'data/test/images/mog_face_detection.jpg'

    def show_result(self, img_path, detection_result):
        """Draw ``detection_result`` on ``img_path`` and save it as result.png.

        The image is written to the current working directory and its
        absolute path is printed.
        """
        img = draw_face_detection_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_dataset(self):
        """Evaluate the detector on the mini WIDER FACE validation images."""
        # NOTE: add download_mode=DownloadMode.FORCE_REDOWNLOAD to this call to
        # request a fresh download (left disabled, as in the original code).
        val_set = MsDataset.load(
            'widerface_mini_train_val',
            namespace='ly261666',
            split='validation')
        img_base_path = next(iter(val_set))[1]
        img_dir = osp.join(img_base_path, 'val_data')
        img_gt = osp.join(img_base_path, 'val_label.txt')
        gt_info = gen_gt_info(img_gt)
        iou_th = 0.5
        thresh_num = 1000
        face_detection_func = pipeline(
            Tasks.face_detection, model=self.model_id)

        img_names = os.listdir(img_dir)
        # Fail with a clear message instead of a NameError at the final
        # show_result() call when the image directory is unexpectedly empty.
        self.assertTrue(img_names, f'no validation images found in {img_dir}')

        count_face = 0
        # np.zeros already produces float64, so no extra astype() is needed.
        pr_curve = np.zeros((thresh_num, 2), dtype=float)
        for img_name in img_names:
            abs_img_name = osp.join(img_dir, img_name)
            result = face_detection_func(abs_img_name)
            # One row per detection: the box values followed by its score.
            pred_info = np.concatenate(
                [result['boxes'],
                 np.array(result['scores'])[:, np.newaxis]],
                axis=1)
            gt_box = np.array(gt_info[img_name])
            pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
            # The second return value is not used by this test.
            _img_pr_info, _ = img_pr_info(thresh_num, pred_info,
                                          proposal_list, pred_recall)
            pr_curve += _img_pr_info
            count_face += gt_box.shape[0]

        pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
        # Column 0 is reported as precision, column 1 as recall.
        propose = pr_curve[:, 0]
        recall = pr_curve[:, 1]
        for srecall in np.arange(0.1, 1.0001, 0.1):
            # Index of the last threshold bin whose recall is <= the target.
            rindex = len(np.where(recall <= srecall)[0]) - 1
            rthresh = 1.0 - float(rindex) / thresh_num
            print('Recall-Precision-Thresh:', recall[rindex], propose[rindex],
                  rthresh)
        ap = voc_ap(recall, propose)
        print('ap: %.5f, iou_th: %.2f' % (ap, iou_th))
        # Visualise the detections of the last processed image.
        self.show_result(abs_img_name, result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        """Run the pipeline on the bundled test image and save the drawing."""
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)

        result = face_detection(self.img_path)
        self.show_result(self.img_path, result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_demo_compatibility(self):
        """Run the compatibility check provided by DemoCompatibilityCheck."""
        self.compatibility_check()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
78
face_project/face_detection/DamoFD/test_damofd_500m.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_result
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
from modelscope.utils.constant import DownloadMode
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
|
||||
class DamoFDFaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
    """Inference and evaluation tests for the DamoFD model in ``model_id``.

    ``test_run_modelhub`` runs the pipeline on a single local image, while
    ``test_run_with_dataset`` accumulates a precision/recall curve over the
    ``widerface_mini_train_val`` validation split and prints the resulting AP.
    """

    def setUp(self) -> None:
        self.task = Tasks.face_detection
        self.model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd'
        self.img_path = 'data/test/images/mog_face_detection.jpg'

    def show_result(self, img_path, detection_result):
        """Draw ``detection_result`` on ``img_path`` and save it as result.png.

        The image is written to the current working directory and its
        absolute path is printed.
        """
        img = draw_face_detection_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_dataset(self):
        """Evaluate the detector on the mini WIDER FACE validation images."""
        # NOTE: add download_mode=DownloadMode.FORCE_REDOWNLOAD to this call to
        # request a fresh download (left disabled, as in the original code).
        val_set = MsDataset.load(
            'widerface_mini_train_val',
            namespace='ly261666',
            split='validation')
        img_base_path = next(iter(val_set))[1]
        img_dir = osp.join(img_base_path, 'val_data')
        img_gt = osp.join(img_base_path, 'val_label.txt')
        gt_info = gen_gt_info(img_gt)
        iou_th = 0.5
        thresh_num = 1000
        face_detection_func = pipeline(
            Tasks.face_detection, model=self.model_id)

        img_names = os.listdir(img_dir)
        # Fail with a clear message instead of a NameError at the final
        # show_result() call when the image directory is unexpectedly empty.
        self.assertTrue(img_names, f'no validation images found in {img_dir}')

        count_face = 0
        # np.zeros already produces float64, so no extra astype() is needed.
        pr_curve = np.zeros((thresh_num, 2), dtype=float)
        for img_name in img_names:
            abs_img_name = osp.join(img_dir, img_name)
            result = face_detection_func(abs_img_name)
            # One row per detection: the box values followed by its score.
            pred_info = np.concatenate(
                [result['boxes'],
                 np.array(result['scores'])[:, np.newaxis]],
                axis=1)
            gt_box = np.array(gt_info[img_name])
            pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
            # The second return value is not used by this test.
            _img_pr_info, _ = img_pr_info(thresh_num, pred_info,
                                          proposal_list, pred_recall)
            pr_curve += _img_pr_info
            count_face += gt_box.shape[0]

        pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
        # Column 0 is reported as precision, column 1 as recall.
        propose = pr_curve[:, 0]
        recall = pr_curve[:, 1]
        for srecall in np.arange(0.1, 1.0001, 0.1):
            # Index of the last threshold bin whose recall is <= the target.
            rindex = len(np.where(recall <= srecall)[0]) - 1
            rthresh = 1.0 - float(rindex) / thresh_num
            print('Recall-Precision-Thresh:', recall[rindex], propose[rindex],
                  rthresh)
        ap = voc_ap(recall, propose)
        print('ap: %.5f, iou_th: %.2f' % (ap, iou_th))
        # Visualise the detections of the last processed image.
        self.show_result(abs_img_name, result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        """Run the pipeline on the bundled test image and save the drawing."""
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)

        result = face_detection(self.img_path)
        self.show_result(self.img_path, result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_demo_compatibility(self):
        """Run the compatibility check provided by DemoCompatibilityCheck."""
        self.compatibility_check()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
150
face_project/face_detection/DamoFD/train_damofd_10g.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.test_utils import DistributedTestCase, test_level
|
||||
|
||||
|
||||
def _setup():
    """Prepare dataset roots, model files and a work directory for one test.

    Loads the ``WIDER_FACE_mini`` dataset, downloads the DamoFD-10G model
    snapshot and creates a fresh temporary work directory.

    Returns:
        tuple: ``(train_root, val_root, max_epochs, cache_path, tmp_dir)``.

    Raises:
        FileNotFoundError: if a split directory holds no usable entry.
    """
    model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd-10G'
    # mini dataset only for unit test, remove '_mini' for full dataset.
    ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan')

    def _split_root(split_dir):
        # os.listdir() order is arbitrary, so skip '_'-prefixed entries and
        # sort the rest to pick the data folder deterministically.
        candidates = sorted(
            name for name in os.listdir(split_dir)
            if not name.startswith('_'))
        if not candidates:
            raise FileNotFoundError(
                f'no dataset folder found under {split_dir}')
        return split_dir + '/' + candidates[0] + '/'

    data_path = ms_ds_widerface.config_kwargs['split_config']
    train_root = _split_root(data_path['train'])
    val_root = _split_root(data_path['validation'])
    max_epochs = 1  # run epochs in unit test

    cache_path = snapshot_download(model_id)

    # mkdtemp() creates a directory that stays on disk until the caller
    # removes it. TemporaryDirectory().name pointed at a directory that was
    # deleted as soon as the temporary object was finalized.
    tmp_dir = tempfile.mkdtemp()
    return train_root, val_root, max_epochs, cache_path, tmp_dir
|
||||
|
||||
|
||||
def train_func(**kwargs):
    """Build the ``face_detection_scrfd`` trainer from ``kwargs`` and train.

    The keyword arguments are forwarded unchanged as the trainer's
    ``default_args``.
    """
    scrfd_trainer = build_trainer(
        default_args=kwargs, name=Trainers.face_detection_scrfd)
    scrfd_trainer.train()
|
||||
|
||||
|
||||
class TestFaceDetectionDamofdTrainerSingleGPU(unittest.TestCase):
    """Single-process training tests for the DamoFD-10G trainer."""

    def setUp(self):
        banner = 'SingleGPU Testing %s.%s' % (type(self).__name__,
                                              self._testMethodName)
        print(banner)
        (self.train_root, self.val_root, self.max_epochs, self.cache_path,
         self.tmp_dir) = _setup()

    def tearDown(self):
        # Remove the per-test work directory before the base-class teardown.
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    def _cfg_modify_fn(self, cfg):
        """Apply the test-sized intervals and loader settings to ``cfg``."""
        cfg.checkpoint_config.interval = 1
        cfg.log_config.interval = 10
        cfg.evaluation.interval = 1
        cfg.data.workers_per_gpu = 3
        cfg.data.samples_per_gpu = 4  # batch size
        return cfg

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_from_scratch(self):
        """Train for ``max_epochs`` and check the log and checkpoints."""
        trainer_args = {
            'cfg_file': os.path.join(self.cache_path, 'DamoFD_lms.py'),
            'work_dir': self.tmp_dir,
            'train_root': self.train_root,
            'val_root': self.val_root,
            'total_epochs': self.max_epochs,
            'cfg_modify_fn': self._cfg_modify_fn,
        }

        trainer = build_trainer(
            name=Trainers.face_detection_scrfd, default_args=trainer_args)
        trainer.train()

        produced = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', produced)
        # One checkpoint is expected per epoch, numbered from 1.
        for epoch in range(1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_trainer_finetune(self):
        """Resume from the released checkpoint and train the extra epochs."""
        pretrain_epoch = 640
        # total_epochs counts the epochs of the resumed checkpoint as well.
        self.max_epochs = self.max_epochs + pretrain_epoch
        trainer_args = {
            'cfg_file': os.path.join(self.cache_path, 'DamoFD_lms.py'),
            'work_dir': self.tmp_dir,
            'train_root': self.train_root,
            'val_root': self.val_root,
            'total_epochs': self.max_epochs,
            'resume_from': os.path.join(self.cache_path,
                                        ModelFile.TORCH_MODEL_FILE),
            'cfg_modify_fn': self._cfg_modify_fn,
        }

        trainer = build_trainer(
            name=Trainers.face_detection_scrfd, default_args=trainer_args)
        trainer.train()

        produced = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', produced)
        # Only the epochs after the pretrained ones write new checkpoints.
        for epoch in range(pretrain_epoch + 1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)
|
||||
|
||||
|
||||
@unittest.skipIf(
    not torch.cuda.is_available() or torch.cuda.device_count() <= 1,
    'distributed unittest')
class TestFaceDetectionDamofdTrainerMultiGpus(DistributedTestCase):
    """Two-GPU fine-tuning test for the DamoFD-10G trainer."""

    def setUp(self):
        banner = 'MultiGPUs Testing %s.%s' % (type(self).__name__,
                                              self._testMethodName)
        print(banner)
        (self.train_root, self.val_root, self.max_epochs, self.cache_path,
         self.tmp_dir) = _setup()

        # Write the test overrides into the cached config file itself; the
        # distributed test below passes no cfg_modify_fn.
        cfg_file_path = os.path.join(self.cache_path, 'DamoFD_lms.py')
        cfg = Config.from_file(cfg_file_path)
        cfg.checkpoint_config.interval = 1
        cfg.log_config.interval = 10
        cfg.evaluation.interval = 1
        cfg.data.workers_per_gpu = 3
        cfg.data.samples_per_gpu = 4
        cfg.dump(cfg_file_path)

    def tearDown(self):
        # Remove the per-test work directory before the base-class teardown.
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_multi_gpus_finetune(self):
        """Resume from the released checkpoint on two GPUs and verify output."""
        pretrain_epoch = 640
        # total_epochs counts the epochs of the resumed checkpoint as well.
        self.max_epochs = self.max_epochs + pretrain_epoch
        launch_args = {
            'cfg_file': os.path.join(self.cache_path, 'DamoFD_lms.py'),
            'work_dir': self.tmp_dir,
            'train_root': self.train_root,
            'val_root': self.val_root,
            'total_epochs': self.max_epochs,
            'resume_from': os.path.join(self.cache_path,
                                        ModelFile.TORCH_MODEL_FILE),
            'launcher': 'pytorch',
        }
        self.start(train_func, num_gpus=2, **launch_args)

        produced = os.listdir(self.tmp_dir)
        log_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
        self.assertEqual(len(log_files), 1)
        # Only the epochs after the pretrained ones write new checkpoints.
        for epoch in range(pretrain_epoch + 1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
150
face_project/face_detection/DamoFD/train_damofd_2500m.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.test_utils import DistributedTestCase, test_level
|
||||
|
||||
|
||||
def _setup():
    """Prepare dataset roots, model files and a work directory for one test.

    Loads the ``WIDER_FACE_mini`` dataset, downloads the DamoFD-2.5G model
    snapshot and creates a fresh temporary work directory.

    Returns:
        tuple: ``(train_root, val_root, max_epochs, cache_path, tmp_dir)``.

    Raises:
        FileNotFoundError: if a split directory holds no usable entry.
    """
    model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd-2.5G'
    # mini dataset only for unit test, remove '_mini' for full dataset.
    ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan')

    def _split_root(split_dir):
        # os.listdir() order is arbitrary, so skip '_'-prefixed entries and
        # sort the rest to pick the data folder deterministically.
        candidates = sorted(
            name for name in os.listdir(split_dir)
            if not name.startswith('_'))
        if not candidates:
            raise FileNotFoundError(
                f'no dataset folder found under {split_dir}')
        return split_dir + '/' + candidates[0] + '/'

    data_path = ms_ds_widerface.config_kwargs['split_config']
    train_root = _split_root(data_path['train'])
    val_root = _split_root(data_path['validation'])
    max_epochs = 1  # run epochs in unit test

    cache_path = snapshot_download(model_id)

    # mkdtemp() creates a directory that stays on disk until the caller
    # removes it. TemporaryDirectory().name pointed at a directory that was
    # deleted as soon as the temporary object was finalized.
    tmp_dir = tempfile.mkdtemp()
    return train_root, val_root, max_epochs, cache_path, tmp_dir
|
||||
|
||||
|
||||
def train_func(**kwargs):
    """Build the ``face_detection_scrfd`` trainer from ``kwargs`` and train.

    The keyword arguments are forwarded unchanged as the trainer's
    ``default_args``.
    """
    scrfd_trainer = build_trainer(
        default_args=kwargs, name=Trainers.face_detection_scrfd)
    scrfd_trainer.train()
|
||||
|
||||
|
||||
class TestFaceDetectionDamofdTrainerSingleGPU(unittest.TestCase):
    """Single-process training tests for the DamoFD-2.5G trainer."""

    def setUp(self):
        banner = 'SingleGPU Testing %s.%s' % (type(self).__name__,
                                              self._testMethodName)
        print(banner)
        (self.train_root, self.val_root, self.max_epochs, self.cache_path,
         self.tmp_dir) = _setup()

    def tearDown(self):
        # Remove the per-test work directory before the base-class teardown.
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    def _cfg_modify_fn(self, cfg):
        """Apply the test-sized intervals and loader settings to ``cfg``."""
        cfg.checkpoint_config.interval = 1
        cfg.log_config.interval = 10
        cfg.evaluation.interval = 1
        cfg.data.workers_per_gpu = 3
        cfg.data.samples_per_gpu = 4  # batch size
        return cfg

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_from_scratch(self):
        """Train for ``max_epochs`` and check the log and checkpoints."""
        trainer_args = {
            'cfg_file': os.path.join(self.cache_path, 'DamoFD_lms.py'),
            'work_dir': self.tmp_dir,
            'train_root': self.train_root,
            'val_root': self.val_root,
            'total_epochs': self.max_epochs,
            'cfg_modify_fn': self._cfg_modify_fn,
        }

        trainer = build_trainer(
            name=Trainers.face_detection_scrfd, default_args=trainer_args)
        trainer.train()

        produced = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', produced)
        # One checkpoint is expected per epoch, numbered from 1.
        for epoch in range(1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_trainer_finetune(self):
        """Resume from the released checkpoint and train the extra epochs."""
        pretrain_epoch = 640
        # total_epochs counts the epochs of the resumed checkpoint as well.
        self.max_epochs = self.max_epochs + pretrain_epoch
        trainer_args = {
            'cfg_file': os.path.join(self.cache_path, 'DamoFD_lms.py'),
            'work_dir': self.tmp_dir,
            'train_root': self.train_root,
            'val_root': self.val_root,
            'total_epochs': self.max_epochs,
            'resume_from': os.path.join(self.cache_path,
                                        ModelFile.TORCH_MODEL_FILE),
            'cfg_modify_fn': self._cfg_modify_fn,
        }

        trainer = build_trainer(
            name=Trainers.face_detection_scrfd, default_args=trainer_args)
        trainer.train()

        produced = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', produced)
        # Only the epochs after the pretrained ones write new checkpoints.
        for epoch in range(pretrain_epoch + 1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)
|
||||
|
||||
|
||||
@unittest.skipIf(
    not (torch.cuda.is_available() and torch.cuda.device_count() > 1),
    'distributed unittest')
class TestFaceDetectionDamofdTrainerMultiGpus(DistributedTestCase):
    """Two-GPU fine-tuning test for the DamoFD trainer."""

    def setUp(self):
        print('MultiGPUs Testing %s.%s' %
              (type(self).__name__, self._testMethodName))
        (self.train_root, self.val_root, self.max_epochs, self.cache_path,
         self.tmp_dir) = _setup()

        # The multi-GPU test passes no cfg_modify_fn, so the overrides are
        # written into the config file itself.
        # NOTE(review): the file sits inside the downloaded snapshot, so this
        # appears to rewrite the cached config in place -- confirm intended.
        cfg_path = os.path.join(self.cache_path, 'DamoFD_lms.py')
        cfg = Config.from_file(cfg_path)
        overrides = (
            ('checkpoint_config', 'interval', 1),
            ('log_config', 'interval', 10),
            ('evaluation', 'interval', 1),
            ('data', 'workers_per_gpu', 3),
            ('data', 'samples_per_gpu', 4),
        )
        for section, option, value in overrides:
            setattr(getattr(cfg, section), option, value)
        cfg.dump(cfg_path)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_multi_gpus_finetune(self):
        """Resume training on two GPUs and check the log and checkpoints."""
        pretrain_epoch = 640
        self.max_epochs += pretrain_epoch
        launch_args = {
            'cfg_file': os.path.join(self.cache_path, 'DamoFD_lms.py'),
            'work_dir': self.tmp_dir,
            'train_root': self.train_root,
            'val_root': self.val_root,
            'total_epochs': self.max_epochs,
            'resume_from': os.path.join(self.cache_path,
                                        ModelFile.TORCH_MODEL_FILE),
            'launcher': 'pytorch',
        }
        self.start(train_func, num_gpus=2, **launch_args)

        produced = os.listdir(self.tmp_dir)
        log_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
        # Exactly one JSON log is expected for the whole run.
        self.assertEqual(len(log_files), 1)
        for epoch in range(pretrain_epoch + 1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run every test case in this module when it is executed as a script.
    unittest.main()
|
||||
150
face_project/face_detection/DamoFD/train_damofd_34g.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.test_utils import DistributedTestCase, test_level
|
||||
|
||||
|
||||
def _setup():
    """Prepare the data, model snapshot and scratch directory for one test.

    Loads the mini WIDER FACE dataset, downloads the DamoFD-34G model
    snapshot and creates a fresh temporary working directory.

    Returns:
        tuple: ``(train_root, val_root, max_epochs, cache_path, tmp_dir)``.
        ``train_root`` / ``val_root`` end with ``/``; ``tmp_dir`` exists and
        must be removed by the caller.
    """
    model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd-34G'
    # mini dataset only for unit test, remove '_mini' for full dataset.
    ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan')

    data_path = ms_ds_widerface.config_kwargs['split_config']
    train_dir = data_path['train']
    val_dir = data_path['validation']
    # The data root is the first entry found inside each split directory.
    train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/'
    val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/'
    max_epochs = 1  # run epochs in unit test

    cache_path = snapshot_download(model_id)

    # mkdtemp() creates the directory and leaves it in place for the caller.
    # TemporaryDirectory().name dropped the only reference to the object, so
    # its finalizer deleted the directory and it had to be recreated by hand.
    tmp_dir = tempfile.mkdtemp()
    return train_root, val_root, max_epochs, cache_path, tmp_dir
|
||||
|
||||
|
||||
def train_func(**kwargs):
    """Build the SCRFD face-detection trainer from ``kwargs`` and train it.

    Args:
        **kwargs: Forwarded unchanged as the trainer's ``default_args``.
    """
    scrfd_trainer = build_trainer(
        default_args=kwargs, name=Trainers.face_detection_scrfd)
    scrfd_trainer.train()
|
||||
|
||||
|
||||
class TestFaceDetectionDamofdTrainerSingleGPU(unittest.TestCase):
    """Single-process training tests for the DamoFD face detector."""

    def setUp(self):
        print('SingleGPU Testing %s.%s' %
              (type(self).__name__, self._testMethodName))
        (self.train_root, self.val_root, self.max_epochs, self.cache_path,
         self.tmp_dir) = _setup()

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    def _cfg_modify_fn(self, cfg):
        """Apply the unit-test overrides to ``cfg`` and return it."""
        overrides = (
            ('checkpoint_config', 'interval', 1),
            ('log_config', 'interval', 10),
            ('evaluation', 'interval', 1),
            ('data', 'workers_per_gpu', 3),
            ('data', 'samples_per_gpu', 4),  # batch size
        )
        for section, option, value in overrides:
            setattr(getattr(cfg, section), option, value)
        return cfg

    def _train(self, **extra_args):
        """Train with the shared arguments plus ``extra_args``.

        Returns:
            tuple: ``(trainer, files)`` where ``files`` lists the work dir.
        """
        trainer_args = dict(
            cfg_file=os.path.join(self.cache_path, 'DamoFD_lms.py'),
            work_dir=self.tmp_dir,
            train_root=self.train_root,
            val_root=self.val_root,
            total_epochs=self.max_epochs)
        trainer_args.update(extra_args)
        trainer = build_trainer(
            name=Trainers.face_detection_scrfd, default_args=trainer_args)
        trainer.train()
        return trainer, os.listdir(self.tmp_dir)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_from_scratch(self):
        """Train without a checkpoint; expect the log and every checkpoint."""
        trainer, produced = self._train(cfg_modify_fn=self._cfg_modify_fn)
        self.assertIn(f'{trainer.timestamp}.log.json', produced)
        for epoch in range(1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_trainer_finetune(self):
        """Resume from the downloaded checkpoint; expect the new checkpoints."""
        pretrain_epoch = 640
        self.max_epochs += pretrain_epoch
        trainer, produced = self._train(
            resume_from=os.path.join(self.cache_path,
                                     ModelFile.TORCH_MODEL_FILE),
            cfg_modify_fn=self._cfg_modify_fn)
        self.assertIn(f'{trainer.timestamp}.log.json', produced)
        # Only epochs after the resumed one are expected to be written.
        for epoch in range(pretrain_epoch + 1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)
|
||||
|
||||
|
||||
@unittest.skipIf(
    not (torch.cuda.is_available() and torch.cuda.device_count() > 1),
    'distributed unittest')
class TestFaceDetectionDamofdTrainerMultiGpus(DistributedTestCase):
    """Two-GPU fine-tuning test for the DamoFD trainer."""

    def setUp(self):
        print('MultiGPUs Testing %s.%s' %
              (type(self).__name__, self._testMethodName))
        (self.train_root, self.val_root, self.max_epochs, self.cache_path,
         self.tmp_dir) = _setup()

        # The multi-GPU test passes no cfg_modify_fn, so the overrides are
        # written into the config file itself.
        # NOTE(review): the file sits inside the downloaded snapshot, so this
        # appears to rewrite the cached config in place -- confirm intended.
        cfg_path = os.path.join(self.cache_path, 'DamoFD_lms.py')
        cfg = Config.from_file(cfg_path)
        overrides = (
            ('checkpoint_config', 'interval', 1),
            ('log_config', 'interval', 10),
            ('evaluation', 'interval', 1),
            ('data', 'workers_per_gpu', 3),
            ('data', 'samples_per_gpu', 4),
        )
        for section, option, value in overrides:
            setattr(getattr(cfg, section), option, value)
        cfg.dump(cfg_path)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_multi_gpus_finetune(self):
        """Resume training on two GPUs and check the log and checkpoints."""
        pretrain_epoch = 640
        self.max_epochs += pretrain_epoch
        launch_args = {
            'cfg_file': os.path.join(self.cache_path, 'DamoFD_lms.py'),
            'work_dir': self.tmp_dir,
            'train_root': self.train_root,
            'val_root': self.val_root,
            'total_epochs': self.max_epochs,
            'resume_from': os.path.join(self.cache_path,
                                        ModelFile.TORCH_MODEL_FILE),
            'launcher': 'pytorch',
        }
        self.start(train_func, num_gpus=2, **launch_args)

        produced = os.listdir(self.tmp_dir)
        log_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
        # Exactly one JSON log is expected for the whole run.
        self.assertEqual(len(log_files), 1)
        for epoch in range(pretrain_epoch + 1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run every test case in this module when it is executed as a script.
    unittest.main()
|
||||
150
face_project/face_detection/DamoFD/train_damofd_500m.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.test_utils import DistributedTestCase, test_level
|
||||
|
||||
|
||||
def _setup():
    """Prepare the data, model snapshot and scratch directory for one test.

    Loads the mini WIDER FACE dataset, downloads the DamoFD model snapshot
    and creates a fresh temporary working directory.

    Returns:
        tuple: ``(train_root, val_root, max_epochs, cache_path, tmp_dir)``.
        ``train_root`` / ``val_root`` end with ``/``; ``tmp_dir`` exists and
        must be removed by the caller.
    """
    model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd'
    # mini dataset only for unit test, remove '_mini' for full dataset.
    ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan')

    data_path = ms_ds_widerface.config_kwargs['split_config']
    train_dir = data_path['train']
    val_dir = data_path['validation']
    # The data root is the first entry found inside each split directory.
    train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/'
    val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/'
    max_epochs = 1  # run epochs in unit test

    cache_path = snapshot_download(model_id)

    # mkdtemp() creates the directory and leaves it in place for the caller.
    # TemporaryDirectory().name dropped the only reference to the object, so
    # its finalizer deleted the directory and it had to be recreated by hand.
    tmp_dir = tempfile.mkdtemp()
    return train_root, val_root, max_epochs, cache_path, tmp_dir
|
||||
|
||||
|
||||
def train_func(**kwargs):
    """Build the SCRFD face-detection trainer from ``kwargs`` and train it.

    Args:
        **kwargs: Forwarded unchanged as the trainer's ``default_args``.
    """
    scrfd_trainer = build_trainer(
        default_args=kwargs, name=Trainers.face_detection_scrfd)
    scrfd_trainer.train()
|
||||
|
||||
|
||||
class TestFaceDetectionDamofdTrainerSingleGPU(unittest.TestCase):
    """Single-process training tests for the DamoFD face detector."""

    def setUp(self):
        print('SingleGPU Testing %s.%s' %
              (type(self).__name__, self._testMethodName))
        (self.train_root, self.val_root, self.max_epochs, self.cache_path,
         self.tmp_dir) = _setup()

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    def _cfg_modify_fn(self, cfg):
        """Apply the unit-test overrides to ``cfg`` and return it."""
        overrides = (
            ('checkpoint_config', 'interval', 1),
            ('log_config', 'interval', 10),
            ('evaluation', 'interval', 1),
            ('data', 'workers_per_gpu', 3),
            ('data', 'samples_per_gpu', 4),  # batch size
        )
        for section, option, value in overrides:
            setattr(getattr(cfg, section), option, value)
        return cfg

    def _train(self, **extra_args):
        """Train with the shared arguments plus ``extra_args``.

        Returns:
            tuple: ``(trainer, files)`` where ``files`` lists the work dir.
        """
        trainer_args = dict(
            cfg_file=os.path.join(self.cache_path, 'DamoFD_lms.py'),
            work_dir=self.tmp_dir,
            train_root=self.train_root,
            val_root=self.val_root,
            total_epochs=self.max_epochs)
        trainer_args.update(extra_args)
        trainer = build_trainer(
            name=Trainers.face_detection_scrfd, default_args=trainer_args)
        trainer.train()
        return trainer, os.listdir(self.tmp_dir)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_from_scratch(self):
        """Train without a checkpoint; expect the log and every checkpoint."""
        trainer, produced = self._train(cfg_modify_fn=self._cfg_modify_fn)
        self.assertIn(f'{trainer.timestamp}.log.json', produced)
        for epoch in range(1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_trainer_finetune(self):
        """Resume from the downloaded checkpoint; expect the new checkpoints."""
        pretrain_epoch = 640
        self.max_epochs += pretrain_epoch
        trainer, produced = self._train(
            resume_from=os.path.join(self.cache_path,
                                     ModelFile.TORCH_MODEL_FILE),
            cfg_modify_fn=self._cfg_modify_fn)
        self.assertIn(f'{trainer.timestamp}.log.json', produced)
        # Only epochs after the resumed one are expected to be written.
        for epoch in range(pretrain_epoch + 1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)
|
||||
|
||||
|
||||
@unittest.skipIf(
    not (torch.cuda.is_available() and torch.cuda.device_count() > 1),
    'distributed unittest')
class TestFaceDetectionDamofdTrainerMultiGpus(DistributedTestCase):
    """Two-GPU fine-tuning test for the DamoFD trainer."""

    def setUp(self):
        print('MultiGPUs Testing %s.%s' %
              (type(self).__name__, self._testMethodName))
        (self.train_root, self.val_root, self.max_epochs, self.cache_path,
         self.tmp_dir) = _setup()

        # The multi-GPU test passes no cfg_modify_fn, so the overrides are
        # written into the config file itself.
        # NOTE(review): the file sits inside the downloaded snapshot, so this
        # appears to rewrite the cached config in place -- confirm intended.
        cfg_path = os.path.join(self.cache_path, 'DamoFD_lms.py')
        cfg = Config.from_file(cfg_path)
        overrides = (
            ('checkpoint_config', 'interval', 1),
            ('log_config', 'interval', 10),
            ('evaluation', 'interval', 1),
            ('data', 'workers_per_gpu', 3),
            ('data', 'samples_per_gpu', 4),
        )
        for section, option, value in overrides:
            setattr(getattr(cfg, section), option, value)
        cfg.dump(cfg_path)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_multi_gpus_finetune(self):
        """Resume training on two GPUs and check the log and checkpoints."""
        pretrain_epoch = 640
        self.max_epochs += pretrain_epoch
        launch_args = {
            'cfg_file': os.path.join(self.cache_path, 'DamoFD_lms.py'),
            'work_dir': self.tmp_dir,
            'train_root': self.train_root,
            'val_root': self.val_root,
            'total_epochs': self.max_epochs,
            'resume_from': os.path.join(self.cache_path,
                                        ModelFile.TORCH_MODEL_FILE),
            'launcher': 'pytorch',
        }
        self.start(train_func, num_gpus=2, **launch_args)

        produced = os.listdir(self.tmp_dir)
        log_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
        # Exactly one JSON log is expected for the whole run.
        self.assertEqual(len(log_files), 1)
        for epoch in range(pretrain_epoch + 1, self.max_epochs + 1):
            self.assertIn(f'epoch_{epoch}.pth', produced)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run every test case in this module when it is executed as a script.
    unittest.main()
|
||||
146
face_project/face_detection/MogFace/README.md
Normal file
@@ -0,0 +1,146 @@
|
||||
|
||||
<div align="center">
|
||||
<img src="demo/MogFace.jpg" width="100%" height="10%" />
|
||||
</div>
|
||||
<h4 align="center">
|
||||
<a href=#MogFace模型介绍> 模型介绍 </a> |
|
||||
<a href=#快速使用> 快速使用 </a> |
|
||||
<a href=#单图片推理> 单图片推理 </a> |
|
||||
<a href=#多图片推理和评测> 多图片推理/评测 </a>
|
||||
</h4>
|
||||
|
||||
# MogFace模型介绍
|
||||
MogFace为当前SOTA的人脸检测方法,已在Wider Face六项榜单上霸榜一年以上,后续被CVPR2022录取([论文地址](https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_MogFace_Towards_a_Deeper_Appreciation_on_Face_Detection_CVPR_2022_paper.pdf)),该方法的主要贡献是从下面三个角度提升人脸检测器:
|
||||
- Scale-level Data Augmentation (SSE):SSE是第一个从maximize pyramid layer 表征的角度来控制数据集中gt的尺度分布,而不是intuitive的假想检测器的学习能力,因此会在不同场景下都很鲁棒。
|
||||
- Adaptive Online Anchor Mining Strategy(Ali-AMS):减少对超参的依赖, 简单且有效的adaptive label assign 方法。
|
||||
- Hierarchical Context-aware Module (HCAM): 减少误检是real world人脸检测器面对的最大挑战,HCAM是最近几年第一次在算法侧给出solid solution。
|
||||
|
||||
## 快速使用
|
||||
|
||||
在这个界面中,我们提供几个有关`推理/评测`脚本帮助大家迅速/一键使用MogFace, 代码范例中的实例均集成在test_mog_face_detection.py
|
||||
- `Usage`:
|
||||
```shell
|
||||
PYTHONPATH=. python face_project/face_detection/MogFace/test_mog_face_detection.py
|
||||
```
|
||||
|
||||
## 代码范例
|
||||
|
||||
### 单图片推理
|
||||
```python
|
||||
import cv2
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
face_detection = pipeline(task=Tasks.face_detection, model='damo/cv_resnet101_face-detection_cvpr22papermogface')
|
||||
# 支持 url image and abs dir image path
|
||||
img_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_detection2.jpeg'
|
||||
result = face_detection(img_path)
|
||||
|
||||
# 提供可视化结果
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result
|
||||
from modelscope.preprocessors.image import LoadImage
|
||||
img = LoadImage.convert_to_ndarray(img_path)
|
||||
cv2.imwrite('srcImg.jpg', img)
|
||||
img_draw = draw_face_detection_no_lm_result('srcImg.jpg', result)
|
||||
import matplotlib.pyplot as plt
|
||||
plt.imshow(img_draw)
|
||||
```
|
||||
|
||||
### 多图片推理和评测
|
||||
- 我们提供了100张测试图片,可运行下面代码一键使用(下载数据集+推理);
|
||||
- 也支持测试自建数据集,需要按如下格式建立数据集:
|
||||
```
|
||||
img_base_path/
|
||||
val_data/
|
||||
test_1.jpg
|
||||
...
|
||||
test_N.jpg
|
||||
val_label.txt
|
||||
## val_label.txt format
|
||||
test_1.jpg
|
||||
x0 y0 w h
|
||||
x0 y0 w h
|
||||
...
|
||||
test_N.jpg
|
||||
x0 y0 w h
|
||||
x0 y0 w h
|
||||
...
|
||||
```
|
||||
|
||||
```python
|
||||
import os.path as osp
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
model_id = 'damo/cv_resnet101_face-detection_cvpr22papermogface'
|
||||
val_set = MsDataset.load('widerface_mini_train_val', namespace='ly261666', split='validation')#, download_mode=DownloadMode.FORCE_REDOWNLOAD)
|
||||
img_base_path = next(iter(val_set))[1]
|
||||
img_dir = osp.join(img_base_path, 'val_data')
|
||||
img_gt = osp.join(img_base_path, 'val_label.txt')
|
||||
gt_info = gen_gt_info(img_gt)
|
||||
pred_info = {}
|
||||
iou_th = 0.5
|
||||
thresh_num = 1000
|
||||
face_detection_func = pipeline(Tasks.face_detection, model=model_id, conf_th=0.01)
|
||||
count_face = 0
|
||||
pr_curve = np.zeros((thresh_num, 2)).astype('float')
|
||||
for idx, img_name in enumerate(os.listdir(img_dir)):
|
||||
print ('inference img: {} {}/{}'.format(img_name, idx+1, len(os.listdir(img_dir))))
|
||||
abs_img_name = osp.join(img_dir, img_name)
|
||||
result = face_detection_func(abs_img_name)
|
||||
pred_info = np.concatenate([result['boxes'], np.array(result['scores'])[:,np.newaxis]], axis=1)
|
||||
gt_box = np.array(gt_info[img_name])
|
||||
pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
|
||||
_img_pr_info, fp = img_pr_info(thresh_num, pred_info, proposal_list, pred_recall)
|
||||
pr_curve += _img_pr_info
|
||||
count_face += gt_box.shape[0]
|
||||
|
||||
pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
|
||||
propose = pr_curve[:, 0]
|
||||
recall = pr_curve[:, 1]
|
||||
for srecall in np.arange(0.1, 1.0001, 0.1):
|
||||
rindex = len(np.where(recall<=srecall)[0])-1
|
||||
rthresh = 1.0 - float(rindex)/thresh_num
|
||||
print('Recall-Precision-Thresh:', recall[rindex], propose[rindex], rthresh)
|
||||
ap = voc_ap(recall, propose)
|
||||
print('ap: %.5f, iou_th: %.2f'%(ap, iou_th))
|
||||
```
|
||||
Result:
|
||||
```
|
||||
Recall-Precision-Thresh: 0.09928514694201747 1.0 0.914
|
||||
Recall-Precision-Thresh: 0.19962933545141648 0.9986754966887417 0.841
|
||||
Recall-Precision-Thresh: 0.29864972200158857 0.9964664310954063 0.749
|
||||
Recall-Precision-Thresh: 0.39899391051098754 0.9947194719471947 0.6619999999999999
|
||||
Recall-Precision-Thresh: 0.4996028594122319 0.9823008849557522 0.565
|
||||
Recall-Precision-Thresh: 0.598623245962404 0.9548141891891891 0.471
|
||||
Recall-Precision-Thresh: 0.6997617156473391 0.9091847265221878 0.384
|
||||
Recall-Precision-Thresh: 0.7995763833730474 0.8055481461723126 0.274
|
||||
Recall-Precision-Thresh: 0.8988615303150649 0.05734797297297297 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.8988615303150649 0.05734797297297297 0.0010000000000000009
|
||||
ap: 0.83243, iou_th: 0.50
|
||||
```
|
||||
|
||||
## 模型精度
|
||||

|
||||
|
||||
## 来源说明
|
||||
本模型及代码来自达摩院自研技术
|
||||
|
||||
## 引用
|
||||
如果你觉得该模型对你有所帮助,请考虑引用下面的相关论文:
|
||||
|
||||
```BibTeX
|
||||
@inproceedings{liu2022mogface,
|
||||
title={MogFace: Towards a Deeper Appreciation on Face Detection},
|
||||
author={Liu, Yang and Wang, Fei and Deng, Jiankang and Zhou, Zhipeng and Sun, Baigui and Li, Hao},
|
||||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
||||
pages={4093--4102},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
BIN
face_project/face_detection/MogFace/demo/MogFace.jpg
Normal file
|
After Width: | Height: | Size: 135 KiB |
BIN
face_project/face_detection/MogFace/demo/MogFace_result.jpg
Normal file
|
After Width: | Height: | Size: 341 KiB |
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
from modelscope.utils.constant import DownloadMode
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
class MogFaceDetectionTest(unittest.TestCase):
    """Inference and mini-WIDER-FACE evaluation tests for MogFace."""

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet101_face-detection_cvpr22papermogface'
        self.img_path = 'data/test/images/mog_face_detection.jpg'

    def show_result(self, img_path, detection_result):
        """Draw the detections on the image and save them as ``result.png``.

        Args:
            img_path: Path of the image the detections belong to.
            detection_result: Pipeline output for that image.
        """
        img = draw_face_detection_no_lm_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        """Run the pipeline on one local image and save the visualisation."""
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)

        result = face_detection(self.img_path)
        self.show_result(self.img_path, result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_dataset(self):
        """Evaluate on the mini validation set and print the PR curve and AP."""
        val_set = MsDataset.load(
            'widerface_mini_train_val',
            namespace='ly261666',
            split='validation')
        img_base_path = next(iter(val_set))[1]
        img_dir = osp.join(img_base_path, 'val_data')
        img_gt = osp.join(img_base_path, 'val_label.txt')
        gt_info = gen_gt_info(img_gt)
        iou_th = 0.5
        thresh_num = 1000
        face_detection_func = pipeline(
            Tasks.face_detection, model=self.model_id, conf_th=0.01)
        count_face = 0
        pr_curve = np.zeros((thresh_num, 2)).astype('float')

        # List the directory once instead of on every loop iteration.
        img_names = os.listdir(img_dir)
        # The code after the loop needs the last image and its result, so an
        # empty directory is reported here rather than as a NameError later.
        self.assertTrue(img_names, f'no validation images found in {img_dir}')
        for idx, img_name in enumerate(img_names):
            print('inference img: {} {}/{}'.format(img_name, idx + 1,
                                                   len(img_names)))
            abs_img_name = osp.join(img_dir, img_name)
            result = face_detection_func(abs_img_name)
            # One row per detection: box coordinates followed by the score.
            pred_info = np.concatenate(
                [result['boxes'],
                 np.array(result['scores'])[:, np.newaxis]],
                axis=1)
            gt_box = np.array(gt_info[img_name])
            pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
            _img_pr_info, _ = img_pr_info(thresh_num, pred_info,
                                          proposal_list, pred_recall)
            pr_curve += _img_pr_info
            count_face += gt_box.shape[0]

        pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
        propose = pr_curve[:, 0]
        recall = pr_curve[:, 1]
        for srecall in np.arange(0.1, 1.0001, 0.1):
            # Last curve index whose recall does not exceed the target.
            rindex = len(np.where(recall <= srecall)[0]) - 1
            rthresh = 1.0 - float(rindex) / thresh_num
            print('Recall-Precision-Thresh:', recall[rindex], propose[rindex],
                  rthresh)
        ap = voc_ap(recall, propose)
        print('ap: %.5f, iou_th: %.2f' % (ap, iou_th))
        # Visualise the detections of the last evaluated image.
        self.show_result(abs_img_name, result)
|
||||
|
||||
if __name__ == '__main__':
    # Run every test case in this module when it is executed as a script.
    unittest.main()
|
||||
143
face_project/face_detection/Mtcnn/README.md
Normal file
@@ -0,0 +1,143 @@
|
||||
<div align="center">
|
||||
<img src="demo/Mtcnn.jpg" width="100%" height="10%" />
|
||||
</div>
|
||||
<h4 align="center">
|
||||
<a href=#Mtcnn模型介绍> 模型介绍 </a> |
|
||||
<a href=#快速使用> 快速使用 </a> |
|
||||
<a href=#单图片推理> 单图片推理 </a> |
|
||||
<a href=#多图片推理和评测> 多图片推理/评测 </a>
|
||||
</h4>
|
||||
|
||||
# Mtcnn模型介绍
|
||||
MTCNN是工业界广泛应用的检测关键点二合一模型, ([论文地址](https://arxiv.org/abs/1604.02878), [代码地址](https://github.com/TropComplique/mtcnn-pytorch)),该方法包含下面4个模块:
|
||||
- Image Pyramid: 首先将图像进行不同尺度的变换,构建图像金字塔,以适应不同大小人脸的检测;
|
||||
- Proposal Network: 其基本的构造是一个全卷积网络。对上一步构建完成的图像金字塔,通过一个FCN进行初步特征提取与标定边框,并进行Bounding-Box Regression调整窗口与NMS进行大部分窗口的过滤。
|
||||
- Refine Network: 其基本的构造是一个卷积神经网络,相对于第一层的P-Net来说,增加了一个全连接层,因此对于输入数据的筛选会更加严格。在图片经过P-Net后,会留下许多预测窗口,我们将所有的预测窗口送入R-Net,这个网络会滤除大量效果比较差的候选框,最后对选定的候选框进行Bounding-Box Regression和NMS进一步优化预测结果;
|
||||
- Output Network: 基本结构是一个较为复杂的卷积神经网络,相对于R-Net来说多了一个卷积层。O-Net的效果与R-Net的区别在于这一层结构会通过更多的监督来识别面部的区域,而且会对人的面部特征点进行回归,最终输出五个人脸面部特征点。
|
||||
|
||||
## 快速使用
|
||||
|
||||
在这个界面中,我们提供几个有关`推理/评测`脚本帮助大家迅速/一键使用Mtcnn, 代码范例中的实例均集成在test_mtcnn_face_detection.py
|
||||
- `Usage`:
|
||||
```shell
|
||||
PYTHONPATH=. python face_project/face_detection/Mtcnn/test_mtcnn_face_detection.py
|
||||
```
|
||||
|
||||
## 代码范例
|
||||
|
||||
### 单图片推理
|
||||
```python
|
||||
import cv2
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
model_id = 'damo/cv_manual_face-detection_mtcnn'
|
||||
face_detection = pipeline(task=Tasks.face_detection, model=model_id)
|
||||
# 支持 url image and abs dir image path
|
||||
img_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_detection2.jpeg'
|
||||
result = face_detection(img_path)
|
||||
|
||||
# 提供可视化结果
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_result
|
||||
from modelscope.preprocessors.image import LoadImage
|
||||
img = LoadImage.convert_to_ndarray(img_path)
|
||||
cv2.imwrite('srcImg.jpg', img)
|
||||
img_draw = draw_face_detection_result('srcImg.jpg', result)
|
||||
import matplotlib.pyplot as plt
|
||||
plt.imshow(img_draw)
|
||||
```
|
||||
|
||||
### 多图片推理和评测
|
||||
- 我们提供了100张测试图片,可运行下面代码一键使用(下载数据集+推理);
|
||||
- 也支持测试自建数据集,需要按如下格式建立数据集:
|
||||
```
|
||||
img_base_path/
|
||||
val_data/
|
||||
test_1.jpg
|
||||
...
|
||||
test_N.jpg
|
||||
val_label.txt
|
||||
## val_label.txt format
|
||||
test_1.jpg
|
||||
x0 y0 w h
|
||||
x0 y0 w h
|
||||
...
|
||||
test_N.jpg
|
||||
x0 y0 w h
|
||||
x0 y0 w h
|
||||
...
|
||||
```
|
||||
|
||||
```python
|
||||
import os.path as osp
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
model_id = 'damo/cv_manual_face-detection_mtcnn'
|
||||
val_set = MsDataset.load('widerface_mini_train_val', namespace='ly261666', split='validation')#, download_mode=DownloadMode.FORCE_REDOWNLOAD)
|
||||
img_base_path = next(iter(val_set))[1]
|
||||
img_dir = osp.join(img_base_path, 'val_data')
|
||||
img_gt = osp.join(img_base_path, 'val_label.txt')
|
||||
gt_info = gen_gt_info(img_gt)
|
||||
pred_info = {}
|
||||
iou_th = 0.5
|
||||
thresh_num = 1000
|
||||
face_detection_func = pipeline(Tasks.face_detection, model=model_id, conf_th=0.7)
|
||||
count_face = 0
|
||||
pr_curve = np.zeros((thresh_num, 2)).astype('float')
|
||||
for idx, img_name in enumerate(os.listdir(img_dir)):
|
||||
print ('inference img: {} {}/{}'.format(img_name, idx+1, len(os.listdir(img_dir))))
|
||||
abs_img_name = osp.join(img_dir, img_name)
|
||||
result = face_detection_func(abs_img_name)
|
||||
pred_info = np.concatenate([result['boxes'], np.array(result['scores'])[:,np.newaxis]], axis=1)
|
||||
gt_box = np.array(gt_info[img_name])
|
||||
pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
|
||||
_img_pr_info, fp = img_pr_info(thresh_num, pred_info, proposal_list, pred_recall)
|
||||
pr_curve += _img_pr_info
|
||||
count_face += gt_box.shape[0]
|
||||
|
||||
pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
|
||||
propose = pr_curve[:, 0]
|
||||
recall = pr_curve[:, 1]
|
||||
for srecall in np.arange(0.1, 1.0001, 0.1):
|
||||
rindex = len(np.where(recall<=srecall)[0])-1
|
||||
rthresh = 1.0 - float(rindex)/thresh_num
|
||||
print('Recall-Precision-Thresh:', recall[rindex], propose[rindex], rthresh)
|
||||
ap = voc_ap(recall, propose)
|
||||
print('ap: %.5f, iou_th: %.2f'%(ap, iou_th))
|
||||
```
|
||||
Result:
|
||||
```
|
||||
Recall-Precision-Thresh: 0.4945724119671697 0.013132364106746154 1.001
|
||||
Recall-Precision-Thresh: 0.19909981466772572 0.9791666666666666 0.997
|
||||
Recall-Precision-Thresh: 0.2997087635689701 0.827485380116959 0.95
|
||||
Recall-Precision-Thresh: 0.3995234312946783 0.26216122307157746 0.6579999999999999
|
||||
Recall-Precision-Thresh: 0.4945724119671697 0.013132364106746154 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.4945724119671697 0.013132364106746154 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.4945724119671697 0.013132364106746154 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.4945724119671697 0.013132364106746154 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.4945724119671697 0.013132364106746154 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.4945724119671697 0.013132364106746154 0.0010000000000000009
|
||||
ap: 0.35710, iou_th: 0.50
|
||||
```
|
||||
|
||||
|
||||
## 引用
|
||||
如果你觉得该模型对你有所帮助,请考虑引用下面的相关论文:
|
||||
|
||||
```BibTeX
|
||||
@inproceedings{xiang2017joint,
|
||||
title={Joint face detection and facial expression recognition with MTCNN},
|
||||
author={Xiang, Jia and Zhu, Gengming},
|
||||
booktitle={2017 4th international conference on information science and control engineering (ICISCE)},
|
||||
pages={424--427},
|
||||
year={2017},
|
||||
organization={IEEE}
|
||||
}
|
||||
```
|
||||
|
||||
BIN
face_project/face_detection/Mtcnn/demo/Mtcnn.jpg
Normal file
|
After Width: | Height: | Size: 103 KiB |
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
from modelscope.utils.constant import DownloadMode
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
class RetinaFaceDetectionTest(unittest.TestCase):
    """Inference and mini-WIDER-FACE evaluation tests for the MTCNN detector.

    NOTE(review): the class name says RetinaFace but ``setUp`` selects the
    MTCNN model; the name is kept unchanged here -- consider renaming.
    """

    def setUp(self) -> None:
        self.model_id = 'damo/cv_manual_face-detection_mtcnn'
        self.img_path = 'data/test/images/mog_face_detection.jpg'

    def show_result(self, img_path, detection_result):
        """Draw the detections on the image and save them as ``result.png``.

        Args:
            img_path: Path of the image the detections belong to.
            detection_result: Pipeline output for that image.
        """
        img = draw_face_detection_no_lm_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        """Run the pipeline on one local image and save the visualisation."""
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)

        result = face_detection(self.img_path)
        self.show_result(self.img_path, result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_dataset(self):
        """Evaluate on the mini validation set and print the PR curve and AP."""
        val_set = MsDataset.load(
            'widerface_mini_train_val',
            namespace='ly261666',
            split='validation')
        img_base_path = next(iter(val_set))[1]
        img_dir = osp.join(img_base_path, 'val_data')
        img_gt = osp.join(img_base_path, 'val_label.txt')
        gt_info = gen_gt_info(img_gt)
        iou_th = 0.5
        thresh_num = 1000
        face_detection_func = pipeline(
            Tasks.face_detection, model=self.model_id, conf_th=0.01)
        count_face = 0
        pr_curve = np.zeros((thresh_num, 2)).astype('float')

        # List the directory once instead of on every loop iteration.
        img_names = os.listdir(img_dir)
        # The code after the loop needs the last image and its result, so an
        # empty directory is reported here rather than as a NameError later.
        self.assertTrue(img_names, f'no validation images found in {img_dir}')
        for idx, img_name in enumerate(img_names):
            print('inference img: {} {}/{}'.format(img_name, idx + 1,
                                                   len(img_names)))
            abs_img_name = osp.join(img_dir, img_name)
            result = face_detection_func(abs_img_name)
            # One row per detection: box coordinates followed by the score.
            pred_info = np.concatenate(
                [result['boxes'],
                 np.array(result['scores'])[:, np.newaxis]],
                axis=1)
            gt_box = np.array(gt_info[img_name])
            pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
            _img_pr_info, _ = img_pr_info(thresh_num, pred_info,
                                          proposal_list, pred_recall)
            pr_curve += _img_pr_info
            count_face += gt_box.shape[0]

        pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
        propose = pr_curve[:, 0]
        recall = pr_curve[:, 1]
        for srecall in np.arange(0.1, 1.0001, 0.1):
            # Last curve index whose recall does not exceed the target.
            rindex = len(np.where(recall <= srecall)[0]) - 1
            rthresh = 1.0 - float(rindex) / thresh_num
            print('Recall-Precision-Thresh:', recall[rindex], propose[rindex],
                  rthresh)
        ap = voc_ap(recall, propose)
        print('ap: %.5f, iou_th: %.2f' % (ap, iou_th))
        # Visualise the detections of the last evaluated image.
        self.show_result(abs_img_name, result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
143
face_project/face_detection/RetinaFace/README.md
Normal file
@@ -0,0 +1,143 @@
|
||||
|
||||
<div align="center">
|
||||
<img src="demo/RetinaFace.jpg" width="100%" height="10%" />
|
||||
</div>
|
||||
<h4 align="center">
|
||||
<a href=#RetinaFace模型介绍> 模型介绍 </a> |
|
||||
<a href=#快速使用> 快速使用 </a> |
|
||||
<a href=#单图片推理> 单图片推理 </a> |
|
||||
<a href=#多图片推理和评测> 多图片推理/评测 </a>
|
||||
</h4>
|
||||
|
||||
# RetinaFace模型介绍
|
||||
RetinaFace为当前学术界和工业界精度较高的人脸检测和人脸关键点定位二合一的方法,被CVPR 2020 录取([论文地址](https://arxiv.org/abs/1905.00641), [代码地址](https://github.com/biubug6/Pytorch_Retinaface)),该方法的主要贡献是:
|
||||
- 引入关键点分支,可以在训练阶段引入关键点预测分支进行多任务学习,提供额外的互补特征,inference去掉关键点分支即可,并不会引入额外的计算量。
|
||||
|
||||
## 快速使用
|
||||
|
||||
在这个界面中,我们提供几个有关`推理/评测`脚本帮助大家迅速/一键使用RetinaFace, 代码范例中的实例均集成在test_retina_face_detection.py
|
||||
- `Usage`:
|
||||
```python
|
||||
PYTHONPATH=. python face_project/face_detection/RetinaFace/test_retina_face_detection.py
|
||||
```
|
||||
|
||||
## 代码范例
|
||||
|
||||
### 单图片推理
|
||||
```python
|
||||
import cv2
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
model_id = 'damo/cv_resnet50_face-detection_retinaface'
|
||||
face_detection = pipeline(task=Tasks.face_detection, model=model_id)
|
||||
# 支持 url image and abs dir image path
|
||||
img_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_detection2.jpeg'
|
||||
result = face_detection(img_path)
|
||||
|
||||
# 提供可视化结果
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_result
|
||||
from modelscope.preprocessors.image import LoadImage
|
||||
img = LoadImage.convert_to_ndarray(img_path)
|
||||
cv2.imwrite('srcImg.jpg', img)
|
||||
img_draw = draw_face_detection_result('srcImg.jpg', result)
|
||||
import matplotlib.pyplot as plt
|
||||
plt.imshow(img_draw)
|
||||
```
|
||||
|
||||
### 多图片推理和评测
|
||||
- 我们提供了100张测试图片,可运行下面代码一键使用(下载数据集+推理);
|
||||
- 也支持测试自建数据集,需要按如下格式建立数据集:
|
||||
```
|
||||
img_base_path/
|
||||
val_data/
|
||||
test_1.jpg
|
||||
...
|
||||
test_N.jpg
|
||||
val_label.txt
|
||||
## val_label.txt format
|
||||
test_1.jpg
|
||||
x0 x1 w h
|
||||
x0 x1 w h
|
||||
...
|
||||
test_N.jpg
|
||||
x0 x1 w h
|
||||
x0 x1 w h
|
||||
...
|
||||
```
|
||||
|
||||
```python
|
||||
import os.path as osp
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
model_id = 'damo/cv_resnet50_face-detection_retinaface'
|
||||
val_set = MsDataset.load('widerface_mini_train_val', namespace='ly261666', split='validation')#, download_mode=DownloadMode.FORCE_REDOWNLOAD)
|
||||
img_base_path = next(iter(val_set))[1]
|
||||
img_dir = osp.join(img_base_path, 'val_data')
|
||||
img_gt = osp.join(img_base_path, 'val_label.txt')
|
||||
gt_info = gen_gt_info(img_gt)
|
||||
pred_info = {}
|
||||
iou_th = 0.5
|
||||
thresh_num = 1000
|
||||
face_detection_func = pipeline(Tasks.face_detection, model=model_id, conf_th=0.01)
|
||||
count_face = 0
|
||||
pr_curve = np.zeros((thresh_num, 2)).astype('float')
|
||||
for idx, img_name in enumerate(os.listdir(img_dir)):
|
||||
print ('inference img: {} {}/{}'.format(img_name, idx+1, len(os.listdir(img_dir))))
|
||||
abs_img_name = osp.join(img_dir, img_name)
|
||||
result = face_detection_func(abs_img_name)
|
||||
pred_info = np.concatenate([result['boxes'], np.array(result['scores'])[:,np.newaxis]], axis=1)
|
||||
gt_box = np.array(gt_info[img_name])
|
||||
pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
|
||||
_img_pr_info, fp = img_pr_info(thresh_num, pred_info, proposal_list, pred_recall)
|
||||
pr_curve += _img_pr_info
|
||||
count_face += gt_box.shape[0]
|
||||
|
||||
pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
|
||||
propose = pr_curve[:, 0]
|
||||
recall = pr_curve[:, 1]
|
||||
for srecall in np.arange(0.1, 1.0001, 0.1):
|
||||
rindex = len(np.where(recall<=srecall)[0])-1
|
||||
rthresh = 1.0 - float(rindex)/thresh_num
|
||||
print('Recall-Precision-Thresh:', recall[rindex], propose[rindex], rthresh)
|
||||
ap = voc_ap(recall, propose)
|
||||
print('ap: %.5f, iou_th: %.2f'%(ap, iou_th))
|
||||
```
|
||||
Result:
|
||||
```
|
||||
Recall-Precision-Thresh: 0.09981466772570824 0.9973544973544973 0.979
|
||||
Recall-Precision-Thresh: 0.19962933545141648 0.989501312335958 0.855
|
||||
Recall-Precision-Thresh: 0.2994440031771247 0.9576629974597799 0.486
|
||||
Recall-Precision-Thresh: 0.3995234312946783 0.7038246268656716 0.11099999999999999
|
||||
Recall-Precision-Thresh: 0.4980142970611596 0.3608286974870516 0.029000000000000026
|
||||
Recall-Precision-Thresh: 0.5837966640190627 0.17127543886903837 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.5837966640190627 0.17127543886903837 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.5837966640190627 0.17127543886903837 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.5837966640190627 0.17127543886903837 0.0010000000000000009
|
||||
Recall-Precision-Thresh: 0.5837966640190627 0.17127543886903837 0.0010000000000000009
|
||||
ap: 0.45492, iou_th: 0.50
|
||||
```
|
||||
|
||||
## 模型精度
|
||||

|
||||
|
||||
|
||||
## 引用
|
||||
如果你觉得该模型对你有所帮助,请考虑引用下面的相关论文:
|
||||
|
||||
```BibTeX
|
||||
@inproceedings{deng2020retinaface,
|
||||
title={Retinaface: Single-shot multi-level face localisation in the wild},
|
||||
author={Deng, Jiankang and Guo, Jia and Ververas, Evangelos and Kotsia, Irene and Zafeiriou, Stefanos},
|
||||
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
|
||||
pages={5203--5212},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
BIN
face_project/face_detection/RetinaFace/demo/RetinaFace.jpg
Normal file
|
After Width: | Height: | Size: 137 KiB |
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
from modelscope.utils.constant import DownloadMode
|
||||
from modelscope.utils.cv.image_utils import voc_ap, image_eval,img_pr_info, gen_gt_info, dataset_pr_info, bbox_overlap
|
||||
|
||||
class RetinaFaceDetectionTest(unittest.TestCase):
    """Smoke and evaluation tests for the RetinaFace face-detection pipeline."""

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet50_face-detection_retinaface'
        self.img_path = 'data/test/images/mog_face_detection.jpg'

    def show_result(self, img_path, detection_result):
        """Render the result with ``draw_face_detection_no_lm_result`` and
        write the image to ``result.png`` in the working directory."""
        img = draw_face_detection_no_lm_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        """Run the pipeline on the single image configured in ``setUp``."""
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)

        result = face_detection(self.img_path)
        self.show_result(self.img_path, result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_dataset(self):
        """Run the pipeline over the validation split and print PR/AP figures."""
        val_set = MsDataset.load(
            'widerface_mini_train_val',
            namespace='ly261666',
            split='validation')
        img_base_path = next(iter(val_set))[1]
        img_dir = osp.join(img_base_path, 'val_data')
        img_gt = osp.join(img_base_path, 'val_label.txt')
        gt_info = gen_gt_info(img_gt)
        iou_th = 0.5
        thresh_num = 1000
        face_detection_func = pipeline(
            Tasks.face_detection, model=self.model_id, conf_th=0.7)
        count_face = 0
        pr_curve = np.zeros((thresh_num, 2)).astype('float')

        # List the directory once instead of once per iteration.
        img_names = os.listdir(img_dir)
        # Without any image the final show_result() call would hit unbound
        # variables, so fail early with a readable message instead.
        self.assertTrue(img_names, f'no images found in {img_dir}')
        total = len(img_names)
        for idx, img_name in enumerate(img_names):
            print ('inference img: {} {}/{}'.format(img_name, idx+1, total))
            abs_img_name = osp.join(img_dir, img_name)
            result = face_detection_func(abs_img_name)
            # One row per detection: the box columns followed by the score.
            pred_info = np.concatenate(
                [result['boxes'],
                 np.array(result['scores'])[:, np.newaxis]],
                axis=1)
            gt_box = np.array(gt_info[img_name])
            pred_recall, proposal_list = image_eval(pred_info, gt_box, iou_th)
            _img_pr_info, fp = img_pr_info(thresh_num, pred_info,
                                           proposal_list, pred_recall)
            pr_curve += _img_pr_info
            count_face += gt_box.shape[0]

        pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
        propose = pr_curve[:, 0]
        recall = pr_curve[:, 1]
        for srecall in np.arange(0.1, 1.0001, 0.1):
            rindex = len(np.where(recall <= srecall)[0]) - 1
            rthresh = 1.0 - float(rindex) / thresh_num
            print('Recall-Precision-Thresh:', recall[rindex], propose[rindex], rthresh)
        ap = voc_ap(recall, propose)
        print('ap: %.5f, iou_th: %.2f'%(ap, iou_th))
        # Visualise the detections of the last processed image.
        self.show_result(abs_img_name, result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
4
modelscope/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from .version import __release_datetime__, __version__
|
||||
|
||||
__all__ = ['__version__', '__release_datetime__']
|
||||
4
modelscope/fileio/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .file import File, LocalStorage
|
||||
from .io import dump, dumps, load
|
||||
324
modelscope/fileio/file.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import tempfile
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Generator, Union
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class Storage(metaclass=ABCMeta):
    """Abstract base class of all storage backends.

    A concrete backend has to provide four operations: ``read()`` and
    ``read_text()`` to fetch content as a byte stream or as text, and
    ``write()`` / ``write_text()`` for the corresponding store operations.
    """

    @abstractmethod
    def read(self, filepath: str):
        """Read ``filepath`` as a byte stream."""

    @abstractmethod
    def read_text(self, filepath: str):
        """Read ``filepath`` as text."""

    @abstractmethod
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Store the bytes ``obj`` at ``filepath``."""

    @abstractmethod
    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Store the text ``obj`` at ``filepath`` using ``encoding``."""
|
||||
|
||||
|
||||
class LocalStorage(Storage):
    """Storage backend operating on the local file system."""

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Return the binary content of ``filepath``.

        Args:
            filepath (str or Path): File opened in ``'rb'`` mode.

        Returns:
            bytes: The complete content of the file.
        """
        with open(filepath, 'rb') as stream:
            return stream.read()

    def read_text(self,
                  filepath: Union[str, Path],
                  encoding: str = 'utf-8') -> str:
        """Return the text content of ``filepath``.

        Args:
            filepath (str or Path): File opened in ``'r'`` mode.
            encoding (str): Encoding used to decode the file.
                Default: 'utf-8'.

        Returns:
            str: The complete decoded content of the file.
        """
        with open(filepath, 'r', encoding=encoding) as stream:
            return stream.read()

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write ``obj`` to ``filepath`` in ``'wb'`` mode.

        The parent directory of ``filepath`` is created when it is missing.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Destination file.
        """
        parent_dir = os.path.dirname(filepath)
        if parent_dir:
            if not os.path.exists(parent_dir):
                os.makedirs(parent_dir, exist_ok=True)

        with open(filepath, 'wb') as stream:
            stream.write(obj)

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write ``obj`` to ``filepath`` in ``'w'`` mode.

        The parent directory of ``filepath`` is created when it is missing.

        Args:
            obj (str): Text to be written.
            filepath (str or Path): Destination file.
            encoding (str): Encoding used to encode the text.
                Default: 'utf-8'.
        """
        parent_dir = os.path.dirname(filepath)
        if parent_dir:
            if not os.path.exists(parent_dir):
                os.makedirs(parent_dir, exist_ok=True)

        with open(filepath, 'w', encoding=encoding) as stream:
            stream.write(obj)

    @contextlib.contextmanager
    def as_local_path(
            self,
            filepath: Union[str,
                            Path]) -> Generator[Union[str, Path], None, None]:
        """Yield ``filepath`` unchanged; present only for a unified API."""
        yield filepath
|
||||
|
||||
|
||||
class HTTPStorage(Storage):
    """HTTP and HTTPS storage. Reading is supported, writing is not."""

    def read(self, url):
        """Issue a GET request for ``url`` and return the body as bytes.

        ``raise_for_status()`` is called on the response, so an HTTP error
        status propagates as an exception from ``requests``.
        """
        # TODO @wenmeng.zwm add progress bar if file is too large
        r = requests.get(url)
        r.raise_for_status()
        return r.content

    def read_text(self, url, encoding=None):
        """Issue a GET request for ``url`` and return the body as text.

        Args:
            url (str): Address to download.
            encoding (str, optional): When given, overrides the encoding that
                ``requests`` would otherwise pick for decoding the body.
                Defaults to None, which keeps the previous behavior.

        Returns:
            str: The decoded response body.
        """
        r = requests.get(url)
        r.raise_for_status()
        if encoding is not None:
            r.encoding = encoding
        return r.text

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` to a temporary local file.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
        can be called with ``with`` statement, and when exiting from the
        ``with`` statement, the temporary file is removed.

        Args:
            filepath (str): Address to download the file from.

        Examples:
            >>> storage = HTTPStorage()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('http://path/to/file') as path:
            ...     # do something here
        """
        # Download first: a failed request then never creates a temp file,
        # and the cleanup below can rely on ``tmp`` being bound.
        content = self.read(filepath)
        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            try:
                tmp.write(content)
            finally:
                # Always close the handle before the file is used or removed.
                tmp.close()
            yield tmp.name
        finally:
            os.remove(tmp.name)

    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError('write is not supported by HTTP Storage')

    def write_text(self,
                   obj: str,
                   url: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'write_text is not supported by HTTP Storage')
|
||||
|
||||
|
||||
class OSSStorage(Storage):
    """OSS storage.

    Placeholder backend: the constructor and all read/write methods raise
    ``NotImplementedError``.
    """

    def __init__(self, oss_config_file=None):
        # read from config file or env var
        raise NotImplementedError(
            'OSSStorage.__init__ to be implemented in the future')

    def read(self, filepath):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.read to be implemented in the future')

    def read_text(self, filepath, encoding='utf-8'):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.read_text to be implemented in the future')

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` to a temporary local file.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
        can be called with ``with`` statement, and when exiting from the
        ``with`` statement, the temporary file is removed.

        Args:
            filepath (str): Address to download the file from.

        Examples:
            >>> storage = OSSStorage()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('oss://path/to/file') as path:
            ...     # do something here
        """
        # Read first: a failing read() then never creates a temp file, and
        # the cleanup below can rely on ``tmp`` being bound.
        content = self.read(filepath)
        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            try:
                tmp.write(content)
            finally:
                # Always close the handle before the file is used or removed.
                tmp.close()
            yield tmp.name
        finally:
            os.remove(tmp.name)

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.write to be implemented in the future')

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.write_text to be implemented in the future')
|
||||
|
||||
|
||||
# Cache of storage backend instances, keyed by storage type
# ('local', 'http', ...), filled lazily by ``File._get_storage``.
G_STORAGES = {}


class File(object):
    """Facade that dispatches file operations to a backend chosen by URI prefix."""

    # Maps the part of a URI before '://' to its backend class; URIs without
    # '://' use the 'local' entry.
    _prefix_to_storage: dict = {
        'oss': OSSStorage,
        'http': HTTPStorage,
        'https': HTTPStorage,
        'local': LocalStorage,
    }

    @staticmethod
    def _get_storage(uri):
        """Return the (cached) storage backend responsible for ``uri``.

        Args:
            uri (str or Path): Local path or ``<prefix>://...`` address.

        Raises:
            AssertionError: If ``uri`` is neither str nor Path, or its
                prefix has no registered backend.
        """
        if isinstance(uri, Path):
            # Several public methods are annotated as accepting Path.
            uri = str(uri)
        assert isinstance(uri,
                          str), f'uri should be str type, but got {type(uri)}'

        if '://' not in uri:
            # local path
            storage_type = 'local'
        else:
            # Split once only: the remainder may itself contain '://'
            # (e.g. a URL passed as a query parameter).
            prefix, _ = uri.split('://', 1)
            storage_type = prefix

        assert storage_type in File._prefix_to_storage, \
            f'Unsupported uri {uri}, valid prefixs: '\
            f'{list(File._prefix_to_storage.keys())}'

        if storage_type not in G_STORAGES:
            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()

        return G_STORAGES[storage_type]

    @staticmethod
    def read(uri: str) -> bytes:
        """Read the content of ``uri`` as bytes.

        Args:
            uri (str): Local path or ``<prefix>://...`` address to read.

        Returns:
            bytes: Content returned by the selected backend.
        """
        storage = File._get_storage(uri)
        return storage.read(uri)

    @staticmethod
    def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
        """Read the content of ``uri`` as text.

        Args:
            uri (str or Path): Local path or ``<prefix>://...`` address.
            encoding (str): Encoding used to decode the content.
                Default: 'utf-8'.

        Returns:
            str: Text returned by the selected backend.
        """
        storage = File._get_storage(uri)
        if encoding == 'utf-8':
            # Default case: call exactly as before, since not every backend
            # is guaranteed to take an ``encoding`` argument.
            return storage.read_text(uri)
        # A non-default encoding was previously dropped silently.
        return storage.read_text(uri, encoding=encoding)

    @staticmethod
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write the bytes ``obj`` to ``uri`` through the selected backend.

        Args:
            obj (bytes): Data to be written.
            uri (str or Path): Local path or ``<prefix>://...`` address.
        """
        storage = File._get_storage(uri)
        return storage.write(obj, uri)

    @staticmethod
    def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
        """Write the text ``obj`` to ``uri`` through the selected backend.

        Args:
            obj (str): Text to be written.
            uri (str): Local path or ``<prefix>://...`` address.
            encoding (str): Encoding used to encode the text.
                Default: 'utf-8'.
        """
        storage = File._get_storage(uri)
        # Forward the encoding; it used to be ignored.
        return storage.write_text(obj, uri, encoding)

    @staticmethod
    @contextlib.contextmanager
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Yield a local path for ``uri`` by delegating to the backend's
        ``as_local_path`` context manager."""
        storage = File._get_storage(uri)
        with storage.as_local_path(uri) as local_path:
            yield local_path
|
||||
5
modelscope/fileio/format/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .base import FormatHandler
|
||||
from .json import JsonHandler
|
||||
from .yaml import YamlHandler
|
||||
20
modelscope/fileio/format/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class FormatHandler(metaclass=ABCMeta):
    """Abstract interface of a serialization format handler.

    Subclasses implement ``load``, ``dump`` and ``dumps``.
    """

    # When ``text_mode`` is True the file should be handled in text mode,
    # otherwise in binary mode.
    text_mode = True

    @abstractmethod
    def load(self, file, **kwargs):
        """Load an object from the file-like object ``file``."""

    @abstractmethod
    def dump(self, obj, file, **kwargs):
        """Serialize ``obj`` into the file-like object ``file``."""

    @abstractmethod
    def dumps(self, obj, **kwargs):
        """Serialize ``obj`` and return the result."""
|
||||
35
modelscope/fileio/format/json.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
|
||||
from .base import FormatHandler
|
||||
|
||||
|
||||
def set_default(obj):
    """Convert values json cannot serialize natively into plain Python data.

    ``set`` and ``range`` become lists, ``np.ndarray`` becomes a (nested)
    list via ``tolist()``, and ``np.generic`` scalars (``np.int32``,
    ``np.float32``, etc.) become built-in Python numbers via ``item()``.

    Raises:
        TypeError: If ``obj`` is of any other type.
    """
    if isinstance(obj, (set, range)):
        return list(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'{type(obj)} is unsupported for json dump')
|
||||
|
||||
|
||||
class JsonHandler(FormatHandler):
    """Use jsonplus, serialization of Python types to JSON that "just works"."""

    def load(self, file, **kwargs):
        """Parse the JSON text read from ``file``.

        Args:
            file: File-like object whose ``read()`` returns the JSON text.
            **kwargs: Forwarded to ``jsonplus.loads``. Accepted so the
                signature matches ``FormatHandler.load``; previously any
                keyword argument raised ``TypeError`` here.
                NOTE(review): assumes ``jsonplus.loads`` accepts keyword
                arguments - confirm against the installed version.

        Returns:
            The object decoded by ``jsonplus.loads``.
        """
        import jsonplus
        return jsonplus.loads(file.read(), **kwargs)

    def dump(self, obj, file, **kwargs):
        """Serialize ``obj`` with :meth:`dumps` and write it to ``file``."""
        file.write(self.dumps(obj, **kwargs))

    def dumps(self, obj, **kwargs):
        """Return ``obj`` serialized by ``jsonplus.dumps``.

        ``set_default`` is supplied as ``default`` unless the caller passes
        their own ``default``.
        """
        import jsonplus
        kwargs.setdefault('default', set_default)
        return jsonplus.dumps(obj, **kwargs)
|
||||
24
modelscope/fileio/format/yaml.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import yaml
|
||||
|
||||
try:
|
||||
from yaml import CDumper as Dumper
|
||||
from yaml import CLoader as Loader
|
||||
except ImportError:
|
||||
from yaml import Loader, Dumper # type: ignore
|
||||
|
||||
from .base import FormatHandler # isort:skip
|
||||
|
||||
|
||||
class YamlHandler(FormatHandler):
    """Format handler backed by PyYAML.

    The module-level ``Loader`` / ``Dumper`` are used unless the caller
    supplies their own through ``Loader=`` / ``Dumper=``.
    """

    def load(self, file, **kwargs):
        """Parse YAML from ``file`` with ``yaml.load``."""
        # NOTE(review): the default is yaml's Loader/CLoader, not the safe
        # loader - avoid feeding untrusted documents through this path.
        options = {'Loader': Loader, **kwargs}
        return yaml.load(file, **options)

    def dump(self, obj, file, **kwargs):
        """Write ``obj`` as YAML into ``file``."""
        options = {'Dumper': Dumper, **kwargs}
        yaml.dump(obj, file, **options)

    def dumps(self, obj, **kwargs):
        """Return the result of ``yaml.dump`` called without a stream."""
        options = {'Dumper': Dumper, **kwargs}
        return yaml.dump(obj, **options)
|
||||
127
modelscope/fileio/io.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from io import BytesIO, StringIO
|
||||
from pathlib import Path
|
||||
|
||||
from .file import File
|
||||
from .format import JsonHandler, YamlHandler
|
||||
|
||||
# Registry of handler instances keyed by format name / file extension.
# 'yaml' and 'yml' each get their own YamlHandler instance.
format_handlers = dict(
    json=JsonHandler(),
    yaml=YamlHandler(),
    yml=YamlHandler(),
)
|
||||
|
||||
|
||||
def load(file, file_format=None, **kwargs):
    """Load data from a json/yaml file.

    This method provides a unified api for loading data from serialized files.

    Args:
        file (str or :obj:`Path` or file-like object): Filename, URI or a
            file-like object.
        file_format (str, optional): If not specified, the file format will be
            inferred from the file extension, otherwise use the specified one.
            Currently supported formats include "json", "yaml/yml".
        **kwargs: Forwarded to the handler's ``load``.

    Examples:
        >>> load('/path/of/your/file')  # file is stored in disk
        >>> load('https://path/of/your/file')  # file is stored on internet
        >>> load('oss://path/of/your/file')  # file is stored in oss

    Returns:
        The content from the file.

    Raises:
        TypeError: If the format is unsupported, or ``file`` is neither a
            path string nor an object with a ``read`` attribute.
    """
    source = str(file) if isinstance(file, Path) else file
    is_path = isinstance(source, str)

    # Infer the format from the extension only when given a path.
    if file_format is None and is_path:
        file_format = source.rsplit('.', 1)[-1]
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')
    handler = format_handlers[file_format]

    if not is_path:
        if not hasattr(source, 'read'):
            raise TypeError('"file" must be a filepath str or a file-object')
        return handler.load(source, **kwargs)

    # Paths and URIs are fetched through File and wrapped in an in-memory
    # stream matching the handler's mode.
    if handler.text_mode:
        buffer = StringIO(File.read_text(source))
    else:
        buffer = BytesIO(File.read(source))
    with buffer as f:
        return handler.load(f, **kwargs)
|
||||
|
||||
|
||||
def dump(obj, file=None, file_format=None, **kwargs):
    """Dump data to json/yaml strings or files.

    This method provides a unified api for dumping data as strings or to files.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): If not
            specified, then the object is dumped to a str, otherwise to a file
            specified by the filename or file-like object.
        file_format (str, optional): Same as :func:`load`. Required when
            ``file`` is None.
        **kwargs: Forwarded to the handler's ``dump`` / ``dumps``.

    Examples:
        >>> dump('hello world', '/path/of/your/file')  # disk
        >>> dump('hello world', 'oss://path/of/your/file')  # oss

    Returns:
        The serialized result of the handler's ``dumps`` when ``file`` is
        None, otherwise None.

    Raises:
        ValueError: If both ``file`` and ``file_format`` are None.
        TypeError: If the format is unsupported, or ``file`` is neither a
            path string nor an object with a ``write`` attribute.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None:
        if isinstance(file, str):
            file_format = file.split('.')[-1]
        elif file is None:
            raise ValueError(
                'file_format must be specified since file is None')
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = format_handlers[file_format]
    if file is None:
        # FormatHandler defines ``dumps``; the former ``dump_to_str`` call
        # does not exist on any handler and raised AttributeError.
        return handler.dumps(obj, **kwargs)
    elif isinstance(file, str):
        # Serialize into memory first, then hand the result to File so the
        # backend chosen by the URI prefix performs the actual write.
        if handler.text_mode:
            with StringIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write_text(f.getvalue(), file)
        else:
            with BytesIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write(f.getvalue(), file)
    elif hasattr(file, 'write'):
        handler.dump(obj, file, **kwargs)
    else:
        raise TypeError('"file" must be a filename str or a file-object')
|
||||
|
||||
|
||||
def dumps(obj, format, **kwargs):
    """Dump data to a json/yaml string.

    Args:
        obj (any): The python object to be dumped.
        format (str): Name of a registered format, same as ``file_format``
            of :func:`load`.
        **kwargs: Forwarded to the handler's ``dumps``.

    Examples:
        >>> dumps('hello world', 'json')  # json
        >>> dumps('hello world', 'yaml')  # yaml

    Returns:
        The value returned by the selected handler's ``dumps``.

    Raises:
        TypeError: If ``format`` is not a registered format.
    """
    handler = format_handlers.get(format)
    if handler is None:
        raise TypeError(f'Unsupported format: {format}')
    return handler.dumps(obj, **kwargs)
|
||||
0
modelscope/hub/__init__.py
Normal file
906
modelscope/hub/api.py
Normal file
@@ -0,0 +1,906 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# yapf: disable
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import os
|
||||
import pickle
|
||||
import platform
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from http.cookiejar import CookieJar
|
||||
from os.path import expanduser
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from requests import Session
|
||||
from requests.adapters import HTTPAdapter, Retry
|
||||
|
||||
from modelscope import __version__
|
||||
from modelscope.hub.constants import (API_HTTP_CLIENT_TIMEOUT,
|
||||
API_RESPONSE_FIELD_DATA,
|
||||
API_RESPONSE_FIELD_EMAIL,
|
||||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN,
|
||||
API_RESPONSE_FIELD_MESSAGE,
|
||||
API_RESPONSE_FIELD_USERNAME,
|
||||
DEFAULT_CREDENTIALS_PATH,
|
||||
MODELSCOPE_CLOUD_ENVIRONMENT,
|
||||
MODELSCOPE_CLOUD_USERNAME,
|
||||
ONE_YEAR_SECONDS,
|
||||
REQUESTS_API_HTTP_METHOD, Licenses,
|
||||
ModelVisibility)
|
||||
from modelscope.hub.errors import (InvalidParameter, NotExistError,
|
||||
NotLoginException, NoValidRevisionError,
|
||||
RequestError, datahub_raise_on_error,
|
||||
handle_http_post_error,
|
||||
handle_http_response, is_ok,
|
||||
raise_for_http_status, raise_on_error)
|
||||
from modelscope.hub.git import GitCommandWrapper
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
DEFAULT_MODEL_REVISION,
|
||||
DEFAULT_REPOSITORY_REVISION,
|
||||
MASTER_MODEL_BRANCH, DatasetFormations,
|
||||
DatasetMetaFormats,
|
||||
DatasetVisibilityMap, DownloadChannel,
|
||||
ModelFile)
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
from .utils.utils import (get_endpoint, get_release_datetime,
|
||||
model_id_to_group_owner_name)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class HubApi:
|
||||
"""Model hub api interface.
|
||||
"""
|
||||
def __init__(self, endpoint: Optional[str] = None):
    """Create a hub client backed by a retrying, timeout-bound HTTP session.

    Args:
        endpoint (str, optional): The modelscope server http|https address.
            When None, the value returned by ``get_endpoint()`` is used.
    """
    if endpoint is None:
        endpoint = get_endpoint()
    self.endpoint = endpoint
    self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}

    self.session = Session()
    retry_policy = Retry(
        total=2,
        read=2,
        connect=2,
        backoff_factor=1,
        status_forcelist=(500, 502, 503, 504),
    )
    # A single adapter is mounted for both schemes so they share the
    # same retry policy.
    shared_adapter = HTTPAdapter(max_retries=retry_policy)
    for scheme in ('http://', 'https://'):
        self.session.mount(scheme, shared_adapter)

    # Rebind each listed session method to a partial that carries the
    # default timeout; an explicit ``timeout=`` keyword at the call site
    # still overrides it.
    for verb in REQUESTS_API_HTTP_METHOD:
        bound_call = getattr(self.session, verb)
        setattr(
            self.session, verb,
            functools.partial(bound_call, timeout=API_HTTP_CLIENT_TIMEOUT))
|
||||
|
||||
def login(
        self,
        access_token: str,
) -> Tuple[str, CookieJar]:
    """Login with your SDK access token, which can be obtained from
    https://www.modelscope.cn user center.

    The git token, the response cookies and the user name/email returned
    by the server are persisted through ``ModelScopeConfig``.

    Args:
        access_token (str): user access token on modelscope.

    Returns:
        A ``(git_token, cookies)`` tuple, in that order:
        git_token: token to access your git repository.
        cookies: to authenticate yourself to ModelScope open-api.

    Note:
        You only have to login once within 30 days.
    """
    path = f'{self.endpoint}/api/v1/login'
    r = self.session.post(
        path, json={'AccessToken': access_token}, headers=self.headers)
    raise_for_http_status(r)
    d = r.json()
    raise_on_error(d)

    data = d[API_RESPONSE_FIELD_DATA]
    token = data[API_RESPONSE_FIELD_GIT_ACCESS_TOKEN]
    cookies = r.cookies

    # save token and cookie
    ModelScopeConfig.save_token(token)
    ModelScopeConfig.save_cookies(cookies)
    ModelScopeConfig.save_user_info(
        data[API_RESPONSE_FIELD_USERNAME],
        data[API_RESPONSE_FIELD_EMAIL])

    return token, cookies
|
||||
|
||||
def create_model(self,
                 model_id: str,
                 visibility: Optional[int] = ModelVisibility.PUBLIC,
                 license: Optional[str] = Licenses.APACHE_V2,
                 chinese_name: Optional[str] = None) -> str:
    """Create model repo at ModelScopeHub.

    Args:
        model_id (str): The model id
        visibility (int, optional): visibility of the model(1-private, 5-public),
            defaults to ``ModelVisibility.PUBLIC``.
        license (str, optional): license of the model, defaults to
            ``Licenses.APACHE_V2``.
        chinese_name (str, optional): chinese name of the model.

    Returns:
        URL of the created model repo, built from this client's endpoint.

    Raises:
        InvalidParameter: If model_id is invalid.
        ValueError: If not login.

    Note:
        model_id = {owner}/{name}
    """
    if model_id is None:
        raise InvalidParameter('model_id is required!')
    cookies = ModelScopeConfig.get_cookies()
    if cookies is None:
        raise ValueError('Token does not exist, please login first.')

    path = f'{self.endpoint}/api/v1/models'
    owner_or_group, name = model_id_to_group_owner_name(model_id)
    body = {
        'Path': owner_or_group,
        'Name': name,
        'ChineseName': chinese_name,
        'Visibility': visibility,  # server check
        'License': license
    }
    r = self.session.post(
        path, json=body, cookies=cookies, headers=self.headers)
    handle_http_post_error(r, path, body)
    raise_on_error(r.json())
    # Use the endpoint the request was sent to, so the URL is also right
    # for a client constructed with a non-default endpoint.
    model_repo_url = f'{self.endpoint}/{model_id}'
    return model_repo_url
|
||||
|
||||
def delete_model(self, model_id: str):
    """Delete the model repo identified by model_id from ModelScope.

    Args:
        model_id (str): The model id.

    Raises:
        ValueError: If no local cookies are available (not logged in).

    Note:
        model_id = {owner}/{name}
    """
    credentials = ModelScopeConfig.get_cookies()
    if credentials is None:
        raise ValueError('Token does not exist, please login first.')

    url = f'{self.endpoint}/api/v1/models/{model_id}'
    response = self.session.delete(
        url, cookies=credentials, headers=self.headers)
    raise_for_http_status(response)
    raise_on_error(response.json())
|
||||
|
||||
def get_model_url(self, model_id: str):
    """Return the ``.git`` URL of model_id under this client's endpoint."""
    api_root = f'{self.endpoint}/api/v1/models'
    return f'{api_root}/{model_id}.git'
|
||||
|
||||
def get_model(
        self,
        model_id: str,
        revision: Optional[str] = DEFAULT_MODEL_REVISION,
) -> str:
    """Get model information at ModelScope

    Args:
        model_id (str): The model id.
        revision (str optional): revision of model; when falsy, no
            ``Revision`` query parameter is sent.

    Returns:
        The model detail information (the data field of the response).

    Raises:
        NotExistError: If the response is HTTP 200 but ``is_ok`` rejects it.

    Note:
        model_id = {owner}/{name}
    """
    cookies = ModelScopeConfig.get_cookies()
    owner_or_group, name = model_id_to_group_owner_name(model_id)

    path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}'
    if revision:
        path = f'{path}?Revision={revision}'

    r = self.session.get(path, cookies=cookies, headers=self.headers)
    handle_http_response(r, logger, cookies, model_id)

    if r.status_code != HTTPStatus.OK:
        raise_for_http_status(r)
        # Reached only when the helper above does not raise.
        return None

    payload = r.json()
    if not is_ok(payload):
        raise NotExistError(payload[API_RESPONSE_FIELD_MESSAGE])
    return payload[API_RESPONSE_FIELD_DATA]
|
||||
|
||||
def push_model(self,
               model_id: str,
               model_dir: str,
               visibility: Optional[int] = ModelVisibility.PUBLIC,
               license: Optional[str] = Licenses.APACHE_V2,
               chinese_name: Optional[str] = None,
               commit_message: Optional[str] = 'upload model',
               revision: Optional[str] = DEFAULT_REPOSITORY_REVISION):
    """Upload the contents of a local model directory to a hub repository.

    The directory must contain a configuration.json file. If ``get_model``
    fails for model_id, a new repository is created with the given
    visibility, license and chinese_name. If revision is not among the
    remote branches, a new branch is created for it. The remote working
    tree is replaced: every non-hidden entry of the clone is removed and
    every non-hidden entry of model_dir is copied in before pushing.

    HubApi's login must have been called with a valid token (obtainable
    from ModelScope's website) before calling this function.

    Args:
        model_id (str):
            The model id to be uploaded, caller must have write permission for it.
        model_dir(str):
            The Absolute Path of the finetune result.
        visibility(int, optional):
            Visibility of the new created model(1-private, 5-public).
            Required only when the repository has to be created.
        license(str, optional):
            License of the new created model(see License). Required only
            when the repository has to be created.
        chinese_name(str, optional):
            chinese name of the new created model.
        commit_message(str, optional):
            commit message of the push request; when falsy, a timestamped
            message is generated.
        revision (str, optional):
            which branch to push, defaults to DEFAULT_REPOSITORY_REVISION.

    Raises:
        InvalidParameter: Parameter invalid.
        NotLoginException: Not login
        ValueError: No configuration.json
        Exception: Create failed.
    """
    if model_id is None:
        raise InvalidParameter('model_id cannot be empty!')
    if model_dir is None:
        raise InvalidParameter('model_dir cannot be empty!')
    if os.path.isfile(model_dir) or not os.path.exists(model_dir):
        raise InvalidParameter('model_dir must be a valid directory.')
    if not os.path.exists(os.path.join(model_dir, ModelFile.CONFIGURATION)):
        raise ValueError(f'{model_dir} must contain a configuration.json.')
    if ModelScopeConfig.get_cookies() is None:
        raise NotLoginException('Must login before upload!')

    local_entries = os.listdir(model_dir)

    try:
        self.get_model(model_id=model_id)
    except Exception:
        # Any failure to fetch the model is treated as "repo is missing".
        if visibility is None or license is None:
            raise InvalidParameter(
                'visibility and license cannot be empty if want to create new repo'
            )
        logger.info('Create new model %s' % model_id)
        self.create_model(
            model_id=model_id,
            visibility=visibility,
            license=license,
            chinese_name=chinese_name)

    work_dir = tempfile.mkdtemp()
    git = GitCommandWrapper()
    try:
        repo = Repository(model_dir=work_dir, clone_from=model_id)
        if revision not in git.get_remote_branches(work_dir):
            logger.info('Create new branch %s' % revision)
            git.new_branch(work_dir, revision)
        git.checkout(work_dir, revision)

        # Clear the checked-out tree, keeping hidden entries (names that
        # start with a dot).
        for entry in os.listdir(work_dir):
            if entry[0] == '.':
                continue
            stale = os.path.join(work_dir, entry)
            if os.path.isfile(stale):
                os.remove(stale)
            else:
                shutil.rmtree(stale, ignore_errors=True)

        # Copy the non-hidden local entries into the clone.
        for entry in local_entries:
            if entry[0] == '.':
                continue
            source = os.path.join(model_dir, entry)
            if os.path.isdir(source):
                shutil.copytree(source, os.path.join(work_dir, entry))
            else:
                shutil.copy(source, work_dir)

        if not commit_message:
            stamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
            commit_message = '[automsg] push model %s to hub at %s' % (
                model_id, stamp)
        repo.push(
            commit_message=commit_message,
            local_branch=revision,
            remote_branch=revision)
    finally:
        # Errors propagate unchanged; the temporary clone is always removed.
        shutil.rmtree(work_dir, ignore_errors=True)
|
||||
|
||||
def list_models(self,
                owner_or_group: str,
                page_number: Optional[int] = 1,
                page_size: Optional[int] = 10) -> dict:
    """List models in owner or group.

    Args:
        owner_or_group(str): owner or group.
        page_number(int, optional): The page number, default: 1
        page_size(int, optional): The page size, default: 10

    Raises:
        RequestError: The request error.

    Returns:
        dict: {"models": "list of models", "TotalCount": total_number_of_models_in_owner_or_group}
    """
    import json  # local import: serialise the body instead of string-formatting it

    cookies = ModelScopeConfig.get_cookies()
    path = f'{self.endpoint}/api/v1/models/'
    # json.dumps escapes quotes/backslashes/non-ASCII in owner_or_group and
    # renders None as null, which hand-built '%s' formatting did not.
    body = json.dumps({
        'Path': owner_or_group,
        'PageNumber': page_number,
        'PageSize': page_size,
    })
    r = self.session.put(
        path,
        data=body,
        cookies=cookies,
        headers=self.headers)
    handle_http_response(r, logger, cookies, 'list_model')
    if r.status_code == HTTPStatus.OK:
        payload = r.json()
        if is_ok(payload):
            return payload[API_RESPONSE_FIELD_DATA]
        raise RequestError(payload[API_RESPONSE_FIELD_MESSAGE])
    raise_for_http_status(r)
    # Reached only when the helper above does not raise.
    return None
|
||||
|
||||
def _check_cookie(self,
|
||||
use_cookies: Union[bool,
|
||||
CookieJar] = False) -> CookieJar:
|
||||
cookies = None
|
||||
if isinstance(use_cookies, CookieJar):
|
||||
cookies = use_cookies
|
||||
elif use_cookies:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise ValueError('Token does not exist, please login first.')
|
||||
return cookies
|
||||
|
||||
def list_model_revisions(
        self,
        model_id: str,
        cutoff_timestamp: Optional[int] = None,
        use_cookies: Union[bool, CookieJar] = False) -> List[str]:
    """Get the tags (revisions) of a model.

    Args:
        model_id (str): The model id
        cutoff_timestamp (int): Tags created before the cutoff will be included.
            The timestamp is represented by the seconds elapsed from the epoch time.
            When None, the value of ``get_release_datetime()`` is used.
        use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True,
            will load cookie from local. Defaults to False.

    Returns:
        List[str]: The tag names, in the order returned by the backend;
        an empty list when the model has no tags.
    """
    cookies = self._check_cookie(use_cookies)
    if cutoff_timestamp is None:
        cutoff_timestamp = get_release_datetime()
    # Single f-string: applying a '%' operator to an already-interpolated
    # string breaks when the endpoint or model id contains a '%'.
    path = (f'{self.endpoint}/api/v1/models/{model_id}/revisions'
            f'?EndTime={cutoff_timestamp}')
    r = self.session.get(path, cookies=cookies, headers=self.headers)
    handle_http_response(r, logger, cookies, model_id)
    d = r.json()
    raise_on_error(d)
    info = d[API_RESPONSE_FIELD_DATA]
    # tags returned from backend are guaranteed to be ordered by create-time
    tag_entries = info['RevisionMap']['Tags']
    if not tag_entries:
        return []
    return [entry['Revision'] for entry in tag_entries]
|
||||
|
||||
def get_valid_revision(self,
                       model_id: str,
                       revision=None,
                       cookies: Optional[CookieJar] = None):
    """Validate the requested revision of a model, or pick a default one.

    Args:
        model_id (str): The model id.
        revision: Branch or tag requested by the caller, or None.
        cookies (CookieJar, optional): Cookies forwarded to the hub calls;
            when None the calls are made without cookies.

    Returns:
        The revision to use.

    Raises:
        NotExistError: If the requested revision is not found.
        NoValidRevisionError: If no revision was requested and the model
            has no tag before the release time.
    """
    release_timestamp = get_release_datetime()
    current_timestamp = int(round(datetime.datetime.now().timestamp()))
    cookie_arg = False if cookies is None else cookies

    # for active development in library codes (non-release-branches), release_timestamp
    # is set to be a far-away-time-in-the-future, to ensure that we shall
    # get the master-HEAD version from model repo by default (when no revision is provided)
    if release_timestamp > current_timestamp + ONE_YEAR_SECONDS:
        branches, tags = self.get_model_branches_and_tags(
            model_id, use_cookies=cookie_arg)
        if revision is None:
            revision = MASTER_MODEL_BRANCH
            logger.info(
                'Model revision not specified, use default: %s in development mode'
                % revision)
        if revision not in branches and revision not in tags:
            raise NotExistError('The model: %s has no revision : %s .' % (model_id, revision))
        logger.info('Development mode use revision: %s' % revision)
        return revision

    if revision is not None:
        # use user-specified revision
        known = self.list_model_revisions(
            model_id,
            cutoff_timestamp=current_timestamp,
            use_cookies=cookie_arg)
        if revision not in known:
            raise NotExistError('The model: %s has no revision: %s !' %
                                (model_id, revision))
        logger.info('Use user-specified model revision: %s' % revision)
        return revision

    # user not specified revision, use latest revision before release time
    candidates = self.list_model_revisions(
        model_id,
        cutoff_timestamp=release_timestamp,
        use_cookies=cookie_arg)
    if len(candidates) == 0:
        raise NoValidRevisionError(
            'The model: %s has no valid revision!' % model_id)
    # tags (revisions) returned from backend are guaranteed to be ordered by create-time
    # we shall obtain the latest revision created earlier than release version of this branch
    revision = candidates[0]
    logger.info(
        'Model revision not specified, use the latest revision: %s'
        % revision)
    return revision
|
||||
|
||||
def get_model_branches_and_tags(
        self,
        model_id: str,
        use_cookies: Union[bool, CookieJar] = False,
) -> Tuple[List[str], List[str]]:
    """Get model branch and tags.

    Args:
        model_id (str): The model id
        use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True,
            will load cookie from local. Defaults to False.

    Returns:
        Tuple[List[str], List[str]]: Return list of branch name and tags
    """
    cookies = self._check_cookie(use_cookies)

    url = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
    response = self.session.get(url, cookies=cookies, headers=self.headers)
    handle_http_response(response, logger, cookies, model_id)
    body = response.json()
    raise_on_error(body)
    revision_map = body[API_RESPONSE_FIELD_DATA]['RevisionMap']

    def names_of(kind):
        # A missing/empty section is reported as an empty list.
        entries = revision_map[kind]
        return [entry['Revision'] for entry in entries] if entries else []

    branches = names_of('Branches')
    tags = names_of('Tags')
    return branches, tags
|
||||
|
||||
def get_model_files(self,
                    model_id: str,
                    revision: Optional[str] = DEFAULT_MODEL_REVISION,
                    root: Optional[str] = None,
                    recursive: Optional[str] = False,
                    use_cookies: Union[bool, CookieJar] = False,
                    headers: Optional[dict] = {}) -> List[dict]:
    """List the models files.

    Args:
        model_id (str): The model id
        revision (Optional[str], optional): The branch or tag name.
        root (Optional[str], optional): The root path. Defaults to None.
        recursive (Optional[str], optional): Is recursive list files. Defaults to False.
        use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True,
            will load cookie from local. Defaults to False.
        headers: request headers

    Returns:
        List[dict]: Model file list, without the ``.gitignore`` and
        ``.gitattributes`` entries.
    """
    url = '%s/api/v1/models/%s/repo/files?' % (self.endpoint, model_id)
    if revision:
        url += 'Revision=%s&' % (revision, )
    url += 'Recursive=%s' % (recursive, )

    cookies = self._check_cookie(use_cookies)
    if root is not None:
        url = url + f'&Root={root}'

    # NOTE(review): the default is an empty dict, so self.headers is only
    # used when the caller passes headers=None explicitly.
    if headers is None:
        headers = self.headers
    response = self.session.get(url, cookies=cookies, headers=headers)

    handle_http_response(response, logger, cookies, model_id)
    body = response.json()
    raise_on_error(body)

    hidden = ('.gitignore', '.gitattributes')
    return [
        entry for entry in body[API_RESPONSE_FIELD_DATA]['Files']
        if entry['Name'] not in hidden
    ]
|
||||
|
||||
def list_datasets(self):
    """Return the ``Name`` of every dataset listed by the datasets endpoint."""
    url = f'{self.endpoint}/api/v1/datasets'
    response = self.session.get(url, params={}, headers=self.headers)
    raise_for_http_status(response)
    names = []
    for item in response.json()[API_RESPONSE_FIELD_DATA]:
        names.append(item['Name'])
    return names
|
||||
|
||||
def get_dataset_id_and_type(self, dataset_name: str, namespace: str):
    """Get the dataset id and type.

    Returns:
        A ``(dataset_id, dataset_type)`` tuple read from the ``Id`` and
        ``Type`` fields of the response data.
    """
    datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
    response = self.session.get(
        datahub_url, cookies=ModelScopeConfig.get_cookies())
    payload = response.json()
    datahub_raise_on_error(datahub_url, payload)
    details = payload['Data']
    return details['Id'], details['Type']
|
||||
|
||||
def get_dataset_meta_file_list(self, dataset_name: str, namespace: str, dataset_id: str, revision: str):
    """Get the meta file-list of the dataset.

    Args:
        dataset_name (str): Dataset name, used only in the error message.
        namespace (str): Dataset namespace, used only in the error message.
        dataset_id (str): Id used to address the dataset repo tree.
        revision (str): Revision of the dataset repo to list.

    Returns:
        The ``Files`` entry of the response data.

    Raises:
        NotExistError: If the response carries no data.
    """
    datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
    cookies = ModelScopeConfig.get_cookies()
    # The request is issued once; sending the identical GET a second time
    # only doubled the latency.
    r = self.session.get(
        datahub_url, cookies=cookies, headers=self.headers)
    resp = r.json()
    datahub_raise_on_error(datahub_url, resp)
    file_list = resp['Data']
    if file_list is None:
        raise NotExistError(
            f'The modelscope dataset [dataset_name = {dataset_name}, namespace = {namespace}, '
            f'version = {revision}] dose not exist')

    return file_list['Files']
|
||||
|
||||
def get_dataset_meta_files_local_paths(self, dataset_name: str,
                                       namespace: str,
                                       revision: str,
                                       meta_cache_dir: str, dataset_type: int, file_list: list):
    """Download the dataset meta files into meta_cache_dir and return their paths.

    Only files whose extension is listed in the meta format of the dataset
    formation are fetched. A file that already exists locally is reused
    without contacting the server.

    Args:
        dataset_name (str): Dataset name.
        namespace (str): Dataset namespace.
        revision (str): Dataset revision to download from.
        meta_cache_dir (str): Existing local directory receiving the files.
        dataset_type (int): Value convertible to ``DatasetFormations``.
        file_list (list): Entries carrying a ``Path`` key.

    Returns:
        A ``(local_paths, dataset_formation)`` tuple, where local_paths
        maps a file extension to the list of local file paths.
    """
    local_paths = defaultdict(list)
    dataset_formation = DatasetFormations(dataset_type)
    dataset_meta_format = DatasetMetaFormats[dataset_formation]
    cookies = ModelScopeConfig.get_cookies()

    # Dump the data_type as a local file
    dataset_type_file_path = os.path.join(meta_cache_dir,
                                          f'{str(dataset_type)}{DatasetFormations.formation_mark_ext.value}')
    with open(dataset_type_file_path, 'w') as fp:
        fp.write('*** Automatically-generated file, do not modify ***')

    for file_info in file_list:
        file_path = file_info['Path']
        extension = os.path.splitext(file_path)[-1]
        if extension not in dataset_meta_format:
            continue

        local_path = os.path.join(meta_cache_dir, file_path)
        # Check the cache first so a cached file costs no HTTP round trip.
        if os.path.exists(local_path):
            logger.warning(
                f"Reusing dataset {dataset_name}'s python file ({local_path})"
            )
            local_paths[extension].append(local_path)
            continue

        datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
                      f'Revision={revision}&FilePath={file_path}'
        r = self.session.get(datahub_url, cookies=cookies)
        raise_for_http_status(r)
        # file_path may contain sub-directories that do not exist yet.
        os.makedirs(os.path.dirname(local_path) or '.', exist_ok=True)
        with open(local_path, 'wb') as f:
            f.write(r.content)
        local_paths[extension].append(local_path)

    return local_paths, dataset_formation
|
||||
|
||||
def fetch_single_csv_script(self, script_url: str):
    """Fetch a meta-csv file and split it into header line and content lines.

    Args:
        script_url (str): URL of the csv file.

    Returns:
        A ``(text_headers, text_content)`` tuple: the first line of the
        stripped response text, and the list of remaining lines.

    Raises:
        ValueError: If the response is falsy or its text is empty.
    """
    cookies = ModelScopeConfig.get_cookies()
    resp = self.session.get(script_url, cookies=cookies, headers=self.headers)
    if not resp or not resp.text:
        # Raising a bare string is itself a TypeError in Python 3, which
        # hid this message from the caller.
        raise ValueError(
            'The meta-csv file cannot be empty when the meta-args `big_data` is true.')
    text_list = resp.text.strip().split('\n')
    text_headers = text_list[0]
    text_content = text_list[1:]

    return text_headers, text_content
|
||||
|
||||
def get_dataset_file_url(
        self,
        file_name: str,
        dataset_name: str,
        namespace: str,
        revision: Optional[str] = DEFAULT_DATASET_REVISION):
    """Return the repo download URL for a ``.csv`` file name.

    Any file_name that does not end with ``.csv`` is returned unchanged.
    """
    if not file_name.endswith('.csv'):
        return file_name
    repo_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?'
    return repo_url + f'Revision={revision}&FilePath={file_name}'
|
||||
|
||||
def get_dataset_access_config(
        self,
        dataset_name: str,
        namespace: str,
        revision: Optional[str] = DEFAULT_DATASET_REVISION):
    """Return the data of the dataset's ``ststoken`` endpoint for revision."""
    dataset_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/'
    sts_url = dataset_url + f'ststoken?Revision={revision}'
    return self.datahub_remote_call(sts_url)
|
||||
|
||||
def get_dataset_access_config_session(
        self,
        dataset_name: str,
        namespace: str,
        check_cookie: bool,
        revision: Optional[str] = DEFAULT_DATASET_REVISION):
    """Fetch the dataset's ``ststoken`` data using the shared session.

    Args:
        dataset_name (str): Dataset name.
        namespace (str): Dataset namespace.
        check_cookie (bool): When truthy, local cookies are required and a
            ValueError is raised if none exist; otherwise whatever
            ``ModelScopeConfig.get_cookies()`` returns is sent.
        revision (str, optional): Dataset revision.

    Returns:
        The ``Data`` field of the response.
    """
    datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
                  f'ststoken?Revision={revision}'
    if check_cookie:
        cookies = self._check_cookie(use_cookies=True)
    else:
        cookies = ModelScopeConfig.get_cookies()

    # The request is issued once; the identical second GET was redundant.
    r = self.session.get(
        url=datahub_url, cookies=cookies, headers=self.headers)
    resp = r.json()
    raise_on_error(resp)
    return resp['Data']
|
||||
|
||||
def get_dataset_access_config_for_unzipped(self,
                                           dataset_name: str,
                                           namespace: str,
                                           revision: str,
                                           zip_file_name: str):
    """Fetch the ``ststoken`` data and add the unzipped directory to it.

    Returns:
        The ``Data`` of the ststoken response, with a ``Dir`` entry of the
        form ``<visibility>-unzipped/<namespace>_<dataset_name>_<zip_file_name>``.
    """
    dataset_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
    cookies = ModelScopeConfig.get_cookies()

    # get visibility of the dataset
    dataset_resp = self.session.get(
        url=dataset_url, cookies=cookies, headers=self.headers).json()
    raise_on_error(dataset_resp)
    visibility = DatasetVisibilityMap.get(dataset_resp['Data']['Visibility'])

    sts_resp = self.session.get(
        url=f'{dataset_url}/ststoken?Revision={revision}',
        cookies=cookies,
        headers=self.headers).json()
    raise_on_error(sts_resp)
    sts_data = sts_resp['Data']

    unzipped_root = visibility + '-unzipped'
    leaf = namespace + '_' + dataset_name + '_' + zip_file_name
    sts_data['Dir'] = unzipped_root + '/' + leaf
    return sts_data
|
||||
|
||||
def list_oss_dataset_objects(self, dataset_name, namespace, max_limit,
                             is_recursive, is_filter_dir, revision):
    """Return the ``Data`` of the dataset's OSS tree listing.

    The request uses an explicit 1800-second timeout instead of the
    session default.
    """
    tree_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?'
    query = f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'

    cookies = ModelScopeConfig.get_cookies()
    response = self.session.get(
        url=tree_url + query, cookies=cookies, timeout=1800)
    payload = response.json()
    raise_on_error(payload)
    return payload['Data']
|
||||
|
||||
def delete_oss_dataset_object(self, object_name: str, dataset_name: str,
                              namespace: str, revision: str) -> str:
    """Delete one OSS object of a dataset and return the response message.

    Raises:
        ValueError: If any argument is empty.
    """
    if not all((object_name, dataset_name, namespace, revision)):
        raise ValueError('Args cannot be empty!')

    url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}'

    payload = self.session.delete(
        url=url, cookies=ModelScopeConfig.get_cookies()).json()
    raise_on_error(payload)
    return payload['Message']
|
||||
|
||||
def delete_oss_dataset_dir(self, object_name: str, dataset_name: str,
                           namespace: str, revision: str) -> str:
    """Delete the OSS objects under the ``object_name/`` prefix of a dataset.

    Returns:
        The ``Message`` field of the response.

    Raises:
        ValueError: If any argument is empty.
    """
    if not all((object_name, dataset_name, namespace, revision)):
        raise ValueError('Args cannot be empty!')

    prefix_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/'
    url = prefix_url + f'&Revision={revision}'

    payload = self.session.delete(
        url=url, cookies=ModelScopeConfig.get_cookies()).json()
    raise_on_error(payload)
    return payload['Message']
|
||||
|
||||
def datahub_remote_call(self, url):
    """GET url with the local cookies and return the response's ``Data``.

    The user-agent header is rebuilt for this call rather than taken from
    ``self.headers``.
    """
    request_headers = {'user-agent': ModelScopeConfig.get_user_agent()}
    response = self.session.get(
        url,
        cookies=ModelScopeConfig.get_cookies(),
        headers=request_headers)
    payload = response.json()
    datahub_raise_on_error(url, payload)
    return payload['Data']
|
||||
|
||||
def dataset_download_statistics(self, dataset_name: str, namespace: str, use_streaming: bool) -> None:
    """Report a dataset download to the hub (download count, then download uv).

    Nothing is sent when dataset_name or namespace is empty, when the
    ``CI_TEST`` environment variable equals ``'True'``, or when
    use_streaming is truthy. Any exception is logged, not raised.
    """
    is_ci_test = os.getenv('CI_TEST') == 'True'
    if not (dataset_name and namespace) or is_ci_test or use_streaming:
        return

    dataset_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
    try:
        cookies = ModelScopeConfig.get_cookies()

        # Download count
        count_resp = self.session.post(
            f'{dataset_url}/download/increase',
            cookies=cookies,
            headers=self.headers)
        raise_for_http_status(count_resp)

        # Download uv
        channel = os.environ.get(MODELSCOPE_CLOUD_ENVIRONMENT,
                                 DownloadChannel.LOCAL.value)
        user_name = os.environ.get(MODELSCOPE_CLOUD_USERNAME, '')
        uv_resp = self.session.post(
            f'{dataset_url}/download/uv/{channel}?user={user_name}',
            cookies=cookies,
            headers=self.headers)
        raise_on_error(uv_resp.json())

    except Exception as e:
        logger.error(e)
|
||||
|
||||
|
||||
class ModelScopeConfig:
|
||||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
|
||||
COOKIES_FILE_NAME = 'cookies'
|
||||
GIT_TOKEN_FILE_NAME = 'git_token'
|
||||
USER_INFO_FILE_NAME = 'user'
|
||||
USER_SESSION_ID_FILE_NAME = 'session'
|
||||
|
||||
@staticmethod
def make_sure_credential_path_exist():
    """Create the credential directory if it does not exist yet."""
    credential_dir = ModelScopeConfig.path_credential
    os.makedirs(credential_dir, exist_ok=True)
|
||||
|
||||
@staticmethod
def save_cookies(cookies: CookieJar):
    """Pickle cookies into the cookies file of the credential directory."""
    ModelScopeConfig.make_sure_credential_path_exist()
    target = os.path.join(ModelScopeConfig.path_credential,
                          ModelScopeConfig.COOKIES_FILE_NAME)
    with open(target, 'wb+') as sink:
        pickle.dump(cookies, sink)
|
||||
|
||||
@staticmethod
def get_cookies():
    """Load the saved cookies.

    Returns:
        The unpickled cookies, or None when the cookies file does not
        exist or any cookie in it has expired (a warning is logged in the
        latter case).
    """
    cookies_path = os.path.join(ModelScopeConfig.path_credential,
                                ModelScopeConfig.COOKIES_FILE_NAME)
    if not os.path.exists(cookies_path):
        return None

    # NOTE(review): pickle.load trusts the file content; presumably this
    # file is only ever written by save_cookies -- confirm.
    with open(cookies_path, 'rb') as source:
        jar = pickle.load(source)
        if any(cookie.is_expired() for cookie in jar):
            logger.warning(
                'Authentication has expired, '
                'please re-login if you need to access private models or datasets.')
            return None
        return jar
|
||||
|
||||
@staticmethod
def get_user_session_id():
    """Return the persisted session id, creating a new one when needed.

    The id is read from the first line of the session file in the
    credential directory. If the file is missing, or the stored value is
    not exactly 32 characters long, a new uuid4 hex id is generated and
    written to the file.

    Returns:
        str: The 32-character session id.
    """
    session_path = os.path.join(ModelScopeConfig.path_credential,
                                ModelScopeConfig.USER_SESSION_ID_FILE_NAME)
    session_id = ''
    if os.path.exists(session_path):
        with open(session_path, 'rb') as f:
            session_id = str(f.readline().strip(), encoding='utf-8')
        # No early return here: the stored value must go through the same
        # validation as a missing one, otherwise an empty or truncated
        # file would be returned forever.
    if len(session_id) != 32:
        session_id = str(uuid.uuid4().hex)
        ModelScopeConfig.make_sure_credential_path_exist()
        with open(session_path, 'w+') as wf:
            wf.write(session_id)

    return session_id
|
||||
|
||||
@staticmethod
def save_token(token: str):
    """Write token to the git-token file of the credential directory."""
    ModelScopeConfig.make_sure_credential_path_exist()
    token_path = os.path.join(ModelScopeConfig.path_credential,
                              ModelScopeConfig.GIT_TOKEN_FILE_NAME)
    with open(token_path, 'w+') as sink:
        sink.write(token)
|
||||
|
||||
@staticmethod
def save_user_info(user_name: str, user_email: str):
    """Write ``user_name:user_email`` to the user-info file."""
    ModelScopeConfig.make_sure_credential_path_exist()
    info_path = os.path.join(ModelScopeConfig.path_credential,
                             ModelScopeConfig.USER_INFO_FILE_NAME)
    record = '%s:%s' % (user_name, user_email)
    with open(info_path, 'w+') as sink:
        sink.write(record)
|
||||
|
||||
@staticmethod
def get_user_info() -> Tuple[str, str]:
    """Read the saved user name and email.

    Returns:
        The first two ``:``-separated fields of the user-info file, or
        ``(None, None)`` when the file does not exist.
    """
    info_path = os.path.join(ModelScopeConfig.path_credential,
                             ModelScopeConfig.USER_INFO_FILE_NAME)
    try:
        with open(info_path, 'r', encoding='utf-8') as source:
            fields = source.read().split(':')
            return fields[0], fields[1]
    except FileNotFoundError:
        return None, None
|
||||
|
||||
@staticmethod
def get_token() -> Optional[str]:
    """
    Get token or None if not existent.

    Returns:
        `str` or `None`: The content of the git-token file, `None` if the
        file doesn't exist.

    """
    token_path = os.path.join(ModelScopeConfig.path_credential,
                              ModelScopeConfig.GIT_TOKEN_FILE_NAME)
    try:
        with open(token_path, 'r', encoding='utf-8') as source:
            return source.read()
    except FileNotFoundError:
        return None
|
||||
|
||||
@staticmethod
def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
    """Formats a user-agent string with basic info about a request.

    Args:
        user_agent (`str`, `dict`, *optional*):
            The user agent info in the form of a dictionary or a single string.
            A dict is appended as ``key/value`` pairs, a string verbatim.

    Returns:
        The formatted user-agent string.
    """
    # include some more telemetrics when executing in dedicated
    # cloud containers
    env = os.environ.get(MODELSCOPE_CLOUD_ENVIRONMENT, 'custom')
    user_name = os.environ.get(MODELSCOPE_CLOUD_USERNAME, 'unknown')

    fields = (
        ('modelscope', __version__),
        ('python', platform.python_version()),
        ('session_id', ModelScopeConfig.get_user_session_id()),
        ('platform', platform.platform()),
        ('processor', platform.processor()),
        ('env', env),
        ('user', user_name),
    )
    ua = '; '.join('%s/%s' % field for field in fields)

    if isinstance(user_agent, dict):
        extra = '; '.join(f'{k}/{v}' for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        extra = user_agent
    else:
        return ua
    # The separator is appended even when extra is empty.
    return ua + '; ' + extra
|
||||
93
modelscope/hub/check_model.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Dict, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.constants import FILE_HASH
|
||||
from modelscope.hub.git import GitCommandWrapper
|
||||
from modelscope.hub.utils.caching import ModelFileSystemCache
|
||||
from modelscope.hub.utils.utils import compute_hash
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def check_local_model_is_latest(
    model_root_path: str,
    user_agent: Optional[Union[Dict, str]] = None,
):
    """Log a notice when the local model differs from the hub's latest revision.

    The model id is taken from the git remote url when ``model_root_path``
    contains a ``.git`` directory, otherwise from the ``ModelFileSystemCache``
    rooted at that path. The file list of the latest revision (first entry of
    the revisions returned by the hub, ``'master'`` as fallback) is then
    compared with the local files: through the cache for cached models,
    through the file hash for git checkouts.

    This is a best-effort check: any error raised while talking to the hub
    or hashing files is ignored (logged at debug level).

    Args:
        model_root_path (str): Local root directory of the model.
        user_agent (dict or str, optional): Extra user-agent info passed to
            ``ModelScopeConfig.get_user_agent``.
    """
    model_cache = None
    # download with git
    if os.path.exists(os.path.join(model_root_path, '.git')):
        git_cmd_wrapper = GitCommandWrapper()
        git_url = git_cmd_wrapper.get_repo_remote_url(model_root_path)
        if git_url.endswith('.git'):
            git_url = git_url[:-4]
        u_parse = urlparse(git_url)
        # Drop the leading '/' of the url path to get the model id.
        model_id = u_parse.path[1:]
    else:  # snapshot_download
        model_cache = ModelFileSystemCache(model_root_path)
        model_id = model_cache.get_model_id()

    try:
        # make headers
        headers = {
            'user-agent':
            ModelScopeConfig.get_user_agent(user_agent=user_agent, )
        }
        cookies = ModelScopeConfig.get_cookies()

        snapshot_header = headers if 'CI_TEST' in os.environ else {
            **headers,
            **{
                'Snapshot': 'True'
            }
        }
        _api = HubApi()
        try:
            _, revisions = _api.get_model_branches_and_tags(
                model_id=model_id, use_cookies=cookies)
            latest_revision = revisions[0] if len(revisions) > 0 else 'master'
        except Exception:
            # Fall back to the default branch when revisions are unavailable.
            latest_revision = 'master'

        model_files = _api.get_model_files(
            model_id=model_id,
            revision=latest_revision,
            recursive=True,
            headers=snapshot_header,
            use_cookies=cookies,
        )
        for model_file in model_files:
            if model_file['Type'] == 'tree':
                continue
            # check model_file updated
            if model_cache is not None:
                is_current = model_cache.exists(model_file)
            elif FILE_HASH in model_file:
                local_file_hash = compute_hash(
                    os.path.join(model_root_path, model_file['Path']))
                is_current = local_file_hash == model_file[FILE_HASH]
            else:
                # No hash available for this file: nothing to compare.
                is_current = True
            if not is_current:
                logger.info(
                    'Model is updated from modelscope hub, you can verify from https://www.modelscope.cn.'
                )
                break
    except Exception as e:
        # Deliberately best-effort: never fail the caller because of this
        # check. ``Exception`` (not a bare except) keeps KeyboardInterrupt
        # and SystemExit propagating.
        logger.debug('Skipped the local model freshness check: %s', e)
|
||||
46
modelscope/hub/constants.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
# --- Hub endpoint -----------------------------------------------------------
MODELSCOPE_URL_SCHEME = 'http://'
DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn'
DEFAULT_MODELSCOPE_DATA_ENDPOINT = f'{MODELSCOPE_URL_SCHEME}{DEFAULT_MODELSCOPE_DOMAIN}'

# --- Model identifiers and local files --------------------------------------
DEFAULT_MODELSCOPE_GROUP = 'damo'
MODEL_ID_SEPARATOR = '/'
# Key of the file hash in the file-info dicts returned by the hub.
FILE_HASH = 'Sha256'
LOGGER_NAME = 'ModelScopeHub'
DEFAULT_CREDENTIALS_PATH = Path.home() / '.modelscope' / 'credentials'

# --- HTTP client settings ---------------------------------------------------
REQUESTS_API_HTTP_METHOD = ['get', 'head', 'post', 'put', 'patch', 'delete']
API_HTTP_CLIENT_TIMEOUT = 60  # presumably seconds; usage not visible here
API_RESPONSE_FIELD_DATA = 'Data'
API_FILE_DOWNLOAD_RETRY_TIMES = 5
API_FILE_DOWNLOAD_TIMEOUT = 5 * 60  # seconds, passed as requests timeout
API_FILE_DOWNLOAD_CHUNK_SIZE = 4096  # bytes per streamed chunk

# --- Response field names ---------------------------------------------------
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken'
API_RESPONSE_FIELD_USERNAME = 'Username'
API_RESPONSE_FIELD_EMAIL = 'Email'
API_RESPONSE_FIELD_MESSAGE = 'Message'

# --- Environment variable names ---------------------------------------------
MODELSCOPE_CLOUD_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'

ONE_YEAR_SECONDS = 365 * 24 * 60 * 60

# --- Model meta file --------------------------------------------------------
MODEL_META_FILE_NAME = '.mdl'
MODEL_META_MODEL_ID = 'id'
|
||||
|
||||
|
||||
class Licenses:
    """Names of the licenses that can be attached to a model."""

    APACHE_V2 = 'Apache License 2.0'
    GPL_V2 = 'GPL-2.0'
    GPL_V3 = 'GPL-3.0'
    LGPL_V2_1 = 'LGPL-2.1'
    LGPL_V3 = 'LGPL-3.0'
    AFL_V3 = 'AFL-3.0'
    ECL_V2 = 'ECL-2.0'
    MIT = 'MIT'
|
||||
|
||||
|
||||
class ModelVisibility:
    """Integer codes for the visibility level of a model."""

    PRIVATE = 1
    INTERNAL = 3
    PUBLIC = 5
|
||||
338
modelscope/hub/deploy.py
Normal file
@@ -0,0 +1,338 @@
|
||||
import json
|
||||
import urllib
|
||||
from abc import ABC
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from attrs import asdict, define, field, validators
|
||||
|
||||
from modelscope.hub.api import ModelScopeConfig
|
||||
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
|
||||
API_RESPONSE_FIELD_MESSAGE)
|
||||
from modelscope.hub.errors import (NotLoginException, NotSupportError,
|
||||
RequestError, handle_http_response, is_ok,
|
||||
raise_for_http_status)
|
||||
from modelscope.hub.utils.utils import get_endpoint
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
# yapf: enable
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class Accelerator:
    """Accelerator kinds accepted by ``ServiceResourceConfig.accelerator``."""

    CPU = 'cpu'
    GPU = 'gpu'
|
||||
|
||||
|
||||
class Vendor:
    """Supported deployment vendors (only EAS at present)."""

    EAS = 'eas'
|
||||
|
||||
|
||||
class EASRegion:
    """Region identifiers for EAS deployments."""

    beijing = 'cn-beijing'
    hangzhou = 'cn-hangzhou'
|
||||
|
||||
|
||||
class EASCpuInstanceType:
    """EAS CPU instance types.

    Reference: https://help.aliyun.com/document_detail/144261.html
    """

    tiny = 'ecs.c6.2xlarge'
    small = 'ecs.c6.4xlarge'
    medium = 'ecs.c6.6xlarge'
    large = 'ecs.c6.8xlarge'
|
||||
|
||||
|
||||
class EASGpuInstanceType:
    """EAS GPU instance types.

    Reference: https://help.aliyun.com/document_detail/144261.html
    """

    tiny = 'ecs.gn5-c28g1.7xlarge'
    small = 'ecs.gn5-c8g1.4xlarge'
    medium = 'ecs.gn6i-c24g1.12xlarge'
    large = 'ecs.gn6e-c12g1.3xlarge'
|
||||
|
||||
|
||||
def min_smaller_than_max(instance, attribute, value):
    """attrs validator: reject a ``min_replica`` above ``max_replica``.

    Args:
        instance: The object being validated; its ``max_replica`` is read.
        attribute: The attribute being validated (unused).
        value: The candidate ``min_replica`` value.

    Raises:
        ValueError: If ``value`` is greater than ``instance.max_replica``.
    """
    if value <= instance.max_replica:
        return
    message = "'min_replica' value: %s has to be smaller than 'max_replica' value: %s!"
    raise ValueError(message % (value, instance.max_replica))
|
||||
|
||||
|
||||
@define
class ServiceScalingConfig(object):
    """Replica scaling configuration of a deployed service.

    Note: max_replica is currently ignored.

    Args:
        max_replica: Maximum number of replicas; validated to be >= 1.
        min_replica: Minimum number of replicas; validated to be >= 1 and
            not greater than ``max_replica``.
    """
    max_replica: int = field(default=1, validator=validators.ge(1))
    # min_replica is declared after max_replica so that the cross-field
    # validator can read instance.max_replica.
    min_replica: int = field(
        default=1, validator=[validators.ge(1), min_smaller_than_max])
|
||||
|
||||
|
||||
@define
class ServiceResourceConfig(object):
    """EAS resource request for a deployed service.

    Args:
        instance_type: The instance type.
        scaling: The instance scaling config.
        accelerator: The accelerator, one of ``Accelerator.CPU`` or
            ``Accelerator.GPU`` (defaults to CPU).
    """
    instance_type: str
    scaling: ServiceScalingConfig
    accelerator: str = field(default=Accelerator.CPU,
                             validator=validators.in_(
                                 [Accelerator.CPU, Accelerator.GPU]))
|
||||
|
||||
|
||||
@define
class ServiceProviderParameters(ABC):
    """Base class of the vendor-specific provider parameter classes."""
    pass
|
||||
|
||||
|
||||
@define
class EASDeployParameters(ServiceProviderParameters):
    """Parameters for EAS Deployment.

    Args:
        region: The eas instance region (eg: cn-hangzhou).
        access_key_id: The eas account access key id.
        access_key_secret: The eas account access key secret.
        resource_group: The resource group to deploy to; defaults to None.
        vendor: Must be 'eas'.
    """
    region: str
    access_key_id: str
    access_key_secret: str
    resource_group: Optional[str] = None
    vendor: str = field(default=Vendor.EAS,
                        validator=validators.in_([Vendor.EAS]))
|
||||
|
||||
|
||||
@define
class EASListParameters(ServiceProviderParameters):
    """EAS instance list parameters.

    Args:
        access_key_id: The eas account access key id.
        access_key_secret: The eas account access key secret.
        region: The eas instance region (eg: cn-hangzhou); defaults to None.
        resource_group: The resource group; defaults to None.
        vendor: Must be 'eas'.
    """
    access_key_id: str
    access_key_secret: str
    region: Optional[str] = None
    resource_group: Optional[str] = None
    vendor: str = field(default=Vendor.EAS,
                        validator=validators.in_([Vendor.EAS]))
|
||||
|
||||
|
||||
@define
class DeployServiceParameters(object):
    """Request body of the "create deployment" call.

    ``ServiceDeployer.create`` converts an instance with ``asdict`` and posts
    it as JSON.

    Args:
        instance_name: The name of the service.
        model_id: The modelscope model_id.
        revision: The modelscope model revision.
        resource: The resource requirement.
        provider: The cloud service provider.
    """
    instance_name: str
    model_id: str
    revision: str
    resource: ServiceResourceConfig
    provider: ServiceProviderParameters
|
||||
|
||||
|
||||
class AttrsToQueryString(ABC):
    """Mixin that serializes the ``provider`` attribute into a query string.

    Subclasses are attrs classes that declare a ``provider`` field.
    """

    def to_query_str(self):
        """Return ``provider=<url-quoted JSON>`` for ``self.provider``.

        Only the provider is serialized; fields whose value is ``None`` are
        left out of the JSON document.

        Returns:
            str: The query parameter, e.g. ``provider=%7B...%7D``.
        """
        # Local import: the module only does ``import urllib``, which does
        # not by itself guarantee that ``urllib.parse`` is loaded.
        from urllib.parse import quote_plus

        provider_dict = asdict(self.provider,
                               filter=lambda attr, value: value is not None)
        # The provider carries credentials (e.g. access_key_secret), so the
        # serialized form must never be printed or logged.
        json_str = json.dumps(provider_dict)
        return 'provider=%s' % quote_plus(json_str)
|
||||
|
||||
|
||||
@define
class ListServiceParameters(AttrsToQueryString):
    """Parameters of the "list service instances" request.

    Args:
        provider: The cloud service provider parameters.
        skip: Start offset of the list (default 0).
        limit: Maximum number of instances to return (default 100).
    """
    provider: ServiceProviderParameters
    # NOTE(review): to_query_str() only serializes ``provider``, so skip and
    # limit do not reach the request url.
    skip: int = 0
    limit: int = 100
|
||||
|
||||
|
||||
@define
class GetServiceParameters(AttrsToQueryString):
    """Parameters of the "get service instance" request.

    Args:
        provider: The cloud service provider parameters.
    """
    provider: ServiceProviderParameters
|
||||
|
||||
|
||||
@define
class DeleteServiceParameters(AttrsToQueryString):
    """Parameters of the "delete service instance" request.

    Args:
        provider: The cloud service provider parameters.
    """
    provider: ServiceProviderParameters
|
||||
|
||||
|
||||
class ServiceDeployer(object):
    """Facilitate model deployment on to supported service provider(s).
    """
    def __init__(self, endpoint=None):
        """Prepare the endpoint, headers and cookies used by every request.

        Args:
            endpoint (str, optional): The hub endpoint; ``get_endpoint()`` is
                used when omitted.

        Raises:
            NotLoginException: If no cookies are available.
        """
        self.endpoint = endpoint if endpoint is not None else get_endpoint()
        self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}
        self.cookies = ModelScopeConfig.get_cookies()
        if self.cookies is None:
            raise NotLoginException(
                'Token does not exist, please login with HubApi first.')

    @staticmethod
    def _extract_data(r, accept_any_2xx=False):
        """Return the ``Data`` field of a deployer API response.

        Args:
            r: The ``requests`` response.
            accept_any_2xx (bool): Accept every 2xx status instead of only
                200 (used by ``create``).

        Raises:
            RequestError: The response body reports a failure.
            HTTPError: The status is a 4xx/5xx error.

        Returns:
            The ``Data`` payload, or ``None`` for a status that is neither
            accepted nor a 4xx/5xx error.
        """
        if accept_any_2xx:
            status_ok = (HTTPStatus.OK <= r.status_code <
                         HTTPStatus.MULTIPLE_CHOICES)
        else:
            status_ok = r.status_code == HTTPStatus.OK
        if not status_ok:
            raise_for_http_status(r)
            return None
        body = r.json()  # parse the body once
        if is_ok(body):
            return body[API_RESPONSE_FIELD_DATA]
        raise RequestError(body[API_RESPONSE_FIELD_MESSAGE])

    # deploy_model
    def create(self, model_id: str, revision: str, instance_name: str,
               resource: ServiceResourceConfig,
               provider: ServiceProviderParameters):
        """Deploy model to cloud, current we only support PAI EAS, this is an async API ,
        and the deployment could take a while to finish remotely. Please check deploy instance
        status separately via checking the status.

        Args:
            model_id (str): The deployed model id
            revision (str): The model revision
            instance_name (str): The deployed model instance name.
            resource (ServiceResourceConfig): The service resource information.
            provider (ServiceProviderParameters): The service provider parameter

        Raises:
            NotSupportError: Not supported platform.
            RequestError: The server return error.

        Returns:
            The ``Data`` field of the server response describing the
            deployed service instance.
        """
        if provider.vendor != Vendor.EAS:
            raise NotSupportError(
                'Not support vendor: %s ,only support EAS current.' %
                (provider.vendor))
        create_params = DeployServiceParameters(instance_name=instance_name,
                                                model_id=model_id,
                                                revision=revision,
                                                resource=resource,
                                                provider=provider)
        path = f'{self.endpoint}/api/v1/deployer/endpoint'
        body = asdict(create_params)
        r = requests.post(path,
                          json=body,
                          cookies=self.cookies,
                          headers=self.headers)
        handle_http_response(r, logger, self.cookies, 'create_service')
        return self._extract_data(r, accept_any_2xx=True)

    def get(self, instance_name: str, provider: ServiceProviderParameters):
        """Query the specified instance information.

        Args:
            instance_name (str): The deployed instance name.
            provider (ServiceProviderParameters): The cloud provider information, for eas
                need region(eg: cn-hangzhou), access_key_id and access_key_secret.

        Raises:
            RequestError: The request is failed from server.

        Returns:
            Dict: The information of the requested service instance.
        """
        params = GetServiceParameters(provider=provider)
        path = '%s/api/v1/deployer/endpoint/%s?%s' % (
            self.endpoint, instance_name, params.to_query_str())
        r = requests.get(path, cookies=self.cookies, headers=self.headers)
        handle_http_response(r, logger, self.cookies, 'get_service')
        return self._extract_data(r)

    def delete(self, instance_name: str, provider: ServiceProviderParameters):
        """Delete deployed model, this api send delete command and return, it will take
        some time to delete, please check through the cloud console.

        Args:
            instance_name (str): The instance name you want to delete.
            provider (ServiceProviderParameters): The cloud provider information, for eas
                need region(eg: cn-hangzhou), access_key_id and access_key_secret.

        Raises:
            RequestError: The request is failed.

        Returns:
            Dict: The deleted instance information.
        """
        params = DeleteServiceParameters(provider=provider)
        path = '%s/api/v1/deployer/endpoint/%s?%s' % (
            self.endpoint, instance_name, params.to_query_str())
        r = requests.delete(path, cookies=self.cookies, headers=self.headers)
        handle_http_response(r, logger, self.cookies, 'delete_service')
        return self._extract_data(r)

    def list(self,
             provider: ServiceProviderParameters,
             skip: Optional[int] = 0,
             limit: Optional[int] = 100):
        """List deployed model instances.

        Args:
            provider (ServiceProviderParameters): The cloud service provider parameter,
                for eas, need access_key_id and access_key_secret.
            skip (int, optional): start of the list, current not support.
            limit (int, optional): maximum number of instances return, current not support

        Raises:
            RequestError: The request is failed from server.

        Returns:
            List: List of instance information
        """
        params = ListServiceParameters(provider=provider,
                                       skip=skip,
                                       limit=limit)
        path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint,
                                                   params.to_query_str())
        r = requests.get(path, cookies=self.cookies, headers=self.headers)
        handle_http_response(r, logger, self.cookies, 'list_service_instances')
        return self._extract_data(r)
|
||||
153
modelscope/hub/errors.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class NotSupportError(Exception):
    """The requested feature (e.g. a deployment vendor) is not supported."""


class NoValidRevisionError(Exception):
    """No valid revision is available (raised outside this module)."""


class NotExistError(Exception):
    """The requested resource, such as a file in a model repo, does not exist."""


class RequestError(Exception):
    """The server answered, but reported a failure in the response body."""


class GitError(Exception):
    """A git command failed."""


class InvalidParameter(Exception):
    """A parameter value is invalid (raised outside this module)."""


class NotLoginException(Exception):
    """The operation needs a login but no credentials are available."""


class FileIntegrityError(Exception):
    """A file failed its integrity check (raised outside this module)."""


class FileDownloadError(Exception):
    """A file download failed or was incomplete."""
|
||||
|
||||
|
||||
def is_ok(rsp):
    """Tell whether a decoded response body reports success.

    Args:
        rsp (dict): The decoded response body; must contain the ``Code``
            and ``Success`` keys.

    Returns:
        ``False`` when ``Code`` is not 200, otherwise the value of the
        ``Success`` field.
    """
    if rsp['Code'] != HTTPStatus.OK:
        return False
    return rsp['Success']
|
||||
|
||||
|
||||
def handle_http_post_error(response, url, request_body):
    """Log the request and response details of a failed request, then re-raise.

    Args:
        response: The ``requests`` response to check.
        url (str): The requested url, used in the log message.
        request_body: The request body, used in the log message.

    Raises:
        HTTPError: Re-raised from ``response.raise_for_status()``.
    """
    try:
        response.raise_for_status()
    except HTTPError:
        request_summary = 'Request %s with body: %s exception' % (url, request_body)
        logger.error(request_summary)
        response_summary = 'Response details: %s' % response.content
        logger.error(response_summary)
        raise
|
||||
|
||||
|
||||
def handle_http_response(response, logger, cookies, model_id):
    """Log details of a failed HTTP response, then re-raise the error.

    Args:
        response: The ``requests`` response to check.
        logger: The logger that receives the error messages.
        cookies: The cookies sent with the request; when ``None`` an extra
            hint that a login may be required is logged.
        model_id (str): Identifier of what was accessed, used in the hint.

    Raises:
        HTTPError: Re-raised from ``response.raise_for_status()``.
    """
    try:
        response.raise_for_status()
    except HTTPError:
        if cookies is None:  # code in [403] and
            # Adjacent literals instead of a backslash continuation, which
            # would embed the source indentation in the message.
            logger.error(
                f'Authentication token does not exist, failed to access model {model_id} '
                'which may not exist or may be private. Please login first.')
        logger.error('Response details: %s' % response.content)
        raise
|
||||
|
||||
|
||||
def raise_on_error(rsp):
    """Raise when the decoded response body reports an error.

    Args:
        rsp (dict): The decoded server response; ``Code`` is read, and
            ``Message`` on failure.

    Raises:
        RequestError: Carries the response ``Message`` when ``Code`` is
            not 200.

    Returns:
        bool: ``True`` if the request is OK.
    """
    if rsp['Code'] != HTTPStatus.OK:
        raise RequestError(rsp['Message'])
    return True
|
||||
|
||||
|
||||
def datahub_raise_on_error(url, rsp):
    """Raise when a datahub response body reports an error.

    Args:
        url (str): The request url, included in the error message.
        rsp (dict): The decoded server response; ``Code`` and ``Message``
            are read with ``.get``.

    Raises:
        RequestError: When ``Code`` is missing or not 200.

    Returns:
        bool: ``True`` if the request is OK.
    """
    if rsp.get('Code') != HTTPStatus.OK:
        raise RequestError(
            f"Url = {url}, Message = {rsp.get('Message')}, Please specify correct dataset_name and namespace."
        )
    return True
|
||||
|
||||
|
||||
def raise_for_http_status(rsp):
    """Raise ``HTTPError`` for a 4xx or 5xx response.

    A ``bytes`` reason phrase is decoded as utf-8 first, since some servers
    localize reason strings, falling back to iso-8859-1 for invalid utf-8.
    For POST requests the request body is appended to the error message.
    Other status codes return without raising.

    Args:
        rsp: The http response.

    Raises:
        HTTPError: The http error info.
    """
    reason = rsp.reason
    if isinstance(reason, bytes):
        try:
            reason = reason.decode('utf-8')
        except UnicodeDecodeError:
            reason = reason.decode('iso-8859-1')

    if 400 <= rsp.status_code < 500:
        template = u'%s Client Error: %s for url: %s'
    elif 500 <= rsp.status_code < 600:
        template = u'%s Server Error: %s for url: %s'
    else:
        return

    http_error_msg = template % (rsp.status_code, reason, rsp.url)
    req = rsp.request
    if req.method == 'POST':
        http_error_msg = u'%s, body: %s' % (http_error_msg, req.body)
    raise HTTPError(http_error_msg, response=rsp)
|
||||
261
modelscope/hub/file_download.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import requests
|
||||
from requests.adapters import Retry
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope import __version__
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.constants import (API_FILE_DOWNLOAD_CHUNK_SIZE,
|
||||
API_FILE_DOWNLOAD_RETRY_TIMES,
|
||||
API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH)
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
from .errors import FileDownloadError, NotExistError
|
||||
from .utils.caching import ModelFileSystemCache
|
||||
from .utils.utils import (file_integrity_validation, get_cache_dir,
|
||||
get_endpoint, model_id_to_group_owner_name)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def model_file_download(
    model_id: str,
    file_path: str,
    revision: Optional[str] = DEFAULT_MODEL_REVISION,
    cache_dir: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
) -> Optional[str]:  # pragma: no cover
    """Download one file of a model repo into the local cache and return its path.

    The file list of the requested revision is fetched from the hub first;
    when the matching file is already cached, the cached path is returned
    without downloading. Otherwise the file is downloaded to a temporary
    location, validated against its hash when one is available, and moved
    into the cache.

    Args:
        model_id (str): The model to whom the file to be downloaded belongs.
        file_path (str): Path of the file to be downloaded, relative to the root of model repo.
        revision (str, optional): Revision of the model file to be downloaded.
            Can be any of a branch, tag or commit hash.
        cache_dir (str, Path, optional): Path to the folder where cached files are stored.
        user_agent (dict, str, optional): The user-agent info in the form of a dictionary or a string.
        local_files_only (bool, optional): If `True`, do not access the network:
            return the path of the locally cached file if it exists, otherwise
            raise `ValueError`.
        cookies (CookieJar, optional): The cookie of download request.

    Returns:
        str: Path of the file in the local cache.

    Raises:
        NotExistError: The file does not exist in the model repo.
        ValueError: `local_files_only` is `True` and the file is not cached.
    """
    if cache_dir is None:
        cache_dir = get_cache_dir()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    temporary_cache_dir = os.path.join(cache_dir, 'temp')
    os.makedirs(temporary_cache_dir, exist_ok=True)

    group_or_owner, name = model_id_to_group_owner_name(model_id)

    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)

    # if local_files_only is `True` and the file already exists in cached_path
    # return the cached path
    if local_files_only:
        cached_file_path = cache.get_file_by_path(file_path)
        if cached_file_path is not None:
            logger.warning(
                "File exists in local cache, but we're not sure it's up to date"
            )
            return cached_file_path
        else:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable model look-ups and downloads'
                " online, set 'local_files_only' to False.")

    _api = HubApi()
    headers = {
        'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, )
    }
    if cookies is None:
        cookies = ModelScopeConfig.get_cookies()

    revision = _api.get_valid_revision(model_id,
                                       revision=revision,
                                       cookies=cookies)
    file_to_download_info = None
    # we need to confirm the version is up-to-date
    # we need to get the file list to check if the latest version is cached, if so return, otherwise download
    model_files = _api.get_model_files(
        model_id=model_id,
        revision=revision,
        recursive=True,
        use_cookies=False if cookies is None else cookies)

    for model_file in model_files:
        if model_file['Type'] == 'tree':
            continue

        if model_file['Path'] == file_path:
            if cache.exists(model_file):
                logger.info(
                    f'File {model_file["Name"]} already in cache, skip downloading!'
                )
                return cache.get_file_by_info(model_file)
            else:
                file_to_download_info = model_file
            break

    if file_to_download_info is None:
        raise NotExistError('The file path: %s not exist in: %s' %
                            (file_path, model_id))

    # we need to download again
    url_to_download = get_file_download_url(model_id, file_path, revision)
    file_to_download_info = {
        'Path': file_path,
        'Revision': file_to_download_info['Revision'],
        FILE_HASH: file_to_download_info[FILE_HASH]
    }

    # A random, unique name inside the temp dir. uuid is public API, unlike
    # the private tempfile._get_candidate_names() helper.
    import uuid
    temp_file_name = uuid.uuid4().hex
    # dict_from_cookiejar works for any CookieJar, whereas .get_dict() only
    # exists on requests' own RequestsCookieJar.
    cookie_dict = (None if cookies is None else
                   requests.utils.dict_from_cookiejar(cookies))
    http_get_file(url_to_download,
                  temporary_cache_dir,
                  temp_file_name,
                  headers=headers,
                  cookies=cookie_dict)
    temp_file_path = os.path.join(temporary_cache_dir, temp_file_name)
    # for download with commit we can't get Sha256
    if file_to_download_info[FILE_HASH] is not None:
        file_integrity_validation(temp_file_path,
                                  file_to_download_info[FILE_HASH])
    return cache.put_file(file_to_download_info, temp_file_path)
|
||||
|
||||
|
||||
def get_file_download_url(model_id: str, file_path: str, revision: str):
    """Format file download url according to `model_id`, `revision` and `file_path`.

    e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
    the resulted download url is:
    <endpoint>/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md

    `revision` and `file_path` are percent-encoded, so that characters such
    as `&`, `#`, `+` or spaces in them cannot corrupt the query string.

    Args:
        model_id (str): The model_id.
        file_path (str): File path
        revision (str): File revision.

    Returns:
        str: The file url.
    """
    from urllib.parse import quote_plus

    download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
    return download_url_template.format(
        endpoint=get_endpoint(),
        model_id=model_id,
        revision=quote_plus(revision),
        file_path=quote_plus(file_path),
    )
|
||||
|
||||
|
||||
def http_get_file(
    url: str,
    local_dir: str,
    file_name: str,
    cookies: CookieJar,
    headers: Optional[Dict[str, str]] = None,
):
    """Download a remote file to ``local_dir/file_name``, resuming on errors.

    The data is streamed into a temporary file inside ``local_dir``. On any
    error the request is retried (up to ``API_FILE_DOWNLOAD_RETRY_TIMES``
    times, with backoff) with a ``Range`` header that resumes from the bytes
    already written. When the server announces a ``Content-Length``, the
    final size is checked against it before the file is moved into place.

    Args:
        url(str):
            actual download url of the file
        local_dir(str):
            local directory where the downloaded file stores
        file_name(str):
            name of the file stored in `local_dir`
        cookies(CookieJar):
            cookies used to authentication the user, which is used for downloading private repos
        headers(Dict[str, str], optional):
            http headers to carry necessary info when requesting the remote file

    Raises:
        FileDownloadError: File download failed.
    """
    # Full size of the remote file; None while unknown.
    expected_total = None
    temp_file_manager = partial(tempfile.NamedTemporaryFile,
                                mode='wb',
                                dir=local_dir,
                                delete=False)
    get_headers = {} if headers is None else copy.deepcopy(headers)
    with temp_file_manager() as temp_file:
        logger.info('downloading %s to %s', url, temp_file.name)
        # Backoff between attempts; increment() raises once retries run out.
        retry = Retry(total=API_FILE_DOWNLOAD_RETRY_TIMES,
                      backoff_factor=1,
                      allowed_methods=['GET'])
        while True:
            try:
                downloaded_size = temp_file.tell()
                get_headers['Range'] = 'bytes=%d-' % downloaded_size
                r = requests.get(url,
                                 stream=True,
                                 headers=get_headers,
                                 cookies=cookies,
                                 timeout=API_FILE_DOWNLOAD_TIMEOUT)
                r.raise_for_status()
                if downloaded_size > 0 and r.status_code != 206:
                    # The server ignored the Range header and is sending the
                    # whole file again: drop the partial data and start over.
                    temp_file.seek(0)
                    temp_file.truncate()
                    downloaded_size = 0
                content_length = r.headers.get('Content-Length')
                # For a ranged (resumed) response Content-Length only covers
                # the remaining bytes, so add what is already on disk.
                expected_total = (downloaded_size + int(content_length)
                                  if content_length is not None else None)
                with tqdm(
                        unit='B',
                        unit_scale=True,
                        unit_divisor=1024,
                        total=expected_total,
                        initial=downloaded_size,
                        desc='Downloading',
                ) as progress:
                    for chunk in r.iter_content(
                            chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
                        if chunk:  # filter out keep-alive new chunks
                            progress.update(len(chunk))
                            temp_file.write(chunk)
                break
            except (Exception) as e:  # no matter what happen, we will retry.
                retry = retry.increment('GET', url, error=e)
                retry.sleep()

    logger.info('storing %s in cache at %s', url, local_dir)
    downloaded_length = os.path.getsize(temp_file.name)
    # The size can only be verified when the server announced a length.
    if expected_total is not None and expected_total != downloaded_length:
        os.remove(temp_file.name)
        msg = ('File %s download incomplete, content_length: %s but the '
               'file downloaded length: %s, please download again') % (
                   file_name, expected_total, downloaded_length)
        logger.error(msg)
        raise FileDownloadError(msg)
    os.replace(temp_file.name, os.path.join(local_dir, file_name))
|
||||
260
modelscope/hub/git.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import List, Optional
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
from ..utils.constant import MASTER_MODEL_BRANCH
|
||||
from .errors import GitError
|
||||
|
||||
# Module-level logger used by the git command wrapper in this module.
logger = get_logger()
|
||||
|
||||
|
||||
class Singleton(type):
    """Metaclass that gives each class using it exactly one instance.

    The first call to such a class constructs the instance with the supplied
    arguments. Every later call returns that same object; the arguments of
    those later calls are not used.
    """
    # Shared registry: class object -> the single instance created for it.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
|
||||
|
||||
|
||||
class GitCommandWrapper(metaclass=Singleton):
    """Some git operation wrapper

    Every command is handed to ``subprocess.run`` as an argument list, without
    a shell. Arguments therefore reach git verbatim: paths, user names and
    messages may contain spaces and must not be wrapped in quotes.

    The class uses the ``Singleton`` metaclass, so the first instantiation
    decides ``git_path`` and later ``GitCommandWrapper(path)`` calls return
    that same instance.
    """
    default_git_path = 'git'  # The default git command line

    def __init__(self, path: Optional[str] = None):
        """
        Args:
            path (str, optional): The git executable to run, if empty
                ``default_git_path`` is used.
        """
        self.git_path = path or self.default_git_path

    def _run_git_command(self, *args) -> subprocess.CompletedProcess:
        """Run git command, if command return 0, return subprocess.response
             otherwise raise GitError, message is stdout and stderr.

        Exit status 1 is tolerated as well: it is logged as
        'Nothing to commit.' and the response is returned to the caller.

        Args:
            args: The command arguments, one element per argument. Each one
                is converted with ``str`` before use.

        Raises:
            GitError: Exception with stdout and stderr.

        Returns:
            subprocess.CompletedProcess: the command response
        """
        cmd_args = [str(arg) for arg in args]
        # Remote urls may embed the auth token; strip it before logging so
        # credentials never end up in log files.
        logger.debug(' '.join(
            self.remove_token_from_url(arg) for arg in cmd_args))
        git_env = os.environ.copy()
        # Fail instead of blocking on an interactive credential prompt.
        git_env['GIT_TERMINAL_PROMPT'] = '0'
        response = subprocess.run(
            [self.git_path, *cmd_args],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=git_env,
        )  # compatible for python3.6
        if response.returncode == 0:
            return response
        if response.returncode == 1:
            logger.info('Nothing to commit.')
            return response
        logger.error(
            'There are error run git command, you may need to login first.')
        raise GitError('stdout: %s, stderr: %s' %
                       (response.stdout.decode('utf8'),
                        response.stderr.decode('utf8')))

    @staticmethod
    def _output_lines(response) -> List[str]:
        """Return the stripped, non-empty stdout lines of a git response."""
        text = response.stdout.decode('utf8').replace('\r\n', '\n')
        return [line.strip() for line in text.split('\n') if line.strip()]

    def config_auth_token(self, repo_dir, auth_token):
        """Rewrite the ``origin`` url of a repository so it carries the token.

        Nothing is changed when the url already contains an oauth2 token.

        Args:
            repo_dir (str): The repository directory.
            auth_token (str): The token to embed in the remote url.
        """
        url = self.get_repo_remote_url(repo_dir)
        if '//oauth2' not in url:
            auth_url = self._add_token(auth_token, url)
            rsp = self._run_git_command('-C', repo_dir, 'remote', 'set-url',
                                        'origin', auth_url)
            logger.debug(rsp.stdout.decode('utf8'))

    def _add_token(self, token: str, url: str):
        """Return ``url`` with ``oauth2:<token>@`` inserted after the scheme.

        The url is returned unchanged when there is no token, no url, or the
        url already carries an oauth2 token.
        """
        if token and url and '//oauth2' not in url:
            # Only the first '//' (the scheme separator) is rewritten.
            url = url.replace('//', '//oauth2:%s@' % token, 1)
        return url

    def remove_token_from_url(self, url: str):
        """Return ``url`` without its ``oauth2:<token>@`` part, if present."""
        if url and '//oauth2' in url:
            start_index = url.find('//oauth2') + 2
            end_index = url.find('@', start_index)
            # Without an '@' there is no credential section to cut out.
            if end_index != -1:
                url = url[:start_index] + url[end_index + 1:]
        return url

    def is_lfs_installed(self):
        """Return True when ``git lfs env`` does not raise GitError.

        NOTE(review): _run_git_command tolerates exit status 1, so a git that
        reports a missing ``lfs`` subcommand with status 1 would be treated
        as installed -- confirm against the supported git versions.
        """
        cmd = ['lfs', 'env']
        try:
            self._run_git_command(*cmd)
            return True
        except GitError:
            return False

    def git_lfs_install(self, repo_dir):
        """Run ``git lfs install`` in ``repo_dir``; False when git fails."""
        cmd = ['-C', repo_dir, 'lfs', 'install']
        try:
            self._run_git_command(*cmd)
            return True
        except GitError:
            return False

    def clone(self,
              repo_base_dir: str,
              token: str,
              url: str,
              repo_name: str,
              branch: Optional[str] = None):
        """ git clone command wrapper.
        For public project, token can None, private repo, there must token.

        Args:
            repo_base_dir (str): The local base dir, the repository will be clone to local_dir/repo_name
            token (str): The git token, must be provided for private project.
            url (str): The remote url
            repo_name (str): The local repository path name.
            branch (str, optional): The branch to check out. Defaults to None.

        Returns:
            The popen response.
        """
        url = self._add_token(token, url)
        if branch:
            clone_args = [
                '-C', repo_base_dir, 'clone', url, repo_name, '--branch',
                branch
            ]
        else:
            # Existing behaviour: without a branch repo_name is not passed,
            # so git derives the target directory from the url.
            clone_args = ['-C', repo_base_dir, 'clone', url]
        response = self._run_git_command(*clone_args)
        logger.debug(response.stdout.decode('utf8'))
        return response

    def add_user_info(self, repo_base_dir, repo_name):
        """Set user.name and user.email of the repository from the saved login.

        Nothing is configured unless both values are available.

        Args:
            repo_base_dir (str): The directory containing the repository.
            repo_name (str): The repository directory name.
        """
        from modelscope.hub.api import ModelScopeConfig
        user_name, user_email = ModelScopeConfig.get_user_info()
        if user_name and user_email:
            # config user.name and user.email if exist
            repo_dir = '%s/%s' % (repo_base_dir, repo_name)
            settings = (('user.name', user_name), ('user.email', user_email))
            for key, value in settings:
                # The value is one argument, so names with spaces stay intact.
                response = self._run_git_command('-C', repo_dir, 'config',
                                                 key, value)
                logger.debug(response.stdout.decode('utf8'))

    def add(self,
            repo_dir: str,
            files: Optional[List[str]] = None,
            all_files: bool = False):
        """Run git add command

        Args:
            repo_dir (str): the repository directory.
            files (List[str], optional): the files to stage.
            all_files (bool): stage everything (``git add -A``), ``files`` is
                ignored in that case.

        Returns:
            The command popen response, or None when neither ``all_files``
            nor any file was given (no command is run then).
        """
        if all_files:
            add_args = ['-C', repo_dir, 'add', '-A']
        elif files:
            add_args = ['-C', repo_dir, 'add', *files]
        else:
            logger.warning('No files specified for git add, nothing to do.')
            return None
        rsp = self._run_git_command(*add_args)
        logger.debug(rsp.stdout.decode('utf8'))
        return rsp

    def commit(self, repo_dir: str, message: str):
        """Run git commit command

        Args:
            repo_dir (str): the repository directory.
            message (str): commit message.

        Returns:
            The command popen response.
        """
        # No shell is involved, so the message is passed verbatim; wrapping
        # it in quotes would store the quotes in the commit message.
        commit_args = ['-C', repo_dir, 'commit', '-m', message]
        rsp = self._run_git_command(*commit_args)
        logger.info(rsp.stdout.decode('utf8'))
        return rsp

    def checkout(self, repo_dir: str, revision: str):
        """Check out ``revision`` in ``repo_dir``."""
        cmds = ['-C', repo_dir, 'checkout', revision]
        return self._run_git_command(*cmds)

    def new_branch(self, repo_dir: str, revision: str):
        """Create and check out a new branch named ``revision``."""
        cmds = ['-C', repo_dir, 'checkout', '-b', revision]
        return self._run_git_command(*cmds)

    def get_remote_branches(self, repo_dir: str):
        """List remote branch names with the leading remote name removed.

        When git prints more than one line the first line is dropped.
        NOTE(review): presumably that line is the ``origin/HEAD -> ...``
        pointer -- confirm.

        Returns:
            List[str]: the branch names, empty when git prints nothing.
        """
        rsp = self._run_git_command('-C', repo_dir, 'branch', '-r')
        info = self._output_lines(rsp)
        if len(info) == 1:
            return ['/'.join(info[0].split('/')[1:])]
        return ['/'.join(line.split('/')[1:]) for line in info[1:]]

    def pull(self, repo_dir: str):
        """Run ``git pull`` in ``repo_dir``."""
        cmds = ['-C', repo_dir, 'pull']
        return self._run_git_command(*cmds)

    def push(self,
             repo_dir: str,
             token: str,
             url: str,
             local_branch: str,
             remote_branch: str,
             force: bool = False):
        """Push ``local_branch`` to ``remote_branch`` of ``url``.

        Args:
            repo_dir (str): the repository directory.
            token (str): the git token, added to ``url`` when missing.
            url (str): the remote url.
            local_branch (str): the local branch to push.
            remote_branch (str): the remote branch to update.
            force (bool): whether to use forced-push.

        Returns:
            The command popen response.
        """
        url = self._add_token(token, url)
        push_args = [
            '-C', repo_dir, 'push', url,
            '%s:%s' % (local_branch, remote_branch)
        ]
        if force:
            push_args.append('-f')
        rsp = self._run_git_command(*push_args)
        logger.debug(rsp.stdout.decode('utf8'))
        return rsp

    def get_repo_remote_url(self, repo_dir: str):
        """Return the configured ``remote.origin.url`` (stripped)."""
        rsp = self._run_git_command('-C', repo_dir, 'config', '--get',
                                    'remote.origin.url')
        url = rsp.stdout.decode('utf8')
        return url.strip()

    def list_lfs_files(self, repo_dir: str):
        """Return the paths printed by ``git lfs ls-files``.

        Each output line is expected to be ``<oid> <marker> <path>``; the
        path is everything after the second space so paths containing spaces
        are kept whole.

        Returns:
            List[str]: the file paths, empty when git prints nothing.
        """
        rsp = self._run_git_command('-C', repo_dir, 'lfs', 'ls-files')
        return [line.split(' ', 2)[-1] for line in self._output_lines(rsp)]

    def tag(self,
            repo_dir: str,
            tag_name: str,
            message: str,
            ref: str = MASTER_MODEL_BRANCH):
        """Create an annotated tag.

        Args:
            repo_dir (str): the repository directory.
            tag_name (str): the tag name.
            message (str): the tag message.
            ref (str): the commit id or branch to tag.

        Returns:
            The command popen response.
        """
        # The message is passed verbatim (no shell), so it is not quoted.
        cmd_args = ['-C', repo_dir, 'tag', tag_name, '-m', message, ref]
        rsp = self._run_git_command(*cmd_args)
        logger.debug(rsp.stdout.decode('utf8'))
        return rsp

    def push_tag(self, repo_dir: str, tag_name):
        """Push ``tag_name`` to ``origin``."""
        cmd_args = ['-C', repo_dir, 'push', 'origin', tag_name]
        rsp = self._run_git_command(*cmd_args)
        logger.debug(rsp.stdout.decode('utf8'))
        return rsp
|
||||
290
modelscope/hub/repository.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
DEFAULT_REPOSITORY_REVISION,
|
||||
MASTER_MODEL_BRANCH)
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
from .git import GitCommandWrapper
|
||||
from .utils.utils import get_endpoint
|
||||
|
||||
# Module-level logger used by the repository classes in this module.
logger = get_logger()
|
||||
|
||||
|
||||
class Repository:
    """A local representation of the model git repository.
    """

    def __init__(self,
                 model_dir: str,
                 clone_from: str,
                 revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
                 auth_token: Optional[str] = None,
                 git_path: Optional[str] = None):
        """Instantiate a Repository object by cloning the remote ModelScopeHub repo

        Cloning is skipped when ``model_dir`` is not empty and its origin url
        (token removed) already equals the url of ``clone_from``.

        Args:
            model_dir (str): The model root directory.
            clone_from (str): model id in ModelScope-hub from which git clone
            revision (str, optional): revision of the model you want to clone from.
                     Can be any of a branch, tag or commit hash
            auth_token (str, optional): token obtained when calling `HubApi.login()`.
                        Usually you can safely ignore the parameter as the token is already
                        saved when you login the first time, if None, we will use saved token.
            git_path (str, optional): The git command line path, if None, we use 'git'

        Raises:
            InvalidParameter: revision is None.
        """
        self.model_dir = model_dir
        self.model_base_dir = os.path.dirname(model_dir)
        self.model_repo_name = os.path.basename(model_dir)

        if not revision:
            err_msg = 'a non-default value of revision cannot be empty.'
            raise InvalidParameter(err_msg)

        from modelscope.hub.api import ModelScopeConfig
        if auth_token:
            self.auth_token = auth_token
        else:
            self.auth_token = ModelScopeConfig.get_token()

        # GitCommandWrapper is a singleton: the first instantiation fixes the
        # git executable. Create it with git_path right away, otherwise an
        # earlier argument-less instantiation would discard the caller's path.
        self.git_wrapper = GitCommandWrapper(git_path)
        lfs_installed = self.git_wrapper.is_lfs_installed()
        if not lfs_installed:
            logger.error('git lfs is not installed, please install.')

        os.makedirs(self.model_dir, exist_ok=True)
        url = self._get_model_id_url(clone_from)
        if os.listdir(self.model_dir):  # directory not empty.
            remote_url = self._get_remote_url()
            remote_url = self.git_wrapper.remove_token_from_url(remote_url)
            if remote_url and remote_url == url:  # need not clone again
                return
        self.git_wrapper.clone(self.model_base_dir, self.auth_token, url,
                               self.model_repo_name, revision)

        if lfs_installed:
            self.git_wrapper.git_lfs_install(self.model_dir)  # init repo lfs

        # add user info if login
        self.git_wrapper.add_user_info(self.model_base_dir,
                                       self.model_repo_name)
        if self.auth_token:  # config remote with auth token
            self.git_wrapper.config_auth_token(self.model_dir, self.auth_token)

    def _get_model_id_url(self, model_id):
        """Return the git url of ``model_id`` on the configured endpoint."""
        url = f'{get_endpoint()}/{model_id}.git'
        return url

    def _get_remote_url(self):
        """Return the origin url of ``model_dir``, or None when git fails."""
        try:
            remote = self.git_wrapper.get_repo_remote_url(self.model_dir)
        except GitError:
            remote = None
        return remote

    def push(self,
             commit_message: str,
             local_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION,
             remote_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION,
             force: Optional[bool] = False):
        """Push local files to remote, this method will do.
           Execute git pull, git add, git commit, git push in order.

        Args:
            commit_message (str): commit message
            local_branch(str, optional): The local branch, default master.
            remote_branch (str, optional): The remote branch to push, default master.
            force (bool, optional): whether to use forced-push.

        Raises:
            InvalidParameter: no commit message, or force is not a bool.
            NotLoginException: no auth token.
        """
        if commit_message is None or not isinstance(commit_message, str):
            msg = 'commit_message must be provided!'
            raise InvalidParameter(msg)
        if not isinstance(force, bool):
            raise InvalidParameter('force must be bool')

        if not self.auth_token:
            raise NotLoginException('Must login to push, please login first.')

        self.git_wrapper.config_auth_token(self.model_dir, self.auth_token)
        self.git_wrapper.add_user_info(self.model_base_dir,
                                       self.model_repo_name)

        url = self.git_wrapper.get_repo_remote_url(self.model_dir)
        self.git_wrapper.pull(self.model_dir)

        self.git_wrapper.add(self.model_dir, all_files=True)
        self.git_wrapper.commit(self.model_dir, commit_message)
        # force was validated above but previously never forwarded, so a
        # forced push silently behaved like a normal one.
        self.git_wrapper.push(repo_dir=self.model_dir,
                              token=self.auth_token,
                              url=url,
                              local_branch=local_branch,
                              remote_branch=remote_branch,
                              force=force)

    def tag(self,
            tag_name: str,
            message: str,
            ref: Optional[str] = MASTER_MODEL_BRANCH):
        """Create a new tag.

        Args:
            tag_name (str): The name of the tag
            message (str): The tag message.
            ref (str, optional): The tag reference, can be commit id or branch.

        Raises:
            InvalidParameter: tag_name or message is None or empty.
        """
        if tag_name is None or tag_name == '':
            msg = 'We use tag-based revision, therefore tag_name cannot be None or empty.'
            raise InvalidParameter(msg)
        if message is None or message == '':
            msg = 'We use annotated tag, therefore message cannot None or empty.'
            raise InvalidParameter(msg)
        self.git_wrapper.tag(repo_dir=self.model_dir,
                             tag_name=tag_name,
                             message=message,
                             ref=ref)

    def tag_and_push(self,
                     tag_name: str,
                     message: str,
                     ref: Optional[str] = MASTER_MODEL_BRANCH):
        """Create tag and push to remote

        Args:
            tag_name (str): The name of the tag
            message (str): The tag message.
            ref (str, optional): The tag ref, can be commit id or branch. Defaults to MASTER_MODEL_BRANCH.
        """
        self.tag(tag_name, message, ref)

        self.git_wrapper.push_tag(repo_dir=self.model_dir, tag_name=tag_name)
|
||||
|
||||
|
||||
class DatasetRepository:
    """A local representation of the dataset (metadata) git repository.
    """

    def __init__(self,
                 repo_work_dir: str,
                 dataset_id: str,
                 revision: Optional[str] = DEFAULT_DATASET_REVISION,
                 auth_token: Optional[str] = None,
                 git_path: Optional[str] = None):
        """
        Instantiate a Dataset Repository object by cloning the remote ModelScope dataset repo

        The constructor only prepares the work directory and the repo url;
        call ``clone`` to actually clone.

        Args:
            repo_work_dir (str): The dataset repo root directory.
            dataset_id (str): dataset id in ModelScope from which git clone
            revision (str, optional): revision of the dataset you want to clone from.
                                      Can be any of a branch, tag or commit hash
            auth_token (str, optional): token obtained when calling `HubApi.login()`.
                                        Usually you can safely ignore the parameter as the token is
                                        already saved when you login the first time, if None, we will use saved token.
            git_path (str, optional): The git command line path, if None, we use 'git'

        Raises:
            InvalidParameter: parameter invalid.
        """
        self.dataset_id = dataset_id
        if not repo_work_dir or not isinstance(repo_work_dir, str):
            err_msg = 'dataset_work_dir must be provided!'
            raise InvalidParameter(err_msg)
        # Trailing slashes would make basename() below return ''.
        self.repo_work_dir = repo_work_dir.rstrip('/')
        if not self.repo_work_dir:
            err_msg = 'dataset_work_dir can not be root dir!'
            raise InvalidParameter(err_msg)
        self.repo_base_dir = os.path.dirname(self.repo_work_dir)
        self.repo_name = os.path.basename(self.repo_work_dir)

        if not revision:
            err_msg = 'a non-default value of revision cannot be empty.'
            raise InvalidParameter(err_msg)
        self.revision = revision
        from modelscope.hub.api import ModelScopeConfig
        if auth_token:
            self.auth_token = auth_token
        else:
            self.auth_token = ModelScopeConfig.get_token()

        self.git_wrapper = GitCommandWrapper(git_path)
        os.makedirs(self.repo_work_dir, exist_ok=True)
        self.repo_url = self._get_repo_url(dataset_id=dataset_id)

    def clone(self) -> str:
        """Clone the dataset repository into ``repo_work_dir``.

        Returns:
            str: ``repo_work_dir`` after cloning, or '' when the directory is
            not empty and already points at the same remote url.
        """
        # check local repo dir, directory not empty.
        if os.listdir(self.repo_work_dir):
            remote_url = self._get_remote_url()
            remote_url = self.git_wrapper.remove_token_from_url(remote_url)
            # no need clone again
            if remote_url and remote_url == self.repo_url:
                return ''

        logger.info('Cloning repo from {} '.format(self.repo_url))
        self.git_wrapper.clone(self.repo_base_dir, self.auth_token,
                               self.repo_url, self.repo_name, self.revision)
        return self.repo_work_dir

    def push(self,
             commit_message: str,
             branch: Optional[str] = DEFAULT_DATASET_REVISION,
             force: Optional[bool] = False):
        """Push local files to remote, this method will do.
           git pull
           git add
           git commit
           git push

        Args:
            commit_message (str): commit message
            branch (str, optional): which branch to push.
            force (bool, optional): whether to use forced-push.

        Raises:
            InvalidParameter: no commit message, or force is not a bool.
            NotLoginException: no access token.
        """
        if commit_message is None or not isinstance(commit_message, str):
            msg = 'commit_message must be provided!'
            raise InvalidParameter(msg)

        if not isinstance(force, bool):
            raise InvalidParameter('force must be bool')

        if not self.auth_token:
            raise NotLoginException('Must login to push, please login first.')

        self.git_wrapper.config_auth_token(self.repo_work_dir, self.auth_token)
        self.git_wrapper.add_user_info(self.repo_base_dir, self.repo_name)

        remote_url = self._get_remote_url()
        remote_url = self.git_wrapper.remove_token_from_url(remote_url)

        self.git_wrapper.pull(self.repo_work_dir)
        self.git_wrapper.add(self.repo_work_dir, all_files=True)
        self.git_wrapper.commit(self.repo_work_dir, commit_message)
        # force was validated above but previously never forwarded, so a
        # forced push silently behaved like a normal one.
        self.git_wrapper.push(repo_dir=self.repo_work_dir,
                              token=self.auth_token,
                              url=remote_url,
                              local_branch=branch,
                              remote_branch=branch,
                              force=force)

    def _get_repo_url(self, dataset_id):
        """Return the git url of ``dataset_id`` on the configured endpoint."""
        return f'{get_endpoint()}/datasets/{dataset_id}.git'

    def _get_remote_url(self):
        """Return the origin url of ``repo_work_dir``, or None when git fails."""
        try:
            remote = self.git_wrapper.get_repo_remote_url(self.repo_work_dir)
        except GitError:
            remote = None
        return remote
|
||||
151
modelscope/hub/snapshot_download.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
from .constants import FILE_HASH
|
||||
from .file_download import get_file_download_url, http_get_file
|
||||
from .utils.caching import ModelFileSystemCache
|
||||
from .utils.utils import (file_integrity_validation, get_cache_dir,
|
||||
model_id_to_group_owner_name)
|
||||
|
||||
# Module-level logger used by snapshot_download below.
logger = get_logger()
|
||||
|
||||
|
||||
def snapshot_download(
        model_id: str,
        revision: Optional[str] = DEFAULT_MODEL_REVISION,
        cache_dir: Union[str, Path, None] = None,
        user_agent: Optional[Union[Dict, str]] = None,
        local_files_only: Optional[bool] = False,
        cookies: Optional[CookieJar] = None,
        ignore_file_pattern: Optional[Union[str, List[str]]] = None) -> str:
    """Download all files of a repo.
    Downloads a whole snapshot of a repo's files at the specified revision. This
    is useful when you want all files from a repo, because you don't know which
    ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder.

    An alternative would be to just clone a repo but this would require that the
    user always has git and git-lfs installed, and properly configured.

    Each file is first downloaded into a temporary directory below
    ``<cache_dir>/temp``, validated against its hash when the file listing
    provides one, and then moved into the model cache. Files already present
    in the cache are skipped.

    Args:
        model_id (str): A user or an organization name and a repo name separated by a `/`.
        revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
            commit hash. NOTE: currently only branch and tag name is supported
        cache_dir (str, Path, optional): Path to the folder where cached files are stored.
        user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
        local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        cookies (CookieJar, optional): The cookie of the request, default None.
        ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be ignored in downloading, like exact file names or file extensions.
            Each pattern is applied to the file name with `re.search`.

    Raises:
        ValueError: `local_files_only` is set and nothing is cached for the model.

    Returns:
        str: Local folder path (string) of repo snapshot

    Note:
        Errors raised by the hub API calls, the file download and the
        integrity validation are not caught here.
    """

    if cache_dir is None:
        cache_dir = get_cache_dir()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    temporary_cache_dir = os.path.join(cache_dir, 'temp')
    os.makedirs(temporary_cache_dir, exist_ok=True)

    group_or_owner, name = model_id_to_group_owner_name(model_id)

    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
    if local_files_only:
        if len(cache.cached_files) == 0:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable model look-ups and downloads'
                " online, set 'local_files_only' to False.")
        logger.warning(
            'We can not confirm the cached file is for revision: %s' %
            revision)
        # we can not confirm the cached file is for snapshot 'revision'
        return cache.get_root_location()

    # make headers
    headers = {
        'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent)
    }
    _api = HubApi()
    if cookies is None:
        cookies = ModelScopeConfig.get_cookies()
    revision = _api.get_valid_revision(model_id,
                                       revision=revision,
                                       cookies=cookies)

    # The extra 'Snapshot' header is only left out when CI_TEST is set.
    if 'CI_TEST' in os.environ:
        snapshot_header = headers
    else:
        snapshot_header = {**headers, 'Snapshot': 'True'}
    model_files = _api.get_model_files(
        model_id=model_id,
        revision=revision,
        recursive=True,
        use_cookies=False if cookies is None else cookies,
        headers=snapshot_header,
    )

    if ignore_file_pattern is None:
        ignore_file_pattern = []
    if isinstance(ignore_file_pattern, str):
        ignore_file_pattern = [ignore_file_pattern]

    with tempfile.TemporaryDirectory(
            dir=temporary_cache_dir) as temp_cache_dir:
        for model_file in model_files:
            # Skip directories and anything matching an ignore pattern.
            if model_file['Type'] == 'tree' or any(
                    re.search(pattern, model_file['Name']) is not None
                    for pattern in ignore_file_pattern):
                continue
            # check model_file is exist in cache, if existed, skip download, otherwise download
            if cache.exists(model_file):
                file_name = os.path.basename(model_file['Name'])
                logger.info(
                    f'File {file_name} already in cache, skip downloading!')
                continue

            # get download url
            url = get_file_download_url(model_id=model_id,
                                        file_path=model_file['Path'],
                                        revision=revision)

            # First download into the temporary directory
            http_get_file(url=url,
                          local_dir=temp_cache_dir,
                          file_name=model_file['Name'],
                          headers=headers,
                          cookies=cookies)
            # check file integrity
            temp_file = os.path.join(temp_cache_dir, model_file['Name'])
            if FILE_HASH in model_file:
                file_integrity_validation(temp_file, model_file[FILE_HASH])
            # put file to cache
            cache.put_file(model_file, temp_file)

    return cache.get_root_location()
|
||||
0
modelscope/hub/utils/__init__.py
Normal file
290
modelscope/hub/utils/caching.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
from shutil import move, rmtree
|
||||
|
||||
from modelscope.hub.constants import MODEL_META_FILE_NAME, MODEL_META_MODEL_ID
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
# Module-level logger for this module.
logger = get_logger()
|
||||
"""Implements caching functionality, used internally only
|
||||
"""
|
||||
|
||||
|
||||
class FileSystemCache(object):
    """Local file cache.

    The cache is a directory (``cache_root_location``) holding the cached
    files plus an index file named ``KEY_FILE_NAME``. The index is the
    pickled list ``self.cached_files`` of cache keys.
    """
    KEY_FILE_NAME = '.msc'

    def __init__(
        self,
        cache_root_location: str,
        **kwargs,
    ):
        """Base file system cache interface.

        Creates the root directory when missing and loads the index.

        Args:
            cache_root_location (str): The root location to store files.
            kwargs(dict): The keyword arguments, not used by this class.
        """
        os.makedirs(cache_root_location, exist_ok=True)
        self.cache_root_location = cache_root_location
        self.load_cache()

    def get_root_location(self):
        """Return the cache root directory."""
        return self.cache_root_location

    def load_cache(self):
        """(Re)load ``self.cached_files`` from the index file.

        The list is empty when no index file exists.
        """
        self.cached_files = []
        cache_keys_file_path = os.path.join(self.cache_root_location,
                                            FileSystemCache.KEY_FILE_NAME)
        if os.path.exists(cache_keys_file_path):
            # NOTE: unpickling can execute arbitrary code, so the cache
            # directory must only ever contain index files written by us.
            with open(cache_keys_file_path, 'rb') as f:
                self.cached_files = pickle.load(f)

    def save_cached_files(self):
        """Save cache metadata."""
        cache_keys_file_path = os.path.join(self.cache_root_location,
                                            FileSystemCache.KEY_FILE_NAME)
        # The root may have been removed by clear_cache().
        os.makedirs(self.cache_root_location, exist_ok=True)
        # Write next to the target so the final rename stays on one
        # filesystem; os.replace then swaps the index in atomically and
        # overwrites an existing index on every platform.
        fd, fn = tempfile.mkstemp(dir=self.cache_root_location)
        try:
            with open(fd, 'wb') as f:
                pickle.dump(self.cached_files, f)
            os.replace(fn, cache_keys_file_path)
        except BaseException:
            # Do not leave a half-written temp file in the cache directory.
            try:
                os.remove(fn)
            except OSError:
                pass
            raise

    def get_file(self, key):
        """Check the key is in the cache, if exist, return the file, otherwise return None.

        This base implementation does nothing and returns None.

        Args:
            key(str): The cache key.

        Raises:
            None
        """
        pass

    def put_file(self, key, location):
        """Put file to the cache.

        This base implementation does nothing and returns None.

        Args:
            key (str): The cache key
            location (str): Location of the file, we will move the file to cache.

        Raises:
            None
        """
        pass

    def remove_key(self, key):
        """Remove cache key in index, The file is removed manually

        The index file is rewritten only when the key was present.

        Args:
            key (dict): The cache key.
        """
        if key in self.cached_files:
            self.cached_files.remove(key)
            self.save_cached_files()

    def exists(self, key):
        """Return True when ``key`` equals one of the cached keys."""
        return key in self.cached_files

    def clear_cache(self):
        """Remove all files and metadata from the cache
        In the case of multiple cache locations, this clears only the last one,
        which is assumed to be the read/write one.

        The root directory itself is removed; it is created again by the
        next ``save_cached_files`` call.
        """
        rmtree(self.cache_root_location)
        self.load_cache()

    def hash_name(self, key):
        """Return the hex SHA-256 digest of ``key`` (a str)."""
        return hashlib.sha256(key.encode()).hexdigest()
|
||||
|
||||
|
||||
class ModelFileSystemCache(FileSystemCache):
|
||||
"""Local cache file layout
|
||||
cache_root/owner/model_name/individual cached files and cache index file '.msc'
|
||||
Save only one version for each file.
|
||||
"""
|
||||
def __init__(self, cache_root, owner=None, name=None):
    """Open the cache directory of one model.

    When both ``owner`` and ``name`` are given, the cache lives in
    ``cache_root/owner/name`` and the model id ``owner/name`` is written to
    the model meta file. Otherwise ``cache_root`` itself is used as the
    cache directory and the model meta is loaded from it.

    Args:
        cache_root(`str`): The modelscope local cache root(default: ~/.cache/modelscope/)
        owner(`str`): The model owner.
        name('str'): The name of the model
    """
    has_identity = owner is not None and name is not None
    if has_identity:
        super().__init__(os.path.join(cache_root, owner, name))
        self.model_meta = {MODEL_META_MODEL_ID: '%s/%s' % (owner, name)}
        self.save_model_meta()
        return
    # No owner/name: cache_root already points at a model directory.
    super().__init__(os.path.join(cache_root))
    self.load_model_meta()
|
||||
|
||||
def load_model_meta(self):
    """Load ``self.model_meta`` from the meta file in the cache root.

    When the file does not exist the model id is set to 'unknown'.
    """
    meta_location = os.path.join(self.cache_root_location,
                                 MODEL_META_FILE_NAME)
    if not os.path.exists(meta_location):
        self.model_meta = {MODEL_META_MODEL_ID: 'unknown'}
        return
    with open(meta_location, 'rb') as meta_fp:
        self.model_meta = pickle.load(meta_fp)
|
||||
|
||||
def get_model_id(self):
    """Return the model id stored in the loaded model meta."""
    meta = self.model_meta
    return meta[MODEL_META_MODEL_ID]
|
||||
|
||||
def save_model_meta(self):
    """Pickle ``self.model_meta`` into the meta file in the cache root."""
    meta_location = os.path.join(self.cache_root_location,
                                 MODEL_META_FILE_NAME)
    with open(meta_location, 'wb') as meta_fp:
        pickle.dump(self.model_meta, meta_fp)
|
||||
|
||||
def get_file_by_path(self, file_path):
    """Retrieve the cache if there is file match the path.

    Index entries whose path matches but whose file no longer exists on
    disk are removed from the index.

    Args:
        file_path (str): The file path in the model.

    Returns:
        path: the full path of the file, or None when no cached file matches.
    """
    # Iterate over a snapshot: remove_key() mutates self.cached_files, and
    # removing from a list while iterating it skips the following entry.
    for cached_file in list(self.cached_files):
        if file_path == cached_file['Path']:
            cached_file_path = os.path.join(self.cache_root_location,
                                            cached_file['Path'])
            if os.path.exists(cached_file_path):
                return cached_file_path
            else:
                self.remove_key(cached_file)

    return None
|
||||
|
||||
def get_file_by_path_and_commit_id(self, file_path, commit_id):
    """Retrieve the cached file matching both the path and the commit id.

    A revision matches when either the cached revision or ``commit_id`` is
    a prefix of the other, so abbreviated commit ids are accepted. Index
    entries whose file is no longer on disk are dropped via
    ``self.remove_key``.

    Args:
        file_path (str): The file path in the model.
        commit_id (str): The commit id of the file.

    Returns:
        str or None: The full path of the cached file, or None when no
        matching file exists on disk.
    """
    # Iterate over a snapshot: remove_key() is expected to drop the entry
    # from self.cached_files, and mutating a list while iterating it makes
    # the loop skip the following element.
    for cached_file in list(self.cached_files):
        if file_path != cached_file['Path']:
            continue
        revision = cached_file['Revision']
        if not (revision.startswith(commit_id)
                or commit_id.startswith(revision)):
            continue
        cached_file_path = os.path.join(self.cache_root_location,
                                        cached_file['Path'])
        if os.path.exists(cached_file_path):
            return cached_file_path
        # The file was deleted behind our back; forget the stale entry.
        self.remove_key(cached_file)

    return None
|
||||
|
||||
def get_file_by_info(self, model_file_info):
    """Look up the cached file for an exact path + revision match.

    If the index has a matching entry but the file is missing on disk, the
    entry is dropped via ``self.remove_key`` and None is returned.

    Args:
        model_file_info (ModelFileInfo): The file information of the file.

    Returns:
        str or None: The cached file path, or None when it is not cached.
    """
    wanted = self.__get_cache_key(model_file_info)
    # Only the first equal entry is considered.
    match = next(
        (entry for entry in self.cached_files if entry == wanted), None)
    if match is None:
        return None

    location = os.path.join(self.cache_root_location, match['Path'])
    if os.path.exists(location):
        return location

    self.remove_key(match)
    return None
|
||||
|
||||
def __get_cache_key(self, model_file_info):
    """Build the index key for a file: its path and revision (commit id)."""
    return {field: model_file_info[field] for field in ('Path', 'Revision')}
|
||||
|
||||
def exists(self, model_file_info):
    """Check whether the file is cached.

    An index entry matches when the paths are equal and either revision is
    a prefix of the other. If an entry matches but the file is missing on
    disk, that entry is dropped from the index.

    Args:
        model_file_info (CachedFileInfo): The cached file info.

    Returns:
        bool: True if the file is indexed and present on disk, otherwise
        False.
    """
    key = self.__get_cache_key(model_file_info)
    matched_key = None
    for cached_key in self.cached_files:
        if cached_key['Path'] != key['Path']:
            continue
        if (cached_key['Revision'].startswith(key['Revision'])
                or key['Revision'].startswith(cached_key['Revision'])):
            matched_key = cached_key
            break
    if matched_key is None:
        return False

    file_path = os.path.join(self.cache_root_location,
                             model_file_info['Path'])
    if os.path.exists(file_path):
        return True

    # Someone may have deleted the file manually. Remove the entry that is
    # actually stored in the index: model_file_info can carry extra fields
    # or a different-length revision and so would not identify it.
    self.remove_key(matched_key)
    return False
|
||||
|
||||
def remove_if_exists(self, model_file_info):
    """Remove the cached copy of a file, if there is one.

    The first index entry with the same path is dropped via
    ``self.remove_key`` and its file, when present on disk, is deleted.

    Args:
        model_file_info (ModelFileInfo): The model file information from
            server.
    """
    stale = next(
        (entry for entry in self.cached_files
         if entry['Path'] == model_file_info['Path']), None)
    if stale is None:
        return

    self.remove_key(stale)
    on_disk = os.path.join(self.cache_root_location, stale['Path'])
    if os.path.exists(on_disk):
        os.remove(on_disk)
|
||||
|
||||
def put_file(self, model_file_info, model_file_location):
    """Move a downloaded file into the cache and index it.

    Any previously cached copy with the same path is removed first, then
    the file at ``model_file_location`` is moved to its place under the
    cache root and the index is saved.

    Args:
        model_file_info: The file description returned by get_model_files;
            it is indexed by ``'Path'`` and ``'Revision'``.
        model_file_location (str): The location of the temporary file.

    Returns:
        str: The location of the cached file.
    """
    # Only one version per file is kept: drop the old revision first.
    self.remove_if_exists(model_file_info)

    index_entry = self.__get_cache_key(model_file_info)
    # Branch and Tag do not have same name.
    destination = os.path.join(self.cache_root_location, index_entry['Path'])
    parent_dir = os.path.dirname(destination)
    if not os.path.exists(parent_dir):
        os.makedirs(parent_dir, exist_ok=True)

    # The move and the index update are not transactional.
    move(model_file_location, destination)
    self.cached_files.append(index_entry)
    self.save_cached_files()
    return destination
|
||||
94
modelscope/hub/utils/utils.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
|
||||
DEFAULT_MODELSCOPE_GROUP,
|
||||
MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG,
|
||||
MODELSCOPE_URL_SCHEME)
|
||||
from modelscope.hub.errors import FileIntegrityError
|
||||
from modelscope.utils.file_utils import get_default_cache_dir
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def model_id_to_group_owner_name(model_id):
    """Split a model id into ``(group_or_owner, name)``.

    Ids without ``MODEL_ID_SEPARATOR`` are placed in
    ``DEFAULT_MODELSCOPE_GROUP``. When the separator occurs more than once,
    only the first two components are used.
    """
    if MODEL_ID_SEPARATOR not in model_id:
        return DEFAULT_MODELSCOPE_GROUP, model_id

    pieces = model_id.split(MODEL_ID_SEPARATOR)
    return pieces[0], pieces[1]
|
||||
|
||||
|
||||
def get_cache_dir(model_id: Optional[str] = None):
    """Return the cache directory.

    The cache root is the ``MODELSCOPE_CACHE`` environment variable when it
    is set, otherwise the ``hub`` sub-directory of the default cache dir.

    Args:
        model_id (str, optional): The model id.

    Returns:
        str: the model_id dir (with a trailing ``/``) if model_id is not
        None, otherwise the cache root dir.
    """
    fallback_root = os.path.join(get_default_cache_dir(), 'hub')
    base_path = os.getenv('MODELSCOPE_CACHE', fallback_root)
    if model_id is None:
        return base_path
    return os.path.join(base_path, model_id + '/')
|
||||
|
||||
|
||||
def get_release_datetime():
    """Return the release time as an integer timestamp (rounded seconds).

    When ``MODELSCOPE_SDK_DEBUG`` is present in the environment the current
    time is used; otherwise ``modelscope.version.__release_datetime__`` is
    parsed with the format ``%Y-%m-%d %H:%M:%S``.
    """
    if MODELSCOPE_SDK_DEBUG in os.environ:
        moment = datetime.now()
    else:
        from modelscope import version
        moment = datetime.strptime(version.__release_datetime__,
                                   '%Y-%m-%d %H:%M:%S')
    return int(round(moment.timestamp()))
|
||||
|
||||
|
||||
def get_endpoint():
    """Return the hub endpoint: URL scheme plus domain.

    The domain comes from the ``MODELSCOPE_DOMAIN`` environment variable,
    falling back to ``DEFAULT_MODELSCOPE_DOMAIN``.
    """
    domain = os.environ.get('MODELSCOPE_DOMAIN', DEFAULT_MODELSCOPE_DOMAIN)
    return MODELSCOPE_URL_SCHEME + domain
|
||||
|
||||
|
||||
def compute_hash(file_path):
    """Return the hex SHA-256 digest of the file at ``file_path``.

    The file is read in 64 KiB chunks so it never has to fit in memory.
    """
    chunk_size = 1024 * 64  # 64k buffer size
    digest = hashlib.sha256()
    with open(file_path, 'rb') as stream:
        # iter() with a sentinel stops at the first empty read (EOF).
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def file_integrity_validation(file_path, expected_sha256):
    """Validate the file hash; on mismatch delete the file and raise.

    Args:
        file_path (str): The file to validate.
        expected_sha256 (str): The expected sha256 hash.

    Raises:
        FileIntegrityError: If the hash of file_path is not the expected one.
    """
    actual_sha256 = compute_hash(file_path)
    if actual_sha256 == expected_sha256:
        return

    # A corrupt download must not linger on disk.
    os.remove(file_path)
    msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
    logger.error(msg)
    raise FileIntegrityError(msg)
|
||||
299
modelscope/metainfo.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from modelscope.utils.constant import Fields, Tasks
|
||||
|
||||
|
||||
class Models(object):
    """Standard names used to register and identify models.

    A model name carries model information only, never task information.
    """

    tinynas_damoyolo = 'tinynas-damoyolo'

    # face models
    scrfd = 'scrfd'
    face_2d_keypoints = 'face-2d-keypoints'
    fer = 'fer'
    fairface = 'fairface'
    retinaface = 'retinaface'
    mogface = 'mogface'
    mtcnn = 'mtcnn'
    ulfd = 'ulfd'
    rts = 'rts'
    flir = 'flir'
    arcface = 'arcface'
    facemask = 'facemask'
    flc = 'flc'
    tinymog = 'tinymog'
    damofd = 'damofd'
|
||||
|
||||
|
||||
class TaskModels(object):
    """Placeholder namespace; no task-model names are defined here."""
|
||||
class Heads(object):
    """Placeholder namespace; no head names are defined here."""
|
||||
|
||||
class Pipelines(object):
    """ Names for different pipelines.

    Holds the standard pipeline name to use for identifying different pipeline.
    This should be used to register pipelines.

    For pipeline which support different models and implements the common function, we
    should use task name for this pipeline.
    For pipeline which support only one model, we should use ${Model}-${Task} as its name.
    """
    # vision tasks
    face_2d_keypoints = 'mobilenet_face-2d-keypoints_alignment'
    salient_detection = 'u2net-salient-detection'
    salient_boudary_detection = 'res2net-salient-detection'
    camouflaged_detection = 'res2net-camouflaged-detection'
    image_demoire = 'uhdm-image-demoireing'
    image_classification = 'image-classification'
    face_detection = 'resnet-face-detection-scrfd10gkps'
    face_liveness_ir = 'manual-face-liveness-flir'
    # NOTE(review): same value as face_liveness_ir - confirm this is intended.
    face_liveness_rgb = 'manual-face-liveness-flir'
    face_liveness_xc = 'manual-face-liveness-flxc'
    card_detection = 'resnet-card-detection-scrfd34gkps'
    ulfd_face_detection = 'manual-face-detection-ulfd'
    tinymog_face_detection = 'manual-face-detection-tinymog'
    facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
    facial_landmark_confidence = 'manual-facial-landmark-confidence-flcm'
    face_attribute_recognition = 'resnet34-face-attribute-recognition-fairface'
    retina_face_detection = 'resnet50-face-detection-retinaface'
    mog_face_detection = 'resnet101-face-detection-cvpr22papermogface'
    mtcnn_face_detection = 'manual-face-detection-mtcnn'
    face_recognition = 'ir101-face-recognition-cfglint'
    # This name used to be assigned twice; the earlier value
    # ('ir-face-recognition-ood-rts') was always overwritten by this one, so
    # only the effective value is kept, at the original position.
    face_recognition_ood = 'ir-face-recognition-rts'
    face_quality_assessment = 'manual-face-quality-assessment-fqa'
    face_recognition_onnx_ir = 'manual-face-recognition-frir'
    face_recognition_onnx_fm = 'manual-face-recognition-frfm'
    arc_face_recognition = 'ir50-face-recognition-arcface'
    mask_face_recognition = 'resnet-face-recognition-facemask'
|
||||
|
||||
|
||||
# Default pipeline and model repository for each task:
#   TaskName: (pipeline_module_name, model_repo)
DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.face_detection: (
        Pipelines.mog_face_detection,
        'damo/cv_resnet101_face-detection_cvpr22papermogface',
    ),
    Tasks.face_liveness: (
        Pipelines.face_liveness_ir,
        'damo/cv_manual_face-liveness_flir',
    ),
    Tasks.face_recognition: (
        Pipelines.face_recognition,
        'damo/cv_ir101_facerecognition_cfglint',
    ),
    Tasks.facial_expression_recognition: (
        Pipelines.facial_expression_recognition,
        'damo/cv_vgg19_facial-expression-recognition_fer',
    ),
    Tasks.face_attribute_recognition: (
        Pipelines.face_attribute_recognition,
        'damo/cv_resnet34_face-attribute-recognition_fairface',
    ),
    Tasks.face_2d_keypoints: (
        Pipelines.face_2d_keypoints,
        'damo/cv_mobilenet_face-2d-keypoints_alignment',
    ),
    Tasks.face_quality_assessment: (
        Pipelines.face_quality_assessment,
        'damo/cv_manual_face-quality-assessment_fqa',
    ),
}
|
||||
class CVTrainers(object):
    """Trainer names for computer-vision tasks."""
    face_detection_scrfd = 'face-detection-scrfd'


class Trainers(CVTrainers):
    """ Names for different trainer.

    Holds the standard trainer name to use for identifying different trainer.
    This should be used to register trainers.

    For a general Trainer, you can use EpochBasedTrainer.
    For a model specific Trainer, you can use ${ModelName}-${Task}-trainer.
    """

    default = 'trainer'
    easycv = 'easycv'
    tinynas_damoyolo = 'tinynas-damoyolo'

    @staticmethod
    def get_trainer_domain(attribute_or_value):
        """Map a trainer attribute name or value to its domain.

        Args:
            attribute_or_value: An attribute name of a trainer group class,
                or one of its values.

        Returns:
            ``Fields.cv`` for anything found in ``CVTrainers``, the value
            itself for the ``default`` / ``easycv`` trainers, and
            ``'unknown'`` otherwise.
        """
        # CVTrainers is the only trainer group defined in this module. The
        # former NLP / audio / multi-modal branches referenced classes that
        # do not exist here and raised NameError for every non-CV input.
        cv_trainers = vars(CVTrainers)
        if attribute_or_value in cv_trainers \
                or attribute_or_value in cv_trainers.values():
            return Fields.cv
        if attribute_or_value == Trainers.default:
            return Trainers.default
        if attribute_or_value == Trainers.easycv:
            return Trainers.easycv
        return 'unknown'
|
||||
|
||||
|
||||
class Preprocessors(object):
    """Standard names used to register and identify preprocessors.

    A general preprocessor is named after its function, such as
    resize-image or random-crop.
    A model-specific preprocessor is named ${modelname}-${function}.
    """

    # cv preprocessor
    load_image = 'load-image'
    image_denoise_preprocessor = 'image-denoise-preprocessor'
    image_deblur_preprocessor = 'image-deblur-preprocessor'
    object_detection_tinynas_preprocessor = 'object-detection-tinynas-preprocessor'
    image_classification_mmcv_preprocessor = 'image-classification-mmcv-preprocessor'
    image_color_enhance_preprocessor = 'image-color-enhance-preprocessor'
    image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
    image_driving_perception_preprocessor = 'image-driving-perception-preprocessor'
    image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
    image_quality_assessment_mos_preprocessor = 'image-quality_assessment-mos-preprocessor'
    video_summarization_preprocessor = 'video-summarization-preprocessor'
    movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
    image_classification_bypass_preprocessor = 'image-classification-bypass-preprocessor'
    object_detection_scrfd = 'object-detection-scrfd'
|
||||
|
||||
|
||||
|
||||
class Metrics(object):
    """Standard names used to identify metrics."""

    # accuracy
    accuracy = 'accuracy'

    multi_average_precision = 'mAP'
    audio_noise_metric = 'audio-noise-metric'
    PPL = 'ppl'

    # text gen
    BLEU = 'bleu'

    # image denoise task
    image_denoise_metric = 'image-denoise-metric'
    # video frame-interpolation task
    video_frame_interpolation_metric = 'video-frame-interpolation-metric'
    # real-world video super-resolution task
    video_super_resolution_metric = 'video-super-resolution-metric'

    # image instance segmentation task
    image_ins_seg_coco_metric = 'image-ins-seg-coco-metric'
    # sequence classification task
    seq_cls_metric = 'seq-cls-metric'
    # loss metric
    loss_metric = 'loss-metric'
    # token-classification task
    token_cls_metric = 'token-cls-metric'
    # text-generation task
    text_gen_metric = 'text-gen-metric'
    # file saving wrapper
    prediction_saving_wrapper = 'prediction-saving-wrapper'
    # image-color-enhance task
    image_color_enhance_metric = 'image-color-enhance-metric'
    # image-portrait-enhancement task
    image_portrait_enhancement_metric = 'image-portrait-enhancement-metric'
    video_summarization_metric = 'video-summarization-metric'
    # movie-scene-segmentation task
    movie_scene_segmentation_metric = 'movie-scene-segmentation-metric'
    # inpainting task
    image_inpainting_metric = 'image-inpainting-metric'
    # ocr
    NED = 'ned'
    # cross-modal retrieval
    inbatch_recall = 'inbatch_recall'
    # referring-video-object-segmentation task
    referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric'
    # video stabilization task
    video_stabilization_metric = 'video-stabilization-metric'
    # image-quality-assessment-mos task
    image_quality_assessment_mos_metric = 'image-quality-assessment-mos-metric'
    # image-quality-assessment-degradation task
    image_quality_assessment_degradation_metric = 'image-quality-assessment-degradation-metric'
    # text-ranking task
    text_ranking_metric = 'text-ranking-metric'
|
||||
|
||||
|
||||
class Optimizers(object):
    """Standard names used to register and identify optimizers."""

    default = 'optimizer'

    SGD = 'SGD'
|
||||
|
||||
|
||||
class Hooks(object):
    """Names of all hook kinds, grouped by purpose."""

    # lr
    LrSchedulerHook = 'LrSchedulerHook'
    PlateauLrSchedulerHook = 'PlateauLrSchedulerHook'
    NoneLrSchedulerHook = 'NoneLrSchedulerHook'

    # optimizer
    OptimizerHook = 'OptimizerHook'
    TorchAMPOptimizerHook = 'TorchAMPOptimizerHook'
    ApexAMPOptimizerHook = 'ApexAMPOptimizerHook'
    NoneOptimizerHook = 'NoneOptimizerHook'

    # checkpoint
    CheckpointHook = 'CheckpointHook'
    BestCkptSaverHook = 'BestCkptSaverHook'
    LoadCheckpointHook = 'LoadCheckpointHook'

    # logger
    TextLoggerHook = 'TextLoggerHook'
    TensorboardHook = 'TensorboardHook'

    IterTimerHook = 'IterTimerHook'
    EvaluationHook = 'EvaluationHook'

    # Compression
    SparsityHook = 'SparsityHook'

    # CLIP logit_scale clamp
    ClipClampLogitScaleHook = 'ClipClampLogitScaleHook'

    # train
    EarlyStopHook = 'EarlyStopHook'
    DeepspeedHook = 'DeepspeedHook'
|
||||
|
||||
|
||||
class LR_Schedulers(object):
    """Names of the learning rate schedulers."""
    LinearWarmup = 'LinearWarmup'
    ConstantWarmup = 'ConstantWarmup'
    ExponentialWarmup = 'ExponentialWarmup'
|
||||
|
||||
|
||||
class Datasets(object):
    """Standard names used to identify datasets."""
    ClsDataset = 'ClsDataset'
    Face2dKeypointsDataset = 'FaceKeypointDataset'
    HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset'
    HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset'
    SegDataset = 'SegDataset'
    DetDataset = 'DetDataset'
    DetImagesMixDataset = 'DetImagesMixDataset'
    PanopticDataset = 'PanopticDataset'
    PairedDataset = 'PairedDataset'
|
||||
14
modelscope/models/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.error import (AUDIO_IMPORT_ERROR,
|
||||
TENSORFLOW_IMPORT_WARNING)
|
||||
from modelscope.utils.import_utils import is_torch_available
|
||||
|
||||
from . import cv
|
||||
from .base import Head, Model
|
||||
from .builder import BACKBONES, HEADS, MODELS, build_model
|
||||
|
||||
if is_torch_available():
|
||||
from .base.base_torch_model import TorchModel
|
||||
from .base.base_torch_head import TorchHead
|
||||
10
modelscope/models/base/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.utils.import_utils import is_torch_available
|
||||
|
||||
from .base_head import * # noqa F403
|
||||
from .base_model import * # noqa F403
|
||||
|
||||
if is_torch_available():
|
||||
from .base_torch_model import TorchModel
|
||||
from .base_torch_head import TorchHead
|
||||
39
modelscope/models/base/base_head.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from modelscope.models.base.base_model import Model
|
||||
from modelscope.utils.config import ConfigDict
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# Framework-agnostic tensor alias; string forward references keep this module
# importable without torch or tensorflow.
Tensor = Union['torch.Tensor', 'tf.Tensor']
|
||||
# Type alias: either a dict of named tensors or a Model instance.
Input = Union[Dict[str, Tensor], Model]
|
||||
|
||||
|
||||
class Head(ABC):
    """Abstract base class defining the interface of a task head."""

    def __init__(self, **kwargs):
        # All keyword arguments are kept as the head configuration.
        self.config = ConfigDict(kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        """Run the downstream task on the output of the backbone model.

        Returns (Dict[str, Any]): The output from downstream task.
        """

    @abstractmethod
    def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
        """Compute the loss of the head during finetuning.

        Returns (Dict[str, Any]): The loss dict
        """
|
||||
167
modelscope/models/base/base_model.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import os.path as osp
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from modelscope.hub.check_model import check_local_model_is_latest
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models.builder import build_model
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
|
||||
from modelscope.utils.device import verify_device
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# Framework-agnostic tensor alias; string forward references keep this module
# importable without torch or tensorflow.
Tensor = Union['torch.Tensor', 'tf.Tensor']
|
||||
|
||||
|
||||
class Model(ABC):
    """Base model interface.

    Calling an instance runs ``forward`` and passes its result through
    ``postprocess``.
    """

    def __init__(self, model_dir, *args, **kwargs):
        """Record the model directory and the validated device name.

        Args:
            model_dir: The model directory.
            **kwargs:
                device: The device name, ``'gpu'`` when omitted; it is
                    checked with ``verify_device``.
        """
        self.model_dir = model_dir
        requested_device = kwargs.get('device', 'gpu')
        verify_device(requested_device)
        self._device_name = requested_device

    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
        raw_output = self.forward(*args, **kwargs)
        return self.postprocess(raw_output)

    @abstractmethod
    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Run the forward pass for a model.

        Returns:
            Dict[str, Any]: output from the model forward pass
        """

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """ Model specific postprocess and convert model output to
        standard model outputs.

        The base implementation returns ``inputs`` unchanged.

        Args:
            inputs:  input data

        Return:
            dict of results:  a dict containing outputs of model, each
                output should have the standard output name.
        """
        return inputs

    @classmethod
    def _instantiate(cls, **kwargs):
        """ Define the instantiation method of a model, default method is by
            calling the constructor. Note that in the case of no loading model
            process in constructor of a task model, a load_model method is
            added, and thus this method is overloaded
        """
        return cls(**kwargs)

    @classmethod
    def from_pretrained(cls,
                        model_name_or_path: str,
                        revision: Optional[str] = DEFAULT_MODEL_REVISION,
                        cfg_dict: Config = None,
                        device: str = None,
                        **kwargs):
        """Instantiate a model from local directory or remote model repo. Note
        that when loading from remote, the model revision can be specified.

        Args:
            model_name_or_path(str): A model dir or a model id to be loaded
            revision(str, `optional`): The revision used when the model_name_or_path is
                a model id of the remote hub. Defaults to `DEFAULT_MODEL_REVISION`.
            cfg_dict(Config, `optional`): An optional model config. If provided, it will replace
                the config read out of the `model_name_or_path`
            device(str, `optional`): The device to load the model.
            **kwargs:
                task(str, `optional`): The `Tasks` enumeration value to replace the task value
                    read out of config in the `model_name_or_path`. This is useful when the model to be loaded is not
                    equal to the model saved.
                    For example, load a `backbone` into a `text-classification` model.
                model_prefetched(bool, `optional`): When True and `model_name_or_path` does not exist
                    locally, raise instead of downloading.
                Other kwargs will be directly fed into the `model` key, to replace the default configs.
        Returns:
            A model instance.

        Raises:
            RuntimeError: If `model_prefetched` is True but the model is not found locally.

        Examples:
            >>> from modelscope.models import Model
            >>> Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='text-classification')
        """
        # Control-only kwargs are taken out so they do not end up in the
        # model config below. A key explicitly passed as None is left alone.
        prefetched = kwargs.get('model_prefetched')
        if prefetched is not None:
            del kwargs['model_prefetched']
        invoked_by = kwargs.get(Invoke.KEY)
        if invoked_by is None:
            invoked_by = Invoke.PRETRAINED
        else:
            del kwargs[Invoke.KEY]

        # Resolve the model directory: an existing local path wins,
        # otherwise the snapshot is downloaded from the hub.
        if osp.exists(model_name_or_path):
            local_model_dir = model_name_or_path
        else:
            if prefetched is True:
                raise RuntimeError(
                    'Expecting model is pre-fetched locally, but is not found.'
                )

            invoked_by = '%s/%s' % (Invoke.KEY, invoked_by)
            local_model_dir = snapshot_download(model_name_or_path,
                                                revision,
                                                user_agent=invoked_by)
        logger.info(f'initialize model from {local_model_dir}')

        if cfg_dict is None:
            cfg = Config.from_file(
                osp.join(local_model_dir, ModelFile.CONFIGURATION))
        else:
            cfg = cfg_dict

        # An explicit `task` kwarg overrides the task from the config.
        task_name = kwargs.pop('task', cfg.task)

        model_cfg = cfg.model
        if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
            model_cfg.type = model_cfg.model_type
        model_cfg.model_dir = local_model_dir
        for option, value in kwargs.items():
            model_cfg[option] = value
        if device is not None:
            model_cfg.device = device
        model = build_model(model_cfg, task_name=task_name)

        # dynamically add pipeline info to model for pipeline inference
        if hasattr(cfg, 'pipeline'):
            model.pipeline = cfg.pipeline

        if not hasattr(model, 'cfg'):
            model.cfg = cfg

        model_cfg.pop('model_dir', None)
        model.name = model_name_or_path
        model.model_dir = local_model_dir
        return model

    def save_pretrained(self,
                        target_folder: Union[str, os.PathLike],
                        save_checkpoint_names: Union[str, List[str]] = None,
                        config: Optional[dict] = None,
                        **kwargs):
        """save the pretrained model, its configuration and other related files to a directory,
            so that it can be re-loaded

        The base class only declares the interface.

        Args:
            target_folder (Union[str, os.PathLike]):
            Directory to which to save. Will be created if it doesn't exist.

            save_checkpoint_names (Union[str, List[str]]):
            The checkpoint names to be saved in the target_folder

            config (Optional[dict], optional):
            The config for the configuration.json, might not be identical with model.config

        Raises:
            NotImplementedError: Always, unless overridden by the subclass.
        """
        raise NotImplementedError(
            'save_pretrained method need to be implemented by the subclass.')
|
||||
24
modelscope/models/base/base_torch_head.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.models.base.base_head import Head
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TorchHead(Head, torch.nn.Module):
    """Base head interface for pytorch."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Head.__init__ does not chain to the next class, so the
        # torch.nn.Module state is initialised explicitly.
        torch.nn.Module.__init__(self)

    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        """To be implemented by concrete heads."""
        raise NotImplementedError

    def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
        """To be implemented by concrete heads."""
        raise NotImplementedError
|
||||
128
modelscope/models/base/base_torch_model.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
|
||||
from modelscope.utils.checkpoint import (save_checkpoint, save_configuration,
|
||||
save_pretrained)
|
||||
from modelscope.utils.file_utils import func_receive_dict_inputs
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
from .base_model import Model
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TorchModel(Model, torch.nn.Module):
|
||||
""" Base model interface for pytorch
|
||||
|
||||
"""
|
||||
def __init__(self, model_dir=None, *args, **kwargs):
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
torch.nn.Module.__init__(self)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Dict[str, Any]:
|
||||
# Adapting a model with only one dict arg, and the arg name must be input or inputs
|
||||
if func_receive_dict_inputs(self.forward):
|
||||
return self.postprocess(self.forward(args[0], **kwargs))
|
||||
else:
|
||||
return self.postprocess(self.forward(*args, **kwargs))
|
||||
|
||||
def _load_pretrained(self,
|
||||
net,
|
||||
load_path,
|
||||
strict=True,
|
||||
param_key='params'):
|
||||
if isinstance(net, (DataParallel, DistributedDataParallel)):
|
||||
net = net.module
|
||||
load_net = torch.load(load_path,
|
||||
map_location=lambda storage, loc: storage)
|
||||
if param_key is not None:
|
||||
if param_key not in load_net and 'params' in load_net:
|
||||
param_key = 'params'
|
||||
logger.info(
|
||||
f'Loading: {param_key} does not exist, use params.')
|
||||
if param_key in load_net:
|
||||
load_net = load_net[param_key]
|
||||
logger.info(
|
||||
f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].'
|
||||
)
|
||||
# remove unnecessary 'module.'
|
||||
for k, v in deepcopy(load_net).items():
|
||||
if k.startswith('module.'):
|
||||
load_net[k[7:]] = v
|
||||
load_net.pop(k)
|
||||
net.load_state_dict(load_net, strict=strict)
|
||||
logger.info('load model done.')
|
||||
return net
|
||||
|
||||
def forward(self, *args, **kwargs) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def post_init(self):
|
||||
"""
|
||||
A method executed at the end of each model initialization, to execute code that needs the model's
|
||||
modules properly initialized (such as weight initialization).
|
||||
"""
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
# Initialize weights
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def save_pretrained(self,
                    target_folder: Union[str, os.PathLike],
                    save_checkpoint_names: Union[str, List[str]] = None,
                    save_function: Callable = save_checkpoint,
                    config: Optional[dict] = None,
                    save_config_function: Callable = save_configuration,
                    **kwargs):
    """save the pretrained model, its configuration and other related files to a directory,
        so that it can be re-loaded

    Args:
        target_folder (Union[str, os.PathLike]):
        Directory to which to save. Will be created if it doesn't exist.

        save_checkpoint_names (Union[str, List[str]]):
        The checkpoint names to be saved in the target_folder

        save_function (Callable, optional):
        The function to use to save the state dictionary.

        config (Optional[dict], optional):
        The config for the configuration.json, might not be identical with model.config.
        When None, ``self.cfg`` is used if the model has that attribute.

        save_config_function (Callable, optional):
        The function to use to save the configuration.

        **kwargs:
        Forwarded unchanged to the ``save_pretrained`` helper called below.

    """
    # Fall back to the model's own config when the caller supplied none.
    if config is None and hasattr(self, 'cfg'):
        config = self.cfg

    # Inside a method body the bare name resolves to a module-level
    # ``save_pretrained`` (presumably imported at the top of this file --
    # confirm), not to this method, so this is not a recursive call.
    save_pretrained(self, target_folder, save_checkpoint_names,
                    save_function, **kwargs)

    # The configuration is written only when one is available.
    if config is not None:
        save_config_function(target_folder, config)
|
||||
98
modelscope/models/builder.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.utils.config import ConfigDict
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.registry import Registry, build_from_cfg
|
||||
from modelscope.utils.task_utils import get_task_by_subtask_name
|
||||
|
||||
logger = get_logger()

# Registries used by the build_* helpers in this module. BACKBONES is an alias
# of MODELS (the very same Registry object); HEADS is a separate registry.
MODELS = Registry('models')
BACKBONES = MODELS
HEADS = Registry('heads')

# Entries of the lazy-import index whose key starts with
# ('BACKBONES', Tasks.backbone, ...) are additionally stored under a key whose
# first element is the upper-cased name of the MODELS registry.
# NOTE(review): presumably this lets backbones indexed under 'BACKBONES' be
# found through MODELS -- confirm against LazyImportModule.
modules = LazyImportModule.AST_INDEX[INDEX_KEY]
# Iterate over a snapshot of the keys because the dict is extended in the loop.
for module_index in list(modules.keys()):
    if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES':
        modules[(MODELS.name.upper(), module_index[1],
                 module_index[2])] = modules[module_index]
|
||||
|
||||
|
||||
def build_model(cfg: ConfigDict,
                task_name: str = None,
                default_args: dict = None):
    """Instantiate a model from its config dict.

    The model is first looked up in ``MODELS`` under ``task_name``. If that
    lookup raises ``KeyError``, ``task_name`` is treated as a subtask: its
    parent task and task-model type are resolved, ``cfg['type']`` is
    overwritten with the task-model type, and the build is retried under the
    parent task.

    Args:
        cfg (:obj:`ConfigDict`): config dict for model object.
        task_name (str, optional): task name, refer to
            :obj:`Tasks` for more details
        default_args (dict, optional): Default initialization arguments.

    Raises:
        KeyError: If the lookup fails and no task-model type exists for
            ``task_name``.
    """
    try:
        return build_from_cfg(cfg,
                              MODELS,
                              group_key=task_name,
                              default_args=default_args)
    except KeyError as e:
        # A subtask whose backbone model is not registered must still have a
        # task model registered under its parent task; otherwise it is not a
        # valid subtask and the original failure is reported.
        parent_task, task_model_type = get_task_by_subtask_name(task_name)
        if task_model_type is None:
            raise KeyError(e)
        cfg['type'] = task_model_type
        return build_from_cfg(cfg,
                              MODELS,
                              group_key=parent_task,
                              default_args=default_args)
|
||||
|
||||
|
||||
def build_backbone(cfg: ConfigDict, default_args: dict = None):
    """ build backbone given backbone config dict

    The backbone is first looked up in ``BACKBONES`` under ``Tasks.backbone``.
    If that raises ``KeyError``, a warning is logged, ``cfg['type']`` is
    replaced by ``Models.transformers`` and the build is retried.

    Note that ``cfg`` is modified in place: ``model_dir`` is popped before the
    first attempt and only put back for the fallback attempt.

    Args:
        cfg (:obj:`ConfigDict`): config dict for backbone object.
        default_args (dict, optional): Default initialization arguments.
    """
    try:
        model_dir = cfg.pop('model_dir', None)
        model = build_from_cfg(cfg,
                               BACKBONES,
                               group_key=Tasks.backbone,
                               default_args=default_args)
    except KeyError:
        # Handle backbone that is not in the register group by using transformers AutoModel.
        # AutoModel is mostly used in NLP and part of Multi-Modal, while the number of backbones in CV, Audio and MM
        # is limited, thus they could be added and registered in Modelscope directly
        # Logger methods are lower-case; ``logger.WARNING`` is not one, so the
        # original call raised AttributeError instead of emitting the warning.
        logger.warning(
            f'The backbone {cfg.type} is not registered in modelscope, try to import the backbone from hf transformers.'
        )
        cfg['type'] = Models.transformers
        if model_dir is not None:
            cfg['model_dir'] = model_dir
        model = build_from_cfg(cfg,
                               BACKBONES,
                               group_key=Tasks.backbone,
                               default_args=default_args)
    return model
|
||||
|
||||
|
||||
def build_head(cfg: ConfigDict,
               task_name: str = None,
               default_args: dict = None):
    """Instantiate a head from its config dict via the ``HEADS`` registry.

    Args:
        cfg (:obj:`ConfigDict`): config dict for head object.
        task_name (str, optional): task name, refer to
            :obj:`Tasks` for more details
        default_args (dict, optional): Default initialization arguments.

    Returns:
        Whatever ``build_from_cfg`` returns for the given config.
    """
    head = build_from_cfg(
        cfg,
        HEADS,
        group_key=task_name,
        default_args=default_args,
    )
    return head
|
||||
20
modelscope/models/cv/face_2d_keypoints/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
    # Seen only by static type checkers, which get the real symbol.
    from .face_2d_keypoints_align import Face2DKeypoints

else:
    # Submodule name -> names it exports; handed to LazyImportModule below.
    _import_structure = {'face_2d_keypoints_align': ['Face2DKeypoints']}

    import sys

    # Replace this package's entry in sys.modules with a LazyImportModule
    # (presumably deferring the submodule import until first use -- confirm).
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from easycv.models.face.face_keypoint import FaceKeypoint
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.cv.easycv_base import EasyCVBaseModel
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@MODELS.register_module(group_key=Tasks.face_2d_keypoints,
                        module_name=Models.face_2d_keypoints)
class Face2DKeypoints(EasyCVBaseModel, FaceKeypoint):
    """EasyCV ``FaceKeypoint`` model registered for the face_2d_keypoints task."""

    def __init__(self, model_dir=None, *args, **kwargs):
        # Both base initialisers are called explicitly, in this order.
        # NOTE(review): ``args`` and ``kwargs`` are handed to
        # EasyCVBaseModel.__init__ as two plain objects (not unpacked) --
        # presumably that is the signature it expects; confirm.
        EasyCVBaseModel.__init__(self, model_dir, args, kwargs)
        FaceKeypoint.__init__(self, *args, **kwargs)
|
||||
20
modelscope/models/cv/face_attribute_recognition/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
    # Seen only by static type checkers, which get the real symbol.
    from .fair_face import FaceAttributeRecognition

else:
    # Submodule name -> names it exports; handed to LazyImportModule below.
    _import_structure = {'fair_face': ['FaceAttributeRecognition']}

    import sys

    # Replace this package's entry in sys.modules with a LazyImportModule
    # (presumably deferring the submodule import until first use -- confirm).
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
||||
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from .face_attribute_recognition import FaceAttributeRecognition
|
||||
@@ -0,0 +1,78 @@
|
||||
# The implementation is based on FairFace, available at
|
||||
# https://github.com/dchen236/FairFace
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from torch.autograd import Variable
|
||||
from torchvision import datasets, models, transforms
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Tensor, TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.face_attribute_recognition,
                        module_name=Models.fairface)
class FaceAttributeRecognition(TorchModel):
    """FairFace-based gender and age classifier.

    Wraps a torchvision ResNet-34 whose final fully connected layer is
    replaced by an 18-way linear layer. ``forward`` reads outputs 7:9 as the
    gender logits and 9:18 as the age logits; outputs 0:7 are not used here.
    """

    def __init__(self, model_path, device='cuda'):
        """Build the network, load its weights and move it to ``device``.

        Args:
            model_path: Path of the checkpoint passed to ``torch.load``.
            device: Device the network and the inputs are moved to.
        """
        super().__init__(model_path)
        cudnn.benchmark = True
        self.model_path = model_path
        self.device = device
        self.cfg_path = model_path.replace(ModelFile.TORCH_MODEL_FILE,
                                           ModelFile.CONFIGURATION)
        fair_face = torchvision.models.resnet34(pretrained=False)
        fair_face.fc = nn.Linear(fair_face.fc.in_features, 18)
        self.net = fair_face
        self.load_model()
        self.net = self.net.to(device)
        # uint8 HWC array -> PIL -> 224x224 -> normalized CHW float tensor.
        self.trans = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def load_model(self, load_to_cpu=False):
        """Load the checkpoint onto CPU, apply it strictly and switch to eval mode.

        ``load_to_cpu`` is accepted but not read by this method.
        """
        pretrained_dict = torch.load(self.model_path,
                                     map_location=torch.device('cpu'))
        self.net.load_state_dict(pretrained_dict, strict=True)
        self.net.eval()

    def forward(self, img):
        """ FairFace model forward process.

        Args:
            img: [h, w, c] tensor; it is converted with ``COLOR_BGR2RGB``
                before being fed to the network.

        Return:
            list of attribute result: [gender_score, age_score], each a list
            of softmax probabilities (2 and 9 entries respectively).
        """
        img = cv2.cvtColor(img.cpu().numpy(), cv2.COLOR_BGR2RGB)
        img = img.astype(np.uint8)

        inputs = self.trans(img)

        c, h, w = inputs.shape

        # Add the batch dimension expected by the network.
        inputs = inputs.view(-1, c, h, w)
        inputs = inputs.to(self.device)
        # ``Variable(..., volatile=True)`` no longer disables autograd in
        # current PyTorch; ``torch.no_grad()`` is the supported replacement.
        with torch.no_grad():
            outputs = self.net(inputs)[0]

        gender_outputs = outputs[7:9]
        age_outputs = outputs[9:18]

        # The slices are 1-D, so dim=0 is the dimension the implicit-dim
        # softmax used to pick; stating it avoids the deprecation warning.
        gender_score = F.softmax(gender_outputs,
                                 dim=0).detach().cpu().tolist()
        age_score = F.softmax(age_outputs, dim=0).detach().cpu().tolist()

        return [gender_score, age_score]
|
||||
32
modelscope/models/cv/face_detection/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
    # Seen only by static type checkers, which get the real symbols.
    from .mogface import MogFaceDetector
    from .mtcnn import MtcnnFaceDetector
    from .retinaface import RetinaFaceDetection
    from .ulfd_slim import UlfdFaceDetector
    from .scrfd import ScrfdDetect
    from .scrfd import TinyMogDetect
    from .scrfd import SCRFDPreprocessor
    from .scrfd import DamoFdDetect
else:
    # Submodule name -> names it exports; handed to LazyImportModule below.
    _import_structure = {
        'ulfd_slim': ['UlfdFaceDetector'],
        'retinaface': ['RetinaFaceDetection'],
        'mtcnn': ['MtcnnFaceDetector'],
        'mogface': ['MogFaceDetector'],
        'scrfd': ['TinyMogDetect', 'ScrfdDetect', 'SCRFDPreprocessor', 'DamoFdDetect'],
    }

    import sys

    # Replace this package's entry in sys.modules with a LazyImportModule
    # (presumably deferring the submodule imports until first use -- confirm).
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
||||
2
modelscope/models/cv/face_detection/mogface/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from .models.detectors import MogFaceDetector
|
||||
@@ -0,0 +1,98 @@
|
||||
# The implementation is based on MogFace, available at
|
||||
# https://github.com/damo-cv/MogFace
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
from .mogface import MogFace
|
||||
from .utils import MogPriorBox, mogdecode, py_cpu_nms
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.face_detection, module_name=Models.mogface)
class MogFaceDetector(TorchModel):
    """MogFace face detector: preprocessing, network forward, decoding and NMS."""

    def __init__(self, model_path, device='cuda', **kwargs):
        """Build the network, load its weights and move it to ``device``.

        Args:
            model_path: Path of the checkpoint passed to ``torch.load``.
            device: Device the network and the inputs are moved to.
            **kwargs: Optional ``conf_th`` (score threshold, default -1.82)
                and ``nms_th`` (NMS threshold, default 0.4).
        """
        super().__init__(model_path)
        cudnn.benchmark = True
        self.model_path = model_path
        self.device = device
        self.net = MogFace()
        self.load_model()
        self.net = self.net.to(device)
        self.conf_th = kwargs.get('conf_th', -1.82)
        self.nms_th = kwargs.get('nms_th', 0.4)

        # Not read anywhere in this class; kept because it is a public attribute.
        self.mean = np.array([[104, 117, 123]])

    def load_model(self, load_to_cpu=False):
        """Load the checkpoint onto CPU (non-strict) and switch to eval mode.

        ``load_to_cpu`` is accepted but not read by this method.
        """
        pretrained_dict = torch.load(self.model_path,
                                     map_location=torch.device('cpu'))
        self.net.load_state_dict(pretrained_dict, strict=False)
        self.net.eval()

    def forward(self, input):
        """Detect faces in one image.

        Args:
            input: Mapping whose ``'img'`` entry is an [h, w, c] tensor.
                NOTE(review): the in-place normalization below requires a
                floating dtype -- confirm against the caller.

        Returns:
            np.ndarray of shape [num_det, 5] (float32): four box values in
            the coordinates of the input image followed by the score.
        """
        img_raw = input['img']
        img = np.array(img_raw.cpu().detach())
        # Reverse the channel axis (presumably RGB -> BGR -- confirm).
        img = img[:, :, ::-1]

        im_height, im_width = img.shape[:2]
        ss = 1.0
        # tricky: images whose longer side exceeds 1500 are shrunk so that it
        # becomes 1000; ``ss`` records the factor to undo it on the boxes.
        if max(im_height, im_width) > 1500:
            ss = 1000.0 / max(im_height, im_width)
            img = cv2.resize(img, (0, 0), fx=ss, fy=ss)
            im_height, im_width = img.shape[:2]

        img -= np.array([[103.53, 116.28, 123.675]])
        img /= np.array([[57.375, 57.120003, 58.395]])
        img /= 255
        # Reverse the channels back and make the array contiguous for torch.
        img = img[:, :, ::-1].copy()
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.to(self.device)

        # Inference only: the outputs are converted to numpy below, so no
        # autograd graph is needed.
        with torch.no_grad():
            conf, loc = self.net(img)  # forward pass

        top_k = 5000
        keep_top_k = 750

        priorbox = MogPriorBox(scale_list=[0.68])
        priors = priorbox(im_height, im_width)
        priors = torch.tensor(priors).to(self.device)
        prior_data = priors.data

        boxes = mogdecode(loc.data.squeeze(0), prior_data)
        boxes = boxes.cpu().numpy()
        scores = conf.squeeze(0).data.cpu().numpy()[:, 0]

        # ignore low scores
        inds = np.where(scores > self.conf_th)[0]
        boxes = boxes[inds]
        scores = scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1][:top_k]
        boxes = boxes[order]
        scores = scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32,
                                                                copy=False)
        keep = py_cpu_nms(dets, self.nms_th)
        dets = dets[keep, :]

        # keep top-K faster NMS
        dets = dets[:keep_top_k, :]

        # Undo the resize on the box coordinates only. Dividing the whole
        # array (as before) also rescaled the score column whenever the
        # image had been shrunk.
        dets[:, :4] /= ss
        return dets
|
||||
132
modelscope/models/cv/face_detection/mogface/models/mogface.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# --------------------------------------------------------
|
||||
# The implementation is also open-sourced by the authors as Yang Liu, and is available publicly on
|
||||
# https://github.com/damo-cv/MogFace
|
||||
# --------------------------------------------------------
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .mogprednet import MogPredNet
|
||||
from .resnet import ResNet
|
||||
|
||||
|
||||
class MogFace(nn.Module):
    """MogFace network: ResNet-101 backbone, LFPN neck and MogPredNet head."""

    def __init__(self):
        super().__init__()
        self.backbone = ResNet(depth=101)
        self.fpn = LFPN()
        self.pred_net = MogPredNet()

    def forward(self, x):
        """Return ``(conf, loc)`` as produced by the prediction head for ``x``."""
        backbone_features = self.backbone(x)
        # Only the first element of the neck output (the pyramid features)
        # is consumed by the prediction head.
        pyramid_features = self.fpn(backbone_features)[0]
        conf, loc = self.pred_net(pyramid_features)
        return conf, loc
|
||||
|
||||
|
||||
class FeatureFusion(nn.Module):
    """Fuse a coarser feature map into a finer one (top-down pathway step).

    The finer ("main") map is projected to ``lat_ch`` channels with a 1x1
    convolution; the coarser ("up") map is upsampled by 2, cropped to the
    main map's spatial size if it came out larger, and added to it.
    """

    def __init__(self, lat_ch=256, **channels):
        """
        Args:
            lat_ch: Number of output channels of the 1x1 projection.
            **channels: Must contain ``main``, the channel count of the
                finer feature map.
        """
        super(FeatureFusion, self).__init__()
        self.main_conv = nn.Conv2d(channels['main'], lat_ch, kernel_size=1)

    def forward(self, up, main):
        """Return ``upsample(up) + main_conv(main)``.

        Args:
            up: Coarser feature map with ``lat_ch`` channels.
            main: Finer feature map with ``channels['main']`` channels.
        """
        main = self.main_conv(main)
        _, _, H, W = main.size()
        # F.upsample is deprecated in favour of F.interpolate;
        # align_corners=False is the value the old call resolved to, now
        # stated explicitly so no warning is emitted.
        res = F.interpolate(up,
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=False)
        # An odd-sized main map makes the 2x upsampled map one pixel larger.
        if res.size(2) != H or res.size(3) != W:
            res = res[:, :, 0:H, 0:W]
        res = res + main
        return res
|
||||
|
||||
|
||||
class LFPN(nn.Module):
    """Feature-pyramid neck producing six pyramid levels from four backbone stages.

    Two extra levels (c6, c7) are derived from c5 with stride-2 convolution
    stacks, and c4/c3/c2 are fused top-down through ``FeatureFusion``.
    """

    def __init__(self,
                 c2_out_ch=256,
                 c3_out_ch=512,
                 c4_out_ch=1024,
                 c5_out_ch=2048,
                 c6_mid_ch=512,
                 c6_out_ch=512,
                 c7_mid_ch=128,
                 c7_out_ch=256,
                 out_dsfd_ft=True):
        """
        Args:
            c2_out_ch, c3_out_ch, c4_out_ch, c5_out_ch: Channel counts of the
                four backbone feature maps.
            c6_mid_ch, c6_out_ch: Hidden/output channels of the c6 stack.
            c7_mid_ch, c7_out_ch: Hidden/output channels of the c7 stack.
            out_dsfd_ft: Whether ``forward`` also computes the auxiliary
                ``dsfd`` feature list.
        """
        super(LFPN, self).__init__()
        self.out_dsfd_ft = out_dsfd_ft
        if self.out_dsfd_ft:
            # NOTE: the input channel counts of these convolutions are
            # hard-coded rather than taken from the constructor arguments.
            dsfd_module = []
            dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1))
            dsfd_module.append(nn.Conv2d(512, 256, kernel_size=3, padding=1))
            dsfd_module.append(nn.Conv2d(1024, 256, kernel_size=3, padding=1))
            dsfd_module.append(nn.Conv2d(2048, 256, kernel_size=3, padding=1))
            dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1))
            dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1))
            self.dsfd_modules = nn.ModuleList(dsfd_module)

        c6_input_ch = c5_out_ch
        self.c6 = nn.Sequential(*[
            nn.Conv2d(
                c6_input_ch,
                c6_mid_ch,
                kernel_size=1,
            ),
            nn.BatchNorm2d(c6_mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(c6_mid_ch, c6_out_ch, kernel_size=3, padding=1,
                      stride=2),
            nn.BatchNorm2d(c6_out_ch),
            nn.ReLU(inplace=True)
        ])
        self.c7 = nn.Sequential(*[
            nn.Conv2d(
                c6_out_ch,
                c7_mid_ch,
                kernel_size=1,
            ),
            nn.BatchNorm2d(c7_mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(c7_mid_ch, c7_out_ch, kernel_size=3, padding=1,
                      stride=2),
            nn.BatchNorm2d(c7_out_ch),
            nn.ReLU(inplace=True)
        ])

        self.p2_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.p3_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.p4_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.c5_lat = nn.Conv2d(c6_input_ch, 256, kernel_size=3, padding=1)
        self.c6_lat = nn.Conv2d(c6_out_ch, 256, kernel_size=3, padding=1)
        self.c7_lat = nn.Conv2d(c7_out_ch, 256, kernel_size=3, padding=1)

        self.ff_c5_c4 = FeatureFusion(main=c4_out_ch)
        self.ff_c4_c3 = FeatureFusion(main=c3_out_ch)
        self.ff_c3_c2 = FeatureFusion(main=c2_out_ch)

    def forward(self, feature_list):
        """Build the feature pyramid.

        Args:
            feature_list: The four backbone feature maps ``[c2, c3, c4, c5]``.

        Returns:
            A 2-tuple ``(pyramid, dsfd_fts)`` where ``pyramid`` is
            ``[p2, p3, p4, c5, c6, c7]`` and ``dsfd_fts`` is the list of six
            auxiliary feature maps, or ``None`` when ``out_dsfd_ft`` is False.
        """
        c2, c3, c4, c5 = feature_list
        c6 = self.c6(c5)
        c7 = self.c7(c6)

        c5 = self.c5_lat(c5)
        c6 = self.c6_lat(c6)
        c7 = self.c7_lat(c7)

        dsfd_fts = None
        if self.out_dsfd_ft:
            dsfd_fts = []
            dsfd_fts.append(self.dsfd_modules[0](c2))
            dsfd_fts.append(self.dsfd_modules[1](c3))
            dsfd_fts.append(self.dsfd_modules[2](c4))
            # feature_list[-1] is the raw backbone c5; the local ``c5`` has
            # already been replaced by its lateral projection above.
            dsfd_fts.append(self.dsfd_modules[3](feature_list[-1]))
            dsfd_fts.append(self.dsfd_modules[4](c6))
            dsfd_fts.append(self.dsfd_modules[5](c7))

        p4 = self.ff_c5_c4(c5, c4)
        p3 = self.ff_c4_c3(p4, c3)
        p2 = self.ff_c3_c2(p3, c2)

        p2 = self.p2_lat(p2)
        p3 = self.p3_lat(p3)
        p4 = self.p4_lat(p4)

        # Always return the same 2-tuple shape. Previously the function fell
        # off the end (returning None) when out_dsfd_ft was False, which broke
        # callers that index the result with [0].
        return ([p2, p3, p4, c5, c6, c7], dsfd_fts)
|
||||
168
modelscope/models/cv/face_detection/mogface/models/mogprednet.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# --------------------------------------------------------
|
||||
# The implementation is also open-sourced by the authors as Yang Liu, and is available publicly on
|
||||
# https://github.com/damo-cv/MogFace
|
||||
# --------------------------------------------------------
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class conv_bn(nn.Module):
    """A 2-D convolution immediately followed by batch normalization."""

    def __init__(self, in_plane, out_plane, kernel_size, stride, padding):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_plane,
            out_plane,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn1 = nn.BatchNorm2d(out_plane)

    def forward(self, x):
        """Return ``bn1(conv1(x))``; no activation is applied."""
        return self.bn1(self.conv1(x))
|
||||
|
||||
|
||||
class SSHContext(nn.Module):
    """Context module that concatenates three convolutional branches.

    The output has ``Xchannels + 2 * (Xchannels // 2)`` channels and the same
    spatial size as the input.
    """

    def __init__(self, channels, Xchannels=256):
        super().__init__()
        half = Xchannels // 2

        # Branch 1: a single 3x3 convolution.
        self.conv1 = nn.Conv2d(channels,
                               Xchannels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        # Shared stem of branches 2 and 3: dilated 3x3 convolution.
        self.conv2 = nn.Conv2d(channels,
                               half,
                               kernel_size=3,
                               dilation=2,
                               stride=1,
                               padding=2)
        # Branch 2: one more plain 3x3 convolution.
        self.conv2_1 = nn.Conv2d(half,
                                 half,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)
        # Branch 3: a dilated 3x3 followed by a plain 3x3 convolution.
        self.conv2_2 = nn.Conv2d(half,
                                 half,
                                 kernel_size=3,
                                 dilation=2,
                                 stride=1,
                                 padding=2)
        self.conv2_2_1 = nn.Conv2d(half,
                                   half,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)

    def forward(self, x):
        """Return the channel-wise concatenation of the three branches."""
        branch_a = F.relu(self.conv1(x), inplace=True)
        stem = F.relu(self.conv2(x), inplace=True)
        branch_b = F.relu(self.conv2_1(stem), inplace=True)
        branch_c = F.relu(self.conv2_2(stem), inplace=True)
        branch_c = F.relu(self.conv2_2_1(branch_c), inplace=True)

        return torch.cat([branch_a, branch_b, branch_c], 1)
|
||||
|
||||
|
||||
class DeepHead(nn.Module):
|
||||
def __init__(self,
|
||||
in_channel=256,
|
||||
out_channel=256,
|
||||
use_gn=False,
|
||||
num_conv=4):
|
||||
super(DeepHead, self).__init__()
|
||||
self.use_gn = use_gn
|
||||
self.num_conv = num_conv
|
||||
self.conv1 = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
|
||||
self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
|
||||
self.conv3 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
|
||||
self.conv4 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
|
||||
if self.use_gn:
|
||||
self.gn1 = nn.GroupNorm(16, out_channel)
|
||||
self.gn2 = nn.GroupNorm(16, out_channel)
|
||||
self.gn3 = nn.GroupNorm(16, out_channel)
|
||||
self.gn4 = nn.GroupNorm(16, out_channel)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_gn:
|
||||
x1 = F.relu(self.gn1(self.conv1(x)), inplace=True)
|
||||
x2 = F.relu(self.gn2(self.conv1(x1)), inplace=True)
|
||||
x3 = F.relu(self.gn3(self.conv1(x2)), inplace=True)
|
||||
x4 = F.relu(self.gn4(self.conv1(x3)), inplace=True)
|
||||
else:
|
||||
x1 = F.relu(self.conv1(x), inplace=True)
|
||||
x2 = F.relu(self.conv1(x1), inplace=True)
|
||||
if self.num_conv == 2:
|
||||
return x2
|
||||
x3 = F.relu(self.conv1(x2), inplace=True)
|
||||
x4 = F.relu(self.conv1(x3), inplace=True)
|
||||
|
||||
return x4
|
||||
|
||||
|
||||
class MogPredNet(nn.Module):
    """Prediction head producing per-anchor confidences and box regressions.

    The same SSH context module, deep heads and prediction convolutions are
    shared across all pyramid levels.
    """

    def __init__(self,
                 num_anchor_per_pixel=1,
                 num_classes=1,
                 input_ch_list=[256, 256, 256, 256, 256, 256],
                 use_deep_head=True,
                 deep_head_with_gn=True,
                 use_ssh=True,
                 deep_head_ch=512):
        """
        Args:
            num_anchor_per_pixel: Anchors predicted at each location.
            num_classes: Size of the last dimension of the confidence output.
            input_ch_list: Channel counts of the pyramid levels; only the
                first entry is read (and the default list is never mutated).
            use_deep_head: Build and use the deep classification/regression
                heads.
            deep_head_with_gn: Use GroupNorm inside the deep heads.
            use_ssh: Apply the SSH context module before the deep heads.
            deep_head_ch: Channel count of the deep heads and of the input
                of the prediction convolutions.
        """
        super(MogPredNet, self).__init__()
        self.num_classes = num_classes
        self.use_deep_head = use_deep_head
        self.deep_head_with_gn = deep_head_with_gn

        self.use_ssh = use_ssh

        self.deep_head_ch = deep_head_ch

        if self.use_ssh:
            self.conv_SSH = SSHContext(input_ch_list[0],
                                       self.deep_head_ch // 2)

        if self.use_deep_head:
            # Build the heads whether or not GroupNorm is requested.
            # Previously they were only created when deep_head_with_gn was
            # True, so forward() raised AttributeError for the non-GN setup.
            self.deep_loc_head = DeepHead(self.deep_head_ch,
                                          self.deep_head_ch,
                                          use_gn=self.deep_head_with_gn)
            self.deep_cls_head = DeepHead(self.deep_head_ch,
                                          self.deep_head_ch,
                                          use_gn=self.deep_head_with_gn)

            self.pred_cls = nn.Conv2d(self.deep_head_ch,
                                      1 * num_anchor_per_pixel, 3, 1, 1)
            self.pred_loc = nn.Conv2d(self.deep_head_ch,
                                      4 * num_anchor_per_pixel, 3, 1, 1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, pyramid_feature_list, dsfd_ft_list=None):
        """Predict confidences and box regressions for all pyramid levels.

        Args:
            pyramid_feature_list: Iterable of feature maps, one per level.
            dsfd_ft_list: Accepted but not read by this method.

        Returns:
            Tuple ``(conf, loc)``: sigmoid confidences viewed as
            ``[batch, -1, num_classes]`` and regressions viewed as
            ``[batch, -1, 4]``, concatenated over all levels.

        NOTE(review): with ``use_deep_head=False`` no prediction is appended
        and the concatenation below fails on empty lists.
        """
        loc = []
        conf = []

        if self.use_deep_head:
            for x in pyramid_feature_list:
                if self.use_ssh:
                    x = self.conv_SSH(x)
                x_cls = self.deep_cls_head(x)
                x_loc = self.deep_loc_head(x)

                # Move channels last so that the flattening below groups the
                # values belonging to one location together.
                conf.append(
                    self.pred_cls(x_cls).permute(0, 2, 3, 1).contiguous())
                loc.append(
                    self.pred_loc(x_loc).permute(0, 2, 3, 1).contiguous())

        loc = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1)
        conf = torch.cat(
            [o.view(o.size(0), -1, self.num_classes) for o in conf], 1)
        output = (
            self.sigmoid(conf.view(conf.size(0), -1, self.num_classes)),
            loc.view(loc.size(0), -1, 4),
        )

        return output
|
||||
194
modelscope/models/cv/face_detection/mogface/models/resnet.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# The implementation is modified from the original ResNet implementation, which is
|
||||
# also open-sourced by the authors as Yang Liu,
|
||||
# and is available publicly on https://github.com/damo-cv/MogFace
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Create a bias-free 3x3 convolution whose padding equals its dilation."""
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return conv
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
    """Create a bias-free 1x1 convolution."""
    conv = nn.Conv2d(in_planes, out_planes, 1, stride=stride, bias=False)
    return conv
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus shortcut.

    The block outputs ``planes * expansion`` channels.
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        out_planes = planes * self.expansion
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The shortcut is the input itself unless a projection was supplied.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """Two-convolution residual block used by ``ResNet`` for ``depth=18``.

    ``ResNet`` referenced this class without it being defined in the module,
    so ``ResNet(depth=18)`` raised NameError. The constructor signature
    matches the positional call made in ``ResNet._make_layer``.
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                'Dilation > 1 not supported in BasicBlock')
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet backbone returning the outputs of its four residual stages.

    There is no pooling/classification head: ``forward`` returns a list with
    the feature maps produced by ``layer1`` .. ``layer4``.
    """

    def __init__(self,
                 depth=50,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None,
                 inplanes=64,
                 shrink_ch_ratio=1):
        """
        Args:
            depth: One of 18, 50, 101, 152; selects block type and stage sizes.
            groups: Groups of the 3x3 convolutions inside the blocks.
            width_per_group: Base width passed to the blocks.
            replace_stride_with_dilation: Optional 3-element sequence; a true
                entry replaces the stride-2 of layer2/3/4 with dilation.
            norm_layer: Normalization layer factory (default ``nn.BatchNorm2d``).
            inplanes: Channel count of the stem before shrinking.
            shrink_ch_ratio: Multiplier applied to ``inplanes``; the exact
                value 0.125 additionally switches the stage sizes to
                ``[2, 3, 3, 3]``.

        Raises:
            ValueError: For an unsupported ``depth`` or a
                ``replace_stride_with_dilation`` that is not of length 3.
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        if depth == 50:
            block = Bottleneck
            layers = [3, 4, 6, 3]
        elif depth == 101:
            block = Bottleneck
            layers = [3, 4, 23, 3]
        elif depth == 152:
            block = Bottleneck
            layers = [3, 4, 36, 3]
        elif depth == 18:
            block = BasicBlock
            layers = [2, 2, 2, 2]
        else:
            raise ValueError('only support depth in [18, 50, 101, 152]')

        shrink_input_ch = int(inplanes * shrink_ch_ratio)
        self.inplanes = int(inplanes * shrink_ch_ratio)
        if shrink_ch_ratio == 0.125:
            layers = [2, 3, 3, 3]

        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError('replace_stride_with_dilation should be None '
                             'or a 3-element tuple, got {}'.format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3,
                               self.inplanes,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, shrink_input_ch, layers[0])
        self.layer2 = self._make_layer(block,
                                       shrink_input_ch * 2,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       shrink_input_ch * 4,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       shrink_input_ch * 8,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks.

        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate`` is
        set) so that the next stage is wired to this stage's output.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # A projection shortcut is needed whenever the first block changes
        # the resolution or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation,
                      norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``[layer1_out, layer2_out, layer3_out, layer4_out]`` for ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        four_conv_layer = []
        x = self.layer1(x)
        four_conv_layer.append(x)
        x = self.layer2(x)
        four_conv_layer.append(x)
        x = self.layer3(x)
        four_conv_layer.append(x)
        x = self.layer4(x)
        four_conv_layer.append(x)

        return four_conv_layer
|
||||
210
modelscope/models/cv/face_detection/mogface/models/utils.py
Executable file
@@ -0,0 +1,210 @@
|
||||
# Modified from https://github.com/biubug6/Pytorch_Retinaface
|
||||
|
||||
import math
|
||||
from itertools import product as product
|
||||
from math import ceil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def transform_anchor(anchors):
    """
    Convert corner boxes [x0, y0, x1, y1] to [c_x, c_y, w, h].

    Sizes use the inclusive-pixel convention:
    x1 = x0 + w - 1
    c_x = (x0 + x1) / 2 = (2x0 + w - 1) / 2 = x0 + (w - 1) / 2
    """
    top_left = anchors[:, :2]
    bottom_right = anchors[:, 2:]
    centers = (top_left + bottom_right) / 2
    sizes = bottom_right - top_left + 1
    return np.concatenate((centers, sizes), axis=1)
|
||||
|
||||
|
||||
def normalize_anchor(anchors):
    """
    Convert [c_x, c_y, w, h] boxes to corner boxes [x0, y0, x1, y1],
    using the inclusive-pixel convention x1 = x0 + w - 1.
    """
    centers = anchors[:, :2]
    half_extent = (anchors[:, 2:] - 1) / 2
    corners = (centers - half_extent, centers + half_extent)
    return np.concatenate(corners, axis=1)
|
||||
|
||||
|
||||
class MogPriorBox(object):
    """
    Anchor generator for an FPN (or a single layer; the single-layer case
    still needs testing).

    Calling the instance returns a float32 np.ndarray of shape
    [num_anchors, 4] laid out as [c_x, c_y, w, h].
    """
    def __init__(self,
                 scale_list=[1.],
                 aspect_ratio_list=[1.0],
                 stride_list=[4, 8, 16, 32, 64, 128],
                 anchor_size_list=[16, 32, 64, 128, 256, 512]):
        self.scale_list = scale_list
        self.aspect_ratio_list = aspect_ratio_list
        self.stride_list = stride_list
        self.anchor_size_list = anchor_size_list

    def __call__(self, img_height, img_width):
        """Generate the anchors of every pyramid level for the given image size."""
        anchors_per_level = []

        for level, stride in enumerate(self.stride_list):
            # Feature-map size of this level: halve (rounding up) once per
            # factor of two contained in the stride.
            fm_height, fm_width = img_height, img_width
            remaining = stride
            while remaining != 1:
                remaining = remaining // 2
                fm_height = (fm_height + 1) // 2
                fm_width = (fm_width + 1) // 2

            level_anchors = []
            for row, col, scale, ratio in product(range(fm_height),
                                                  range(fm_width),
                                                  self.scale_list,
                                                  self.aspect_ratio_list):
                center_x = (col + 0.5) * stride
                center_y = (row + 0.5) * stride
                side_x = self.anchor_size_list[level] * scale
                side_y = self.anchor_size_list[level] * scale
                level_anchors.append([
                    center_x, center_y, side_x / math.sqrt(ratio),
                    side_y * math.sqrt(ratio)
                ])

            anchors_per_level.append(level_anchors)

        all_anchors = np.concatenate(anchors_per_level, axis=0)
        # Round-trip through the corner representation in float32.
        corner_anchors = normalize_anchor(all_anchors).astype('float32')
        return transform_anchor(corner_anchors)
|
||||
|
||||
|
||||
class PriorBox(object):
    """Prior (anchor) box generator driven by a config dict.

    Reads ``cfg['min_sizes']`` (per-level list of anchor side lengths),
    ``cfg['steps']`` (per-level stride) and ``cfg['clip']``. ``image_size``
    is indexed as ``[0]`` = height and ``[1]`` = width. ``phase`` is accepted
    but not used.
    """

    def __init__(self, cfg, image_size=None, phase='train'):
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size

        # Feature-map size of each level: ceil(image_dim / step).
        self.feature_maps = []
        for step in self.steps:
            self.feature_maps.append([
                ceil(self.image_size[0] / step),
                ceil(self.image_size[1] / step)
            ])
        self.name = 's'

    def forward(self):
        """Return all priors as a tensor of shape [num_priors, 4].

        Rows are [cx, cy, w, h], each divided by the image width (x-like
        values) or height (y-like values). When ``self.clip`` is truthy the
        values are clamped to [0, 1] in place.
        """
        flat = []
        for level, fmap in enumerate(self.feature_maps):
            level_sizes = self.min_sizes[level]
            for row in range(fmap[0]):
                for col in range(fmap[1]):
                    for min_size in level_sizes:
                        s_kx = min_size / self.image_size[1]
                        s_ky = min_size / self.image_size[0]
                        # Center of the cell, expressed relative to the image.
                        cx = (col + 0.5) * self.steps[level] / self.image_size[1]
                        cy = (row + 0.5) * self.steps[level] / self.image_size[0]
                        flat.extend([cx, cy, s_kx, s_ky])

        # back to torch land
        output = torch.Tensor(flat).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output
|
||||
|
||||
|
||||
def py_cpu_nms(dets, thresh):
    """Greedy non-maximum suppression implemented with NumPy.

    Args:
        dets (np.ndarray): array whose first five columns are
            (x1, y1, x2, y2, score); box extents are inclusive (+1).
        thresh (float): boxes whose IoU with an already kept box exceeds
            this value are discarded.

    Returns:
        list: indices into ``dets`` of the kept boxes, highest score first.
    """
    x1, y1, x2, y2, scores = (dets[:, col] for col in range(5))

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(best)

        # Intersection of the best box with every remaining candidate.
        left = np.maximum(x1[best], x1[rest])
        top = np.maximum(y1[best], y1[rest])
        right = np.minimum(x2[best], x2[rest])
        bottom = np.minimum(y2[best], y2[rest])
        inter = (np.maximum(0.0, right - left + 1)
                 * np.maximum(0.0, bottom - top + 1))
        ovr = inter / (areas[best] + areas[rest] - inter)

        # Only candidates that do not overlap too much survive.
        order = rest[ovr <= thresh]

    return keep
|
||||
|
||||
|
||||
def mogdecode(loc, anchors):
    """Decode regression outputs against center-form anchors.

    Args:
        loc (torch.Tensor): 2-d tensor of offsets (dx, dy, dw, dh).
        anchors (torch.Tensor): 2-d tensor of anchors (cx, cy, w, h).

    Returns:
        torch.Tensor: 2-d tensor of boxes (x0, y0, x1, y1), using the
        inclusive-pixel convention x1 = x0 + w - 1.
    """
    centers = anchors[:, :2] + loc[:, :2] * anchors[:, 2:]
    sizes = anchors[:, 2:] * torch.exp(loc[:, 2:])

    top_left = centers - (sizes - 1) / 2
    bottom_right = sizes + (top_left - 1)

    return torch.cat((top_left, bottom_right), 1)
|
||||
|
||||
|
||||
# Adapted from https://github.com/Hakuyume/chainer-ssd
|
||||
def decode(loc, priors, variances):
    """Turn location predictions back into corner-form boxes.

    Undoes the offset encoding that was applied for regression at train
    time.

    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions, (x0, y0, x1, y1) per row
    """
    centers = priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:]
    sizes = priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])

    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return torch.cat((top_left, bottom_right), 1)
|
||||
|
||||
|
||||
def decode_landm(pre, priors, variances):
    """Turn landmark predictions back into point coordinates.

    Undoes the offset encoding that was applied for regression at train
    time; every (x, y) pair is decoded against the prior's center and size.

    Args:
        pre (tensor): landm predictions for loc layers,
            Shape: [num_priors,10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions
    """
    prior_centers = priors[:, :2]
    prior_sizes = priors[:, 2:]
    # Five landmarks, two columns (x, y) each.
    points = [
        prior_centers + pre[:, start:start + 2] * variances[0] * prior_sizes
        for start in range(0, 10, 2)
    ]
    landms = torch.cat(points, dim=1)
    return landms
|
||||
2
modelscope/models/cv/face_detection/mtcnn/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from .models.detector import MtcnnFaceDetector
|
||||
240
modelscope/models/cv/face_detection/mtcnn/models/box_utils.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def nms(boxes, overlap_threshold=0.5, mode='union'):
    """Non-maximum suppression.

    Arguments:
        boxes: a float numpy array of shape [n, 5],
            where each row is (xmin, ymin, xmax, ymax, score).
        overlap_threshold: a float number.
        mode: 'union' or 'min'.

    Returns:
        list with indices of the selected boxes

    Raises:
        ValueError: if ``mode`` is neither 'union' nor 'min'.
    """
    # Reject unknown modes up front; otherwise the overlap would never be
    # computed and the loop would fail with an unbound local variable.
    if mode not in ('union', 'min'):
        raise ValueError(
            "mode must be 'union' or 'min', got {!r}".format(mode))

    # if there are no boxes, return the empty list
    if len(boxes) == 0:
        return []

    # list of picked indices
    pick = []

    # grab the coordinates of the bounding boxes
    x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)]

    area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0)
    ids = np.argsort(score)  # in increasing order

    while len(ids) > 0:
        # the candidate with the largest score is the last one
        i = ids[-1]
        rest = ids[:-1]
        pick.append(i)

        # intersection of the picked box with all remaining boxes:
        # left top corner ...
        ix1 = np.maximum(x1[i], x1[rest])
        iy1 = np.maximum(y1[i], y1[rest])
        # ... and right bottom corner
        ix2 = np.minimum(x2[i], x2[rest])
        iy2 = np.minimum(y2[i], y2[rest])

        # width and height of intersection boxes
        w = np.maximum(0.0, ix2 - ix1 + 1.0)
        h = np.maximum(0.0, iy2 - iy1 + 1.0)
        inter = w * h

        if mode == 'min':
            overlap = inter / np.minimum(area[i], area[rest])
        else:
            # intersection over union (IoU)
            overlap = inter / (area[i] + area[rest] - inter)

        # drop the picked box and all boxes where overlap is too big
        ids = rest[~(overlap > overlap_threshold)]

    return pick
|
||||
|
||||
|
||||
def convert_to_square(bboxes):
    """Convert bounding boxes to a square form.

    Each box is replaced by the square with the same center whose side is
    the longer of the box's width and height. Columns beyond the four
    coordinates (e.g. the score) are carried over unchanged instead of
    being zeroed.

    Arguments:
        bboxes: a float numpy array of shape [n, 5].

    Returns:
        a float numpy array of shape [n, 5],
            squared bounding boxes. The input array is not modified.
    """
    # Start from a copy so that the score column survives.
    square_bboxes = bboxes.copy()
    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    h = y2 - y1 + 1.0
    w = x2 - x1 + 1.0
    max_side = np.maximum(h, w)
    square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5
    square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5
    square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0
    square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0
    return square_bboxes
|
||||
|
||||
|
||||
def calibrate_box(bboxes, offsets):
    """Shift box corners by network-predicted offsets.

    'offsets' is one of the outputs of the nets. Every corner coordinate is
    moved by its offset times the box width (x) or height (y):

        x1_true = x1 + tx1 * w,    y1_true = y1 + ty1 * h
        x2_true = x2 + tx2 * w,    y2_true = y2 + ty2 * h

    Nothing here enforces x1 < x2 or y1 < y2 after the shift.

    Arguments:
        bboxes: a float numpy array of shape [n, 5].
        offsets: a float numpy array of shape [n, 4].

    Returns:
        a float numpy array of shape [n, 5]. This is the ``bboxes`` array
        itself, whose first four columns were updated in place.
    """
    widths = (bboxes[:, 2] - bboxes[:, 0] + 1.0)[:, np.newaxis]
    heights = (bboxes[:, 3] - bboxes[:, 1] + 1.0)[:, np.newaxis]

    # Per-corner scale factors laid out as (w, h, w, h).
    scales = np.concatenate([widths, heights, widths, heights], axis=1)
    bboxes[:, 0:4] = bboxes[:, 0:4] + scales * offsets
    return bboxes
|
||||
|
||||
|
||||
def get_image_boxes(bounding_boxes, img, size=24):
    """Cut out boxes from the image.

    Every box is cropped (parts outside the image stay black), resized to
    ``size`` x ``size`` with bilinear interpolation and normalized by
    ``_preprocess``.

    Arguments:
        bounding_boxes: a float numpy array of shape [n, 5]. Note that
            ``correct_bboxes`` writes the clipped coordinates back into
            this array.
        img: an instance of PIL.Image.
        size: an integer, size of cutouts.

    Returns:
        a float numpy array of shape [n, 3, size, size].
    """
    num_boxes = len(bounding_boxes)
    width, height = img.size

    [dy, edy, dx, edx, y, ey, x, ex, w,
     h] = correct_bboxes(bounding_boxes, width, height)
    img_boxes = np.zeros((num_boxes, 3, size, size), 'float32')

    # Convert the image once; it is identical for every box.
    img_array = np.asarray(img, 'uint8')

    for i in range(num_boxes):
        # Black canvas of the (unclipped) box size.
        img_box = np.zeros((h[i], w[i], 3), 'uint8')

        # Paste the part of the box that lies inside the image.
        img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] =\
            img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :]

        # resize
        img_box = Image.fromarray(img_box)
        img_box = img_box.resize((size, size), Image.BILINEAR)
        img_box = np.asarray(img_box, 'float32')

        img_boxes[i, :, :, :] = _preprocess(img_box)

    return img_boxes
|
||||
|
||||
|
||||
def correct_bboxes(bboxes, width, height):
    """Clip boxes to the image and compute where they land in their cutouts.

    Note: the clipped x/y coordinates are written back into ``bboxes``
    (the coordinate columns are modified through views).

    Arguments:
        bboxes: a float numpy array of shape [n, 5],
            where each row is (xmin, ymin, xmax, ymax, score).
        width: a float number.
        height: a float number.

    Returns:
        dy, dx, edy, edx: a int numpy arrays of shape [n],
            coordinates of the boxes with respect to the cutouts.
        y, x, ey, ex: a int numpy arrays of shape [n],
            corrected ymin, xmin, ymax, xmax.
        h, w: a int numpy arrays of shape [n],
            just heights and widths of boxes.

        in the following order:
            [dy, edy, dx, edx, y, ey, x, ex, w, h].
    """
    # (x, y) -> (ex, ey); 'e' stands for end. These are views into bboxes.
    x, y, ex, ey = [bboxes[:, col] for col in range(4)]
    # Box sizes are taken before any clipping happens.
    w = ex - x + 1.0
    h = ey - y + 1.0
    num_boxes = bboxes.shape[0]

    # (x, y, ex, ey) end up as the box coordinates inside the image;
    # (dx, dy, edx, edy) are the matching coordinates inside the cutout.
    dx = np.zeros((num_boxes, ))
    dy = np.zeros((num_boxes, ))
    edx = w.copy() - 1.0
    edy = h.copy() - 1.0

    # bottom right corner is too far right
    past_right = ex > width - 1.0
    edx[past_right] = w[past_right] + width - 2.0 - ex[past_right]
    ex[past_right] = width - 1.0

    # bottom right corner is too low
    past_bottom = ey > height - 1.0
    edy[past_bottom] = h[past_bottom] + height - 2.0 - ey[past_bottom]
    ey[past_bottom] = height - 1.0

    # top left corner is too far left
    past_left = x < 0.0
    dx[past_left] = 0.0 - x[past_left]
    x[past_left] = 0.0

    # top left corner is too high
    past_top = y < 0.0
    dy[past_top] = 0.0 - y[past_top]
    y[past_top] = 0.0

    return [
        arr.astype('int32')
        for arr in (dy, edy, dx, edx, y, ey, x, ex, w, h)
    ]
|
||||
|
||||
|
||||
def _preprocess(img):
    """Preprocessing step before feeding the network.

    Moves channels first, adds a leading batch axis and maps pixel values
    with (v - 127.5) * 0.0078125 (i.e. / 128).

    Arguments:
        img: a float numpy array of shape [h, w, c].

    Returns:
        a float numpy array of shape [1, c, h, w].
    """
    batched = np.expand_dims(img.transpose((2, 0, 1)), 0)
    return (batched - 127.5) * 0.0078125
|
||||
153
modelscope/models/cv/face_detection/mtcnn/models/detector.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
from PIL import Image
|
||||
from torch.autograd import Variable
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
from .box_utils import calibrate_box, convert_to_square, get_image_boxes, nms
|
||||
from .first_stage import run_first_stage
|
||||
from .get_nets import ONet, PNet, RNet
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.face_detection, module_name=Models.mtcnn)
class MtcnnFaceDetector(TorchModel):
    """Three-stage MTCNN face detector (P-Net -> R-Net -> O-Net).

    Weights are read from ``pnet.npy``, ``rnet.npy`` and ``onet.npy`` inside
    ``model_path``. An optional ``conf_th`` keyword argument replaces the
    default per-stage score thresholds [0.7, 0.8, 0.9] with one value for
    all three stages.
    """

    def __init__(self, model_path, device='cuda', **kwargs):
        super().__init__(model_path)
        cudnn.benchmark = True
        self.model_path = model_path
        self.device = device

        self.pnet = PNet(model_path=os.path.join(self.model_path, 'pnet.npy'))
        self.rnet = RNet(model_path=os.path.join(self.model_path, 'rnet.npy'))
        self.onet = ONet(model_path=os.path.join(self.model_path, 'onet.npy'))

        self.pnet = self.pnet.to(device)
        self.rnet = self.rnet.to(device)
        self.onet = self.onet.to(device)

        conf_th = kwargs.get('conf_th')
        if conf_th is not None:
            self.threshods = [conf_th] * 3
        else:
            self.threshods = [0.7, 0.8, 0.9]

    def forward(self, input):
        """Detect faces in one image.

        Args:
            input (dict): ``input['img']`` is a tensor holding the image;
                it is converted to uint8 and wrapped in a PIL image.

        Returns:
            tuple: ``(bounding_boxes, landmarks)`` where ``bounding_boxes``
            is an array of shape [n, 5] (xmin, ymin, xmax, ymax, score) and
            ``landmarks`` an array of shape [n, 10] laid out as
            (x1, y1, ..., x5, y5). ``([], [])`` is returned when no
            candidate survives stage 1 or stage 2.
        """
        image = Image.fromarray(np.uint8(input['img'].cpu().numpy()))
        pnet = self.pnet
        rnet = self.rnet
        onet = self.onet
        onet.eval()

        min_face_size = 20.0
        thresholds = self.threshods
        nms_thresholds = [0.7, 0.7, 0.7]

        # BUILD AN IMAGE PYRAMID
        width, height = image.size
        min_length = min(height, width)

        min_detection_size = 12
        factor = 0.707  # sqrt(0.5)

        # scales for scaling the image
        scales = []

        m = min_detection_size / min_face_size
        min_length *= m

        factor_count = 0
        while min_length > min_detection_size:
            scales.append(m * factor**factor_count)
            min_length *= factor
            factor_count += 1

        # STAGE 1

        # run P-Net on different scales
        bounding_boxes = []
        for s in scales:
            boxes = run_first_stage(image,
                                    pnet,
                                    scale=s,
                                    threshold=thresholds[0],
                                    device=self.device)
            bounding_boxes.append(boxes)

        # collect boxes (and offsets, and scores) from different scales
        bounding_boxes = [i for i in bounding_boxes if i is not None]
        if not bounding_boxes:
            # No scale produced a candidate (or the image is too small for
            # any scale); np.vstack would fail on an empty list.
            return [], []
        bounding_boxes = np.vstack(bounding_boxes)

        keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
        bounding_boxes = bounding_boxes[keep]

        # use offsets predicted by pnet to transform bounding boxes
        bounding_boxes = calibrate_box(bounding_boxes[:, 0:5],
                                       bounding_boxes[:, 5:])
        # shape [n_boxes, 5]

        bounding_boxes = convert_to_square(bounding_boxes)
        bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

        # STAGE 2

        img_boxes = get_image_boxes(bounding_boxes, image, size=24)
        # Inference only: torch.no_grad() replaces the removed
        # ``volatile=True`` flag, which no longer disables autograd.
        with torch.no_grad():
            output = rnet(torch.FloatTensor(img_boxes).to(self.device))
        offsets = output[0].cpu().data.numpy()  # shape [n_boxes, 4]
        probs = output[1].cpu().data.numpy()  # shape [n_boxes, 2]

        keep = np.where(probs[:, 1] > thresholds[1])[0]
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, ))
        offsets = offsets[keep]

        keep = nms(bounding_boxes, nms_thresholds[1])
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
        bounding_boxes = convert_to_square(bounding_boxes)
        bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

        # STAGE 3

        img_boxes = get_image_boxes(bounding_boxes, image, size=48)
        if len(img_boxes) == 0:
            return [], []
        with torch.no_grad():
            output = onet(torch.FloatTensor(img_boxes).to(self.device))
        landmarks = output[0].cpu().data.numpy()  # shape [n_boxes, 10]
        offsets = output[1].cpu().data.numpy()  # shape [n_boxes, 4]
        probs = output[2].cpu().data.numpy()  # shape [n_boxes, 2]

        keep = np.where(probs[:, 1] > thresholds[2])[0]
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, ))
        offsets = offsets[keep]
        landmarks = landmarks[keep]

        # compute landmark points: the net predicts them relative to the
        # box, first the five x values, then the five y values
        box_w = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
        box_h = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
        xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
        landmarks[:, 0:5] = np.expand_dims(
            xmin, 1) + np.expand_dims(box_w, 1) * landmarks[:, 0:5]
        landmarks[:, 5:10] = np.expand_dims(
            ymin, 1) + np.expand_dims(box_h, 1) * landmarks[:, 5:10]

        bounding_boxes = calibrate_box(bounding_boxes, offsets)
        keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
        bounding_boxes = bounding_boxes[keep]
        landmarks = landmarks[keep]
        # (x1..x5, y1..y5) -> (x1, y1, ..., x5, y5)
        landmarks = landmarks.reshape(-1, 2, 5).transpose(
            (0, 2, 1)).reshape(-1, 10)

        return bounding_boxes, landmarks
|
||||
100
modelscope/models/cv/face_detection/mtcnn/models/first_stage.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.autograd import Variable
|
||||
|
||||
from .box_utils import _preprocess, nms
|
||||
|
||||
|
||||
def run_first_stage(image, net, scale, threshold, device='cuda'):
    """Run P-Net, generate bounding boxes, and do NMS.

    Arguments:
        image: an instance of PIL.Image.
        net: an instance of pytorch's nn.Module, P-Net.
        scale: a float number,
            scale width and height of the image by this number.
        threshold: a float number,
            threshold on the probability of a face when generating
            bounding boxes from predictions of the net.
        device: device the input tensor is moved to before running ``net``.

    Returns:
        a float numpy array of shape [n_boxes, 9],
            bounding boxes with scores and offsets (4 + 1 + 4),
            or None when no window exceeds ``threshold``.
    """
    # scale the image and convert it to a float array
    width, height = image.size
    sw, sh = math.ceil(width * scale), math.ceil(height * scale)
    img = image.resize((sw, sh), Image.BILINEAR)
    img = np.asarray(img, 'float32')

    img = torch.FloatTensor(_preprocess(img)).to(device)
    # Inference only: torch.no_grad() replaces the removed ``volatile=True``
    # flag, which no longer disables autograd.
    with torch.no_grad():
        output = net(img)
    # probs: probability of a face at each sliding window
    # offsets: transformations to true bounding boxes
    probs = output[1].cpu().data.numpy()[0, 1, :, :]
    offsets = output[0].cpu().data.numpy()

    boxes = _generate_bboxes(probs, offsets, scale, threshold)
    if len(boxes) == 0:
        return None

    keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
    return boxes[keep]
|
||||
|
||||
|
||||
def _generate_bboxes(probs, offsets, scale, threshold):
    """Generate bounding boxes at places
    where there is probably a face.

    Applying P-Net is treated as sliding a 12x12 window with stride 2 over
    the scaled image; each window whose probability exceeds ``threshold``
    becomes a box, mapped back to the unscaled image by dividing by
    ``scale``.

    Arguments:
        probs: a float numpy array of shape [n, m].
        offsets: a float numpy array of shape [1, 4, n, m].
        scale: a float number,
            width and height of the image were scaled by this number.
        threshold: a float number.

    Returns:
        a float numpy array of shape [n_boxes, 9]
        (x1, y1, x2, y2, score, tx1, ty1, tx2, ty2), or an empty 1-d array
        when nothing exceeds the threshold.
    """
    stride = 2
    cell_size = 12

    # grid positions where there is probably a face
    rows, cols = np.where(probs > threshold)
    if rows.size == 0:
        return np.array([])

    # Offsets (tx1, ty1, tx2, ty2) at those positions, defined by
    #   w = x2 - x1 + 1,            h = y2 - y1 + 1
    #   x1_true = x1 + tx1 * w,     x2_true = x2 + tx2 * w
    #   y1_true = y1 + ty1 * h,     y2_true = y2 + ty2 * h
    picked_offsets = np.array([offsets[0, k, rows, cols] for k in range(4)])
    score = probs[rows, cols]

    # Window corners in the scaled image (the original adds one here),
    # rescaled back to the input image.
    left = np.round((stride * cols + 1.0) / scale)
    top = np.round((stride * rows + 1.0) / scale)
    right = np.round((stride * cols + 1.0 + cell_size) / scale)
    bottom = np.round((stride * rows + 1.0 + cell_size) / scale)

    bounding_boxes = np.vstack(
        [left, top, right, bottom, score, picked_offsets])
    return bounding_boxes.T
|
||||
156
modelscope/models/cv/face_detection/mtcnn/models/get_nets.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Flatten(nn.Module):
    """Flatten a 4-d tensor to 2-d after swapping its last two axes."""

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, c, h, w].
        Returns:
            a float tensor with shape [batch_size, c*h*w].
        """
        # The h/w swap is required: without it the pretrained model
        # isn't working.
        swapped = x.transpose(3, 2).contiguous()
        batch_size = swapped.size(0)
        return swapped.view(batch_size, -1)
|
||||
|
||||
|
||||
class PNet(nn.Module):
    """Fully convolutional proposal network (first MTCNN stage).

    Arguments:
        model_path: path of a ``.npy`` file holding a pickled dict that maps
            parameter names to arrays. When None, the parameters keep their
            default initialization.
    """

    def __init__(self, model_path=None):

        super(PNet, self).__init__()

        # suppose we have input with size HxW, then
        # after first layer: H - 2,
        # after pool: ceil((H - 2)/2),
        # after second conv: ceil((H - 2)/2) - 2,
        # after last conv: ceil((H - 2)/2) - 4,
        # and the same for W

        self.features = nn.Sequential(
            OrderedDict([('conv1', nn.Conv2d(3, 10, 3, 1)),
                         ('prelu1', nn.PReLU(10)),
                         ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)),
                         ('conv2', nn.Conv2d(10, 16, 3, 1)),
                         ('prelu2', nn.PReLU(16)),
                         ('conv3', nn.Conv2d(16, 32, 3, 1)),
                         ('prelu3', nn.PReLU(32))]))

        self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
        self.conv4_2 = nn.Conv2d(32, 4, 1, 1)

        if model_path is not None:
            # NOTE: allow_pickle=True unpickles arbitrary objects; only load
            # weight files from a trusted source.
            weights = np.load(model_path, allow_pickle=True)[()]
            for n, p in self.named_parameters():
                p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            b: a float tensor with shape [batch_size, 4, h', w'].
            a: a float tensor with shape [batch_size, 2, h', w'].
        """
        x = self.features(x)
        a = self.conv4_1(x)
        b = self.conv4_2(x)
        # Softmax over the two class channels (explicit dim instead of the
        # deprecated implicit choice).
        a = F.softmax(a, dim=1)
        return b, a
|
||||
|
||||
|
||||
class RNet(nn.Module):
    """Refinement network (second MTCNN stage).

    The linear layer after the flatten expects 576 features, which
    corresponds to 24x24 inputs.

    Arguments:
        model_path: path of a ``.npy`` file holding a pickled dict that maps
            parameter names to arrays. When None, the parameters keep their
            default initialization.
    """

    def __init__(self, model_path=None):

        super(RNet, self).__init__()

        self.features = nn.Sequential(
            OrderedDict([('conv1', nn.Conv2d(3, 28, 3, 1)),
                         ('prelu1', nn.PReLU(28)),
                         ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
                         ('conv2', nn.Conv2d(28, 48, 3, 1)),
                         ('prelu2', nn.PReLU(48)),
                         ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
                         ('conv3', nn.Conv2d(48, 64, 2, 1)),
                         ('prelu3', nn.PReLU(64)), ('flatten', Flatten()),
                         ('conv4', nn.Linear(576, 128)),
                         ('prelu4', nn.PReLU(128))]))

        self.conv5_1 = nn.Linear(128, 2)
        self.conv5_2 = nn.Linear(128, 4)

        if model_path is not None:
            # NOTE: allow_pickle=True unpickles arbitrary objects; only load
            # weight files from a trusted source.
            weights = np.load(model_path, allow_pickle=True)[()]
            for n, p in self.named_parameters():
                p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            b: a float tensor with shape [batch_size, 4].
            a: a float tensor with shape [batch_size, 2].
        """
        x = self.features(x)
        a = self.conv5_1(x)
        b = self.conv5_2(x)
        # Softmax over the two class scores (explicit dim instead of the
        # deprecated implicit choice).
        a = F.softmax(a, dim=1)
        return b, a
|
||||
|
||||
|
||||
class ONet(nn.Module):
    """Output network (third MTCNN stage): landmarks, box offsets, scores.

    The linear layer after the flatten expects 1152 features, which
    corresponds to 48x48 inputs.

    Arguments:
        model_path: path of a ``.npy`` file holding a pickled dict that maps
            parameter names to arrays. When None, the parameters keep their
            default initialization.
    """

    def __init__(self, model_path=None):

        super(ONet, self).__init__()

        self.features = nn.Sequential(
            OrderedDict([
                ('conv1', nn.Conv2d(3, 32, 3, 1)),
                ('prelu1', nn.PReLU(32)),
                ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
                ('conv2', nn.Conv2d(32, 64, 3, 1)),
                ('prelu2', nn.PReLU(64)),
                ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
                ('conv3', nn.Conv2d(64, 64, 3, 1)),
                ('prelu3', nn.PReLU(64)),
                ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)),
                ('conv4', nn.Conv2d(64, 128, 2, 1)),
                ('prelu4', nn.PReLU(128)),
                ('flatten', Flatten()),
                ('conv5', nn.Linear(1152, 256)),
                ('drop5', nn.Dropout(0.25)),
                ('prelu5', nn.PReLU(256)),
            ]))

        self.conv6_1 = nn.Linear(256, 2)
        self.conv6_2 = nn.Linear(256, 4)
        self.conv6_3 = nn.Linear(256, 10)

        if model_path is not None:
            # NOTE: allow_pickle=True unpickles arbitrary objects; only load
            # weight files from a trusted source.
            weights = np.load(model_path, allow_pickle=True)[()]
            for n, p in self.named_parameters():
                p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            c: a float tensor with shape [batch_size, 10].
            b: a float tensor with shape [batch_size, 4].
            a: a float tensor with shape [batch_size, 2].
        """
        x = self.features(x)
        a = self.conv6_1(x)
        b = self.conv6_2(x)
        c = self.conv6_3(x)
        # Softmax over the two class scores (explicit dim instead of the
        # deprecated implicit choice).
        a = F.softmax(a, dim=1)
        return c, b, a
|
||||
96
modelscope/models/cv/face_detection/peppa_pig_face/LK/lk.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# The implementation here is modified based on InsightFace_Pytorch, originally Apache License and publicly available
|
||||
# at https://github.com/610265158/Peppa_Pig_Face_Engine
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GroupTrack():
    """Frame-to-frame smoothing of landmark sets.

    Landmark sets of the current frame are matched to those of the previous
    call by the IoU of their bounding rectangles; matched sets are smoothed
    point by point, unmatched sets are passed through unchanged.
    """

    def __init__(self):
        self.old_frame = None
        self.previous_landmarks_set = None
        self.with_landmark = True
        # Points that moved less than this distance keep their old position.
        self.thres = 1
        # Weight of the current position in the moving average.
        self.alpha = 0.95
        # Minimum IoU for two landmark sets to count as the same face.
        self.iou_thres = 0.5

    def calculate(self, img, current_landmarks_set):
        """Smooth the landmark sets of the current frame.

        ``img`` is accepted but not used. The returned array is also stored
        as the reference for the next call.
        """
        # Nothing to compare against yet: adopt the current sets as they are.
        if self.previous_landmarks_set is None:
            self.previous_landmarks_set = current_landmarks_set
            return current_landmarks_set

        previous_lm_num = self.previous_landmarks_set.shape[0]
        if previous_lm_num == 0:
            self.previous_landmarks_set = current_landmarks_set
            return current_landmarks_set

        tracked = []
        for i in range(current_landmarks_set.shape[0]):
            current = current_landmarks_set[i]
            for j in range(previous_lm_num):
                previous = self.previous_landmarks_set[j]
                if self.iou(current, previous) > self.iou_thres:
                    # First sufficiently overlapping previous set wins.
                    tracked.append(self.smooth(current, previous))
                    break
            else:
                # No previous set overlaps enough: treat it as new.
                tracked.append(current)

        result = np.array(tracked)
        self.previous_landmarks_set = result
        return result

    def iou(self, p_set0, p_set1):
        """IoU of the axis-aligned bounding rectangles of two point sets."""

        def bounding_rect(points):
            xs, ys = points[:, 0], points[:, 1]
            return np.min(xs), np.min(ys), np.max(xs), np.max(ys)

        ax0, ay0, ax1, ay1 = bounding_rect(p_set0)
        bx0, by0, bx1, by1 = bounding_rect(p_set1)

        # areas of the two rectangles
        area_a = (ax1 - ax0) * (ay1 - ay0)
        area_b = (bx1 - bx0) * (by1 - by0)
        sum_area = area_a + area_b

        # overlap extents, zero when the rectangles do not intersect
        overlap_w = max(0, min(ax1, bx1) - max(ax0, bx0))
        overlap_h = max(0, min(ay1, by1) - max(ay0, by0))
        intersect = overlap_w * overlap_h

        return intersect / (sum_area - intersect)

    def smooth(self, now_landmarks, previous_landmarks):
        """Smooth one landmark set against its previous version.

        A point that moved less than ``self.thres`` keeps its previous
        position; otherwise it becomes the moving average of both.
        """
        smoothed = []
        for idx in range(now_landmarks.shape[0]):
            now_pt = now_landmarks[idx]
            prev_pt = previous_landmarks[idx]
            delta_x = now_pt[0] - prev_pt[0]
            delta_y = now_pt[1] - prev_pt[1]
            dis = np.sqrt(np.square(delta_x) + np.square(delta_y))
            if dis < self.thres:
                smoothed.append(prev_pt)
            else:
                smoothed.append(self.do_moving_average(now_pt, prev_pt))

        return np.array(smoothed)

    def do_moving_average(self, p_now, p_previous):
        """Blend the two positions with weight ``self.alpha`` on ``p_now``."""
        return self.alpha * p_now + (1 - self.alpha) * p_previous
|
||||
@@ -0,0 +1,113 @@
|
||||
# The implementation here is modified based on InsightFace_Pytorch, originally Apache License and publicly available
|
||||
# at https://github.com/610265158/Peppa_Pig_Face_Engine
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
if tf.__version__ >= '2.0':
|
||||
tf = tf.compat.v1
|
||||
|
||||
|
||||
class FaceDetector:
    """Face detector backed by a frozen TensorFlow graph (``detector.pb``).

    Calling an instance on an image returns an ``(N, 5)`` array whose rows
    are ``[x1, y1, x2, y2, score]``.
    """

    def __init__(self, dir):
        """Load ``<dir>/detector.pb`` and look up its input/output tensors.

        Args:
            dir: directory that contains ``detector.pb``. The name shadows
                the builtin but is kept so keyword callers keep working.
        """
        self.model_path = dir + '/detector.pb'
        self.thres = 0.8  # detections scoring at or below this are dropped
        self.input_shape = (512, 512, 3)  # (height, width, channels)
        self.pixel_means = np.array([123., 116., 103.])  # padding colour

        # init_model loads the model into its own graph, so the tensors are
        # looked up on that graph rather than on the ambient default graph.
        self._graph, self._sess = self.init_model(self.model_path)

        graph = self._graph
        self.input_image = graph.get_tensor_by_name('tower_0/images:0')
        self.training = graph.get_tensor_by_name('training_flag:0')
        self.output_ops = [
            graph.get_tensor_by_name('tower_0/boxes:0'),
            graph.get_tensor_by_name('tower_0/scores:0'),
            graph.get_tensor_by_name('tower_0/num_detections:0'),
        ]

    def __call__(self, image):
        """Detect faces in ``image``.

        Args:
            image: ``(H, W, C)`` array; it is letterboxed by ``preprocess``
                before being fed to the network.

        Returns:
            ``(N, 5)`` array with rows ``[x1, y1, x2, y2, score]`` for the
            detections whose score exceeds ``self.thres``.
        """
        image, scale_x, scale_y = self.preprocess(
            image,
            target_width=self.input_shape[1],
            target_height=self.input_shape[0])

        image = np.expand_dims(image, 0)

        boxes, scores, num_boxes = self._sess.run(self.output_ops,
                                                  feed_dict={
                                                      self.input_image: image,
                                                      self.training: False
                                                  })

        # Only the first ``num_boxes`` entries of the batch item are used.
        num_boxes = num_boxes[0]
        boxes = boxes[0][:num_boxes]
        scores = scores[0][:num_boxes]

        to_keep = scores > self.thres
        boxes = boxes[to_keep]
        scores = scores[to_keep]

        # The network emits [y1, x1, y2, x2]; presumably normalised to the
        # network input (TODO confirm). Multiplying by input size / scale
        # undoes the resize applied in ``preprocess``.
        y_factor = self.input_shape[0] / scale_y
        x_factor = self.input_shape[1] / scale_x
        scaler = np.array([y_factor, x_factor, y_factor, x_factor],
                          dtype='float32')
        boxes = boxes * scaler

        # Reorder the columns to [x1, y1, x2, y2] in one vectorised step.
        boxes = boxes[:, [1, 0, 3, 2]]
        scores = scores.reshape([-1, 1])

        return np.concatenate([boxes, scores], axis=1)

    def preprocess(self, image, target_height, target_width, label=None):
        """Resize ``image`` and paste it top-left onto a mean-colour canvas.

        Args:
            image: ``(H, W, C)`` array.
            target_height: canvas height.
            target_width: canvas width.
            label: unused; kept for interface compatibility.

        Returns:
            ``(canvas, scale_x, scale_y)``. Both scales equal
            ``target_height / max(H, W)``.
        """
        h, w, c = image.shape

        bimage = np.zeros(shape=[target_height, target_width, c],
                          dtype=image.dtype) + np.array(self.pixel_means,
                                                        dtype=image.dtype)
        long_side = max(h, w)

        # NOTE: the scale is derived from target_height only, which matches
        # target_width for the square input_shape used by this class.
        scale_x = scale_y = target_height / long_side

        image = cv2.resize(image, None, fx=scale_x, fy=scale_y)

        h_, w_, _ = image.shape
        bimage[:h_, :w_, :] = image

        return bimage, scale_x, scale_y

    def init_model(self, *args):
        """Load a frozen graph into its own ``tf.Graph`` and open a session.

        Args:
            *args: ``args[0]`` is the path of the ``.pb`` file.

        Returns:
            ``(graph, session)`` where ``graph`` holds the imported model
            and ``session`` is bound to that graph.
        """
        pb_path = args[0]

        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.2

        compute_graph = tf.Graph()
        # ``as_default()`` only takes effect as a context manager; calling
        # it bare would import the model into whatever graph happens to be
        # the default and leave ``compute_graph`` empty.
        with compute_graph.as_default():
            with tf.gfile.GFile(pb_path, 'rb') as fid:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(fid.read())
            tf.import_graph_def(graph_def, name='')

        sess = tf.Session(graph=compute_graph, config=config)

        return compute_graph, sess
|
||||
@@ -0,0 +1,153 @@
|
||||
# The implementation here is modified based on InsightFace_Pytorch, originally Apache License and publicly available
|
||||
# at https://github.com/610265158/Peppa_Pig_Face_Engine
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
# The code below uses the TF1 graph/session API, which TensorFlow 2.x only
# exposes through ``tf.compat.v1``. The major version is compared as an
# integer: comparing the raw version strings is lexicographic and would
# misorder a two-digit major version such as '10.0' against '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
|
||||
|
||||
|
||||
class FaceLandmark:
    """Facial keypoint predictor backed by a TensorFlow model.

    Calling an instance with an image and a sequence of face boxes returns
    the landmarks and the "state" outputs of every face that is large
    enough to be processed.
    """

    def __init__(self, dir):
        """Load ``<dir>/keypoints.pb`` and derive the output tensors.

        Args:
            dir: directory that contains ``keypoints.pb``. The name shadows
                the builtin but is kept so keyword callers keep working.
        """
        self.model_path = dir + '/keypoints.pb'
        self.min_face = 60  # boxes not larger than this in both axes are skipped
        self.keypoint_num = 136  # flattened (x, y) pairs -> 68 points
        self.pixel_means = np.array([123., 116., 103.])  # border colour
        self.kp_extend_range = [0.2, 0.3]
        self.kp_shape = (160, 160, 3)  # network input (height, width, channels)

        # init_model loads the model into its own graph and binds the
        # session to it.
        self._graph, self._sess = self.init_model(self.model_path)

        # The slicing/scaling below creates new ops; make sure they land in
        # the model's graph.
        with self._graph.as_default():
            self.img_input = self._graph.get_tensor_by_name(
                'tower_0/images:0')
            self.embeddings = self._graph.get_tensor_by_name(
                'tower_0/prediction:0')
            self.training = self._graph.get_tensor_by_name('training_flag:0')

            self.landmark = self.embeddings[:, :self.keypoint_num]
            self.headpose = self.embeddings[:, -7:-4] * 90.
            self.state = tf.nn.sigmoid(self.embeddings[:, -4:])

    def __call__(self, img, bboxes):
        """Predict landmarks for every box in ``bboxes``.

        Args:
            img: ``(H, W, C)`` image array.
            bboxes: iterable of boxes whose first four entries are
                ``x1, y1, x2, y2``.

        Returns:
            ``(landmarks, states)`` arrays covering only the boxes that
            were not rejected by the minimum-size check.
        """
        landmark_result = []
        state_result = []
        for i, bbox in enumerate(bboxes):
            landmark, state = self._one_shot_run(img, bbox, i)
            if landmark is not None:
                landmark_result.append(landmark)
                state_result.append(state)
        return np.array(landmark_result), np.array(state_result)

    def simple_run(self, cropped_img):
        """Run the network on one already cropped and resized face.

        Args:
            cropped_img: array matching ``self.kp_shape``; a batch axis is
                added here.

        Returns:
            ``(landmark, states)`` as produced by the session. The head
            pose output is evaluated as well but not returned.
        """
        with self._graph.as_default():
            cropped_img = np.expand_dims(cropped_img, axis=0)
            landmark, p, states = self._sess.run(
                [self.landmark, self.headpose, self.state],
                feed_dict={
                    self.img_input: cropped_img,
                    self.training: False
                })

        return landmark, states

    def _one_shot_run(self, image, bbox, i):
        """Crop one face, run the network and map the points back.

        Args:
            image: full ``(H, W, C)`` image.
            bbox: box whose first four entries are ``x1, y1, x2, y2``.
                It is copied, so the caller's data is left untouched.
            i: index of the box; unused.

        Returns:
            ``(landmark, state)``, or ``(None, None)`` when both box sides
            are at most ``self.min_face``.
        """
        bbox_width = bbox[2] - bbox[0]
        bbox_height = bbox[3] - bbox[1]
        if (bbox_width <= self.min_face and bbox_height <= self.min_face):
            return None, None

        # Pad the image so the enlarged square crop stays inside it.
        add = int(max(bbox_width, bbox_height))
        bimg = cv2.copyMakeBorder(image,
                                  add,
                                  add,
                                  add,
                                  add,
                                  borderType=cv2.BORDER_CONSTANT,
                                  value=self.pixel_means)

        # Work on a private copy: the original code shifted and rewrote the
        # caller's array row in place.
        bbox = np.array(bbox)
        bbox += add

        # Square crop centred on the box, widened on both sides by the
        # first extend ratio.
        one_edge = (1 + 2 * self.kp_extend_range[0]) * bbox_width
        center = [(bbox[0] + bbox[2]) // 2, (bbox[1] + bbox[3]) // 2]

        bbox[0] = center[0] - one_edge // 2
        bbox[1] = center[1] - one_edge // 2
        bbox[2] = center[0] + one_edge // 2
        bbox[3] = center[1] + one_edge // 2

        # ``np.int`` was an alias of the builtin ``int`` and was removed in
        # NumPy 1.24; the builtin keeps the behaviour on every version.
        bbox = bbox.astype(int)
        crop_image = bimg[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
        h, w, _ = crop_image.shape
        crop_image = cv2.resize(crop_image,
                                (self.kp_shape[1], self.kp_shape[0]))
        crop_image = crop_image.astype(np.float32)

        keypoints, state = self.simple_run(crop_image)

        res = keypoints[0][:self.keypoint_num].reshape((-1, 2))
        res[:, 0] = res[:, 0] * w / self.kp_shape[1]
        res[:, 1] = res[:, 1] * h / self.kp_shape[0]

        # Scale to crop pixels, then undo the crop offset and the padding.
        landmark = [[
            int(x_y[0] * self.kp_shape[0] + bbox[0] - add),
            int(x_y[1] * self.kp_shape[1] + bbox[1] - add)
        ] for x_y in res]

        landmark = np.array(landmark, np.float32)

        return landmark, state

    def init_model(self, *args):
        """Load the model into its own graph and open a session on it.

        Args:
            *args: one argument is taken as the path of a frozen ``.pb``
                graph; otherwise ``args[0]`` is a meta-graph path and
                ``args[1]`` the checkpoint to restore.

        Returns:
            ``(graph, session)`` where ``session`` is bound to ``graph``.
        """
        graph = tf.Graph()

        if len(args) == 1:
            pb_path = args[0]
            config = tf.ConfigProto()
            config.gpu_options.per_process_gpu_memory_fraction = 0.2
            # ``as_default()`` only takes effect as a context manager; a
            # bare call would leave ``graph`` empty.
            with graph.as_default():
                with tf.gfile.GFile(pb_path, 'rb') as fid:
                    graph_def = tf.GraphDef()
                    graph_def.ParseFromString(fid.read())
                tf.import_graph_def(graph_def, name='')
            sess = tf.Session(graph=graph, config=config)
        else:
            meta_path = args[0]
            restore_model_path = args[1]
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            with graph.as_default():
                sess = tf.Session(graph=graph, config=config)
                saver = tf.train.import_meta_graph(meta_path)
                saver.restore(sess, restore_model_path)
            print('Model restored!')

        return graph, sess
|
||||
136
modelscope/models/cv/face_detection/peppa_pig_face/facer.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# The implementation here is modified based on InsightFace_Pytorch, originally Apache License and publicly available
|
||||
# at https://github.com/610265158/Peppa_Pig_Face_Engine
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from .face_detector import FaceDetector
|
||||
from .face_landmark import FaceLandmark
|
||||
from .LK.lk import GroupTrack
|
||||
|
||||
|
||||
class FaceAna():
    """Face analysis front end: detection, landmarks and box smoothing."""

    def __init__(self, model_dir):
        """Create the detector, landmark model and tracker.

        Args:
            model_dir: directory handed to ``FaceDetector`` and
                ``FaceLandmark``.
        """
        self.face_detector = FaceDetector(model_dir)
        self.face_landmark = FaceLandmark(model_dir)
        self.trace = GroupTrack()

        self.track_box = None
        self.previous_image = None
        self.previous_box = None

        self.diff_thres = 5  # mean abs pixel difference used by diff_frames
        self.top_k = 10  # maximum number of detections kept by ``sort``
        self.iou_thres = 0.5  # IoU above which two boxes are matched
        self.alpha = 0.3  # weight of the newer value in the moving average

    def run(self, image):
        """Detect faces and predict their landmarks.

        Args:
            image: image array passed to the detector and landmark model.

        Returns:
            ``(track_box, landmarks, states)``, all ordered by decreasing
            box area.
        """
        boxes = self.face_detector(image)

        if boxes.shape[0] > self.top_k:
            boxes = self.sort(boxes)

        boxes_return = np.array(boxes)
        landmarks, states = self.face_landmark(image, boxes)

        # Tight bounding box around each face's landmarks.
        tmp_box = np.array([[
            np.min(points[:, 0]),
            np.min(points[:, 1]),
            np.max(points[:, 0]),
            np.max(points[:, 1])
        ] for points in landmarks])

        self.track_box = self.judge_boxs(boxes_return, tmp_box)

        # ``states`` must follow the same permutation as boxes/landmarks,
        # otherwise each state would be attributed to the wrong face.
        order = self._area_order(self.track_box)
        self.track_box, landmarks = self.sort_res(self.track_box, landmarks)
        states = np.array([states[k] for k in order])

        return self.track_box, landmarks, states

    @staticmethod
    def _area_order(bboxes):
        """Return the indices that sort ``bboxes`` by decreasing area."""
        area = np.array([(bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                         for bbox in bboxes])
        return area.argsort()[::-1]

    @staticmethod
    def _iou(rec1, rec2):
        """Intersection over union of two ``[x1, y1, x2, y2, ...]`` boxes.

        Returns 0.0 when the union is empty instead of dividing by zero.
        """
        sum_area = ((rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
                    + (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]))

        x1 = max(rec1[0], rec2[0])
        y1 = max(rec1[1], rec2[1])
        x2 = min(rec1[2], rec2[2])
        y2 = min(rec1[3], rec2[3])

        intersect = max(0, x2 - x1) * max(0, y2 - y1)
        union = sum_area - intersect
        if union == 0:
            return 0.0
        return intersect / union

    def sort_res(self, bboxes, points):
        """Order boxes and their points by decreasing box area.

        Args:
            bboxes: sequence of ``[x1, y1, x2, y2, ...]`` boxes.
            points: sequence aligned with ``bboxes``.

        Returns:
            ``(sorted_bboxes, sorted_points)`` as arrays.
        """
        picked = self._area_order(bboxes)
        sorted_bboxes = [bboxes[x] for x in picked]
        sorted_points = [points[x] for x in picked]
        return np.array(sorted_bboxes), np.array(sorted_points)

    def diff_frames(self, previous_frame, image):
        """Tell whether ``image`` differs noticeably from ``previous_frame``.

        Returns:
            ``True`` when there is no previous frame, otherwise whether the
            summed absolute difference divided by height, width and 3
            exceeds ``self.diff_thres``.
        """
        if previous_frame is None:
            return True

        _diff = cv2.absdiff(previous_frame, image)
        diff = np.sum(
            _diff) / previous_frame.shape[0] / previous_frame.shape[1] / 3.
        return diff > self.diff_thres

    def sort(self, bboxes):
        """Keep the ``self.top_k`` largest boxes, largest first.

        The input is returned untouched when ``self.top_k`` exceeds 100.
        """
        if self.top_k > 100:
            return bboxes

        area = np.array([(bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                         for bbox in bboxes])

        picked = area.argsort()[-self.top_k:][::-1]
        sorted_bboxes = [bboxes[x] for x in picked]
        return np.array(sorted_bboxes)

    def judge_boxs(self, previuous_bboxs, now_bboxs):
        """Smooth each current box against the first matching previous box.

        Args:
            previuous_bboxs: reference boxes, or ``None``.
            now_bboxs: current boxes (array exposing ``shape[0]``).

        Returns:
            ``now_bboxs`` itself when there is no reference. Otherwise an
            array with one box per current box: the moving average with the
            first reference box whose IoU exceeds ``self.iou_thres``, or
            the current box unchanged when none matches.
        """
        if previuous_bboxs is None:
            return now_bboxs

        result = []
        for i in range(now_bboxs.shape[0]):
            matched = None
            for j in range(previuous_bboxs.shape[0]):
                if self._iou(now_bboxs[i],
                             previuous_bboxs[j]) > self.iou_thres:
                    matched = self.smooth(now_bboxs[i], previuous_bboxs[j])
                    break
            result.append(now_bboxs[i] if matched is None else matched)

        return np.array(result)

    def smooth(self, now_box, previous_box):
        """Moving average of the first four entries of two boxes."""
        return self.do_moving_average(now_box[:4], previous_box[:4])

    def do_moving_average(self, p_now, p_previous):
        """Return ``alpha * p_now + (1 - alpha) * p_previous``."""
        p = self.alpha * p_now + (1 - self.alpha) * p_previous
        return p

    def reset(self):
        """Forget the tracked boxes and the previous frame."""
        self.track_box = None
        self.previous_image = None
        self.previous_box = None
|
||||
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from .detection import RetinaFaceDetection
|
||||
136
modelscope/models/cv/face_detection/retinaface/detection.py
Executable file
@@ -0,0 +1,136 @@
|
||||
# The implementation is based on resnet, available at https://github.com/biubug6/Pytorch_Retinaface
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Tensor, TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
|
||||
from .models.retinaface import RetinaFace
|
||||
from .utils import PriorBox, decode, decode_landm, py_cpu_nms
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.face_detection, module_name=Models.retinaface)
class RetinaFaceDetection(TorchModel):
    """RetinaFace face detector wrapped as a ``TorchModel``."""

    def __init__(self, model_path, device='cuda', **kwargs):
        """Build the network and load its weights.

        Args:
            model_path: path of the checkpoint. The configuration file is
                located by replacing ``ModelFile.TORCH_MODEL_FILE`` in this
                path with ``ModelFile.CONFIGURATION``.
            device: device the network is moved to.
            **kwargs: optional ``conf_th`` (score threshold, default 0.82)
                and ``nms_th`` (NMS threshold, default 0.4).
        """
        super().__init__(model_path)
        cudnn.benchmark = True
        self.model_path = model_path
        self.cfg = Config.from_file(
            model_path.replace(ModelFile.TORCH_MODEL_FILE,
                               ModelFile.CONFIGURATION))['models']
        self.net = RetinaFace(cfg=self.cfg)
        self.load_model()
        self.device = device
        self.net = self.net.to(self.device)

        self.conf_th = kwargs.get('conf_th', 0.82)
        self.nms_th = kwargs.get('nms_th', 0.4)
        self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device)

    def check_keys(self, pretrained_state_dict):
        """Assert that the checkpoint shares at least one key with the net.

        Returns:
            ``True`` when the assertion holds.
        """
        ckpt_keys = set(pretrained_state_dict.keys())
        model_keys = set(self.net.state_dict().keys())
        used_pretrained_keys = model_keys & ckpt_keys
        assert len(
            used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
        return True

    def remove_prefix(self, state_dict, prefix):
        """Return a copy of ``state_dict`` with ``prefix`` stripped.

        Keys that do not start with ``prefix`` are copied unchanged.
        """
        return {(k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()}

    def load_model(self, load_to_cpu=False):
        """Load the checkpoint at ``self.model_path`` into ``self.net``.

        The checkpoint is always mapped to the CPU first; ``load_to_cpu``
        is unused and kept for interface compatibility. A leading
        ``module.`` is stripped from every key, and the network is left in
        eval mode.
        """
        pretrained_dict = torch.load(self.model_path,
                                     map_location=torch.device('cpu'))
        if 'state_dict' in pretrained_dict.keys():
            pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'],
                                                 'module.')
        else:
            pretrained_dict = self.remove_prefix(pretrained_dict, 'module.')
        self.check_keys(pretrained_dict)
        self.net.load_state_dict(pretrained_dict, strict=False)
        self.net.eval()

    def forward(self, input):
        """Detect faces in ``input['img']``.

        Args:
            input: mapping whose ``'img'`` entry is an ``(H, W, 3)`` tensor.

        Returns:
            ``(dets, landms)``: ``dets`` has rows
            ``[x1, y1, x2, y2, score]`` and ``landms`` has ten landmark
            coordinates per row, both in the coordinates of the original
            image.
        """
        img_raw = input['img'].cpu().numpy()
        img = np.float32(img_raw)

        im_height, im_width = img.shape[:2]
        ss = 1.0
        # Very large images are shrunk so their long side becomes 1000 px.
        if max(im_height, im_width) > 1500:
            ss = 1000.0 / max(im_height, im_width)
            img = cv2.resize(img, (0, 0), fx=ss, fy=ss)
            im_height, im_width = img.shape[:2]

        scale = torch.Tensor(
            [img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
        img -= (104, 117, 123)
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.to(self.device)
        scale = scale.to(self.device)

        top_k = 5000
        keep_top_k = 750

        # Pure inference: no autograd graph is needed for any of this.
        with torch.no_grad():
            loc, conf, landms = self.net(img)  # forward pass
            del img

            priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
            priors = priorbox.forward()
            priors = priors.to(self.device)
            prior_data = priors.data
            boxes = decode(
                loc.data.squeeze(0), prior_data, self.cfg['variance'])
            boxes = boxes * scale
            boxes = boxes.cpu().numpy()
            scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
            landms = decode_landm(landms.data.squeeze(0), prior_data,
                                  self.cfg['variance'])
            scale1 = torch.Tensor([
                im_width, im_height, im_width, im_height, im_width, im_height,
                im_width, im_height, im_width, im_height
            ])
            scale1 = scale1.to(self.device)
            landms = landms * scale1
            landms = landms.cpu().numpy()

        # ignore low scores
        inds = np.where(scores > self.conf_th)[0]
        boxes = boxes[inds]
        landms = landms[inds]
        scores = scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1][:top_k]
        boxes = boxes[order]
        landms = landms[order]
        scores = scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32,
                                                                copy=False)
        keep = py_cpu_nms(dets, self.nms_th)
        dets = dets[keep, :]
        landms = landms[keep]

        # keep top-K after NMS
        dets = dets[:keep_top_k, :]
        landms = landms[:keep_top_k, :]

        landms = landms.reshape(-1, 10)

        # Map coordinates back to the original image size. Only the four
        # box columns are rescaled: dividing the whole ``dets`` array would
        # also distort the confidence in column 4 whenever ``ss`` != 1.
        dets[:, :4] = dets[:, :4] / ss
        return dets, landms / ss
|
||||
0
modelscope/models/cv/face_detection/retinaface/models/__init__.py
Executable file
162
modelscope/models/cv/face_detection/retinaface/models/net.py
Executable file
@@ -0,0 +1,162 @@
|
||||
# The implementation is based on resnet, available at https://github.com/biubug6/Pytorch_Retinaface
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as models
|
||||
import torchvision.models._utils as _utils
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
def conv_bn(inp, oup, stride=1, leaky=0):
    """Build a 3x3 conv (padding 1, no bias) + BatchNorm + LeakyReLU stack.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the convolution.
        leaky: negative slope of the in-place LeakyReLU.
    """
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def conv_bn_no_relu(inp, oup, stride):
    """Build a 3x3 conv (padding 1, no bias) + BatchNorm stack, no activation.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the convolution.
    """
    conv = nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
    norm = nn.BatchNorm2d(oup)
    return nn.Sequential(conv, norm)
|
||||
|
||||
|
||||
def conv_bn1X1(inp, oup, stride, leaky=0):
    """Build a 1x1 conv (no padding, no bias) + BatchNorm + LeakyReLU stack.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the convolution.
        leaky: negative slope of the in-place LeakyReLU.
    """
    layers = [
        nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def conv_dw(inp, oup, stride, leaky=0.1):
    """Build a depthwise-separable convolution block.

    A 3x3 convolution with ``groups=inp`` is followed by a 1x1 convolution
    that changes the channel count; each is followed by BatchNorm and an
    in-place LeakyReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the 3x3 convolution.
        leaky: negative slope of both LeakyReLU layers.
    """
    depthwise = [
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)
|
||||
|
||||
|
||||
class SSH(nn.Module):
    """Three-branch context module.

    The 3x3 branch yields half of the output channels. The 5x5 and 7x7
    branches (built by stacking 3x3 convolutions) yield a quarter each.
    The branches are concatenated along the channel axis and passed
    through ReLU.
    """

    def __init__(self, in_channel, out_channel):
        """
        Args:
            in_channel: channels of the input feature map.
            out_channel: channels of the output; must be divisible by 4.
        """
        super(SSH, self).__init__()
        assert out_channel % 4 == 0
        # Narrow heads use a leaky activation; wider ones a slope of 0.
        leaky = 0.1 if out_channel <= 64 else 0
        half = out_channel // 2
        quarter = out_channel // 4

        self.conv3X3 = conv_bn_no_relu(in_channel, half, stride=1)

        self.conv5X5_1 = conv_bn(in_channel, quarter, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(quarter, quarter, stride=1)

        self.conv7X7_2 = conv_bn(quarter, quarter, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(quarter, quarter, stride=1)

    def forward(self, input):
        """Return the ReLU of the concatenated branch outputs."""
        branch3 = self.conv3X3(input)

        # The first 5x5 stage is shared with the 7x7 branch.
        shared = self.conv5X5_1(input)
        branch5 = self.conv5X5_2(shared)

        branch7 = self.conv7x7_3(self.conv7X7_2(shared))

        return F.relu(torch.cat([branch3, branch5, branch7], dim=1))
|
||||
|
||||
|
||||
class FPN(nn.Module):
    """Top-down feature pyramid over three backbone feature maps."""

    def __init__(self, in_channels_list, out_channels):
        """
        Args:
            in_channels_list: channel counts of the three input maps, in
                the order they are consumed by ``forward``.
            out_channels: channel count of every output map.
        """
        super(FPN, self).__init__()
        # Narrow pyramids use a leaky activation; wider ones a slope of 0.
        leaky = 0.1 if out_channels <= 64 else 0

        # 1x1 lateral projections, one per input level.
        self.output1 = conv_bn1X1(
            in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(
            in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(
            in_channels_list[2], out_channels, stride=1, leaky=leaky)

        # 3x3 convolutions applied after each top-down addition.
        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        """Fuse the three levels from the last one back to the first.

        Args:
            input: mapping whose values are the three feature maps.

        Returns:
            List ``[output1, output2, output3]`` of fused feature maps.
        """
        features = list(input.values())

        level1 = self.output1(features[0])
        level2 = self.output2(features[1])
        level3 = self.output3(features[2])

        # Upsample level 3 to level 2's spatial size and merge.
        target2 = [level2.size(2), level2.size(3)]
        level2 = level2 + F.interpolate(level3, size=target2, mode='nearest')
        level2 = self.merge2(level2)

        # Upsample the merged level 2 to level 1's spatial size and merge.
        target1 = [level1.size(2), level1.size(3)]
        level1 = level1 + F.interpolate(level2, size=target1, mode='nearest')
        level1 = self.merge1(level1)

        return [level1, level2, level3]
|
||||
|
||||
|
||||
class MobileNetV1(nn.Module):
    """Small MobileNet-style backbone with a 1000-way linear head.

    Three sequential stages of depthwise-separable blocks are followed by
    global average pooling and a fully connected layer.
    """

    def __init__(self):
        super(MobileNetV1, self).__init__()
        # Each entry is (in_channels, out_channels, stride) for one
        # ``conv_dw`` block. The trailing numbers are carried over from the
        # original annotations (presumably receptive-field sizes — TODO
        # confirm).
        stage1_blocks = [
            (8, 16, 1),  # 7
            (16, 32, 2),  # 11
            (32, 32, 1),  # 19
            (32, 64, 2),  # 27
            (64, 64, 1),  # 43
        ]
        stage2_blocks = [
            (64, 128, 2),  # 43 + 16 = 59
            (128, 128, 1),  # 59 + 32 = 91
            (128, 128, 1),  # 91 + 32 = 123
            (128, 128, 1),  # 123 + 32 = 155
            (128, 128, 1),  # 155 + 32 = 187
            (128, 128, 1),  # 187 + 32 = 219
        ]
        stage3_blocks = [
            (128, 256, 2),  # 219 + 32 = 241
            (256, 256, 1),  # 241 + 64 = 301
        ]

        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky=0.1),  # 3
            *[conv_dw(cin, cout, step) for cin, cout, step in stage1_blocks])
        self.stage2 = nn.Sequential(
            *[conv_dw(cin, cout, step) for cin, cout, step in stage2_blocks])
        self.stage3 = nn.Sequential(
            *[conv_dw(cin, cout, step) for cin, cout, step in stage3_blocks])
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 1000)

    def forward(self, x):
        """Run the three stages, pool to 1x1 and apply the linear head."""
        for stage in (self.stage1, self.stage2, self.stage3):
            x = stage(x)
        pooled = self.avg(x).view(-1, 256)
        return self.fc(pooled)
|
||||