|
|
|
@@ -198,6 +198,7 @@ class ZeroShotClassificationPreprocessor(Preprocessor): |
|
|
|
self.sequence_length = kwargs.pop('sequence_length', 512) |
|
|
|
self.candidate_labels = kwargs.pop('candidate_labels') |
|
|
|
self.hypothesis_template = kwargs.pop('hypothesis_template', '{}') |
|
|
|
self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir) |
|
|
|
|
|
|
|
@type_assert(object, str) |
|
|
|
def __call__(self, data: str) -> Dict[str, Any]: |
|
|
|
|