diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index d011dd4a..fa6f8a99 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -75,6 +75,7 @@ class EpochBasedTrainer(BaseTrainer):
             this preprocessing action will be executed every time the dataset's __getitem__ is called.
         optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`, *optional*): A tuple
             containing the optimizer and the scheduler to use.
+        seed (int): The optional random seed for torch, cuda, numpy and random.
         max_epochs: (int, optional): Total training epochs.
     """
 
@@ -93,8 +94,11 @@ class EpochBasedTrainer(BaseTrainer):
                                    torch.optim.lr_scheduler._LRScheduler] = (None, None),
                  model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
+                 seed: int = 42,
                  **kwargs):
 
+        self._seed = seed
+        set_random_seed(self._seed)
         if isinstance(model, str):
             if os.path.exists(model):
                 self.model_dir = model if os.path.isdir(
@@ -213,9 +217,6 @@ class EpochBasedTrainer(BaseTrainer):
 
         self.use_fp16 = kwargs.get('use_fp16', False)
 
-        # TODO @wenmeng.zwm add seed init fn
-        self._seed = 0
-
         if kwargs.get('launcher', None) is not None:
             init_dist(kwargs['launcher'])
 
diff --git a/modelscope/utils/regress_test_utils.py b/modelscope/utils/regress_test_utils.py
index ca50d579..82267447 100644
--- a/modelscope/utils/regress_test_utils.py
+++ b/modelscope/utils/regress_test_utils.py
@@ -133,6 +133,7 @@ class RegressTool:
                           compare_fn=None,
                           ignore_keys=None,
                           compare_random=True,
+                          reset_dropout=True,
                           lazy_stop_callback=None):
         """Monitor a pytorch module's backward data and cfg data within a step of the optimizer.
 
@@ -151,6 +152,7 @@ class RegressTool:
        @param compare_fn: A custom fn used to compare the results manually.
        @param ignore_keys: The keys to ignore of the named_parameters.
        @param compare_random: If to compare random setttings, default True.
+       @param reset_dropout: Reset all dropout modules to 0.0.
        @param lazy_stop_callback: A callback passed in, when the moniting is over, this callback will be called.
        >>> def compare_fn(v1, v2, key, type):
@@ -202,6 +204,18 @@ class RegressTool:
             trainer, '_seed') else trainer.seed if hasattr(trainer, 'seed') else None
 
+        if reset_dropout:
+            with torch.no_grad():
+
+                def reinit_dropout(_module):
+                    for name, submodule in _module.named_children():
+                        if isinstance(submodule, torch.nn.Dropout):
+                            setattr(_module, name, torch.nn.Dropout(0.))
+                        else:
+                            reinit_dropout(submodule)
+
+                reinit_dropout(module)
+
         if level == 'strict':
             hack_forward(module, file_name, io_json)
             intercept_module(module, io_json)
diff --git a/modelscope/utils/torch_utils.py b/modelscope/utils/torch_utils.py
index eaa285a2..6d4132f6 100644
--- a/modelscope/utils/torch_utils.py
+++ b/modelscope/utils/torch_utils.py
@@ -186,6 +186,7 @@ def set_random_seed(seed):
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
     else:
         raise ValueError(
             f'Random seed should be positive, current seed is {seed}')
diff --git a/tests/trainers/data/test/regression/sbert-base-tnews.bin b/tests/trainers/data/test/regression/sbert-base-tnews.bin
new file mode 100644
index 00000000..3a06d49c
--- /dev/null
+++ b/tests/trainers/data/test/regression/sbert-base-tnews.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2df2a5f3cdfc6dded52d31a8e97d9a9c41a803cb6d46dee709c51872eda37b21
+size 151830
diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py
index 24f1a2fd..f2adfa22 100644
--- a/tests/trainers/test_finetune_sequence_classification.py
+++ b/tests/trainers/test_finetune_sequence_classification.py
@@ -10,11 +10,14 @@ from modelscope.msdatasets import MsDataset
 from modelscope.pipelines import pipeline
 from modelscope.trainers import build_trainer
 from modelscope.trainers.hooks import Hook
-from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
+from modelscope.trainers.nlp_trainer import (EpochBasedTrainer,
+                                             NlpEpochBasedTrainer)
 from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \
     calculate_fisher
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.data_utils import to_device
+from modelscope.utils.regress_test_utils import MsRegressTool
+from modelscope.utils.test_utils import test_level
 
 
 class TestFinetuneSequenceClassification(unittest.TestCase):
@@ -28,11 +31,76 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
         self.tmp_dir = tempfile.TemporaryDirectory().name
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)
+        self.regress_tool = MsRegressTool(baseline=False)
 
     def tearDown(self):
         shutil.rmtree(self.tmp_dir)
         super().tearDown()
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_trainer_repeatable(self):
+        import torch  # noqa
+
+        def cfg_modify_fn(cfg):
+            cfg.task = 'nli'
+            cfg['preprocessor'] = {'type': 'nli-tokenizer'}
+            cfg.train.optimizer.lr = 2e-5
+            cfg['dataset'] = {
+                'train': {
+                    'labels': [
+                        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10',
+                        '11', '12', '13', '14'
+                    ],
+                    'first_sequence': 'sentence',
+                    'label': 'label',
+                }
+            }
+            cfg.train.max_epochs = 5
+            cfg.train.lr_scheduler = {
+                'type': 'LinearLR',
+                'start_factor': 1.0,
+                'end_factor': 0.0,
+                'total_iters':
+                int(len(dataset['train']) / 32) * cfg.train.max_epochs,
+                'options': {
+                    'by_epoch': False
+                }
+            }
+            cfg.train.hooks = [{
+                'type': 'CheckpointHook',
+                'interval': 1
+            }, {
+                'type': 'TextLoggerHook',
+                'interval': 1
+            }, {
+                'type': 'IterTimerHook'
+            }, {
+                'type': 'EvaluationHook',
+                'by_epoch': False,
+                'interval': 100
+            }]
+            return cfg
+
+        dataset = MsDataset.load('clue', subset_name='tnews')
+
+        kwargs = dict(
+            model='damo/nlp_structbert_backbone_base_std',
+            train_dataset=dataset['train'],
+            eval_dataset=dataset['validation'],
+            work_dir=self.tmp_dir,
+            seed=42,
+            cfg_modify_fn=cfg_modify_fn)
+
+        os.environ['LOCAL_RANK'] = '0'
+        trainer: EpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+
+        with self.regress_tool.monitor_ms_train(
+                trainer, 'sbert-base-tnews', level='strict'):
+            trainer.train()
+
     def finetune(self,
                  model_id,
                  train_dataset,
@@ -54,7 +122,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
         for i in range(self.epoch_num):
-            self.assertIn(f'epoch_{i+1}.pth', results_files)
+            self.assertIn(f'epoch_{i + 1}.pth', results_files)
 
         output_files = os.listdir(
             os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR))
@@ -118,11 +186,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
             }]
             return cfg
 
-        from datasets import load_dataset
-        from datasets import DownloadConfig
-        dc = DownloadConfig()
-        dc.local_files_only = True
-        dataset = load_dataset('clue', 'afqmc', download_config=dc)
+        dataset = MsDataset.load('clue', subset_name='afqmc')
 
         self.finetune(
             model_id='damo/nlp_structbert_backbone_base_std',
             train_dataset=dataset['train'],
@@ -182,11 +246,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
             }]
             return cfg
 
-        from datasets import load_dataset
-        from datasets import DownloadConfig
-        dc = DownloadConfig()
-        dc.local_files_only = True
-        dataset = load_dataset('clue', 'tnews', download_config=dc)
+        dataset = MsDataset.load('clue', subset_name='tnews')
 
         self.finetune(
             model_id='damo/nlp_structbert_backbone_base_std',
diff --git a/tests/trainers/test_finetune_text_generation.py b/tests/trainers/test_finetune_text_generation.py
index a561effe..6aefa969 100644
--- a/tests/trainers/test_finetune_text_generation.py
+++ b/tests/trainers/test_finetune_text_generation.py
@@ -129,7 +129,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
     @unittest.skip
     def test_finetune_cnndm(self):
         from modelscope.msdatasets import MsDataset
-        dataset_dict = MsDataset.load('dureader_robust_qg')
+        dataset_dict = MsDataset.load('DuReader_robust-QG')
         train_dataset = dataset_dict['train'].to_hf_dataset() \
             .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
         eval_dataset = dataset_dict['validation'].to_hf_dataset() \
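
Usage sketch (not part of the diff): the snippet below mirrors test_trainer_repeatable and shows how the new seed argument and the regression baseline added above might be exercised outside the unittest scaffolding. The import path for Trainers and the work_dir value are assumptions based on the test code, not something this diff guarantees.

import os

from modelscope.metainfo import Trainers  # assumed location of the Trainers registry
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.regress_test_utils import MsRegressTool

dataset = MsDataset.load('clue', subset_name='tnews')

kwargs = dict(
    model='damo/nlp_structbert_backbone_base_std',
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    work_dir='./work_dir',  # hypothetical output directory
    seed=42)                # new EpochBasedTrainer argument introduced by this diff

os.environ['LOCAL_RANK'] = '0'
trainer = build_trainer(name=Trainers.nlp_base_trainer, default_args=kwargs)

# baseline=False compares the run against a stored baseline such as
# tests/trainers/data/test/regression/sbert-base-tnews.bin; baseline=True would
# presumably record a new one.
regress_tool = MsRegressTool(baseline=False)
with regress_tool.monitor_ms_train(trainer, 'sbert-base-tnews', level='strict'):
    trainer.train()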