diff --git a/modelscope/models/nlp/nli_model.py b/modelscope/models/nlp/nli_model.py index 05166bd0..91972a62 100644 --- a/modelscope/models/nlp/nli_model.py +++ b/modelscope/models/nlp/nli_model.py @@ -16,7 +16,7 @@ from ..builder import MODELS __all__ = ['SbertForNLI'] -class TextClassifier(SbertPreTrainedModel): +class SbertTextClassifier(SbertPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -53,7 +53,8 @@ class SbertForNLI(Model): super().__init__(model_dir, *args, **kwargs) self.model_dir = model_dir - self.model = TextClassifier.from_pretrained(model_dir, num_labels=3) + self.model = SbertTextClassifier.from_pretrained( + model_dir, num_labels=3) self.model.eval() def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 3c97c2be..8afbd041 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -18,7 +18,8 @@ PIPELINES = Registry('pipelines') DEFAULT_MODEL_FOR_PIPELINE = { # TaskName: (pipeline_module_name, model_repo) Tasks.image_matting: ('image-matting', 'damo/image-matting-person'), - Tasks.nli: ('nli', 'damo/nlp_structbert_nli_chinese-base'), + Tasks.nli: ('nlp_structbert_nli_chinese-base', + 'damo/nlp_structbert_nli_chinese-base'), Tasks.text_classification: ('bert-sentiment-analysis', 'damo/bert-base-sst2'), Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), diff --git a/tests/pipelines/test_nli.py b/tests/pipelines/test_nli.py index 9167b897..ad94697a 100644 --- a/tests/pipelines/test_nli.py +++ b/tests/pipelines/test_nli.py @@ -15,6 +15,7 @@ class NLITest(unittest.TestCase): sentence1 = '四川商务职业学院和四川财经职业学院哪个好?' sentence2 = '四川商务职业学院商务管理在哪个校区?' + @unittest.skip('skip temporarily to save test time') def test_run_from_local(self): cache_path = snapshot_download(self.model_id) tokenizer = NLIPreprocessor(cache_path)