@@ -24,12 +24,13 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
 @TRAINERS.register_module(module_name=Trainers.ofa_tasks)
 class OFATrainer(EpochBasedTrainer):

-    def __init__(self, model: str, *args, **kwargs):
+    def __init__(self, model: str, cfg_file, work_dir, train_dataset,
+                 eval_dataset, *args, **kwargs):
         model = Model.from_pretrained(model)
         model_dir = model.model_dir
-        cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
+        # cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
         cfg = Config.from_file(cfg_file)
-        dataset = self._build_dataset_with_config(cfg)
+        # dataset = self._build_dataset_with_config(cfg)
         preprocessor = {
             ConfigKeys.train:
             OfaPreprocessor(
@@ -41,7 +42,7 @@ class OFATrainer(EpochBasedTrainer):
         # use torchrun launch
         world_size = int(os.environ.get('WORLD_SIZE', 1))
         epoch_steps = math.ceil(
-            len(dataset['train']) /  # noqa
+            len(train_dataset) /  # noqa
             (cfg.train.dataloader.batch_size_per_gpu * world_size))  # noqa
         cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
         cfg.train.criterion.tokenizer = model.tokenizer
@@ -68,11 +69,11 @@ class OFATrainer(EpochBasedTrainer):
             cfg_file=cfg_file,
             model=model,
             data_collator=collator,
-            train_dataset=dataset['train'],
-            eval_dataset=dataset['valid'],
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
             preprocessor=preprocessor,
             optimizers=(optimizer, lr_scheduler),
-            work_dir=cfg.train.work_dir,
+            work_dir=work_dir,
             *args,
             **kwargs,
         )
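Net effect: OFATrainer no longer derives its config file, datasets, and work directory internally; callers pass them in explicitly. A minimal sketch of driving the new signature through build_trainer (the dataset variables and paths below are illustrative placeholders, not values from this diff):

    from modelscope.metainfo import Trainers
    from modelscope.trainers import build_trainer

    # Everything the trainer needs is now supplied by the caller.
    kwargs = dict(
        model='damo/ofa_text-classification_mnli_large_en',
        cfg_file='/path/to/configuration.json',  # placeholder path
        train_dataset=my_train_dataset,          # placeholder MsDataset
        eval_dataset=my_eval_dataset,            # placeholder MsDataset
        work_dir='/path/to/work_dir')            # placeholder path
    trainer = build_trainer(name=Trainers.ofa_tasks, default_args=kwargs)
    trainer.train()

The test below exercises exactly this path with a small MNLI subset.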
@@ -3,22 +3,51 @@ import glob
 import os
 import os.path as osp
 import shutil
 import tempfile
 import unittest

+from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
+from modelscope.utils.constant import DownloadMode, ModelFile
 from modelscope.utils.test_utils import test_level

 class TestOfaTrainer(unittest.TestCase):

+    def setUp(self):
+        # Rename GLUE MNLI columns to the names the OFA preprocessor expects.
+        column_map = {'premise': 'text', 'hypothesis': 'text2'}
+        data_train = MsDataset.load(
+            dataset_name='glue',
+            subset_name='mnli',
+            namespace='modelscope',
+            split='train[:100]',
+            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
+        self.train_dataset = MsDataset.from_hf_dataset(
+            data_train._hf_ds.rename_columns(column_map))
+        data_eval = MsDataset.load(
+            dataset_name='glue',
+            subset_name='mnli',
+            namespace='modelscope',
+            split='validation_matched[:8]',
+            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
+        self.test_dataset = MsDataset.from_hf_dataset(
+            data_eval._hf_ds.rename_columns(column_map))
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_trainer(self):
         os.environ['LOCAL_RANK'] = '0'
         model_id = 'damo/ofa_text-classification_mnli_large_en'
-        default_args = {'model': model_id}
-        trainer = build_trainer(
-            name=Trainers.ofa_tasks, default_args=default_args)
+        # Resolve the cached model dir and use a throwaway work dir rather
+        # than hard-coding per-user absolute paths.
+        model_dir = snapshot_download(model_id)
+        kwargs = dict(
+            model=model_id,
+            cfg_file=os.path.join(model_dir, ModelFile.CONFIGURATION),
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            work_dir=tempfile.mkdtemp())
+        trainer = build_trainer(name=Trainers.ofa_tasks, default_args=kwargs)
         os.makedirs(trainer.work_dir, exist_ok=True)
         trainer.train()
         assert len(