Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9590940 (target branch: master)
@@ -203,6 +203,7 @@ class Preprocessors(object):
     # multi-modal
     ofa_image_caption = 'ofa-image-caption'
+    ofa_text_to_image_synthesis = 'ofa-text-to-image-synthesis'
     mplug_visual_question_answering = 'mplug-visual-question-answering'
@@ -20,7 +20,9 @@ else:
         'mmr': ['VideoCLIPForMultiModalEmbedding'],
         'mplug_for_visual_question_answering':
         ['MPlugForVisualQuestionAnswering'],
-        'ofa_for_all_tasks': ['OfaForAllTasks']
+        'ofa_for_all_tasks': ['OfaForAllTasks'],
+        'ofa_for_text_to_image_synthesis_model':
+        ['OfaForTextToImageSynthesis']
     }

     import sys
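
For context: this dictionary feeds ModelScope's lazy-import machinery, so the key must match the new submodule's file name and the value lists the symbols it exports. A minimal sketch of the eager import the new entry resolves to (module path inferred from the key; not itself part of the diff):

    # Equivalent eager import implied by the new registry entry
    from modelscope.models.multi_modal.ofa_for_text_to_image_synthesis_model import (
        OfaForTextToImageSynthesis)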
@@ -0,0 +1,90 @@
+import os
+from typing import Any, Dict
+
+import json
+import numpy as np
+import torch
+import torch.cuda
+from PIL import Image
+from taming.models.vqgan import GumbelVQ, VQModel
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Model
+from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer
+from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg
+from modelscope.models.multi_modal.ofa.generate.search import Sampling
+from modelscope.models.multi_modal.ofa.generate.utils import move_to_device
+from modelscope.utils.constant import Tasks
+
+__all__ = ['OfaForTextToImageSynthesis']
+
+
+def custom_to_pil(x):
+    x = x.detach().cpu()
+    x = torch.clamp(x, -1., 1.)
+    x = (x + 1.) / 2.
+    x = x.permute(1, 2, 0).numpy()
+    x = (255 * x).astype(np.uint8)
+    x = Image.fromarray(x)
+    if not x.mode == 'RGB':
+        x = x.convert('RGB')
+    return x
+
+
+def load_vqgan(config, ckpt_path=None, is_gumbel=False):
+    if is_gumbel:
+        model = GumbelVQ(**config['model']['params'])
+    else:
+        model = VQModel(**config['model']['params'])
+    if ckpt_path is not None:
+        sd = torch.load(ckpt_path, map_location='cpu')['state_dict']
+        missing, unexpected = model.load_state_dict(sd, strict=False)
+    return model.eval()
+
+
+@MODELS.register_module(Tasks.text_to_image_synthesis, module_name=Models.ofa)
+class OfaForTextToImageSynthesis(Model):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        super().__init__(model_dir=model_dir, *args, **kwargs)
+        # Initialize ofa
+        model = OFAModel.from_pretrained(model_dir)
+        self.model = model.module if hasattr(model, 'module') else model
+        self.tokenizer = OFATokenizer.from_pretrained(model_dir)
+        self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
+        self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
+        self._device = torch.device('cuda') if torch.cuda.is_available() \
+            else torch.device('cpu')
+        self.model.to(self._device)
+        # Initialize vqgan
+        vqgan_config = json.load(
+            open(os.path.join(model_dir, 'vqgan_config.json')))
+        self.vqgan_model = load_vqgan(
+            vqgan_config,
+            ckpt_path=os.path.join(model_dir, 'vqgan_model.ckpt'),
+            is_gumbel=True).to(self._device)
+        # Initialize generator
+        sampling = Sampling(self.tokenizer, sampling_topp=0.9)
+        sg_args = {
+            'tokenizer': self.tokenizer,
+            'beam_size': 1,
+            'max_len_b': 1024,
+            'min_len': 1024,
+            'search_strategy': sampling,
+            'gen_code': True,
+            'constraint_range': '50265,58457'
+        }
+        self.generator = sg.SequenceGenerator(**sg_args)
+
+    def forward(self, input: Dict[str, Any]):
+        input = move_to_device(input, self._device)
+        gen_output = self.generator.generate([self.model], input)
+        gen_tokens = gen_output[0][0]['tokens'][:-1]
+        codes = gen_tokens.view(1, 32, 32) - 50265
+        quant_b = self.vqgan_model.quantize.get_codebook_entry(
+            codes.view(-1),
+            list(codes.size()) + [self.vqgan_model.quantize.embedding_dim])
+        dec = self.vqgan_model.decode(quant_b)[0]
+        return custom_to_pil(dec)
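
A note on the constants above: the tokenizer is extended with 8192 `<code_*>` tokens, the generator is constrained to emit exactly 1024 tokens inside vocab-id range [50265, 58457), and `forward` subtracts the 50265 offset to recover a 32x32 grid of VQGAN codebook indices that the decoder turns back into pixels. A small self-contained sketch restating that arithmetic (shapes taken from the code, not verified against the checkpoint):

    import torch

    vocab_offset = 50265                               # id of the first '<code_*>' token
    num_codes = 58457 - 50265                          # == 8192, matches add_tokens above
    gen_tokens = torch.randint(50265, 58457, (1024,))  # stand-in for generator output
    codes = gen_tokens.view(1, 32, 32) - vocab_offset  # 1024 == 32 * 32 latent grid
    assert 0 <= codes.min() and codes.max() < num_codes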
@@ -1,11 +1,13 @@
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import torch

 from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForTextToImageSynthesis
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import OfaPreprocessor, Preprocessor
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
@@ -17,7 +19,10 @@ logger = get_logger()
     module_name=Pipelines.text_to_image_synthesis)
 class TextToImageSynthesisPipeline(Pipeline):

-    def __init__(self, model: str, **kwargs):
+    def __init__(self,
+                 model: str,
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
         """
         use `model` and `preprocessor` to create a text-to-image synthesis pipeline for prediction
         Args:
@@ -31,13 +36,20 @@ class TextToImageSynthesisPipeline(Pipeline):
         else:
             raise NotImplementedError(
                 f'expecting a Model instance or str, but get {type(model)}.')
-        super().__init__(model=pipe_model, **kwargs)
-
-    def preprocess(self, input: Input) -> Dict[str, Any]:
-        return input
+        if preprocessor is None and isinstance(pipe_model,
+                                               OfaForTextToImageSynthesis):
+            preprocessor = OfaPreprocessor(pipe_model.model_dir)
+        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
+
+    def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
+        if self.preprocessor is not None:
+            return self.preprocessor(input, **preprocess_params)
+        else:
+            return input

     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        if isinstance(self.model, OfaForTextToImageSynthesis):
+            return self.model(input)
         return self.model.generate(input)

     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
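
A minimal usage sketch of the new dispatch (model id taken from the tests at the end of this review): when the pipeline wraps an OfaForTextToImageSynthesis, an OfaPreprocessor is attached automatically and forward calls the model directly instead of model.generate:

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipe = pipeline(
        Tasks.text_to_image_synthesis,
        model='damo/ofa_text-to-image-synthesis_coco_large_en')
    result = pipe({'text': 'a bear in the water.'})
    result[OutputKeys.OUTPUT_IMG].save('demo.png')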
@@ -23,6 +23,8 @@ __all__ = [
 @PREPROCESSORS.register_module(
     Fields.multi_modal, module_name=Preprocessors.ofa_image_caption)
+@PREPROCESSORS.register_module(
+    Fields.multi_modal, module_name=Preprocessors.ofa_text_to_image_synthesis)
 class OfaPreprocessor(Preprocessor):

     def __init__(self, model_dir: str, *args, **kwargs):
@@ -40,7 +42,8 @@ class OfaPreprocessor(Preprocessor):
             Tasks.visual_entailment: OfaVisualEntailmentPreprocessor,
             Tasks.image_classification: OfaImageClassificationPreprocessor,
             Tasks.text_classification: OfaTextClassificationPreprocessor,
-            Tasks.summarization: OfaSummarizationPreprocessor
+            Tasks.summarization: OfaSummarizationPreprocessor,
+            Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor
         }
         input_key_mapping = {
             Tasks.image_captioning: ['image'],
@@ -50,6 +53,7 @@ class OfaPreprocessor(Preprocessor):
             Tasks.visual_grounding: ['image', 'text'],
             Tasks.visual_question_answering: ['image', 'text'],
             Tasks.visual_entailment: ['image', 'text', 'text2'],
+            Tasks.text_to_image_synthesis: ['text']
         }
         model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
             model_dir)
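
How the two tables combine, as a rough sketch (build_sample below is a hypothetical condensation; the real routing lives in OfaPreprocessor.__call__): the task key selects the concrete preprocessor, and input_key_mapping names the keys a raw input is expected to carry, so the two new entries are all that is needed to route Tasks.text_to_image_synthesis:

    # Hypothetical condensation of the dispatch implied by the two mappings
    def build_sample(task, data, preprocess_mapping, input_key_mapping, cfg, model_dir):
        preprocessor = preprocess_mapping[task](cfg, model_dir)  # e.g. OfaTextToImageSynthesisPreprocessor
        assert all(key in data for key in input_key_mapping[task])
        return preprocessor(data)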
@@ -3,6 +3,7 @@ from .image_captioning import OfaImageCaptioningPreprocessor
 from .image_classification import OfaImageClassificationPreprocessor
 from .summarization import OfaSummarizationPreprocessor
 from .text_classification import OfaTextClassificationPreprocessor
+from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor
 from .visual_entailment import OfaVisualEntailmentPreprocessor
 from .visual_grounding import OfaVisualGroundingPreprocessor
 from .visual_question_answering import OfaVisualQuestionAnsweringPreprocessor
@@ -0,0 +1,31 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict
+
+import torch
+
+from .base import OfaBasePreprocessor
+
+
+class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):
+
+    def __init__(self, cfg, model_dir):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            cfg: model configuration
+            model_dir (str): model path
+        """
+        super(OfaTextToImageSynthesisPreprocessor,
+              self).__init__(cfg, model_dir)
+        self.max_src_length = 64
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        source = ' '.join(
+            data['text'].lower().strip().split()[:self.max_src_length])
+        source = 'what is the complete image? caption: {}'.format(source)
+        inputs = self.get_inputs(source)
+        sample = {
+            'source': inputs,
+            'patch_images': None,
+            'patch_masks': torch.tensor([False]),
+            'code_masks': torch.tensor([False])
+        }
+        return sample
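
To make the template above concrete, here is what __call__ builds for one caption (pure string manipulation, runnable standalone; note the join keeps the truncated word list a plain string):

    text = 'A Bear in the Water. '
    source = ' '.join(text.lower().strip().split()[:64])
    prompt = 'what is the complete image? caption: {}'.format(source)
    assert prompt == 'what is the complete image? caption: a bear in the water.'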
@@ -6,6 +6,7 @@ pycocotools>=2.0.4
 # rouge_score was just recently updated from 0.0.4 to 0.0.7
 # which introduced compatibility issues that are being investigated
 rouge_score<=0.0.4
+taming-transformers-rom1504
 timm
 tokenizers
 torchvision
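
For reference: taming-transformers-rom1504 is the community PyPI fork of taming-transformers that provides the `taming.models.vqgan` imports (GumbelVQ, VQModel) used by the new model file.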
@@ -244,6 +244,25 @@ class OfaTasksTest(unittest.TestCase):
         result = ofa_pipe(input)
         print(result)

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_text_to_image_synthesis_with_name(self):
+        model = 'damo/ofa_text-to-image-synthesis_coco_large_en'
+        ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model)
+        example = {'text': 'a bear in the water.'}
+        result = ofa_pipe(example)
+        result[OutputKeys.OUTPUT_IMG].save('result.png')
+        print(f'Output written to {osp.abspath("result.png")}')
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_text_to_image_synthesis_with_model(self):
+        model = Model.from_pretrained(
+            'damo/ofa_text-to-image-synthesis_coco_large_en')
+        ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model)
+        example = {'text': 'a bear in the water.'}
+        result = ofa_pipe(example)
+        result[OutputKeys.OUTPUT_IMG].save('result.png')
+        print(f'Output written to {osp.abspath("result.png")}')
+

 if __name__ == '__main__':
     unittest.main()