Browse Source

revert pipeline params update

master
智丞 3 years ago
parent
commit
f219e1a8bc
1 changed files with 17 additions and 37 deletions
  1. +17
    -37
      modelscope/pipelines/base.py

+ 17
- 37
modelscope/pipelines/base.py View File

@@ -14,7 +14,7 @@ from .outputs import TASK_OUTPUTS
from .util import is_model, is_official_hub_path

Tensor = Union['torch.Tensor', 'tf.Tensor']
Input = Union[str, tuple, dict, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
InputModel = Union[str, Model]

output_keys = [
@@ -74,7 +74,7 @@ class Pipeline(ABC):
self.preprocessor = preprocessor

def __call__(self, input: Union[Input, List[Input]], *args,
**kwargs) -> Union[Dict[str, Any], Generator]:
**post_kwargs) -> Union[Dict[str, Any], Generator]:
# model provider should leave it as it is
# modelscope library developer will handle this function

@@ -83,42 +83,24 @@ class Pipeline(ABC):
if isinstance(input, list):
output = []
for ele in input:
output.append(self._process_single(ele, *args, **kwargs))
output.append(self._process_single(ele, *args, **post_kwargs))

elif isinstance(input, PyDataset):
return self._process_iterator(input, *args, **kwargs)
return self._process_iterator(input, *args, **post_kwargs)

else:
output = self._process_single(input, *args, **kwargs)
output = self._process_single(input, *args, **post_kwargs)
return output

def _process_iterator(self, input: Input, *args, **kwargs):
def _process_iterator(self, input: Input, *args, **post_kwargs):
for ele in input:
yield self._process_single(ele, *args, **kwargs)

def _sanitize_parameters(self, **pipeline_parameters):
"""
this method should sanitize the keyword args to preprocessor params,
forward params and postprocess params on '__call__' or '_process_single' method
considering to be a normal classmethod with default implementation / output

Returns:
Dict[str, str]: preprocess_params = {}
Dict[str, str]: forward_params = {}
Dict[str, str]: postprocess_params = pipeline_parameters
"""
# raise NotImplementedError("_sanitize_parameters not implemented")
return {}, {}, pipeline_parameters

def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:

# sanitize the parameters
preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
**kwargs)
out = self.preprocess(input, **preprocess_params)
out = self.forward(out, **forward_params)
out = self.postprocess(out, **postprocess_params)
yield self._process_single(ele, *args, **post_kwargs)

def _process_single(self, input: Input, *args,
**post_kwargs) -> Dict[str, Any]:
out = self.preprocess(input)
out = self.forward(out)
out = self.postprocess(out, **post_kwargs)
self._check_output(out)
return out

@@ -138,25 +120,23 @@ class Pipeline(ABC):
raise ValueError(f'expected output keys are {output_keys}, '
f'those {missing_keys} are missing')

def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
def preprocess(self, inputs: Input) -> Dict[str, Any]:
""" Provide default implementation based on preprocess_cfg and user can reimplement it
"""
assert self.preprocessor is not None, 'preprocess method should be implemented'
assert not isinstance(self.preprocessor, List),\
'default implementation does not support using multiple preprocessors.'
return self.preprocessor(inputs, **preprocess_params)
return self.preprocessor(inputs)

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
""" Provide default implementation using self.model and user can reimplement it
"""
assert self.model is not None, 'forward method should be implemented'
assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
return self.model(inputs, **forward_params)
return self.model(inputs)

@abstractmethod
def postprocess(self, inputs: Dict[str, Any],
**postprocess_params) -> Dict[str, Any]:
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
""" If current pipeline support model reuse, common postprocess
code should be write here.



Loading…
Cancel
Save