From e1ab73b7d848a558f9498ac4f5d2277eb2752e5a Mon Sep 17 00:00:00 2001 From: "zhangzhicheng.zzc" Date: Tue, 25 Oct 2022 13:55:09 +0800 Subject: [PATCH] [to #42322933]support type str for for zero-shot labels' input Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10506320 --- .../nlp/zero_shot_classification_pipeline.py | 11 ++++++++++- tests/pipelines/test_zero_shot_classification.py | 5 +++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py index 88792b45..ecd538b9 100644 --- a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py +++ b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py @@ -74,7 +74,8 @@ class ZeroShotClassificationPipeline(Pipeline): preprocess_params = {} postprocess_params = {} if 'candidate_labels' in kwargs: - candidate_labels = kwargs.pop('candidate_labels') + candidate_labels = self._parse_labels( + kwargs.pop('candidate_labels')) preprocess_params['candidate_labels'] = candidate_labels postprocess_params['candidate_labels'] = candidate_labels else: @@ -84,6 +85,14 @@ class ZeroShotClassificationPipeline(Pipeline): postprocess_params['multi_label'] = kwargs.pop('multi_label', False) return preprocess_params, {}, postprocess_params + def _parse_labels(self, labels): + if isinstance(labels, str): + labels = labels.replace(',', ',') # replace cn comma to en comma + labels = [ + label.strip() for label in labels.split(',') if label.strip() + ] + return labels + def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: return self.model(**inputs, **forward_params) diff --git a/tests/pipelines/test_zero_shot_classification.py b/tests/pipelines/test_zero_shot_classification.py index da1854c9..6a98132a 100644 --- a/tests/pipelines/test_zero_shot_classification.py +++ b/tests/pipelines/test_zero_shot_classification.py @@ -21,6 +21,7 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): sentence = '全新突破 解放军运20版空中加油机曝光' labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'] + labels_str = '文化, 体育, 娱乐, 财经, 家居, 汽车, 教育, 科技, 军事' template = '这篇文章的标题是{}' regress_tool = MsRegressTool(baseline=False) @@ -40,6 +41,10 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): f'sentence: {self.sentence}\n' f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}' ) + print( + f'sentence: {self.sentence}\n' + f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels_str,hypothesis_template=self.template)}' + ) print( f'sentence: {self.sentence}\n' f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}'