Browse Source

[to #42322933]support type str for for zero-shot labels' input

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10506320
master
zhangzhicheng.zzc yingda.chen 3 years ago
parent
commit
e1ab73b7d8
2 changed files with 15 additions and 1 deletions
  1. +10
    -1
      modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
  2. +5
    -0
      tests/pipelines/test_zero_shot_classification.py

+ 10
- 1
modelscope/pipelines/nlp/zero_shot_classification_pipeline.py View File

@@ -74,7 +74,8 @@ class ZeroShotClassificationPipeline(Pipeline):
preprocess_params = {} preprocess_params = {}
postprocess_params = {} postprocess_params = {}
if 'candidate_labels' in kwargs: if 'candidate_labels' in kwargs:
candidate_labels = kwargs.pop('candidate_labels')
candidate_labels = self._parse_labels(
kwargs.pop('candidate_labels'))
preprocess_params['candidate_labels'] = candidate_labels preprocess_params['candidate_labels'] = candidate_labels
postprocess_params['candidate_labels'] = candidate_labels postprocess_params['candidate_labels'] = candidate_labels
else: else:
@@ -84,6 +85,14 @@ class ZeroShotClassificationPipeline(Pipeline):
postprocess_params['multi_label'] = kwargs.pop('multi_label', False) postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
return preprocess_params, {}, postprocess_params return preprocess_params, {}, postprocess_params


def _parse_labels(self, labels):
if isinstance(labels, str):
labels = labels.replace(',', ',') # replace cn comma to en comma
labels = [
label.strip() for label in labels.split(',') if label.strip()
]
return labels

def forward(self, inputs: Dict[str, Any], def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]: **forward_params) -> Dict[str, Any]:
return self.model(**inputs, **forward_params) return self.model(**inputs, **forward_params)


+ 5
- 0
tests/pipelines/test_zero_shot_classification.py View File

@@ -21,6 +21,7 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck):


sentence = '全新突破 解放军运20版空中加油机曝光' sentence = '全新突破 解放军运20版空中加油机曝光'
labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'] labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
labels_str = '文化, 体育, 娱乐, 财经, 家居, 汽车, 教育, 科技, 军事'
template = '这篇文章的标题是{}' template = '这篇文章的标题是{}'
regress_tool = MsRegressTool(baseline=False) regress_tool = MsRegressTool(baseline=False)


@@ -40,6 +41,10 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck):
f'sentence: {self.sentence}\n' f'sentence: {self.sentence}\n'
f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}' f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}'
) )
print(
f'sentence: {self.sentence}\n'
f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels_str,hypothesis_template=self.template)}'
)
print( print(
f'sentence: {self.sentence}\n' f'sentence: {self.sentence}\n'
f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}' f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}'


Loading…
Cancel
Save