| @@ -16,7 +16,7 @@ from ..builder import MODELS | |||||
| __all__ = ['SbertForNLI'] | __all__ = ['SbertForNLI'] | ||||
| class TextClassifier(SbertPreTrainedModel): | |||||
| class SbertTextClassifier(SbertPreTrainedModel): | |||||
| def __init__(self, config): | def __init__(self, config): | ||||
| super().__init__(config) | super().__init__(config) | ||||
| @@ -53,7 +53,8 @@ class SbertForNLI(Model): | |||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| self.model_dir = model_dir | self.model_dir = model_dir | ||||
| self.model = TextClassifier.from_pretrained(model_dir, num_labels=3) | |||||
| self.model = SbertTextClassifier.from_pretrained( | |||||
| model_dir, num_labels=3) | |||||
| self.model.eval() | self.model.eval() | ||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | ||||
| @@ -18,7 +18,8 @@ PIPELINES = Registry('pipelines') | |||||
| DEFAULT_MODEL_FOR_PIPELINE = { | DEFAULT_MODEL_FOR_PIPELINE = { | ||||
| # TaskName: (pipeline_module_name, model_repo) | # TaskName: (pipeline_module_name, model_repo) | ||||
| Tasks.image_matting: ('image-matting', 'damo/image-matting-person'), | Tasks.image_matting: ('image-matting', 'damo/image-matting-person'), | ||||
| Tasks.nli: ('nli', 'damo/nlp_structbert_nli_chinese-base'), | |||||
| Tasks.nli: ('nlp_structbert_nli_chinese-base', | |||||
| 'damo/nlp_structbert_nli_chinese-base'), | |||||
| Tasks.text_classification: | Tasks.text_classification: | ||||
| ('bert-sentiment-analysis', 'damo/bert-base-sst2'), | ('bert-sentiment-analysis', 'damo/bert-base-sst2'), | ||||
| Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | ||||
| @@ -15,6 +15,7 @@ class NLITest(unittest.TestCase): | |||||
| sentence1 = '四川商务职业学院和四川财经职业学院哪个好?' | sentence1 = '四川商务职业学院和四川财经职业学院哪个好?' | ||||
| sentence2 = '四川商务职业学院商务管理在哪个校区?' | sentence2 = '四川商务职业学院商务管理在哪个校区?' | ||||
| @unittest.skip('skip temporarily to save test time') | |||||
| def test_run_from_local(self): | def test_run_from_local(self): | ||||
| cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
| tokenizer = NLIPreprocessor(cache_path) | tokenizer = NLIPreprocessor(cache_path) | ||||