修复了 mplug evaluation 使用了错误的 metrics 的问题,将部分中文处理代码独立到 utils 中,为 mplug 添加 trainer
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10612875
master
| @@ -299,6 +299,7 @@ class Trainers(object): | |||||
| # multi-modal trainers | # multi-modal trainers | ||||
| clip_multi_modal_embedding = 'clip-multi-modal-embedding' | clip_multi_modal_embedding = 'clip-multi-modal-embedding' | ||||
| ofa = 'ofa' | ofa = 'ofa' | ||||
| mplug = 'mplug' | |||||
| # cv trainers | # cv trainers | ||||
| image_instance_segmentation = 'image-instance-segmentation' | image_instance_segmentation = 'image-instance-segmentation' | ||||
| @@ -6,6 +6,7 @@ import numpy as np | |||||
| from modelscope.metainfo import Metrics | from modelscope.metainfo import Metrics | ||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.utils.chinese_utils import remove_space_between_chinese_chars | |||||
| from modelscope.utils.registry import default_group | from modelscope.utils.registry import default_group | ||||
| from .base import Metric | from .base import Metric | ||||
| from .builder import METRICS, MetricKeys | from .builder import METRICS, MetricKeys | ||||
| @@ -26,10 +27,10 @@ class AccuracyMetric(Metric): | |||||
| def add(self, outputs: Dict, inputs: Dict): | def add(self, outputs: Dict, inputs: Dict): | ||||
| label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS | label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS | ||||
| ground_truths = inputs[label_name] | ground_truths = inputs[label_name] | ||||
| eval_results = outputs[label_name] | |||||
| eval_results = None | |||||
| for key in [ | for key in [ | ||||
| OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, | OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, | ||||
| OutputKeys.LABELS, OutputKeys.SCORES | |||||
| OutputKeys.LABEL, OutputKeys.LABELS, OutputKeys.SCORES | |||||
| ]: | ]: | ||||
| if key in outputs and outputs[key] is not None: | if key in outputs and outputs[key] is not None: | ||||
| eval_results = outputs[key] | eval_results = outputs[key] | ||||
| @@ -39,7 +40,7 @@ class AccuracyMetric(Metric): | |||||
| self.labels.append(truth) | self.labels.append(truth) | ||||
| for result in eval_results: | for result in eval_results: | ||||
| if isinstance(truth, str): | if isinstance(truth, str): | ||||
| self.preds.append(result.strip().replace(' ', '')) | |||||
| self.preds.append(remove_space_between_chinese_chars(result)) | |||||
| else: | else: | ||||
| self.preds.append(result) | self.preds.append(result) | ||||
| @@ -41,8 +41,8 @@ task_default_metrics = { | |||||
| Tasks.image_portrait_enhancement: | Tasks.image_portrait_enhancement: | ||||
| [Metrics.image_portrait_enhancement_metric], | [Metrics.image_portrait_enhancement_metric], | ||||
| Tasks.video_summarization: [Metrics.video_summarization_metric], | Tasks.video_summarization: [Metrics.video_summarization_metric], | ||||
| Tasks.image_captioning: [Metrics.text_gen_metric], | |||||
| Tasks.visual_question_answering: [Metrics.text_gen_metric], | |||||
| Tasks.image_captioning: [Metrics.accuracy], | |||||
| Tasks.visual_question_answering: [Metrics.accuracy], | |||||
| Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], | Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], | ||||
| Tasks.image_inpainting: [Metrics.image_inpainting_metric], | Tasks.image_inpainting: [Metrics.image_inpainting_metric], | ||||
| Tasks.referring_video_object_segmentation: | Tasks.referring_video_object_segmentation: | ||||
| @@ -8,6 +8,7 @@ from rouge import Rouge | |||||
| from modelscope.metainfo import Metrics | from modelscope.metainfo import Metrics | ||||
| from modelscope.metrics.base import Metric | from modelscope.metrics.base import Metric | ||||
| from modelscope.metrics.builder import METRICS, MetricKeys | from modelscope.metrics.builder import METRICS, MetricKeys | ||||
| from modelscope.utils.chinese_utils import rebuild_chinese_str | |||||
| from modelscope.utils.registry import default_group | from modelscope.utils.registry import default_group | ||||
| @@ -24,25 +25,13 @@ class TextGenerationMetric(Metric): | |||||
| self.tgts: List[str] = [] | self.tgts: List[str] = [] | ||||
| self.rouge = Rouge() | self.rouge = Rouge() | ||||
| @staticmethod | |||||
| def is_chinese_char(char: str): | |||||
| # the length of char must be 1 | |||||
| return '\u4e00' <= char <= '\u9fa5' | |||||
| # add space for each chinese char | |||||
| def rebuild_str(self, string: str): | |||||
| return ' '.join(''.join([ | |||||
| f' {char} ' if self.is_chinese_char(char) else char | |||||
| for char in string | |||||
| ]).split()) | |||||
| def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]): | def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]): | ||||
| ground_truths = inputs['tgts'] | ground_truths = inputs['tgts'] | ||||
| eval_results = outputs['preds'] | eval_results = outputs['preds'] | ||||
| for truth in ground_truths: | for truth in ground_truths: | ||||
| self.tgts.append(self.rebuild_str(truth)) | |||||
| self.tgts.append(rebuild_chinese_str(truth)) | |||||
| for result in eval_results: | for result in eval_results: | ||||
| self.preds.append(self.rebuild_str(result)) | |||||
| self.preds.append(rebuild_chinese_str(result)) | |||||
| def _check(self, pred: str, tgt: str) -> bool: | def _check(self, pred: str, tgt: str) -> bool: | ||||
| @@ -45,10 +45,6 @@ class MPlugForAllTasks(TorchModel): | |||||
| } | } | ||||
| """ | """ | ||||
| replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), | |||||
| ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), | |||||
| ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) | |||||
| # get task from config file | # get task from config file | ||||
| task = Config.from_file( | task = Config.from_file( | ||||
| osp.join(self.model_dir, ModelFile.CONFIGURATION)).task | osp.join(self.model_dir, ModelFile.CONFIGURATION)).task | ||||
| @@ -60,10 +56,7 @@ class MPlugForAllTasks(TorchModel): | |||||
| return {OutputKeys.SCORES: output[0].tolist()} | return {OutputKeys.SCORES: output[0].tolist()} | ||||
| topk_ids, _ = output | topk_ids, _ = output | ||||
| pred_string: List[str] = \ | pred_string: List[str] = \ | ||||
| self.tokenizer.decode(topk_ids[0][0]) | |||||
| for _old, _new in replace_tokens_bert: | |||||
| pred_string = pred_string.replace(_old, _new) | |||||
| pred_string = pred_string.strip() | |||||
| self.tokenizer.decode(topk_ids[0][0], skip_special_tokens=True) | |||||
| output_key = OutputKeys.CAPTION \ | output_key = OutputKeys.CAPTION \ | ||||
| if task == Tasks.image_captioning else OutputKeys.TEXT | if task == Tasks.image_captioning else OutputKeys.TEXT | ||||
| return {output_key: pred_string} | return {output_key: pred_string} | ||||
| @@ -87,19 +80,4 @@ class MPlugForAllTasks(TorchModel): | |||||
| # evaluate | # evaluate | ||||
| topk_ids, _ = output | topk_ids, _ = output | ||||
| preds: List[str] = [ | |||||
| self.tokenizer.decode(batch[0]) for batch in topk_ids | |||||
| ] | |||||
| for i in range(len(preds)): | |||||
| for _old, _new in replace_tokens_bert: | |||||
| preds[i] = preds[i].replace(_old, _new) | |||||
| preds[i] = preds[i].strip() | |||||
| tgts: List[str] = [ | |||||
| self.tokenizer.decode(batch) | |||||
| for batch in input['answer_input_ids'].cpu().numpy().tolist() | |||||
| ] | |||||
| for i in range(len(tgts)): | |||||
| for _old, _new in replace_tokens_bert: | |||||
| tgts[i] = tgts[i].replace(_old, _new) | |||||
| preds[i] = preds[i].strip() | |||||
| return {'preds': preds, 'tgts': tgts} | |||||
| return {'sequences': [list_tensor[0] for list_tensor in topk_ids]} | |||||
| @@ -2,7 +2,7 @@ | |||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| import numpy as np | import numpy as np | ||||
| from transformers.modeling_utils import PreTrainedModel | |||||
| from transformers.modeling_utils import GenerationMixin | |||||
| from modelscope.metainfo import TaskModels | from modelscope.metainfo import TaskModels | ||||
| from modelscope.models.builder import MODELS | from modelscope.models.builder import MODELS | ||||
| @@ -17,7 +17,8 @@ __all__ = ['TaskModelForTextGeneration'] | |||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.text_generation, module_name=TaskModels.text_generation) | Tasks.text_generation, module_name=TaskModels.text_generation) | ||||
| class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): | |||||
| class TaskModelForTextGeneration(SingleBackboneTaskModelBase, GenerationMixin): | |||||
| main_input_name = 'input_ids' | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| """initialize the text generation model from the `model_dir` path. | """initialize the text generation model from the `model_dir` path. | ||||
| @@ -10,6 +10,7 @@ from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Pipeline, Tensor | from modelscope.pipelines.base import Pipeline, Tensor | ||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import Preprocessor, build_preprocessor | from modelscope.preprocessors import Preprocessor, build_preprocessor | ||||
| from modelscope.utils.chinese_utils import remove_space_between_chinese_chars | |||||
| from modelscope.utils.constant import Fields, Tasks | from modelscope.utils.constant import Fields, Tasks | ||||
| from modelscope.utils.hub import read_config | from modelscope.utils.hub import read_config | ||||
| @@ -78,28 +79,6 @@ class TextGenerationPipeline(Pipeline): | |||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| return self.model.generate(inputs, **forward_params) | return self.model.generate(inputs, **forward_params) | ||||
| def _is_chinese_char(self, word: str): | |||||
| chinese_punctuations = (',', '。', ';', ':' '!', '?', '《', '》') | |||||
| return len(word) == 1 \ | |||||
| and ('\u4e00' <= word <= '\u9fa5' or word in chinese_punctuations) | |||||
| def _remove_space_between_chinese_chars(self, decoded: str): | |||||
| old_word_list = decoded.split(' ') | |||||
| new_word_list = [] | |||||
| start = -1 | |||||
| for i, word in enumerate(old_word_list): | |||||
| if self._is_chinese_char(word): | |||||
| if start == -1: | |||||
| start = i | |||||
| else: | |||||
| if start != -1: | |||||
| new_word_list.append(''.join(old_word_list[start:i])) | |||||
| start = -1 | |||||
| new_word_list.append(word) | |||||
| if start != -1: | |||||
| new_word_list.append(''.join(old_word_list[start:])) | |||||
| return ' '.join(new_word_list) | |||||
| def decode(self, inputs) -> str: | def decode(self, inputs) -> str: | ||||
| tokenizer = self.preprocessor.tokenizer | tokenizer = self.preprocessor.tokenizer | ||||
| return tokenizer.decode(inputs.tolist(), skip_special_tokens=True) | return tokenizer.decode(inputs.tolist(), skip_special_tokens=True) | ||||
| @@ -128,5 +107,5 @@ class TextGenerationPipeline(Pipeline): | |||||
| if isinstance(inputs, list) or len(inputs.shape) > 1: | if isinstance(inputs, list) or len(inputs.shape) > 1: | ||||
| inputs = inputs[0] | inputs = inputs[0] | ||||
| decoded = getattr(self, self.postprocessor)(inputs) | decoded = getattr(self, self.postprocessor)(inputs) | ||||
| text = self._remove_space_between_chinese_chars(decoded) | |||||
| text = remove_space_between_chinese_chars(decoded) | |||||
| return {OutputKeys.TEXT: text} | return {OutputKeys.TEXT: text} | ||||
| @@ -6,11 +6,15 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .clip import CLIPTrainer | from .clip import CLIPTrainer | ||||
| from .team import TEAMImgClsTrainer | from .team import TEAMImgClsTrainer | ||||
| from .ofa import OFATrainer | |||||
| from .mplug import MPlugTrainer | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'clip': ['CLIPTrainer'], | 'clip': ['CLIPTrainer'], | ||||
| 'team': ['TEAMImgClsTrainer'] | |||||
| 'team': ['TEAMImgClsTrainer'], | |||||
| 'ofa': ['OFATrainer'], | |||||
| 'mplug': ['MPlugTrainer'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,3 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from .mplug_trainer import MPlugTrainer | |||||
| @@ -0,0 +1,40 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from collections.abc import Mapping | |||||
| import torch | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.trainers import NlpEpochBasedTrainer | |||||
| from modelscope.trainers.builder import TRAINERS | |||||
| from modelscope.utils.file_utils import func_receive_dict_inputs | |||||
| @TRAINERS.register_module(module_name=Trainers.mplug) | |||||
| class MPlugTrainer(NlpEpochBasedTrainer): | |||||
| def _decode(self, tokens): | |||||
| tokenizer = self.eval_preprocessor.tokenizer | |||||
| return tokenizer.decode(tokens, skip_special_tokens=True) | |||||
| def evaluation_step(self, data): | |||||
| model = self.model.module if self._dist else self.model | |||||
| model.eval() | |||||
| with torch.no_grad(): | |||||
| if isinstance( | |||||
| data, | |||||
| Mapping) and not func_receive_dict_inputs(model.forward): | |||||
| result = model.forward(**data) | |||||
| else: | |||||
| result = model.forward(data) | |||||
| result[OutputKeys.TEXT] = [ | |||||
| self._decode(seq) for seq in result['sequences'] | |||||
| ] | |||||
| data[OutputKeys.LABELS] = [ | |||||
| self._decode(seq) for seq in data['answer_input_ids'] | |||||
| ] | |||||
| return result | |||||
| @@ -0,0 +1,35 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| def is_chinese_char(word: str): | |||||
| chinese_punctuations = { | |||||
| ',', '。', ';', ':' | |||||
| '!', '?', '《', '》', '‘', '’', '“', '”', '(', ')', '【', '】' | |||||
| } | |||||
| return len(word) == 1 \ | |||||
| and ('\u4e00' <= word <= '\u9fa5' or word in chinese_punctuations) | |||||
| def remove_space_between_chinese_chars(decoded_str: str): | |||||
| old_word_list = decoded_str.split(' ') | |||||
| new_word_list = [] | |||||
| start = -1 | |||||
| for i, word in enumerate(old_word_list): | |||||
| if is_chinese_char(word): | |||||
| if start == -1: | |||||
| start = i | |||||
| else: | |||||
| if start != -1: | |||||
| new_word_list.append(''.join(old_word_list[start:i])) | |||||
| start = -1 | |||||
| new_word_list.append(word) | |||||
| if start != -1: | |||||
| new_word_list.append(''.join(old_word_list[start:])) | |||||
| return ' '.join(new_word_list).strip() | |||||
| # add space for each chinese char | |||||
| def rebuild_chinese_str(string: str): | |||||
| return ' '.join(''.join([ | |||||
| f' {char} ' if is_chinese_char(char) else char for char in string | |||||
| ]).split()) | |||||
| @@ -20,10 +20,7 @@ class TestFinetuneMPlug(unittest.TestCase): | |||||
| self.tmp_dir = tempfile.TemporaryDirectory().name | self.tmp_dir = tempfile.TemporaryDirectory().name | ||||
| if not os.path.exists(self.tmp_dir): | if not os.path.exists(self.tmp_dir): | ||||
| os.makedirs(self.tmp_dir) | os.makedirs(self.tmp_dir) | ||||
| from modelscope.utils.constant import DownloadMode | |||||
| datadict = MsDataset.load( | |||||
| 'coco_captions_small_slice', | |||||
| download_mode=DownloadMode.FORCE_REDOWNLOAD) | |||||
| datadict = MsDataset.load('coco_captions_small_slice') | |||||
| self.train_dataset = MsDataset( | self.train_dataset = MsDataset( | ||||
| datadict['train'].remap_columns({ | datadict['train'].remap_columns({ | ||||
| 'image:FILE': 'image', | 'image:FILE': 'image', | ||||
| @@ -40,18 +37,6 @@ class TestFinetuneMPlug(unittest.TestCase): | |||||
| shutil.rmtree(self.tmp_dir) | shutil.rmtree(self.tmp_dir) | ||||
| super().tearDown() | super().tearDown() | ||||
| def _cfg_modify_fn(self, cfg): | |||||
| cfg.train.hooks = [{ | |||||
| 'type': 'CheckpointHook', | |||||
| 'interval': self.max_epochs | |||||
| }, { | |||||
| 'type': 'TextLoggerHook', | |||||
| 'interval': 1 | |||||
| }, { | |||||
| 'type': 'IterTimerHook' | |||||
| }] | |||||
| return cfg | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_trainer_with_caption(self): | def test_trainer_with_caption(self): | ||||
| kwargs = dict( | kwargs = dict( | ||||
| @@ -59,11 +44,10 @@ class TestFinetuneMPlug(unittest.TestCase): | |||||
| train_dataset=self.train_dataset, | train_dataset=self.train_dataset, | ||||
| eval_dataset=self.test_dataset, | eval_dataset=self.test_dataset, | ||||
| max_epochs=self.max_epochs, | max_epochs=self.max_epochs, | ||||
| work_dir=self.tmp_dir, | |||||
| cfg_modify_fn=self._cfg_modify_fn) | |||||
| work_dir=self.tmp_dir) | |||||
| trainer: EpochBasedTrainer = build_trainer( | trainer: EpochBasedTrainer = build_trainer( | ||||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | |||||
| name=Trainers.mplug, default_args=kwargs) | |||||
| trainer.train() | trainer.train() | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| @@ -80,7 +64,7 @@ class TestFinetuneMPlug(unittest.TestCase): | |||||
| work_dir=self.tmp_dir) | work_dir=self.tmp_dir) | ||||
| trainer: EpochBasedTrainer = build_trainer( | trainer: EpochBasedTrainer = build_trainer( | ||||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | |||||
| name=Trainers.mplug, default_args=kwargs) | |||||
| trainer.train() | trainer.train() | ||||
| results_files = os.listdir(self.tmp_dir) | results_files = os.listdir(self.tmp_dir) | ||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | self.assertIn(f'{trainer.timestamp}.log.json', results_files) | ||||
| @@ -94,11 +78,10 @@ class TestFinetuneMPlug(unittest.TestCase): | |||||
| train_dataset=self.train_dataset, | train_dataset=self.train_dataset, | ||||
| eval_dataset=self.test_dataset, | eval_dataset=self.test_dataset, | ||||
| max_epochs=self.max_epochs, | max_epochs=self.max_epochs, | ||||
| work_dir=self.tmp_dir, | |||||
| cfg_modify_fn=self._cfg_modify_fn) | |||||
| work_dir=self.tmp_dir) | |||||
| trainer: EpochBasedTrainer = build_trainer( | trainer: EpochBasedTrainer = build_trainer( | ||||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | |||||
| name=Trainers.mplug, default_args=kwargs) | |||||
| trainer.train() | trainer.train() | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| @@ -115,7 +98,7 @@ class TestFinetuneMPlug(unittest.TestCase): | |||||
| work_dir=self.tmp_dir) | work_dir=self.tmp_dir) | ||||
| trainer: EpochBasedTrainer = build_trainer( | trainer: EpochBasedTrainer = build_trainer( | ||||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | |||||
| name=Trainers.mplug, default_args=kwargs) | |||||
| trainer.train() | trainer.train() | ||||
| results_files = os.listdir(self.tmp_dir) | results_files = os.listdir(self.tmp_dir) | ||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | self.assertIn(f'{trainer.timestamp}.log.json', results_files) | ||||
| @@ -129,11 +112,10 @@ class TestFinetuneMPlug(unittest.TestCase): | |||||
| train_dataset=self.train_dataset, | train_dataset=self.train_dataset, | ||||
| eval_dataset=self.test_dataset, | eval_dataset=self.test_dataset, | ||||
| max_epochs=self.max_epochs, | max_epochs=self.max_epochs, | ||||
| work_dir=self.tmp_dir, | |||||
| cfg_modify_fn=self._cfg_modify_fn) | |||||
| work_dir=self.tmp_dir) | |||||
| trainer: EpochBasedTrainer = build_trainer( | trainer: EpochBasedTrainer = build_trainer( | ||||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | |||||
| name=Trainers.mplug, default_args=kwargs) | |||||
| trainer.train() | trainer.train() | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| @@ -150,7 +132,7 @@ class TestFinetuneMPlug(unittest.TestCase): | |||||
| work_dir=self.tmp_dir) | work_dir=self.tmp_dir) | ||||
| trainer: EpochBasedTrainer = build_trainer( | trainer: EpochBasedTrainer = build_trainer( | ||||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | |||||
| name=Trainers.mplug, default_args=kwargs) | |||||
| trainer.train() | trainer.train() | ||||
| results_files = os.listdir(self.tmp_dir) | results_files = os.listdir(self.tmp_dir) | ||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | self.assertIn(f'{trainer.timestamp}.log.json', results_files) | ||||
| @@ -41,7 +41,7 @@ class AstScaningTest(unittest.TestCase): | |||||
| self.assertIsInstance(from_imports, dict) | self.assertIsInstance(from_imports, dict) | ||||
| self.assertIsInstance(decorators, list) | self.assertIsInstance(decorators, list) | ||||
| self.assertListEqual(list(set(imports.keys()) - set(['torch'])), []) | self.assertListEqual(list(set(imports.keys()) - set(['torch'])), []) | ||||
| self.assertEqual(len(from_imports.keys()), 9) | |||||
| self.assertEqual(len(from_imports.keys()), 10) | |||||
| self.assertTrue(from_imports['modelscope.metainfo'] is not None) | self.assertTrue(from_imports['modelscope.metainfo'] is not None) | ||||
| self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines']) | self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines']) | ||||
| self.assertEqual(decorators, | self.assertEqual(decorators, | ||||