From 6940d7720e0ad3ca68c4cb8f4c8064856dc2d7a7 Mon Sep 17 00:00:00 2001 From: myf272609 Date: Tue, 2 Aug 2022 13:09:49 +0800 Subject: [PATCH] [to #42322933] rename task of person_image_cartoon and fix a bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 应工程同学要求,修改原使用的一级task(image-generation)至新二级task(image-cartoon) - 修复无人脸时无图像结果返回,更新为 返回背景卡通化后的图像(无人脸卡通化处理) Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9594912 --- modelscope/pipelines/builder.py | 2 +- modelscope/pipelines/cv/image_cartoon_pipeline.py | 13 +++++++------ modelscope/utils/constant.py | 1 + tests/pipelines/test_person_image_cartoon.py | 8 +++++--- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 6e2f9c14..7ab77d98 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -60,7 +60,7 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/nlp_bart_text-error-correction_chinese'), Tasks.image_captioning: (Pipelines.image_captioning, 'damo/ofa_image-caption_coco_large_en'), - Tasks.image_generation: + Tasks.image_portrait_stylization: (Pipelines.person_image_cartoon, 'damo/cv_unet_person-image-cartoon_compound-models'), Tasks.ocr_detection: (Pipelines.ocr_detection, diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py index 46a30ad0..9c3c418e 100644 --- a/modelscope/pipelines/cv/image_cartoon_pipeline.py +++ b/modelscope/pipelines/cv/image_cartoon_pipeline.py @@ -25,7 +25,8 @@ logger = get_logger() @PIPELINES.register_module( - Tasks.image_generation, module_name=Pipelines.person_image_cartoon) + Tasks.image_portrait_stylization, + module_name=Pipelines.person_image_cartoon) class ImageCartoonPipeline(Pipeline): def __init__(self, model: str, **kwargs): @@ -85,11 +86,6 @@ class ImageCartoonPipeline(Pipeline): img_brg = img[:, :, ::-1] - landmarks = self.detect_face(img) - if landmarks is None: - print('No face detected!') - return {OutputKeys.OUTPUT_IMG: None} - # background process pad_bg, pad_h, pad_w = padTo16x(img_brg) @@ -99,6 +95,11 @@ class ImageCartoonPipeline(Pipeline): feed_dict={'model_anime_bg/input_image:0': pad_bg}) res = bg_res[:pad_h, :pad_w, :] + landmarks = self.detect_face(img) + if landmarks is None: + print('No face detected!') + return {OutputKeys.OUTPUT_IMG: res} + for landmark in landmarks: # get facial 5 points f5p = get_f5p(landmark, img_brg) diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 0cc43e00..dd402101 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -43,6 +43,7 @@ class CVTasks(object): video_category = 'video-category' image_classification_imagenet = 'image-classification-imagenet' image_classification_dailylife = 'image-classification-dailylife' + image_portrait_stylization = 'image-portrait-stylization' image_to_image_generation = 'image-to-image-generation' diff --git a/tests/pipelines/test_person_image_cartoon.py b/tests/pipelines/test_person_image_cartoon.py index d6ef1894..660ba1df 100644 --- a/tests/pipelines/test_person_image_cartoon.py +++ b/tests/pipelines/test_person_image_cartoon.py @@ -33,17 +33,19 @@ class ImageCartoonTest(unittest.TestCase): ) os.system('unzip assets.zip') - img_cartoon = pipeline(Tasks.image_generation, model=model_dir) + img_cartoon = pipeline( + Tasks.image_portrait_stylization, model=model_dir) self.pipeline_inference(img_cartoon, self.test_image) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_modelhub(self): - img_cartoon = pipeline(Tasks.image_generation, model=self.model_id) + img_cartoon = pipeline( + Tasks.image_portrait_stylization, model=self.model_id) self.pipeline_inference(img_cartoon, self.test_image) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_modelhub_default_model(self): - img_cartoon = pipeline(Tasks.image_generation) + img_cartoon = pipeline(Tasks.image_portrait_stylization) self.pipeline_inference(img_cartoon, self.test_image)