Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8993758master
| @@ -23,7 +23,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | |||
| Tasks.image_captioning: ('ofa', None), | |||
| Tasks.image_generation: | |||
| ('cv_unet_person-image-cartoon', 'damo/cv_unet_image-matting_damo'), | |||
| ('person-image-cartoon', | |||
| 'damo/cv_unet_person-image-cartoon_compound-models'), | |||
| } | |||
| @@ -25,20 +25,19 @@ logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_generation, module_name='cv_unet_person-image-cartoon') | |||
| Tasks.image_generation, module_name='person-image-cartoon') | |||
| class ImageCartoonPipeline(Pipeline): | |||
| def __init__(self, model: str): | |||
| super().__init__(model=model) | |||
| self.facer = FaceAna(model) | |||
| self.facer = FaceAna(self.model) | |||
| self.sess_anime_head = self.load_sess( | |||
| os.path.join(model, 'cartoon_anime_h.pb'), 'model_anime_head') | |||
| os.path.join(self.model, 'cartoon_anime_h.pb'), 'model_anime_head') | |||
| self.sess_anime_bg = self.load_sess( | |||
| os.path.join(model, 'cartoon_anime_bg.pb'), 'model_anime_bg') | |||
| os.path.join(self.model, 'cartoon_anime_bg.pb'), 'model_anime_bg') | |||
| self.box_width = 288 | |||
| global_mask = cv2.imread(os.path.join(model, 'alpha.jpg')) | |||
| global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg')) | |||
| global_mask = cv2.resize( | |||
| global_mask, (self.box_width, self.box_width), | |||
| interpolation=cv2.INTER_AREA) | |||
| @@ -1,26 +1,31 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import unittest | |||
| import cv2 | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| def all_file(file_dir): | |||
| L = [] | |||
| for root, dirs, files in os.walk(file_dir): | |||
| for file in files: | |||
| extend = os.path.splitext(file)[1] | |||
| if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG' or extend == '.HEIC': | |||
| L.append(os.path.join(root, file)) | |||
| return L | |||
| class ImageCartoonTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/cv_unet_person-image-cartoon_compound-models' | |||
| self.test_image = \ | |||
| 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \ | |||
| '/data/test/maas/image_carton/test.png' | |||
| class ImageCartoonTest(unittest.TestCase): | |||
| def pipeline_inference(self, pipeline: Pipeline, input_location: str): | |||
| result = pipeline(input_location) | |||
| if result is not None: | |||
| cv2.imwrite('result.png', result['output_png']) | |||
| print(f'Output written to {osp.abspath("result.png")}') | |||
| def test_run(self): | |||
| @unittest.skip('deprecated, download model from model hub instead') | |||
| def test_run_by_direct_model_download(self): | |||
| model_dir = './assets' | |||
| if not os.path.exists(model_dir): | |||
| os.system( | |||
| @@ -29,9 +34,15 @@ class ImageCartoonTest(unittest.TestCase): | |||
| os.system('unzip assets.zip') | |||
| img_cartoon = pipeline(Tasks.image_generation, model=model_dir) | |||
| result = img_cartoon(os.path.join(model_dir, 'test.png')) | |||
| if result is not None: | |||
| cv2.imwrite('result.png', result['output_png']) | |||
| self.pipeline_inference(img_cartoon, self.test_image) | |||
| def test_run_modelhub(self): | |||
| img_cartoon = pipeline(Tasks.image_generation, model=self.model_id) | |||
| self.pipeline_inference(img_cartoon, self.test_image) | |||
| def test_run_modelhub_default_model(self): | |||
| img_cartoon = pipeline(Tasks.image_generation) | |||
| self.pipeline_inference(img_cartoon, self.test_image) | |||
| if __name__ == '__main__': | |||
| @@ -0,0 +1,50 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| import unittest | |||
| from maas_hub.maas_api import MaasApi | |||
| from maas_hub.repository import Repository | |||
| USER_NAME = 'maasadmin' | |||
| PASSWORD = '12345678' | |||
| class HubOperationTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.api = MaasApi() | |||
| # note this is temporary before official account management is ready | |||
| self.api.login(USER_NAME, PASSWORD) | |||
| @unittest.skip('to be used for local test only') | |||
| def test_model_repo_creation(self): | |||
| # change to proper model names before use | |||
| model_name = 'cv_unet_person-image-cartoon_compound-models' | |||
| model_chinese_name = '达摩卡通化模型' | |||
| model_org = 'damo' | |||
| try: | |||
| self.api.create_model( | |||
| owner=model_org, | |||
| name=model_name, | |||
| chinese_name=model_chinese_name, | |||
| visibility=5, # 1-private, 5-public | |||
| license='apache-2.0') | |||
| # TODO: support proper name duplication checking | |||
| except KeyError as ke: | |||
| if ke.args[0] == 'name': | |||
| print(f'model {self.model_name} already exists, ignore') | |||
| else: | |||
| raise | |||
| # Note that this can be done via git operation once model repo | |||
| # has been created. Git-Op is the RECOMMENDED model upload approach | |||
| @unittest.skip('to be used for local test only') | |||
| def test_model_upload(self): | |||
| local_path = '/path/to/local/model/directory' | |||
| assert osp.exists(local_path), 'Local model directory not exist.' | |||
| repo = Repository(local_dir=local_path) | |||
| repo.push_to_hub(commit_message='Upload model files') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||