From c2da44b371d88f81e39fac4a0bbdfbd1573c6e21 Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Tue, 25 Oct 2022 22:38:49 +0800
Subject: [PATCH] [to #42322933] remove dev model inference and fix some bugs

1. Change structbert dev revision to master revision
2. Fix bug: Sample code failed because of the model configuration update
3. Fix bug: Continue-training regression test failed

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10519992
---
 modelscope/models/builder.py                  | 10 +++++++--
 modelscope/models/nlp/__init__.py             |  2 ++
 modelscope/models/nlp/structbert/adv_utils.py |  2 +-
 modelscope/trainers/hooks/checkpoint_hook.py  | 15 +++++++------
 modelscope/trainers/trainer.py                | 21 ++++++++++++-------
 tests/pipelines/test_fill_mask.py             |  7 ++-----
 .../test_sentiment_classification.py          | 13 +++++-------
 tests/trainers/test_trainer_with_nlp.py       | 15 ++++++-------
 8 files changed, 45 insertions(+), 40 deletions(-)

diff --git a/modelscope/models/builder.py b/modelscope/models/builder.py
index a35358c1..2804c6c7 100644
--- a/modelscope/models/builder.py
+++ b/modelscope/models/builder.py
@@ -2,13 +2,19 @@
 
 from modelscope.utils.config import ConfigDict
 from modelscope.utils.constant import Tasks
+from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule
 from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg
 
 MODELS = Registry('models')
-BACKBONES = Registry('backbones')
-BACKBONES._modules = MODELS._modules
+BACKBONES = MODELS
 HEADS = Registry('heads')
 
+modules = LazyImportModule.AST_INDEX[INDEX_KEY]
+for module_index in list(modules.keys()):
+    if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES':
+        modules[(MODELS.name.upper(), module_index[1],
+                 module_index[2])] = modules[module_index]
+
 
 def build_model(cfg: ConfigDict,
                 task_name: str = None,
diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index dff42d1c..5ae93caa 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
         SbertForSequenceClassification,
         SbertForTokenClassification,
         SbertTokenizer,
+        SbertModel,
         SbertTokenizerFast,
     )
     from .bert import (
@@ -61,6 +62,7 @@ else:
             'SbertForTokenClassification',
             'SbertTokenizer',
             'SbertTokenizerFast',
+            'SbertModel',
         ],
         'veco': [
             'VecoModel', 'VecoConfig', 'VecoForTokenClassification',
diff --git a/modelscope/models/nlp/structbert/adv_utils.py b/modelscope/models/nlp/structbert/adv_utils.py
index 44aae85c..91a4cb82 100644
--- a/modelscope/models/nlp/structbert/adv_utils.py
+++ b/modelscope/models/nlp/structbert/adv_utils.py
@@ -98,7 +98,7 @@ def compute_adv_loss(embedding,
     if is_nan:
         logger.warning('Nan occured when calculating adv loss.')
         return ori_loss
-    emb_grad = emb_grad / emb_grad_norm
+    emb_grad = emb_grad / (emb_grad_norm + 1e-6)
     embedding_2 = embedding_1 + adv_grad_factor * emb_grad
     embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)
     embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)
diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py
index c9f51a88..9b86d5b5 100644
--- a/modelscope/trainers/hooks/checkpoint_hook.py
+++ b/modelscope/trainers/hooks/checkpoint_hook.py
@@ -69,7 +69,7 @@ class CheckpointHook(Hook):
             self.rng_state = meta.get('rng_state')
             self.need_load_rng_state = True
 
-    def before_train_epoch(self, trainer):
+    def before_train_iter(self, trainer):
         if self.need_load_rng_state:
             if self.rng_state is not None:
                 random.setstate(self.rng_state['random'])
@@ -84,13 +84,6 @@ class CheckpointHook(Hook):
                     'this may cause a random data order or model initialization.'
                 )
 
-        self.rng_state = {
-            'random': random.getstate(),
-            'numpy': np.random.get_state(),
-            'cpu': torch.random.get_rng_state(),
-            'cuda': torch.cuda.get_rng_state_all(),
-        }
-
     def after_train_epoch(self, trainer):
         if not self.by_epoch:
             return
@@ -142,6 +135,12 @@ class CheckpointHook(Hook):
             cur_save_name = os.path.join(
                 self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')
 
+        self.rng_state = {
+            'random': random.getstate(),
+            'numpy': np.random.get_state(),
+            'cpu': torch.random.get_rng_state(),
+            'cuda': torch.cuda.get_rng_state_all(),
+        }
         meta = {
             'epoch': trainer.epoch,
             'iter': trainer.iter + 1,
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index f47bff10..605136e5 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -354,6 +354,9 @@ class EpochBasedTrainer(BaseTrainer):
                 task_dataset.trainer = self
             return task_dataset
         else:
+            if task_data_config is None:
+                # adapt to some special models
+                task_data_config = {}
             # avoid add no str value datasets, preprocessors in cfg
             task_data_build_config = ConfigDict(
                 type=self.cfg.model.type,
@@ -419,13 +422,17 @@ class EpochBasedTrainer(BaseTrainer):
         return metrics
 
     def set_checkpoint_file_to_hook(self, checkpoint_path):
-        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
-            from modelscope.trainers.hooks import CheckpointHook
-            checkpoint_hooks = list(
-                filter(lambda hook: isinstance(hook, CheckpointHook),
-                       self.hooks))
-            for hook in checkpoint_hooks:
-                hook.checkpoint_file = checkpoint_path
+        if checkpoint_path is not None:
+            if os.path.isfile(checkpoint_path):
+                from modelscope.trainers.hooks import CheckpointHook
+                checkpoint_hooks = list(
+                    filter(lambda hook: isinstance(hook, CheckpointHook),
+                           self.hooks))
+                for hook in checkpoint_hooks:
+                    hook.checkpoint_file = checkpoint_path
+            else:
+                self.logger.error(
+                    f'No {checkpoint_path} found in local file system.')
 
     def train(self, checkpoint_path=None, *args, **kwargs):
         self._mode = ModeKeys.TRAIN
diff --git a/tests/pipelines/test_fill_mask.py b/tests/pipelines/test_fill_mask.py
index 568865c6..35202b88 100644
--- a/tests/pipelines/test_fill_mask.py
+++ b/tests/pipelines/test_fill_mask.py
@@ -83,7 +83,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
 
         # bert
         language = 'zh'
-        model_dir = snapshot_download(self.model_id_bert, revision='beta')
+        model_dir = snapshot_download(self.model_id_bert)
         preprocessor = NLPPreprocessor(
             model_dir, first_sequence='sentence', second_sequence=None)
         model = Model.from_pretrained(model_dir)
@@ -149,10 +149,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
 
         # Bert
         language = 'zh'
-        pipeline_ins = pipeline(
-            task=Tasks.fill_mask,
-            model=self.model_id_bert,
-            model_revision='beta')
+        pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_bert)
         print(
             f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
             f'{pipeline_ins(self.test_inputs[language])}\n')
diff --git a/tests/pipelines/test_sentiment_classification.py b/tests/pipelines/test_sentiment_classification.py
index b3d9b9d6..5c8d4e93 100644
--- a/tests/pipelines/test_sentiment_classification.py
+++ b/tests/pipelines/test_sentiment_classification.py
@@ -24,10 +24,10 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_direct_file_download(self):
-        cache_path = snapshot_download(self.model_id, revision='beta')
+        cache_path = snapshot_download(self.model_id)
         tokenizer = SequenceClassificationPreprocessor(cache_path)
         model = SequenceClassificationModel.from_pretrained(
-            self.model_id, num_labels=2, revision='beta')
+            self.model_id, num_labels=2)
         pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer)
         pipeline2 = pipeline(
             Tasks.text_classification, model=model, preprocessor=tokenizer)
@@ -38,7 +38,7 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
-        model = Model.from_pretrained(self.model_id, revision='beta')
+        model = Model.from_pretrained(self.model_id)
         tokenizer = SequenceClassificationPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.text_classification,
@@ -51,17 +51,14 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_name(self):
         pipeline_ins = pipeline(
-            task=Tasks.text_classification,
-            model=self.model_id,
-            model_revision='beta')
+            task=Tasks.text_classification, model=self.model_id)
         print(pipeline_ins(input=self.sentence1))
         self.assertTrue(
             isinstance(pipeline_ins.model, SequenceClassificationModel))
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_default_model(self):
-        pipeline_ins = pipeline(
-            task=Tasks.text_classification, model_revision='beta')
+        pipeline_ins = pipeline(task=Tasks.text_classification)
         print(pipeline_ins(input=self.sentence1))
         self.assertTrue(
             isinstance(pipeline_ins.model, SequenceClassificationModel))
diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py
index 9380ad0f..8aaa42a3 100644
--- a/tests/trainers/test_trainer_with_nlp.py
+++ b/tests/trainers/test_trainer_with_nlp.py
@@ -37,13 +37,12 @@ class TestTrainerWithNlp(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer(self):
-        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
+        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
         kwargs = dict(
             model=model_id,
             train_dataset=self.dataset,
             eval_dataset=self.dataset,
-            work_dir=self.tmp_dir,
-            model_revision='beta')
+            work_dir=self.tmp_dir)
 
         trainer = build_trainer(default_args=kwargs)
         trainer.train()
@@ -80,8 +79,7 @@ class TestTrainerWithNlp(unittest.TestCase):
             model=model_id,
             train_dataset=self.dataset,
             eval_dataset=self.dataset,
-            work_dir=self.tmp_dir,
-            model_revision='beta')
+            work_dir=self.tmp_dir)
 
         trainer = build_trainer(default_args=kwargs)
         trainer.train()
@@ -97,7 +95,7 @@ class TestTrainerWithNlp(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_trainer_with_user_defined_config(self):
         model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
-        cfg = read_config(model_id, revision='beta')
+        cfg = read_config(model_id)
         cfg.train.max_epochs = 20
         cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1}
         cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1}
@@ -108,8 +106,7 @@ class TestTrainerWithNlp(unittest.TestCase):
             model=model_id,
             train_dataset=self.dataset,
             eval_dataset=self.dataset,
-            cfg_file=cfg_file,
-            model_revision='beta')
+            cfg_file=cfg_file)
 
         trainer = build_trainer(default_args=kwargs)
         trainer.train()
@@ -233,7 +230,7 @@ class TestTrainerWithNlp(unittest.TestCase):
             os.makedirs(tmp_dir)
         model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
-        cache_path = snapshot_download(model_id, revision='beta')
+        cache_path = snapshot_download(model_id)
         model = SbertForSequenceClassification.from_pretrained(cache_path)
         kwargs = dict(
             cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
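
Reviewer note on modelscope/models/builder.py: since BACKBONES is now an alias of MODELS, the loop re-keys every backbone entry of the lazy-import AST index under the MODELS registry name, so lazy lookups through the merged registry still resolve. Below is a standalone Python analogue with hypothetical index contents (the real entries live in LazyImportModule.AST_INDEX), not the actual ModelScope data:

# Standalone analogue of the re-keying loop in builder.py; keys are
# (registry_name, task, model_type) -> module path. Values here are
# hypothetical, for illustration only.
index = {
    ('BACKBONES', 'backbone', 'bert'): 'modelscope.models.nlp.bert',
    ('MODELS', 'fill-mask', 'bert'): 'modelscope.models.nlp.bert',
}

# Duplicate each backbone entry under the MODELS registry name,
# mirroring what builder.py does after aliasing BACKBONES = MODELS.
for key in list(index.keys()):
    if key[0] == 'BACKBONES' and key[1] == 'backbone':
        index[('MODELS', key[1], key[2])] = index[key]

assert ('MODELS', 'backbone', 'bert') in index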
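
Reviewer note on modelscope/models/nlp/structbert/adv_utils.py: adding 1e-6 to the denominator guards the degenerate case where the embedding gradient is all zeros, so its norm is 0.0 and the old division produced NaNs that fed straight into the adversarial perturbation. A minimal standalone sketch of the failure mode (not the ModelScope code itself):

# Why dividing a gradient by its own norm needs an epsilon guard.
import torch

emb_grad = torch.zeros(4, 8)              # degenerate case: zero gradient
emb_grad_norm = emb_grad.norm()           # 0.0

bad = emb_grad / emb_grad_norm            # 0 / 0 -> NaN in every element
safe = emb_grad / (emb_grad_norm + 1e-6)  # stays a well-defined zero tensor

print(torch.isnan(bad).any().item())   # True
print(torch.isnan(safe).any().item())  # False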
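
Reviewer note on modelscope/trainers/hooks/checkpoint_hook.py: the RNG snapshot is now taken when the checkpoint is written rather than in before_train_epoch, so the saved state matches the exact saved iteration, and before_train_iter restores it once after loading. A minimal standalone sketch of the same save/restore pattern (the real logic lives in CheckpointHook; this only illustrates the idea):

import random

import numpy as np
import torch

def capture_rng_state():
    # Snapshot every RNG stream at checkpoint-save time so a resumed run
    # reproduces the same data order and augmentations.
    return {
        'random': random.getstate(),
        'numpy': np.random.get_state(),
        'cpu': torch.random.get_rng_state(),
        'cuda': torch.cuda.get_rng_state_all(),  # empty list on CPU-only hosts
    }

def restore_rng_state(state):
    # Mirror of capture_rng_state(), applied once before the first
    # training iteration after a checkpoint is loaded.
    random.setstate(state['random'])
    np.random.set_state(state['numpy'])
    torch.random.set_rng_state(state['cpu'])
    torch.cuda.set_rng_state_all(state['cuda'])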