| @@ -14,7 +14,7 @@ from .outputs import TASK_OUTPUTS | |||||
| from .util import is_model, is_official_hub_path | from .util import is_model, is_official_hub_path | ||||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | Tensor = Union['torch.Tensor', 'tf.Tensor'] | ||||
| Input = Union[str, tuple, dict, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] | |||||
| Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] | |||||
| InputModel = Union[str, Model] | InputModel = Union[str, Model] | ||||
| output_keys = [ | output_keys = [ | ||||
| @@ -74,7 +74,7 @@ class Pipeline(ABC): | |||||
| self.preprocessor = preprocessor | self.preprocessor = preprocessor | ||||
| def __call__(self, input: Union[Input, List[Input]], *args, | def __call__(self, input: Union[Input, List[Input]], *args, | ||||
| **kwargs) -> Union[Dict[str, Any], Generator]: | |||||
| **post_kwargs) -> Union[Dict[str, Any], Generator]: | |||||
| # model provider should leave it as it is | # model provider should leave it as it is | ||||
| # modelscope library developer will handle this function | # modelscope library developer will handle this function | ||||
| @@ -83,42 +83,24 @@ class Pipeline(ABC): | |||||
| if isinstance(input, list): | if isinstance(input, list): | ||||
| output = [] | output = [] | ||||
| for ele in input: | for ele in input: | ||||
| output.append(self._process_single(ele, *args, **kwargs)) | |||||
| output.append(self._process_single(ele, *args, **post_kwargs)) | |||||
| elif isinstance(input, PyDataset): | elif isinstance(input, PyDataset): | ||||
| return self._process_iterator(input, *args, **kwargs) | |||||
| return self._process_iterator(input, *args, **post_kwargs) | |||||
| else: | else: | ||||
| output = self._process_single(input, *args, **kwargs) | |||||
| output = self._process_single(input, *args, **post_kwargs) | |||||
| return output | return output | ||||
| def _process_iterator(self, input: Input, *args, **kwargs): | |||||
| def _process_iterator(self, input: Input, *args, **post_kwargs): | |||||
| for ele in input: | for ele in input: | ||||
| yield self._process_single(ele, *args, **kwargs) | |||||
| def _sanitize_parameters(self, **pipeline_parameters): | |||||
| """ | |||||
| this method should sanitize the keyword args to preprocessor params, | |||||
| forward params and postprocess params on '__call__' or '_process_single' method | |||||
| considering to be a normal classmethod with default implementation / output | |||||
| Returns: | |||||
| Dict[str, str]: preprocess_params = {} | |||||
| Dict[str, str]: forward_params = {} | |||||
| Dict[str, str]: postprocess_params = pipeline_parameters | |||||
| """ | |||||
| # raise NotImplementedError("_sanitize_parameters not implemented") | |||||
| return {}, {}, pipeline_parameters | |||||
| def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: | |||||
| # sanitize the parameters | |||||
| preprocess_params, forward_params, postprocess_params = self._sanitize_parameters( | |||||
| **kwargs) | |||||
| out = self.preprocess(input, **preprocess_params) | |||||
| out = self.forward(out, **forward_params) | |||||
| out = self.postprocess(out, **postprocess_params) | |||||
| yield self._process_single(ele, *args, **post_kwargs) | |||||
| def _process_single(self, input: Input, *args, | |||||
| **post_kwargs) -> Dict[str, Any]: | |||||
| out = self.preprocess(input) | |||||
| out = self.forward(out) | |||||
| out = self.postprocess(out, **post_kwargs) | |||||
| self._check_output(out) | self._check_output(out) | ||||
| return out | return out | ||||
| @@ -138,25 +120,23 @@ class Pipeline(ABC): | |||||
| raise ValueError(f'expected output keys are {output_keys}, ' | raise ValueError(f'expected output keys are {output_keys}, ' | ||||
| f'those {missing_keys} are missing') | f'those {missing_keys} are missing') | ||||
| def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: | |||||
| def preprocess(self, inputs: Input) -> Dict[str, Any]: | |||||
| """ Provide default implementation based on preprocess_cfg and user can reimplement it | """ Provide default implementation based on preprocess_cfg and user can reimplement it | ||||
| """ | """ | ||||
| assert self.preprocessor is not None, 'preprocess method should be implemented' | assert self.preprocessor is not None, 'preprocess method should be implemented' | ||||
| assert not isinstance(self.preprocessor, List),\ | assert not isinstance(self.preprocessor, List),\ | ||||
| 'default implementation does not support using multiple preprocessors.' | 'default implementation does not support using multiple preprocessors.' | ||||
| return self.preprocessor(inputs, **preprocess_params) | |||||
| return self.preprocessor(inputs) | |||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """ Provide default implementation using self.model and user can reimplement it | """ Provide default implementation using self.model and user can reimplement it | ||||
| """ | """ | ||||
| assert self.model is not None, 'forward method should be implemented' | assert self.model is not None, 'forward method should be implemented' | ||||
| assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.' | assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.' | ||||
| return self.model(inputs, **forward_params) | |||||
| return self.model(inputs) | |||||
| @abstractmethod | @abstractmethod | ||||
| def postprocess(self, inputs: Dict[str, Any], | |||||
| **postprocess_params) -> Dict[str, Any]: | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """ If current pipeline support model reuse, common postprocess | """ If current pipeline support model reuse, common postprocess | ||||
| code should be write here. | code should be write here. | ||||