@@ -24,12 +24,13 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
 @TRAINERS.register_module(module_name=Trainers.ofa_tasks)
 class OFATrainer(EpochBasedTrainer):
 
-    def __init__(self, model: str, *args, **kwargs):
+    def __init__(self, model: str, cfg_file, work_dir, train_dataset,
+                 eval_dataset, *args, **kwargs):
         model = Model.from_pretrained(model)
         model_dir = model.model_dir
-        cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
+        # cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
         cfg = Config.from_file(cfg_file)
-        dataset = self._build_dataset_with_config(cfg)
+        # dataset = self._build_dataset_with_config(cfg)
         preprocessor = {
             ConfigKeys.train:
             OfaPreprocessor(
@@ -41,7 +42,7 @@ class OFATrainer(EpochBasedTrainer):
         # use torchrun launch
         world_size = int(os.environ.get('WORLD_SIZE', 1))
         epoch_steps = math.ceil(
-            len(dataset['train']) /  # noqa
+            len(train_dataset) /  # noqa
             (cfg.train.dataloader.batch_size_per_gpu * world_size))  # noqa
         cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
         cfg.train.criterion.tokenizer = model.tokenizer
@@ -68,11 +69,11 @@ class OFATrainer(EpochBasedTrainer):
             cfg_file=cfg_file,
             model=model,
             data_collator=collator,
-            train_dataset=dataset['train'],
-            eval_dataset=dataset['valid'],
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
             preprocessor=preprocessor,
             optimizers=(optimizer, lr_scheduler),
-            work_dir=cfg.train.work_dir,
+            work_dir=work_dir,
             *args,
             **kwargs,
         )
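
The step-count wiring above is the part worth double-checking: each optimizer step consumes one batch of batch_size_per_gpu samples on each of the world_size workers launched by torchrun, so one epoch is ceil(len(train_dataset) / (batch_size_per_gpu * world_size)) steps and the scheduler sees epoch_steps * max_epochs steps in total. A minimal standalone sketch of that arithmetic, with illustrative values (the real ones come from configuration.json and the WORLD_SIZE environment variable set by torchrun):

import math

num_samples = 100        # len(train_dataset); matches the train[:100] slice in the test below
batch_size_per_gpu = 4   # assumed stand-in for cfg.train.dataloader.batch_size_per_gpu
world_size = 2           # assumed number of torchrun workers
max_epochs = 3           # assumed stand-in for cfg.train.max_epochs

# One optimizer step processes batch_size_per_gpu * world_size samples overall.
epoch_steps = math.ceil(num_samples / (batch_size_per_gpu * world_size))
num_train_steps = epoch_steps * max_epochs
print(epoch_steps, num_train_steps)  # -> 13 39

Getting this total wrong skews the learning-rate schedule: an undercount makes the warmup and decay finish before training does.
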
@@ -3,22 +3,51 @@ import glob
 import os
 import os.path as osp
 import shutil
+import tempfile
 import unittest
 
+from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
+from modelscope.utils.constant import DownloadMode, ModelFile
 from modelscope.utils.test_utils import test_level
 
 
 class TestOfaTrainer(unittest.TestCase):
 
+    def setUp(self):
+        column_map = {'premise': 'text', 'hypothesis': 'text2'}
+        data_train = MsDataset.load(
+            dataset_name='glue',
+            subset_name='mnli',
+            namespace='modelscope',
+            split='train[:100]',
+            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
+        self.train_dataset = MsDataset.from_hf_dataset(
+            data_train._hf_ds.rename_columns(column_map))
+        data_eval = MsDataset.load(
+            dataset_name='glue',
+            subset_name='mnli',
+            namespace='modelscope',
+            split='validation_matched[:8]',
+            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
+        self.test_dataset = MsDataset.from_hf_dataset(
+            data_eval._hf_ds.rename_columns(column_map))
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_trainer(self):
         os.environ['LOCAL_RANK'] = '0'
         model_id = 'damo/ofa_text-classification_mnli_large_en'
-        default_args = {'model': model_id}
-        trainer = build_trainer(
-            name=Trainers.ofa_tasks, default_args=default_args)
+        # Resolve configuration.json from the local model cache rather than
+        # hard-coding a user-specific absolute path, and write checkpoints to
+        # a temporary directory so the test runs on any machine.
+        model_dir = snapshot_download(model_id)
+        kwargs = dict(
+            model=model_id,
+            cfg_file=os.path.join(model_dir, ModelFile.CONFIGURATION),
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            work_dir=tempfile.mkdtemp())
+        trainer = build_trainer(name=Trainers.ofa_tasks, default_args=kwargs)
         os.makedirs(trainer.work_dir, exist_ok=True)
         trainer.train()
         assert len(