diff --git a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py
index 5cdc9668..b942e3fa 100644
--- a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py
+++ b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py
@@ -6,17 +6,30 @@ import numpy as np
 import torch
 import torch.cuda
 from PIL import Image
+from pkg_resources import packaging
 from taming.models.vqgan import GumbelVQ, VQModel
+from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
+                                    ToTensor)
 
 from modelscope.metainfo import Models
 from modelscope.models.base import Model
 from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.mmr.models.module_clip import CLIP
+from modelscope.models.multi_modal.mmr.models.tokenization_clip import \
+    SimpleTokenizer as ClipTokenizer
 from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer
 from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg
 from modelscope.models.multi_modal.ofa.generate.search import Sampling
 from modelscope.models.multi_modal.ofa.generate.utils import move_to_device
 from modelscope.utils.constant import Tasks
 
+try:
+    from torchvision.transforms import InterpolationMode
+
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
 __all__ = ['OfaForTextToImageSynthesis']
 
 
@@ -43,6 +56,74 @@ def load_vqgan(config, ckpt_path=None, is_gumbel=False):
     return model.eval()
 
 
+def build_clip_model(model_path):
+    state_dict = torch.load(model_path, map_location='cpu').state_dict()
+    vit = 'visual.proj' in state_dict
+    if vit:
+        vision_width = state_dict['visual.conv1.weight'].shape[0]
+        vision_layers = len([
+            k for k in state_dict.keys()
+            if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
+        ])
+        vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
+        grid_size = round(
+            (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
+        image_resolution = vision_patch_size * grid_size
+    else:
+        counts: list = [
+            len(
+                set(
+                    k.split('.')[2] for k in state_dict
+                    if k.startswith(f'visual.layer{b}')))
+            for b in [1, 2, 3, 4]
+        ]
+        vision_layers = tuple(counts)
+        vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
+        output_width = round(
+            (state_dict['visual.attnpool.positional_embedding'].shape[0]
+             - 1)**0.5)
+        vision_patch_size = None
+        assert output_width**2 + 1 == state_dict[
+            'visual.attnpool.positional_embedding'].shape[0]
+        image_resolution = output_width * 32
+
+    embed_dim = state_dict['text_projection'].shape[1]
+    context_length = state_dict['positional_embedding'].shape[0]
+    vocab_size = state_dict['token_embedding.weight'].shape[0]
+    transformer_width = state_dict['ln_final.weight'].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(
+        set(
+            k.split('.')[2] for k in state_dict
+            if k.startswith('transformer.resblocks')))
+
+    model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
+                 vision_patch_size, context_length, vocab_size,
+                 transformer_width, transformer_heads, transformer_layers)
+
+    for key in ['input_resolution', 'context_length', 'vocab_size']:
+        if key in state_dict:
+            del state_dict[key]
+
+    model.load_state_dict(state_dict)
+    return model.eval()
+
+
+def _convert_image_to_rgb(image):
+    return image.convert('RGB')
+
+
+def build_clip_transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=BICUBIC),
+        CenterCrop(n_px),
+        _convert_image_to_rgb,
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073),
+                  (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+
 @MODELS.register_module(
     Tasks.text_to_image_synthesis, module_name=Models.ofa)
 class OfaForTextToImageSynthesis(Model):
@@ -65,11 +146,23 @@ class OfaForTextToImageSynthesis(Model):
             vqgan_config,
             ckpt_path=os.path.join(model_dir, 'vqgan_model.ckpt'),
             is_gumbel=True).to(self._device)
+
+        # Initialize OpenAI clip
+
+        self.clip_tokenizer = ClipTokenizer(model_dir)
+        self.clip_model = build_clip_model(
+            os.path.join(model_dir, 'ViT-B-16.pt'))
+        self.clip_preprocess = build_clip_transform(
+            self.clip_model.visual.input_resolution)
+
+        self.clip_model.to(self._device)
+        self.clip_model.eval()
+
         # Initialize generator
         sampling = Sampling(self.tokenizer, sampling_topp=0.9)
         sg_args = {
             'tokenizer': self.tokenizer,
-            'beam_size': 1,
+            'beam_size': 2,
             'max_len_b': 1024,
             'min_len': 1024,
             'search_strategy': sampling,
@@ -78,13 +171,68 @@
         }
         self.generator = sg.SequenceGenerator(**sg_args)
 
+    def clip_tokenize(self, texts, context_length=77, truncate=False):
+
+        if isinstance(texts, str):
+            texts = [texts]
+
+        sot_token = self.clip_tokenizer.encoder['<|startoftext|>']
+        eot_token = self.clip_tokenizer.encoder['<|endoftext|>']
+        all_tokens = [[sot_token] + self.clip_tokenizer.encode(text)
+                      + [eot_token] for text in texts]
+        if packaging.version.parse(
+                torch.__version__) < packaging.version.parse('1.8.0'):
+            result = torch.zeros(
+                len(all_tokens), context_length, dtype=torch.long)
+        else:
+            result = torch.zeros(
+                len(all_tokens), context_length, dtype=torch.int)
+
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > context_length:
+                if truncate:
+                    tokens = tokens[:context_length]
+                    tokens[-1] = eot_token
+                else:
+                    raise RuntimeError(
+                        f'Input {texts[i]} is too long for context length {context_length}'
+                    )
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        return result
+
     def forward(self, input: Dict[str, Any]):
+
+        text = input['samples'][0]['text']
         input = move_to_device(input, self._device)
+        clip_text_input = self.clip_tokenize([text]).to(self._device)
+
         gen_output = self.generator.generate([self.model], input)
-        gen_tokens = gen_output[0][0]['tokens'][:-1]
-        codes = gen_tokens.view(1, 32, 32) - 50265
+        gen_tokens = torch.stack(
+            [item['tokens'][:-1] for item in gen_output[0]], dim=0)
+        codes = gen_tokens.view(-1, 32, 32) - 50265
+
         quant_b = self.vqgan_model.quantize.get_codebook_entry(
             codes.view(-1),
             list(codes.size()) + [self.vqgan_model.quantize.embedding_dim])
-        dec = self.vqgan_model.decode(quant_b)[0]
-        return custom_to_pil(dec)
+        imgs = self.vqgan_model.decode(quant_b)
+
+        sample_num = imgs.size()[0]
+        pil_imgs = [custom_to_pil(imgs[i]) for i in range(sample_num)]
+
+        clip_image_input = torch.stack(
+            [self.clip_preprocess(img) for img in pil_imgs],
+            dim=0).to(self._device)
+
+        with torch.no_grad():
+            hyp_image_features = self.clip_model.encode_image(clip_image_input)
+            hyp_image_features /= hyp_image_features.norm(dim=-1, keepdim=True)
+            text_features = self.clip_model.encode_text(clip_text_input)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+            ti_similarity = hyp_image_features @ text_features.T
+
+        sorted_score, ti_indices = torch.sort(
+            ti_similarity.view(-1), descending=True)
+
+        pil_imgs_orderby_ti = [pil_imgs[index] for index in ti_indices]
+        return pil_imgs_orderby_ti[0]
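
Reviewer note (not part of the patch): the rewritten forward() now draws multiple candidates (top-p sampling with 'beam_size': 2), decodes each through the VQGAN, scores every candidate image against the prompt with CLIP, and returns the highest-scoring one; the full ranking is kept in pil_imgs_orderby_ti but only the top-1 image is returned. A minimal standalone sketch of that reranking step, reusing the names the patch introduces (clip_model, clip_preprocess, clip_tokenize) and assuming those objects are already built as above:

import torch


def rerank_by_clip(pil_imgs, text, clip_model, clip_preprocess,
                   clip_tokenize, device):
    # Preprocess the candidate PIL images and batch them for CLIP.
    image_batch = torch.stack([clip_preprocess(img)
                               for img in pil_imgs]).to(device)
    # Tokenize the prompt into a (1, context_length) tensor.
    text_batch = clip_tokenize([text]).to(device)
    with torch.no_grad():
        image_feats = clip_model.encode_image(image_batch)
        image_feats /= image_feats.norm(dim=-1, keepdim=True)
        text_feats = clip_model.encode_text(text_batch)
        text_feats /= text_feats.norm(dim=-1, keepdim=True)
        # Cosine similarity between every candidate image and the prompt.
        similarity = image_feats @ text_feats.T  # shape: (num_images, 1)
    best = similarity.view(-1).argmax().item()
    return pil_imgs[best]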
diff --git a/modelscope/preprocessors/ofa/text_to_image_synthesis.py b/modelscope/preprocessors/ofa/text_to_image_synthesis.py
index 938f50de..e10de82c 100644
--- a/modelscope/preprocessors/ofa/text_to_image_synthesis.py
+++ b/modelscope/preprocessors/ofa/text_to_image_synthesis.py
@@ -19,7 +19,8 @@ class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):
         self.max_src_length = 64
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        source = data['text'].lower().strip().split()[:self.max_src_length]
+        source = ' '.join(
+            data['text'].lower().strip().split()[:self.max_src_length])
         source = 'what is the complete image? caption: {}'.format(source)
         inputs = self.get_inputs(source)
         sample = {
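
Reviewer note (not part of the patch): the preprocessor change fixes prompt construction. The removed line left `source` as a Python list, so str.format() rendered the list's repr into the caption; the new line joins the first max_src_length whitespace-split tokens back into a plain string before formatting. A small illustration, using a made-up input and assuming max_src_length = 64:

text = 'A red bus parked on the street'
old = text.lower().strip().split()[:64]
# -> ['a', 'red', 'bus', 'parked', 'on', 'the', 'street']  (list repr leaks into the prompt)
new = ' '.join(text.lower().strip().split()[:64])
# -> 'a red bus parked on the street'
prompt = 'what is the complete image? caption: {}'.format(new)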