Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10506320master
| @@ -74,7 +74,8 @@ class ZeroShotClassificationPipeline(Pipeline): | |||||
| preprocess_params = {} | preprocess_params = {} | ||||
| postprocess_params = {} | postprocess_params = {} | ||||
| if 'candidate_labels' in kwargs: | if 'candidate_labels' in kwargs: | ||||
| candidate_labels = kwargs.pop('candidate_labels') | |||||
| candidate_labels = self._parse_labels( | |||||
| kwargs.pop('candidate_labels')) | |||||
| preprocess_params['candidate_labels'] = candidate_labels | preprocess_params['candidate_labels'] = candidate_labels | ||||
| postprocess_params['candidate_labels'] = candidate_labels | postprocess_params['candidate_labels'] = candidate_labels | ||||
| else: | else: | ||||
| @@ -84,6 +85,14 @@ class ZeroShotClassificationPipeline(Pipeline): | |||||
| postprocess_params['multi_label'] = kwargs.pop('multi_label', False) | postprocess_params['multi_label'] = kwargs.pop('multi_label', False) | ||||
| return preprocess_params, {}, postprocess_params | return preprocess_params, {}, postprocess_params | ||||
| def _parse_labels(self, labels): | |||||
| if isinstance(labels, str): | |||||
| labels = labels.replace(',', ',') # replace cn comma to en comma | |||||
| labels = [ | |||||
| label.strip() for label in labels.split(',') if label.strip() | |||||
| ] | |||||
| return labels | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| return self.model(**inputs, **forward_params) | return self.model(**inputs, **forward_params) | ||||
| @@ -21,6 +21,7 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| sentence = '全新突破 解放军运20版空中加油机曝光' | sentence = '全新突破 解放军运20版空中加油机曝光' | ||||
| labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'] | labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'] | ||||
| labels_str = '文化, 体育, 娱乐, 财经, 家居, 汽车, 教育, 科技, 军事' | |||||
| template = '这篇文章的标题是{}' | template = '这篇文章的标题是{}' | ||||
| regress_tool = MsRegressTool(baseline=False) | regress_tool = MsRegressTool(baseline=False) | ||||
| @@ -40,6 +41,10 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| f'sentence: {self.sentence}\n' | f'sentence: {self.sentence}\n' | ||||
| f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}' | f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}' | ||||
| ) | ) | ||||
| print( | |||||
| f'sentence: {self.sentence}\n' | |||||
| f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels_str,hypothesis_template=self.template)}' | |||||
| ) | |||||
| print( | print( | ||||
| f'sentence: {self.sentence}\n' | f'sentence: {self.sentence}\n' | ||||
| f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}' | f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}' | ||||