Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9590940 (master)
@@ -203,6 +203,7 @@ class Preprocessors(object):
    # multi-modal
    ofa_image_caption = 'ofa-image-caption'
    ofa_text_to_image_synthesis = 'ofa-text-to-image-synthesis'
    mplug_visual_question_answering = 'mplug-visual-question-answering'

@@ -20,7 +20,9 @@ else:
        'mmr': ['VideoCLIPForMultiModalEmbedding'],
        'mplug_for_visual_question_answering':
            ['MPlugForVisualQuestionAnswering'],
        'ofa_for_all_tasks': ['OfaForAllTasks']
        'ofa_for_all_tasks': ['OfaForAllTasks'],
        'ofa_for_text_to_image_synthesis_model':
            ['OfaForTextToImageSynthesis']
    }
    import sys
@@ -0,0 +1,90 @@
import os
from typing import Any, Dict
import json

import numpy as np
import torch
import torch.cuda
from PIL import Image
from taming.models.vqgan import GumbelVQ, VQModel

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer
from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg
from modelscope.models.multi_modal.ofa.generate.search import Sampling
from modelscope.models.multi_modal.ofa.generate.utils import move_to_device
from modelscope.utils.constant import Tasks

__all__ = ['OfaForTextToImageSynthesis']


def custom_to_pil(x):
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == 'RGB':
        x = x.convert('RGB')
    return x


def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    if is_gumbel:
        model = GumbelVQ(**config['model']['params'])
    else:
        model = VQModel(**config['model']['params'])
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location='cpu')['state_dict']
        missing, unexpected = model.load_state_dict(sd, strict=False)
    return model.eval()


@MODELS.register_module(Tasks.text_to_image_synthesis, module_name=Models.ofa)
class OfaForTextToImageSynthesis(Model):

    def __init__(self, model_dir, *args, **kwargs):
        super().__init__(model_dir=model_dir, *args, **kwargs)
        # Initialize ofa
        model = OFAModel.from_pretrained(model_dir)
        self.model = model.module if hasattr(model, 'module') else model
        self.tokenizer = OFATokenizer.from_pretrained(model_dir)
        self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
        self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
        self._device = torch.device('cuda') if torch.cuda.is_available() \
            else torch.device('cpu')
        self.model.to(self._device)
        # Initialize vqgan
        vqgan_config = json.load(
            open(os.path.join(model_dir, 'vqgan_config.json')))
        self.vqgan_model = load_vqgan(
            vqgan_config,
            ckpt_path=os.path.join(model_dir, 'vqgan_model.ckpt'),
            is_gumbel=True).to(self._device)
        # Initialize generator
        sampling = Sampling(self.tokenizer, sampling_topp=0.9)
        sg_args = {
            'tokenizer': self.tokenizer,
            'beam_size': 1,
            'max_len_b': 1024,
            'min_len': 1024,
            'search_strategy': sampling,
            'gen_code': True,
            'constraint_range': '50265,58457'
        }
        self.generator = sg.SequenceGenerator(**sg_args)

    def forward(self, input: Dict[str, Any]):
        input = move_to_device(input, self._device)
        gen_output = self.generator.generate([self.model], input)
        # drop the trailing EOS token from the generated sequence
        gen_tokens = gen_output[0][0]['tokens'][:-1]
        # map token ids to VQGAN codebook indices (<code_*> tokens start at 50265)
        codes = gen_tokens.view(1, 32, 32) - 50265
        quant_b = self.vqgan_model.quantize.get_codebook_entry(
            codes.view(-1),
            list(codes.size()) + [self.vqgan_model.quantize.embedding_dim])
        # decode the 32x32 code grid back to pixels and convert to a PIL image
        dec = self.vqgan_model.decode(quant_b)[0]
        return custom_to_pil(dec)
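
For reference, a minimal usage sketch of the new model through the pipeline API.
The model id, task constant, and output key are taken from the unit tests added
in this review; the rest is standard ModelScope pipeline usage.

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # build the text-to-image pipeline from the released OFA checkpoint
    ofa_pipe = pipeline(
        Tasks.text_to_image_synthesis,
        model='damo/ofa_text-to-image-synthesis_coco_large_en')
    # the pipeline preprocesses the caption, runs OFA generation plus VQGAN
    # decoding, and returns a PIL image under OutputKeys.OUTPUT_IMG
    result = ofa_pipe({'text': 'a bear in the water.'})
    result[OutputKeys.OUTPUT_IMG].save('result.png')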
@@ -1,11 +1,13 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

import torch

from modelscope.metainfo import Pipelines
from modelscope.models.multi_modal import OfaForTextToImageSynthesis
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import OfaPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

@@ -17,7 +19,10 @@ logger = get_logger()
    module_name=Pipelines.text_to_image_synthesis)
class TextToImageSynthesisPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
    def __init__(self,
                 model: str,
                 preprocessor: Optional[Preprocessor] = None,
                 **kwargs):
| """ | |||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||
| Args: | |||
@@ -31,13 +36,20 @@ class TextToImageSynthesisPipeline(Pipeline):
        else:
            raise NotImplementedError(
                f'expecting a Model instance or str, but get {type(model)}.')
        super().__init__(model=pipe_model, **kwargs)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input
        if preprocessor is None and isinstance(pipe_model,
                                               OfaForTextToImageSynthesis):
            preprocessor = OfaPreprocessor(pipe_model.model_dir)
        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

    def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
        if self.preprocessor is not None:
            return self.preprocessor(input, **preprocess_params)
        else:
            return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        if isinstance(self.model, OfaForTextToImageSynthesis):
            return self.model(input)
        return self.model.generate(input)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
@@ -23,6 +23,8 @@ __all__ = [
@PREPROCESSORS.register_module(
    Fields.multi_modal, module_name=Preprocessors.ofa_image_caption)
@PREPROCESSORS.register_module(
    Fields.multi_modal, module_name=Preprocessors.ofa_text_to_image_synthesis)
class OfaPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):

@@ -40,7 +42,8 @@ class OfaPreprocessor(Preprocessor):
            Tasks.visual_entailment: OfaVisualEntailmentPreprocessor,
            Tasks.image_classification: OfaImageClassificationPreprocessor,
            Tasks.text_classification: OfaTextClassificationPreprocessor,
            Tasks.summarization: OfaSummarizationPreprocessor
            Tasks.summarization: OfaSummarizationPreprocessor,
            Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor
        }
        input_key_mapping = {
            Tasks.image_captioning: ['image'],

@@ -50,6 +53,7 @@ class OfaPreprocessor(Preprocessor):
            Tasks.visual_grounding: ['image', 'text'],
            Tasks.visual_question_answering: ['image', 'text'],
            Tasks.visual_entailment: ['image', 'text', 'text2'],
            Tasks.text_to_image_synthesis: ['text']
        }
        model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
            model_dir)
@@ -3,6 +3,7 @@ from .image_captioning import OfaImageCaptioningPreprocessor
from .image_classification import OfaImageClassificationPreprocessor
from .summarization import OfaSummarizationPreprocessor
from .text_classification import OfaTextClassificationPreprocessor
from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor
from .visual_entailment import OfaVisualEntailmentPreprocessor
from .visual_grounding import OfaVisualGroundingPreprocessor
from .visual_question_answering import OfaVisualQuestionAnsweringPreprocessor
@@ -0,0 +1,31 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import torch

from .base import OfaBasePreprocessor


class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir):
        """preprocess the data via the vocab.txt from the `model_dir` path

        Args:
            model_dir (str): model path
        """
        super(OfaTextToImageSynthesisPreprocessor,
              self).__init__(cfg, model_dir)
        self.max_src_length = 64

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # lowercase the caption, truncate to max_src_length words, and rejoin
        source = ' '.join(
            data['text'].lower().strip().split()[:self.max_src_length])
        source = 'what is the complete image? caption: {}'.format(source)
        inputs = self.get_inputs(source)
        sample = {
            'source': inputs,
            'patch_images': None,
            'patch_masks': torch.tensor([False]),
            'code_masks': torch.tensor([False])
        }
        return sample
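
As a quick sanity check on the prompt construction above, the caption used in
the new unit tests would be turned into the following OFA query (the template
and truncation length come from the preprocessor code; tokenization via
`get_inputs` is inherited from OfaBasePreprocessor and not shown here):

    text = 'a bear in the water.'
    words = text.lower().strip().split()[:64]
    source = 'what is the complete image? caption: {}'.format(' '.join(words))
    # -> 'what is the complete image? caption: a bear in the water.'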
@@ -6,6 +6,7 @@ pycocotools>=2.0.4
# rouge-score was recently updated from 0.0.4 to 0.0.7,
# which introduced compatibility issues that are being investigated
rouge_score<=0.0.4
taming-transformers-rom1504
timm
tokenizers
torchvision
@@ -244,6 +244,25 @@ class OfaTasksTest(unittest.TestCase):
        result = ofa_pipe(input)
        print(result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_text_to_image_synthesis_with_name(self):
        model = 'damo/ofa_text-to-image-synthesis_coco_large_en'
        ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model)
        example = {'text': 'a bear in the water.'}
        result = ofa_pipe(example)
        result[OutputKeys.OUTPUT_IMG].save('result.png')
        print(f'Output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_text_to_image_synthesis_with_model(self):
        model = Model.from_pretrained(
            'damo/ofa_text-to-image-synthesis_coco_large_en')
        ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model)
        example = {'text': 'a bear in the water.'}
        result = ofa_pipe(example)
        result[OutputKeys.OUTPUT_IMG].save('result.png')
        print(f'Output written to {osp.abspath("result.png")}')


if __name__ == '__main__':
    unittest.main()