|
|
|
@@ -20,6 +20,10 @@ class ImageInpaintingTest(unittest.TestCase): |
|
|
|
self.input_location = 'data/test/images/image_inpainting/image_inpainting.png' |
|
|
|
self.input_mask_location = 'data/test/images/image_inpainting/image_inpainting_mask.png' |
|
|
|
self.model_id = 'damo/cv_fft_inpainting_lama' |
|
|
|
self.input = { |
|
|
|
'img': self.input_location, |
|
|
|
'mask': self.input_mask_location |
|
|
|
} |
|
|
|
|
|
|
|
def save_result(self, result): |
|
|
|
vis_img = result[OutputKeys.OUTPUT_IMG] |
|
|
|
@@ -28,8 +32,7 @@ class ImageInpaintingTest(unittest.TestCase): |
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
def test_inpainting(self): |
|
|
|
inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) |
|
|
|
result = inpainting(self.input_location + '+' |
|
|
|
+ self.input_mask_location) |
|
|
|
result = inpainting(self.input) |
|
|
|
if result: |
|
|
|
self.save_result(result) |
|
|
|
else: |
|
|
|
@@ -41,8 +44,7 @@ class ImageInpaintingTest(unittest.TestCase): |
|
|
|
# if input image is HR, set refine=True is more better |
|
|
|
inpainting = pipeline( |
|
|
|
Tasks.image_inpainting, model=self.model_id, refine=True) |
|
|
|
result = inpainting(self.input_location + '+' |
|
|
|
+ self.input_mask_location) |
|
|
|
result = inpainting(self.input) |
|
|
|
if result: |
|
|
|
self.save_result(result) |
|
|
|
else: |
|
|
|
@@ -53,10 +55,7 @@ class ImageInpaintingTest(unittest.TestCase): |
|
|
|
inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) |
|
|
|
img = Image.open(self.input_location).convert('RGB') |
|
|
|
mask = Image.open(self.input_mask_location).convert('RGB') |
|
|
|
img_new = Image.new('RGB', (img.width + mask.width, img.height)) |
|
|
|
img_new.paste(img, (0, 0)) |
|
|
|
img_new.paste(mask, (img.width, 0)) |
|
|
|
result = inpainting(img_new) |
|
|
|
result = inpainting({'img': img, 'mask': mask}) |
|
|
|
if result: |
|
|
|
self.save_result(result) |
|
|
|
else: |
|
|
|
@@ -65,8 +64,7 @@ class ImageInpaintingTest(unittest.TestCase): |
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_inpainting_with_default_task(self): |
|
|
|
inpainting = pipeline(Tasks.image_inpainting) |
|
|
|
result = inpainting(self.input_location + '+' |
|
|
|
+ self.input_mask_location) |
|
|
|
result = inpainting(self.input) |
|
|
|
if result: |
|
|
|
self.save_result(result) |
|
|
|
else: |
|
|
|
|