diff --git a/data/test/images/style_transfer_content.jpg b/data/test/images/style_transfer_content.jpg
new file mode 100644
index 00000000..5602662d
--- /dev/null
+++ b/data/test/images/style_transfer_content.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f33a6ad9fcd7367cec2e81b8b0e4234d4f5f7d1be284d48085a25bb6d03782d7
+size 72130
diff --git a/data/test/images/style_transfer_style.jpg b/data/test/images/style_transfer_style.jpg
new file mode 100644
index 00000000..820b093f
--- /dev/null
+++ b/data/test/images/style_transfer_style.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1af09b2c18a6674b7d88849cb87564dd77e1ce04d1517bb085449b614cc0c8d8
+size 376101
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 79fd3b4f..f4f100dd 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -49,6 +49,7 @@ class Pipelines(object):
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognation = 'resnet101-animal_recog'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+    style_transfer = 'AAMS-style-transfer'
 
     # nlp tasks
     sentence_similarity = 'sentence-similarity'
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 05a03166..c47c6744 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -64,7 +64,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                              'damo/cv_r2p1d_video_embedding'),
     Tasks.text_to_image_synthesis:
     (Pipelines.text_to_image_synthesis,
-     'damo/cv_imagen_text-to-image-synthesis_tiny')
+     'damo/cv_imagen_text-to-image-synthesis_tiny'),
+    Tasks.style_transfer: (Pipelines.style_transfer,
+                           'damo/cv_aams_style-transfer_damo')
 }
 
 
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index 18dd1e3a..b4b27b4b 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -15,11 +15,12 @@ except ModuleNotFoundError as e:
 try:
     from .image_cartoon_pipeline import ImageCartoonPipeline
     from .image_matting_pipeline import ImageMattingPipeline
+    from .style_transfer_pipeline import StyleTransferPipeline
     from .ocr_detection_pipeline import OCRDetectionPipeline
 except ModuleNotFoundError as e:
     if str(e) == "No module named 'tensorflow'":
         print(
             TENSORFLOW_IMPORT_ERROR.format(
-                'image-cartoon image-matting ocr-detection'))
+                'image-cartoon image-matting ocr-detection style-transfer'))
     else:
         raise ModuleNotFoundError(e)
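The three registry changes above are what make the new pipeline discoverable: `Pipelines.style_transfer` names the implementation, `CVTasks.style_transfer` (see the constant.py hunk further down) names the task, and the builder entry supplies a default model so a task-only `pipeline(...)` call works. A condensed sketch of the default-model lookup follows; `resolve_default` is a hypothetical helper, and the real builder additionally handles explicit model paths and revisions:

from modelscope.metainfo import Pipelines
from modelscope.utils.constant import Tasks

# Mirrors the DEFAULT_MODEL_FOR_PIPELINE entry added in builder.py above.
DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.style_transfer: (Pipelines.style_transfer,
                           'damo/cv_aams_style-transfer_damo'),
}


def resolve_default(task: str, model: str = None):
    """Hypothetical helper: return (pipeline name, model id) for a task."""
    pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
    return pipeline_name, model or default_model


# pipeline(Tasks.style_transfer) with no model id falls back to the default:
assert resolve_default(Tasks.style_transfer) == (
    'AAMS-style-transfer', 'damo/cv_aams_style-transfer_damo')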
+ """ + super().__init__(model=model) + import tensorflow as tf + if tf.__version__ >= '2.0': + tf = tf.compat.v1 + model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE) + + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + self._session = tf.Session(config=config) + self.max_length = 800 + with self._session.as_default(): + logger.info(f'loading model from {model_path}') + with tf.gfile.FastGFile(model_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + + self.content = tf.get_default_graph().get_tensor_by_name( + 'content:0') + self.style = tf.get_default_graph().get_tensor_by_name( + 'style:0') + self.output = tf.get_default_graph().get_tensor_by_name( + 'stylized_output:0') + self.attention = tf.get_default_graph().get_tensor_by_name( + 'attention_map:0') + self.inter_weight = tf.get_default_graph().get_tensor_by_name( + 'inter_weight:0') + self.centroids = tf.get_default_graph().get_tensor_by_name( + 'centroids:0') + logger.info('load model done') + + def _sanitize_parameters(self, **pipeline_parameters): + return pipeline_parameters, {}, {} + + def preprocess(self, content: Input, style: Input) -> Dict[str, Any]: + if isinstance(content, str): + content = np.array(load_image(content)) + elif isinstance(content, PIL.Image.Image): + content = np.array(content.convert('RGB')) + elif isinstance(content, np.ndarray): + if len(content.shape) == 2: + content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR) + content = content[:, :, ::-1] # in rgb order + else: + raise TypeError( + f'modelscope error: content should be either str, PIL.Image,' + f' np.array, but got {type(content)}') + if len(content.shape) == 2: + content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR) + content_img = content.astype(np.float) + + if isinstance(style, str): + style_img = np.array(load_image(style)) + elif isinstance(style, PIL.Image.Image): + style_img = np.array(style.convert('RGB')) + elif isinstance(style, np.ndarray): + if len(style.shape) == 2: + style_img = cv2.cvtColor(style, cv2.COLOR_GRAY2BGR) + style_img = style_img[:, :, ::-1] # in rgb order + else: + raise TypeError( + f'modelscope error: style should be either str, PIL.Image,' + f' np.array, but got {type(style)}') + + if len(style_img.shape) == 2: + style_img = cv2.cvtColor(style_img, cv2.COLOR_GRAY2BGR) + style_img = style_img.astype(np.float) + + result = {'content': content_img, 'style': style_img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + content_feed, style_feed = input['content'], input['style'] + h = np.shape(content_feed)[0] + w = np.shape(content_feed)[1] + if h > self.max_length or w > self.max_length: + if h > w: + content_feed = cv2.resize( + content_feed, + (int(self.max_length * w / h), self.max_length)) + else: + content_feed = cv2.resize( + content_feed, + (self.max_length, int(self.max_length * h / w))) + + with self._session.as_default(): + feed_dict = { + self.content: content_feed, + self.style: style_feed, + self.inter_weight: 1.0 + } + output_img = self._session.run(self.output, feed_dict=feed_dict) + + # print('out_img shape:{}'.format(output_img.shape)) + output_img = cv2.cvtColor(output_img[0], cv2.COLOR_RGB2BGR) + output_img = np.clip(output_img, 0, 255).astype(np.uint8) + + output_img = cv2.resize(output_img, (w, h)) + + return {OutputKeys.OUTPUT_IMG: output_img} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git 
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index d6afb35a..e5935c3e 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -27,6 +27,7 @@ class CVTasks(object):
     ocr_detection = 'ocr-detection'
     action_recognition = 'action-recognition'
     video_embedding = 'video-embedding'
+    style_transfer = 'style-transfer'
 
 
 class NLPTasks(object):
diff --git a/tests/pipelines/test_style_transfer.py b/tests/pipelines/test_style_transfer.py
new file mode 100644
index 00000000..7bf7f1c4
--- /dev/null
+++ b/tests/pipelines/test_style_transfer.py
@@ -0,0 +1,55 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+import tempfile
+import unittest
+
+import cv2
+
+from modelscope.fileio import File
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class StyleTransferTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_aams_style-transfer_damo'
+
+    @unittest.skip('deprecated, download model from model hub instead')
+    def test_run_by_direct_model_download(self):
+        snapshot_path = snapshot_download(self.model_id)
+        print('snapshot_path: {}'.format(snapshot_path))
+        style_transfer = pipeline(Tasks.style_transfer, model=snapshot_path)
+
+        result = style_transfer(
+            'data/test/images/style_transfer_content.jpg',
+            style='data/test/images/style_transfer_style.jpg')
+        cv2.imwrite('result_styletransfer1.png', result[OutputKeys.OUTPUT_IMG])
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_modelhub(self):
+        style_transfer = pipeline(Tasks.style_transfer, model=self.model_id)
+
+        result = style_transfer(
+            'data/test/images/style_transfer_content.jpg',
+            style='data/test/images/style_transfer_style.jpg')
+        cv2.imwrite('result_styletransfer2.png', result[OutputKeys.OUTPUT_IMG])
+        print('style_transfer.test_run_modelhub done')
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        style_transfer = pipeline(Tasks.style_transfer)
+
+        result = style_transfer(
+            'data/test/images/style_transfer_content.jpg',
+            style='data/test/images/style_transfer_style.jpg')
+        cv2.imwrite('result_styletransfer3.png', result[OutputKeys.OUTPUT_IMG])
+        print('style_transfer.test_run_modelhub_default_model done')
+
+
+if __name__ == '__main__':
+    unittest.main()
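Because `preprocess()` accepts str paths, `PIL.Image.Image` objects, or `np.ndarray`s, the new pipeline can also be driven with in-memory images rather than the file paths used in the tests. A minimal usage sketch, with the model id and output key taken from this diff; that array inputs pass through the `style=` keyword exactly like paths do is an assumption from reading `preprocess()`:

import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

style_transfer = pipeline(
    Tasks.style_transfer, model='damo/cv_aams_style-transfer_damo')

# BGR arrays as read by OpenCV; preprocess() flips them to RGB internally.
content = cv2.imread('data/test/images/style_transfer_content.jpg')
style = cv2.imread('data/test/images/style_transfer_style.jpg')

result = style_transfer(content, style=style)
cv2.imwrite('result_styletransfer.png', result[OutputKeys.OUTPUT_IMG])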