Files
EasyFace/tests/trainers/test_dialog_modeling_trainer.py
2023-03-02 11:17:26 +08:00

77 lines
2.5 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
import torch
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Preprocessors, Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.test_utils import test_level
class TestDialogModelingTrainer(unittest.TestCase):
    """Smoke test: the dialog-modeling trainer can be built from a hub model.

    Only trainer construction is checked; the training/evaluation steps are
    currently disabled (see the TODO at the end of the test method).
    """

    # Hub identifier of the pretrained model passed to ``snapshot_download``.
    model_id = 'damo/nlp_space_pretrained-dialog-model'
    # Written into the trainer config as ``save_dir``.
    output_dir = './dialog_fintune_result'

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):
        """Build the trainer with a customised config and check it exists."""
        # Download the data set, reusing an existing local copy if present.
        data_multiwoz = MsDataset.load(
            'MultiWoz2.0', download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
        # NOTE(review): relies on the private ``_hf_ds`` attribute; presumably
        # the 'train' split config is a local directory path -- confirm.
        data_dir = os.path.join(
            data_multiwoz._hf_ds.config_kwargs['split_config']['train'],
            'data')

        # Download the model snapshot.
        model_dir = snapshot_download(self.model_id)

        def cfg_modify_fn(cfg):
            """Replace ``cfg.Trainer`` with the fine-tune settings below."""
            config = {
                'seed': 10,
                'gpu': 1,
                'use_data_distributed': False,
                'valid_metric_name': '-loss',
                'num_epochs': 60,
                'save_dir': self.output_dir,
                'token_loss': True,
                'batch_size': 4,
                'log_steps': 10,
                'valid_steps': 0,
                'save_checkpoint': True,
                'save_summary': False,
                'shuffle': True,
                'sort_pool_size': 0,
            }
            cfg.Trainer = config
            # Only request the GPU when CUDA is actually usable on this host.
            cfg.use_gpu = torch.cuda.is_available() and config['gpu'] >= 1
            return cfg

        # Trainer construction arguments.
        kwargs = dict(
            model_dir=model_dir,
            cfg_name='gen_train_config.json',
            data_dir=data_dir,
            cfg_modify_fn=cfg_modify_fn)
        trainer = build_trainer(
            name=Trainers.dialog_modeling_trainer, default_args=kwargs)

        # Use the unittest assertion rather than a bare ``assert`` so the
        # check is not stripped when Python runs with ``-O``.
        self.assertIsNotNone(trainer)

        # TODO: training and evaluation take too long to run here; re-enable
        # the steps below once they have been optimized.
        #
        # trainer.train()
        # checkpoint_path = os.path.join(self.output_dir,
        #                                ModelFile.TORCH_MODEL_BIN_FILE)
        # assert os.path.exists(checkpoint_path)
        # trainer.evaluate(checkpoint_path=checkpoint_path)
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()