diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index 77940c3c..13560229 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -91,6 +91,10 @@ TASK_INPUTS = { InputType.IMAGE, Tasks.crowd_counting: InputType.IMAGE, + Tasks.image_inpainting: { + 'img': InputType.IMAGE, + 'mask': InputType.IMAGE, + }, # image generation task result for a single image Tasks.image_to_image_generation: diff --git a/modelscope/pipelines/cv/image_inpainting_pipeline.py b/modelscope/pipelines/cv/image_inpainting_pipeline.py index 6ae0d63e..aff9788d 100644 --- a/modelscope/pipelines/cv/image_inpainting_pipeline.py +++ b/modelscope/pipelines/cv/image_inpainting_pipeline.py @@ -77,21 +77,22 @@ class ImageInpaintingPipeline(Pipeline): img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric') - def preprocess(self, input: Input) -> Dict[str, Any]: - if isinstance(input, str): - image_name, mask_name = input.split('+') + def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]: + if isinstance(input['img'], str): + image_name, mask_name = input['img'], input['mask'] img = LoadImage.convert_to_ndarray(image_name) img = self.transforms(img) mask = np.array(LoadImage(mode='L')(mask_name)['img']) mask = self.transforms(mask) - elif isinstance(input, PIL.Image.Image): - img = input.crop((0, 0, int(input.width / 2), input.height)) + elif isinstance(input['img'], PIL.Image.Image): + img = input['img'] img = self.transforms(np.array(img)) - mask = input.crop((int(input.width / 2), 0, input.width, - input.height)).convert('L') + mask = input['mask'].convert('L') mask = self.transforms(np.array(mask)) else: - raise TypeError('input should be either str or PIL.Image') + raise TypeError( + 'input should be either str or PIL.Image, and both inputs should have the same type' + ) result = dict(image=img, mask=mask[None, ...]) if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: diff --git a/tests/pipelines/test_image_inpainting.py b/tests/pipelines/test_image_inpainting.py index b89ce399..a8b704b7 100644 --- a/tests/pipelines/test_image_inpainting.py +++ b/tests/pipelines/test_image_inpainting.py @@ -20,6 +20,10 @@ class ImageInpaintingTest(unittest.TestCase): self.input_location = 'data/test/images/image_inpainting/image_inpainting.png' self.input_mask_location = 'data/test/images/image_inpainting/image_inpainting_mask.png' self.model_id = 'damo/cv_fft_inpainting_lama' + self.input = { + 'img': self.input_location, + 'mask': self.input_mask_location + } def save_result(self, result): vis_img = result[OutputKeys.OUTPUT_IMG] @@ -28,8 +32,7 @@ class ImageInpaintingTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_inpainting(self): inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) - result = inpainting(self.input_location + '+' - + self.input_mask_location) + result = inpainting(self.input) if result: self.save_result(result) else: @@ -41,8 +44,7 @@ class ImageInpaintingTest(unittest.TestCase): # if input image is HR, set refine=True is more better inpainting = pipeline( Tasks.image_inpainting, model=self.model_id, refine=True) - result = inpainting(self.input_location + '+' - + self.input_mask_location) + result = inpainting(self.input) if result: self.save_result(result) else: @@ -53,10 +55,7 @@ class ImageInpaintingTest(unittest.TestCase): inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) img = Image.open(self.input_location).convert('RGB') mask = Image.open(self.input_mask_location).convert('RGB') - img_new = Image.new('RGB', (img.width + mask.width, img.height)) - img_new.paste(img, (0, 0)) - img_new.paste(mask, (img.width, 0)) - result = inpainting(img_new) + result = inpainting({'img': img, 'mask': mask}) if result: self.save_result(result) else: @@ -65,8 +64,7 @@ class ImageInpaintingTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_inpainting_with_default_task(self): inpainting = pipeline(Tasks.image_inpainting) - result = inpainting(self.input_location + '+' - + self.input_mask_location) + result = inpainting(self.input) if result: self.save_result(result) else: