diff --git a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py index 88792b45..ecd538b9 100644 --- a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py +++ b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py @@ -74,7 +74,8 @@ class ZeroShotClassificationPipeline(Pipeline): preprocess_params = {} postprocess_params = {} if 'candidate_labels' in kwargs: - candidate_labels = kwargs.pop('candidate_labels') + candidate_labels = self._parse_labels( + kwargs.pop('candidate_labels')) preprocess_params['candidate_labels'] = candidate_labels postprocess_params['candidate_labels'] = candidate_labels else: @@ -84,6 +85,14 @@ class ZeroShotClassificationPipeline(Pipeline): postprocess_params['multi_label'] = kwargs.pop('multi_label', False) return preprocess_params, {}, postprocess_params + def _parse_labels(self, labels): + if isinstance(labels, str): + labels = labels.replace(',', ',') # replace cn comma to en comma + labels = [ + label.strip() for label in labels.split(',') if label.strip() + ] + return labels + def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: return self.model(**inputs, **forward_params) diff --git a/tests/pipelines/test_zero_shot_classification.py b/tests/pipelines/test_zero_shot_classification.py index da1854c9..6a98132a 100644 --- a/tests/pipelines/test_zero_shot_classification.py +++ b/tests/pipelines/test_zero_shot_classification.py @@ -21,6 +21,7 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): sentence = '全新突破 解放军运20版空中加油机曝光' labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'] + labels_str = '文化, 体育, 娱乐, 财经, 家居, 汽车, 教育, 科技, 军事' template = '这篇文章的标题是{}' regress_tool = MsRegressTool(baseline=False) @@ -40,6 +41,10 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): f'sentence: {self.sentence}\n' f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}' ) + print( + f'sentence: {self.sentence}\n' + f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels_str,hypothesis_template=self.template)}' + ) print( f'sentence: {self.sentence}\n' f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}'