Browse Source

[to #42322933]support type str for for zero-shot labels' input

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10506320
master
zhangzhicheng.zzc yingda.chen 3 years ago
parent
commit
e1ab73b7d8
2 changed files with 15 additions and 1 deletions
  1. +10
    -1
      modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
  2. +5
    -0
      tests/pipelines/test_zero_shot_classification.py

+ 10
- 1
modelscope/pipelines/nlp/zero_shot_classification_pipeline.py View File

@@ -74,7 +74,8 @@ class ZeroShotClassificationPipeline(Pipeline):
preprocess_params = {}
postprocess_params = {}
if 'candidate_labels' in kwargs:
candidate_labels = kwargs.pop('candidate_labels')
candidate_labels = self._parse_labels(
kwargs.pop('candidate_labels'))
preprocess_params['candidate_labels'] = candidate_labels
postprocess_params['candidate_labels'] = candidate_labels
else:
@@ -84,6 +85,14 @@ class ZeroShotClassificationPipeline(Pipeline):
postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
return preprocess_params, {}, postprocess_params

def _parse_labels(self, labels):
if isinstance(labels, str):
labels = labels.replace(',', ',') # replace cn comma to en comma
labels = [
label.strip() for label in labels.split(',') if label.strip()
]
return labels

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
return self.model(**inputs, **forward_params)


+ 5
- 0
tests/pipelines/test_zero_shot_classification.py View File

@@ -21,6 +21,7 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck):

sentence = '全新突破 解放军运20版空中加油机曝光'
labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
labels_str = '文化, 体育, 娱乐, 财经, 家居, 汽车, 教育, 科技, 军事'
template = '这篇文章的标题是{}'
regress_tool = MsRegressTool(baseline=False)

@@ -40,6 +41,10 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck):
f'sentence: {self.sentence}\n'
f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}'
)
print(
f'sentence: {self.sentence}\n'
f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels_str,hypothesis_template=self.template)}'
)
print(
f'sentence: {self.sentence}\n'
f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}'


Loading…
Cancel
Save