diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py
index 60d67786..86ea6dab 100644
--- a/modelscope/pipelines/base.py
+++ b/modelscope/pipelines/base.py
@@ -160,6 +160,7 @@ class Pipeline(ABC):
         # input_dict = self._handle_input(input)
 
         # sanitize the parameters
+        batch_size = kwargs.pop('batch_size', None)
         preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
             **kwargs)
         kwargs['preprocess_params'] = preprocess_params
@@ -167,9 +168,12 @@ class Pipeline(ABC):
         kwargs['postprocess_params'] = postprocess_params
 
         if isinstance(input, list):
-            output = []
-            for ele in input:
-                output.append(self._process_single(ele, *args, **kwargs))
+            if batch_size is None:
+                output = []
+                for ele in input:
+                    output.append(self._process_single(ele, *args, **kwargs))
+            else:
+                output = self._process_batch(input, batch_size, **kwargs)
 
         elif isinstance(input, MsDataset):
             return self._process_iterator(input, *args, **kwargs)
@@ -204,6 +208,7 @@ class Pipeline(ABC):
         postprocess_params = kwargs.get('postprocess_params', {})
         self._check_input(input)
         out = self.preprocess(input, **preprocess_params)
+
         with device_placement(self.framework, self.device_name):
             if self.framework == Frameworks.torch:
                 with torch.no_grad():
@@ -217,6 +222,55 @@ class Pipeline(ABC):
         self._check_output(out)
         return out
 
+    def _batch(self, data_list):
+        batch_data = {}
+        for sample_preprocessed in data_list:
+            for k, v in sample_preprocessed.items():
+                value_list = batch_data.get(k, [])
+                value_list.append(v)
+                batch_data[k] = value_list
+        for k in batch_data.keys():
+            if isinstance(batch_data[k][0], torch.Tensor):
+                batch_data[k] = torch.concat(batch_data[k])
+        return batch_data
+
+    def _process_batch(self, input: List[Input], batch_size,
+                       **kwargs) -> Dict[str, Any]:
+        preprocess_params = kwargs.get('preprocess_params')
+        forward_params = kwargs.get('forward_params')
+        postprocess_params = kwargs.get('postprocess_params')
+
+        # batch data
+        batched_input = {}
+        output_list = []
+        for i in range(0, len(input), batch_size):
+            end = min(i + batch_size, len(input))
+            real_batch_size = end - i
+            preprocessed_list = [
+                self.preprocess(i, **preprocess_params) for i in input[i:end]
+            ]
+            with device_placement(self.framework, self.device_name):
+                if self.framework == Frameworks.torch:
+                    with torch.no_grad():
+                        if self._auto_collate:
+                            out = self._batch(preprocessed_list)
+                            batched_out = self._collate_fn(out)
+                        batched_out = self.forward(batched_out,
+                                                   **forward_params)
+                else:
+                    batched_out = self.forward(batched_input, **forward_params)
+
+            for batch_idx in range(real_batch_size):
+                out = {}
+                for k, element in batched_out.items():
+                    if element is not None:
+                        out[k] = element[batch_idx]
+                out = self.postprocess(out, **postprocess_params)
+                self._check_output(out)
+                output_list.append(out)
+
+        return output_list
+
     def _check_input(self, input):
         task_name = self.group_key
         if task_name in TASK_INPUTS:
@@ -290,12 +344,14 @@ class Pipeline(ABC):
         return self.model(inputs, **forward_params)
 
     @abstractmethod
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def postprocess(self, inputs: Dict[str, Any],
+                    **post_params) -> Dict[str, Any]:
        """ If current pipeline support model reuse, common postprocess
            code should be write here.

        Args:
            inputs: input data
+           post_params: post process parameters

        Return:
            dict of results: a dict containing outputs of model, each
@@ -429,7 +485,11 @@ def collate_fn(data, device):
     from torch.utils.data.dataloader import default_collate
     from modelscope.preprocessors.nlp import InputFeatures
     if isinstance(data, dict) or isinstance(data, Mapping):
-        return type(data)({k: collate_fn(v, device) for k, v in data.items()})
+        # add compatibility for img_metas for mmlab models
+        return type(data)({
+            k: collate_fn(v, device) if k != 'img_metas' else v
+            for k, v in data.items()
+        })
     elif isinstance(data, (tuple, list)):
         if 0 == len(data):
             return torch.Tensor([])
diff --git a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
index 81a5f8cd..63966ed4 100644
--- a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
+++ b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
@@ -46,6 +46,31 @@ class ImageCaptioningPipeline(Pipeline):
             preprocessor = MPlugPreprocessor(pipe_model.model_dir)
         super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
 
+    def _batch(self, data):
+        if isinstance(self.model, OfaForAllTasks):
+            # collate batch data due to the nested data structure
+            if isinstance(data, list):
+                batch_data = {}
+                batch_data['nsentences'] = len(data)
+                batch_data['samples'] = [d['samples'][0] for d in data]
+                batch_data['net_input'] = {}
+                for k in data[0]['net_input'].keys():
+                    batch_data['net_input'][k] = torch.concat(
+                        [d['net_input'][k] for d in data])
+
+                return batch_data
+        elif isinstance(self.model, MPlugForAllTasks):
+            from transformers.tokenization_utils_base import BatchEncoding
+            batch_data = dict(train=data[0]['train'])
+            batch_data['image'] = torch.concat([d['image'] for d in data])
+            question = {}
+            for k in data[0]['question'].keys():
+                question[k] = torch.concat([d['question'][k] for d in data])
+            batch_data['question'] = BatchEncoding(question)
+            return batch_data
+        else:
+            return super()._batch(data)
+
     def forward(self, inputs: Dict[str, Any],
                 **forward_params) -> Dict[str, Any]:
         with torch.no_grad():
diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py
index 6be70468..bd8a8d48 100644
--- a/tests/pipelines/test_ofa_tasks.py
+++ b/tests/pipelines/test_ofa_tasks.py
@@ -45,6 +45,19 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
         result = img_captioning('data/test/images/image_captioning.png')
         print(result[OutputKeys.CAPTION])
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_image_captioning_batch(self):
+        img_captioning = pipeline(
+            Tasks.image_captioning,
+            model='damo/ofa_image-caption_coco_large_en')
+        results = img_captioning(
+            [{
+                'image': 'data/test/images/image_captioning.png'
+            } for _ in range(6)],
+            batch_size=2)
+        for r in results:
+            print(r[OutputKeys.CAPTION])
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_ocr_recognize_with_name(self):
         ocr_recognize = pipeline(
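For reference, here is a minimal standalone sketch (not part of the diff) of what the new default `Pipeline._batch` collation does: per-sample preprocessed dicts are merged key by key, and tensor values are concatenated along the leading batch dimension. The keys and shapes below are illustrative assumptions, not taken from any specific model.

```python
import torch

# Two preprocessed samples, as pipeline.preprocess() might return them
# (keys and shapes are made up for illustration).
sample_a = {'input_ids': torch.zeros(1, 16, dtype=torch.long),
            'image': torch.rand(1, 3, 224, 224)}
sample_b = {'input_ids': torch.ones(1, 16, dtype=torch.long),
            'image': torch.rand(1, 3, 224, 224)}


def batch(data_list):
    # Same key-wise merge as the _batch method added to Pipeline above.
    batch_data = {}
    for sample in data_list:
        for k, v in sample.items():
            batch_data.setdefault(k, []).append(v)
    for k in batch_data:
        if isinstance(batch_data[k][0], torch.Tensor):
            # torch.concat joins along dim 0, so each per-sample tensor
            # is expected to carry a leading batch dimension of 1.
            batch_data[k] = torch.concat(batch_data[k])
    return batch_data


batched = batch([sample_a, sample_b])
print(batched['input_ids'].shape)  # torch.Size([2, 16])
print(batched['image'].shape)      # torch.Size([2, 3, 224, 224])
```

This flat key-wise concat is also why `ImageCaptioningPipeline` overrides `_batch`: the OFA and MPlug preprocessors emit nested structures (`net_input`, `question`), so each nesting level has to be concatenated explicitly, as done in the pipeline-specific override in the diff.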