|
|
|
@@ -74,7 +74,8 @@ class ZeroShotClassificationPipeline(Pipeline): |
|
|
|
preprocess_params = {} |
|
|
|
postprocess_params = {} |
|
|
|
if 'candidate_labels' in kwargs: |
|
|
|
candidate_labels = kwargs.pop('candidate_labels') |
|
|
|
candidate_labels = self._parse_labels( |
|
|
|
kwargs.pop('candidate_labels')) |
|
|
|
preprocess_params['candidate_labels'] = candidate_labels |
|
|
|
postprocess_params['candidate_labels'] = candidate_labels |
|
|
|
else: |
|
|
|
@@ -84,6 +85,14 @@ class ZeroShotClassificationPipeline(Pipeline): |
|
|
|
postprocess_params['multi_label'] = kwargs.pop('multi_label', False) |
|
|
|
return preprocess_params, {}, postprocess_params |
|
|
|
|
|
|
|
def _parse_labels(self, labels): |
|
|
|
if isinstance(labels, str): |
|
|
|
labels = labels.replace(',', ',') # replace cn comma to en comma |
|
|
|
labels = [ |
|
|
|
label.strip() for label in labels.split(',') if label.strip() |
|
|
|
] |
|
|
|
return labels |
|
|
|
|
|
|
|
def forward(self, inputs: Dict[str, Any], |
|
|
|
**forward_params) -> Dict[str, Any]: |
|
|
|
return self.model(**inputs, **forward_params) |
|
|
|
|