Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8993758master
| @@ -23,7 +23,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | ||||
| Tasks.image_captioning: ('ofa', None), | Tasks.image_captioning: ('ofa', None), | ||||
| Tasks.image_generation: | Tasks.image_generation: | ||||
| ('cv_unet_person-image-cartoon', 'damo/cv_unet_image-matting_damo'), | |||||
| ('person-image-cartoon', | |||||
| 'damo/cv_unet_person-image-cartoon_compound-models'), | |||||
| } | } | ||||
| @@ -25,20 +25,19 @@ logger = get_logger() | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.image_generation, module_name='cv_unet_person-image-cartoon') | |||||
| Tasks.image_generation, module_name='person-image-cartoon') | |||||
| class ImageCartoonPipeline(Pipeline): | class ImageCartoonPipeline(Pipeline): | ||||
| def __init__(self, model: str): | def __init__(self, model: str): | ||||
| super().__init__(model=model) | super().__init__(model=model) | ||||
| self.facer = FaceAna(model) | |||||
| self.facer = FaceAna(self.model) | |||||
| self.sess_anime_head = self.load_sess( | self.sess_anime_head = self.load_sess( | ||||
| os.path.join(model, 'cartoon_anime_h.pb'), 'model_anime_head') | |||||
| os.path.join(self.model, 'cartoon_anime_h.pb'), 'model_anime_head') | |||||
| self.sess_anime_bg = self.load_sess( | self.sess_anime_bg = self.load_sess( | ||||
| os.path.join(model, 'cartoon_anime_bg.pb'), 'model_anime_bg') | |||||
| os.path.join(self.model, 'cartoon_anime_bg.pb'), 'model_anime_bg') | |||||
| self.box_width = 288 | self.box_width = 288 | ||||
| global_mask = cv2.imread(os.path.join(model, 'alpha.jpg')) | |||||
| global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg')) | |||||
| global_mask = cv2.resize( | global_mask = cv2.resize( | ||||
| global_mask, (self.box_width, self.box_width), | global_mask, (self.box_width, self.box_width), | ||||
| interpolation=cv2.INTER_AREA) | interpolation=cv2.INTER_AREA) | ||||
| @@ -1,26 +1,31 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | import os | ||||
| import os.path as osp | |||||
| import unittest | import unittest | ||||
| import cv2 | import cv2 | ||||
| from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| def all_file(file_dir): | |||||
| L = [] | |||||
| for root, dirs, files in os.walk(file_dir): | |||||
| for file in files: | |||||
| extend = os.path.splitext(file)[1] | |||||
| if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG' or extend == '.HEIC': | |||||
| L.append(os.path.join(root, file)) | |||||
| return L | |||||
| class ImageCartoonTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | |||||
| self.model_id = 'damo/cv_unet_person-image-cartoon_compound-models' | |||||
| self.test_image = \ | |||||
| 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \ | |||||
| '/data/test/maas/image_carton/test.png' | |||||
| class ImageCartoonTest(unittest.TestCase): | |||||
| def pipeline_inference(self, pipeline: Pipeline, input_location: str): | |||||
| result = pipeline(input_location) | |||||
| if result is not None: | |||||
| cv2.imwrite('result.png', result['output_png']) | |||||
| print(f'Output written to {osp.abspath("result.png")}') | |||||
| def test_run(self): | |||||
| @unittest.skip('deprecated, download model from model hub instead') | |||||
| def test_run_by_direct_model_download(self): | |||||
| model_dir = './assets' | model_dir = './assets' | ||||
| if not os.path.exists(model_dir): | if not os.path.exists(model_dir): | ||||
| os.system( | os.system( | ||||
| @@ -29,9 +34,15 @@ class ImageCartoonTest(unittest.TestCase): | |||||
| os.system('unzip assets.zip') | os.system('unzip assets.zip') | ||||
| img_cartoon = pipeline(Tasks.image_generation, model=model_dir) | img_cartoon = pipeline(Tasks.image_generation, model=model_dir) | ||||
| result = img_cartoon(os.path.join(model_dir, 'test.png')) | |||||
| if result is not None: | |||||
| cv2.imwrite('result.png', result['output_png']) | |||||
| self.pipeline_inference(img_cartoon, self.test_image) | |||||
| def test_run_modelhub(self): | |||||
| img_cartoon = pipeline(Tasks.image_generation, model=self.model_id) | |||||
| self.pipeline_inference(img_cartoon, self.test_image) | |||||
| def test_run_modelhub_default_model(self): | |||||
| img_cartoon = pipeline(Tasks.image_generation) | |||||
| self.pipeline_inference(img_cartoon, self.test_image) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| @@ -0,0 +1,50 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path as osp | |||||
| import unittest | |||||
| from maas_hub.maas_api import MaasApi | |||||
| from maas_hub.repository import Repository | |||||
| USER_NAME = 'maasadmin' | |||||
| PASSWORD = '12345678' | |||||
| class HubOperationTest(unittest.TestCase): | |||||
| def setUp(self): | |||||
| self.api = MaasApi() | |||||
| # note this is temporary before official account management is ready | |||||
| self.api.login(USER_NAME, PASSWORD) | |||||
| @unittest.skip('to be used for local test only') | |||||
| def test_model_repo_creation(self): | |||||
| # change to proper model names before use | |||||
| model_name = 'cv_unet_person-image-cartoon_compound-models' | |||||
| model_chinese_name = '达摩卡通化模型' | |||||
| model_org = 'damo' | |||||
| try: | |||||
| self.api.create_model( | |||||
| owner=model_org, | |||||
| name=model_name, | |||||
| chinese_name=model_chinese_name, | |||||
| visibility=5, # 1-private, 5-public | |||||
| license='apache-2.0') | |||||
| # TODO: support proper name duplication checking | |||||
| except KeyError as ke: | |||||
| if ke.args[0] == 'name': | |||||
| print(f'model {self.model_name} already exists, ignore') | |||||
| else: | |||||
| raise | |||||
| # Note that this can be done via git operation once model repo | |||||
| # has been created. Git-Op is the RECOMMENDED model upload approach | |||||
| @unittest.skip('to be used for local test only') | |||||
| def test_model_upload(self): | |||||
| local_path = '/path/to/local/model/directory' | |||||
| assert osp.exists(local_path), 'Local model directory not exist.' | |||||
| repo = Repository(local_dir=local_path) | |||||
| repo.push_to_hub(commit_message='Upload model files') | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||