
Allow params to be passed to the pipeline's __call__ method

master
智丞 3 years ago
parent commit 8ae2e46ad3
4 changed files with 82 additions and 57 deletions
  1. modelscope/pipelines/base.py (+35, -16)
  2. modelscope/pipelines/nlp/zero_shot_classification_pipeline.py (+23, -10)
  3. modelscope/preprocessors/nlp.py (+4, -5)
  4. tests/pipelines/test_zero_shot_classification.py (+20, -26)

modelscope/pipelines/base.py (+35, -16)

@@ -80,7 +80,7 @@ class Pipeline(ABC):
         self.preprocessor = preprocessor

     def __call__(self, input: Union[Input, List[Input]], *args,
-                 **post_kwargs) -> Union[Dict[str, Any], Generator]:
+                 **kwargs) -> Union[Dict[str, Any], Generator]:
         # model provider should leave it as it is
         # modelscope library developer will handle this function

@@ -89,24 +89,41 @@ class Pipeline(ABC):
         if isinstance(input, list):
             output = []
             for ele in input:
-                output.append(self._process_single(ele, *args, **post_kwargs))
+                output.append(self._process_single(ele, *args, **kwargs))

         elif isinstance(input, PyDataset):
-            return self._process_iterator(input, *args, **post_kwargs)
+            return self._process_iterator(input, *args, **kwargs)

         else:
-            output = self._process_single(input, *args, **post_kwargs)
+            output = self._process_single(input, *args, **kwargs)
         return output

-    def _process_iterator(self, input: Input, *args, **post_kwargs):
+    def _process_iterator(self, input: Input, *args, **kwargs):
         for ele in input:
-            yield self._process_single(ele, *args, **post_kwargs)
+            yield self._process_single(ele, *args, **kwargs)

-    def _process_single(self, input: Input, *args,
-                        **post_kwargs) -> Dict[str, Any]:
-        out = self.preprocess(input)
-        out = self.forward(out)
-        out = self.postprocess(out, **post_kwargs)
+    def _sanitize_parameters(self, **pipeline_parameters):
+        """Sanitize the keyword args passed to '__call__' (and through to
+        '_process_single') into separate preprocess, forward and postprocess
+        params. This is a normal method that subclasses may override; the
+        default implementation routes all kwargs to postprocess.
+
+        Returns:
+            Dict[str, str]: preprocess_params = {}
+            Dict[str, str]: forward_params = {}
+            Dict[str, str]: postprocess_params = pipeline_parameters
+        """
+        # raise NotImplementedError("_sanitize_parameters not implemented")
+        return {}, {}, pipeline_parameters
+
+    def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
+        # sanitize the parameters
+        preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
+            **kwargs)
+        out = self.preprocess(input, **preprocess_params)
+        out = self.forward(out, **forward_params)
+        out = self.postprocess(out, **postprocess_params)
         self._check_output(out)
         return out


@@ -126,23 +143,25 @@ class Pipeline(ABC):
             raise ValueError(f'expected output keys are {output_keys}, '
                              f'those {missing_keys} are missing')

-    def preprocess(self, inputs: Input) -> Dict[str, Any]:
+    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
         """ Provide default implementation based on preprocess_cfg and user can reimplement it
         """
         assert self.preprocessor is not None, 'preprocess method should be implemented'
         assert not isinstance(self.preprocessor, List),\
             'default implementation does not support using multiple preprocessors.'
-        return self.preprocessor(inputs)
+        return self.preprocessor(inputs, **preprocess_params)

-    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
         """ Provide default implementation using self.model and user can reimplement it
         """
         assert self.model is not None, 'forward method should be implemented'
         assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
-        return self.model(inputs)
+        return self.model(inputs, **forward_params)

     @abstractmethod
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def postprocess(self, inputs: Dict[str, Any],
+                    **postprocess_params) -> Dict[str, Any]:
         """ If the current pipeline supports model reuse, common postprocess
         code should be written here.
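
With this contract in place, a subclass opts into per-call parameters by overriding _sanitize_parameters and declaring matching keyword parameters on the stage methods. Below is a minimal sketch, not part of this commit: the subclass name and the max_length / top_k parameters are hypothetical, used only to illustrate the routing.

# A hypothetical subclass illustrating the new parameter routing.
# 'MyCustomPipeline', 'max_length' and 'top_k' are made-up names.
from typing import Any, Dict

from modelscope.pipelines.base import Pipeline


class MyCustomPipeline(Pipeline):

    def _sanitize_parameters(self, **pipeline_parameters):
        # Send each recognized kwarg to the stage that consumes it.
        preprocess_params = {}
        postprocess_params = {}
        if 'max_length' in pipeline_parameters:
            preprocess_params['max_length'] = pipeline_parameters.pop('max_length')
        if 'top_k' in pipeline_parameters:
            postprocess_params['top_k'] = pipeline_parameters.pop('top_k')
        # Leftover kwargs default to postprocess, mirroring the base class.
        postprocess_params.update(pipeline_parameters)
        return preprocess_params, {}, postprocess_params

    def postprocess(self, inputs: Dict[str, Any], top_k: int = 1) -> Dict[str, Any]:
        # Receives whatever _sanitize_parameters routed here.
        return inputs

A call such as my_pipeline(text, max_length=128, top_k=3) then reaches the preprocessor with max_length=128 and postprocess with top_k=3, without either stage having to inspect the raw kwargs.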




modelscope/pipelines/nlp/zero_shot_classification_pipeline.py (+23, -10)

@@ -39,18 +39,32 @@ class ZeroShotClassificationPipeline(Pipeline):

         self.entailment_id = 0
         self.contradiction_id = 2
-        self.candidate_labels = kwargs.pop('candidate_labels')
-        self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
-        self.multi_label = kwargs.pop('multi_label', False)

         if preprocessor is None:
             preprocessor = ZeroShotClassificationPreprocessor(
-                sc_model.model_dir,
-                candidate_labels=self.candidate_labels,
-                hypothesis_template=self.hypothesis_template)
+                sc_model.model_dir)
         super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)

-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_params = {}
+        postprocess_params = {}
+
+        if 'candidate_labels' in kwargs:
+            candidate_labels = kwargs.pop('candidate_labels')
+            preprocess_params['candidate_labels'] = candidate_labels
+            postprocess_params['candidate_labels'] = candidate_labels
+        else:
+            raise ValueError('You must include at least one label.')
+        preprocess_params['hypothesis_template'] = kwargs.pop(
+            'hypothesis_template', '{}')
+
+        postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
+        return preprocess_params, {}, postprocess_params
+
+    def postprocess(self,
+                    inputs: Dict[str, Any],
+                    candidate_labels,
+                    multi_label=False) -> Dict[str, Any]:
         """process the prediction results

         Args:
@@ -61,8 +75,7 @@ class ZeroShotClassificationPipeline(Pipeline):
         """

         logits = inputs['logits']
-
-        if self.multi_label or len(self.candidate_labels) == 1:
+        if multi_label or len(candidate_labels) == 1:
             logits = logits[..., [self.contradiction_id, self.entailment_id]]
             scores = softmax(logits, axis=-1)[..., 1]
         else:
@@ -71,7 +84,7 @@ class ZeroShotClassificationPipeline(Pipeline):

         reversed_index = list(reversed(scores.argsort()))
         result = {
-            'labels': [self.candidate_labels[i] for i in reversed_index],
+            'labels': [candidate_labels[i] for i in reversed_index],
             'scores': [scores[i].item() for i in reversed_index],
         }
         return result
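
With the labels moved out of the constructor, one pipeline instance can now be reused with a different label set on every call. The following usage sketch is assembled from the updated tests; the model id, sentence and labels are taken from tests/pipelines/test_zero_shot_classification.py:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

classifier = pipeline(
    task=Tasks.zero_shot_classification,
    model='damo/nlp_structbert_zero-shot-classification_chinese-base')

# candidate_labels is required per call; hypothesis_template and
# multi_label are optional per-call arguments.
result = classifier(
    '全新突破 解放军运20版空中加油机曝光',
    candidate_labels=['文化', '体育', '娱乐', '财经', '军事'],
    hypothesis_template='这篇文章的标题是{}')
print(result['labels'], result['scores'])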

modelscope/preprocessors/nlp.py (+4, -5)

@@ -196,12 +196,11 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
         from sofa import SbertTokenizer
         self.model_dir: str = model_dir
         self.sequence_length = kwargs.pop('sequence_length', 512)
-        self.candidate_labels = kwargs.pop('candidate_labels')
-        self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
         self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)

     @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
+    def __call__(self, data: str, hypothesis_template: str,
+                 candidate_labels: list) -> Dict[str, Any]:
         """process the raw input data

         Args:
@@ -212,8 +211,8 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
         Returns:
             Dict[str, Any]: the preprocessed data
         """
-        pairs = [[data, self.hypothesis_template.format(label)]
-                 for label in self.candidate_labels]
+        pairs = [[data, hypothesis_template.format(label)]
+                 for label in candidate_labels]

         features = self.tokenizer(
             pairs,
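
Since the preprocessor no longer caches the labels or the template, a direct call must supply both each time. A minimal sketch, assuming model_dir points at a locally downloaded copy of the model:

from modelscope.preprocessors.nlp import ZeroShotClassificationPreprocessor

# model_dir is assumed to be a local model directory, e.g. the path
# returned by snapshot_download in the tests below.
preprocessor = ZeroShotClassificationPreprocessor(model_dir)
features = preprocessor(
    '全新突破 解放军运20版空中加油机曝光',
    hypothesis_template='这篇文章的标题是{}',
    candidate_labels=['文化', '体育', '军事'])
# 'features' holds one tokenized (premise, hypothesis) pair per label.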


tests/pipelines/test_zero_shot_classification.py (+20, -26)

@@ -13,53 +13,47 @@ from modelscope.utils.constant import Tasks
 class ZeroShotClassificationTest(unittest.TestCase):
     model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
     sentence = '全新突破 解放军运20版空中加油机曝光'
-    candidate_labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
+    labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
+    template = '这篇文章的标题是{}'

     def test_run_from_local(self):
         cache_path = snapshot_download(self.model_id)
-        tokenizer = ZeroShotClassificationPreprocessor(
-            cache_path, candidate_labels=self.candidate_labels)
+        tokenizer = ZeroShotClassificationPreprocessor(cache_path)
         model = BertForZeroShotClassification(cache_path, tokenizer=tokenizer)
         pipeline1 = ZeroShotClassificationPipeline(
-            model,
-            preprocessor=tokenizer,
-            candidate_labels=self.candidate_labels,
-        )
+            model, preprocessor=tokenizer)
         pipeline2 = pipeline(
             Tasks.zero_shot_classification,
             model=model,
-            preprocessor=tokenizer,
-            candidate_labels=self.candidate_labels)
+            preprocessor=tokenizer)

-        print(f'sentence: {self.sentence}\n'
-              f'pipeline1:{pipeline1(input=self.sentence)}')
+        print(
+            f'sentence: {self.sentence}\n'
+            f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}'
+        )
         print()
-        print(f'sentence: {self.sentence}\n'
-              f'pipeline2: {pipeline2(input=self.sentence)}')
+        print(
+            f'sentence: {self.sentence}\n'
+            f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}'
+        )

     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
-        tokenizer = ZeroShotClassificationPreprocessor(
-            model.model_dir, candidate_labels=self.candidate_labels)
+        tokenizer = ZeroShotClassificationPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.zero_shot_classification,
             model=model,
-            preprocessor=tokenizer,
-            candidate_labels=self.candidate_labels)
-        print(pipeline_ins(input=self.sentence))
+            preprocessor=tokenizer)
+        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))

     def test_run_with_model_name(self):
         pipeline_ins = pipeline(
-            task=Tasks.zero_shot_classification,
-            model=self.model_id,
-            candidate_labels=self.candidate_labels)
-        print(pipeline_ins(input=self.sentence))
+            task=Tasks.zero_shot_classification, model=self.model_id)
+        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))

     def test_run_with_default_model(self):
-        pipeline_ins = pipeline(
-            task=Tasks.zero_shot_classification,
-            candidate_labels=self.candidate_labels)
-        print(pipeline_ins(input=self.sentence))
+        pipeline_ins = pipeline(task=Tasks.zero_shot_classification)
+        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))


 if __name__ == '__main__':

