mirror of
https://gitcode.com/gh_mirrors/eas/EasyFace.git
synced 2025-12-30 04:52:28 +00:00
77 lines
2.5 KiB
Python
77 lines
2.5 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import os
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from modelscope.hub.snapshot_download import snapshot_download
|
|
from modelscope.metainfo import Preprocessors, Trainers
|
|
from modelscope.msdatasets import MsDataset
|
|
from modelscope.trainers import build_trainer
|
|
from modelscope.utils.constant import DownloadMode, ModelFile
|
|
from modelscope.utils.test_utils import test_level
|
|
|
|
|
|
class TestDialogModelingTrainer(unittest.TestCase):
|
|
|
|
model_id = 'damo/nlp_space_pretrained-dialog-model'
|
|
output_dir = './dialog_fintune_result'
|
|
|
|
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
|
def test_trainer_with_model_and_args(self):
|
|
# download data set
|
|
data_multiwoz = MsDataset.load(
|
|
'MultiWoz2.0', download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
|
|
data_dir = os.path.join(
|
|
data_multiwoz._hf_ds.config_kwargs['split_config']['train'],
|
|
'data')
|
|
|
|
# download model
|
|
model_dir = snapshot_download(self.model_id)
|
|
|
|
# dialog finetune config
|
|
def cfg_modify_fn(cfg):
|
|
config = {
|
|
'seed': 10,
|
|
'gpu': 1,
|
|
'use_data_distributed': False,
|
|
'valid_metric_name': '-loss',
|
|
'num_epochs': 60,
|
|
'save_dir': self.output_dir,
|
|
'token_loss': True,
|
|
'batch_size': 4,
|
|
'log_steps': 10,
|
|
'valid_steps': 0,
|
|
'save_checkpoint': True,
|
|
'save_summary': False,
|
|
'shuffle': True,
|
|
'sort_pool_size': 0
|
|
}
|
|
|
|
cfg.Trainer = config
|
|
cfg.use_gpu = torch.cuda.is_available() and config['gpu'] >= 1
|
|
return cfg
|
|
|
|
# trainer config
|
|
kwargs = dict(model_dir=model_dir,
|
|
cfg_name='gen_train_config.json',
|
|
data_dir=data_dir,
|
|
cfg_modify_fn=cfg_modify_fn)
|
|
|
|
trainer = build_trainer(name=Trainers.dialog_modeling_trainer,
|
|
default_args=kwargs)
|
|
assert trainer is not None
|
|
|
|
# todo: it takes too long time to train and evaluate. It will be optimized later.
|
|
"""
|
|
trainer.train()
|
|
checkpoint_path = os.path.join(self.output_dir,
|
|
ModelFile.TORCH_MODEL_BIN_FILE)
|
|
assert os.path.exists(checkpoint_path)
|
|
trainer.evaluate(checkpoint_path=checkpoint_path)
|
|
"""
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|