Browse Source

update pipeline registry info

master
智丞 3 years ago
parent
commit
753b98f526
3 changed files with 6 additions and 3 deletions
  1. +3
    -2
      modelscope/models/nlp/nli_model.py
  2. +2
    -1
      modelscope/pipelines/builder.py
  3. +1
    -0
      tests/pipelines/test_nli.py

+ 3
- 2
modelscope/models/nlp/nli_model.py View File

@@ -16,7 +16,7 @@ from ..builder import MODELS
__all__ = ['SbertForNLI']


class TextClassifier(SbertPreTrainedModel):
class SbertTextClassifier(SbertPreTrainedModel):

def __init__(self, config):
super().__init__(config)
@@ -53,7 +53,8 @@ class SbertForNLI(Model):
super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir

self.model = TextClassifier.from_pretrained(model_dir, num_labels=3)
self.model = SbertTextClassifier.from_pretrained(
model_dir, num_labels=3)
self.model.eval()

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:


+ 2
- 1
modelscope/pipelines/builder.py View File

@@ -18,7 +18,8 @@ PIPELINES = Registry('pipelines')
DEFAULT_MODEL_FOR_PIPELINE = {
# TaskName: (pipeline_module_name, model_repo)
Tasks.image_matting: ('image-matting', 'damo/image-matting-person'),
Tasks.nli: ('nli', 'damo/nlp_structbert_nli_chinese-base'),
Tasks.nli: ('nlp_structbert_nli_chinese-base',
'damo/nlp_structbert_nli_chinese-base'),
Tasks.text_classification:
('bert-sentiment-analysis', 'damo/bert-base-sst2'),
Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),


+ 1
- 0
tests/pipelines/test_nli.py View File

@@ -15,6 +15,7 @@ class NLITest(unittest.TestCase):
sentence1 = '四川商务职业学院和四川财经职业学院哪个好?'
sentence2 = '四川商务职业学院商务管理在哪个校区?'

@unittest.skip('skip temporarily to save test time')
def test_run_from_local(self):
cache_path = snapshot_download(self.model_id)
tokenizer = NLIPreprocessor(cache_path)


Loading…
Cancel
Save