[to #42322933] OFA text-to-image: add CLIP reranking post-processing & fix a bug in the preprocessor

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9880918
Branch: master
menrui.mr, yingda.chen · 3 years ago
Parent commit: 427f0e83ea
2 changed files with 155 additions and 6 deletions
  1. modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py (+153, -5)
  2. modelscope/preprocessors/ofa/text_to_image_synthesis.py (+2, -1)
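For context, the post-processing added in the first file makes forward() sample several image candidates, decode them with VQGAN, score each against the input caption with CLIP, and return the highest-scoring image. A minimal sketch of that reranking idea, using the standalone openai-clip package rather than the in-repo CLIP wiring from this commit (the function name, the 'ViT-B/16' checkpoint choice and the candidate_images argument are illustrative):

import clip
import torch


def rerank_by_clip(prompt, candidate_images, device='cuda'):
    # Return candidate PIL images sorted by CLIP text-image similarity, best first.
    model, preprocess = clip.load('ViT-B/16', device=device)
    image_input = torch.stack([preprocess(img)
                               for img in candidate_images]).to(device)
    text_input = clip.tokenize([prompt]).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_input)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (image_features @ text_features.T).view(-1)
    order = similarity.argsort(descending=True).tolist()
    return [candidate_images[i] for i in order]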

modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py (+153, -5)

@@ -6,17 +6,30 @@ import numpy as np
 import torch
 import torch.cuda
 from PIL import Image
+from pkg_resources import packaging
 from taming.models.vqgan import GumbelVQ, VQModel
+from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
+                                    ToTensor)
 
 from modelscope.metainfo import Models
 from modelscope.models.base import Model
 from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.mmr.models.module_clip import CLIP
+from modelscope.models.multi_modal.mmr.models.tokenization_clip import \
+    SimpleTokenizer as ClipTokenizer
 from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer
 from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg
 from modelscope.models.multi_modal.ofa.generate.search import Sampling
 from modelscope.models.multi_modal.ofa.generate.utils import move_to_device
 from modelscope.utils.constant import Tasks
 
+try:
+    from torchvision.transforms import InterpolationMode
+
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
 __all__ = ['OfaForTextToImageSynthesis']
 
 
@@ -43,6 +56,74 @@ def load_vqgan(config, ckpt_path=None, is_gumbel=False):
     return model.eval()
 
 
+def build_clip_model(model_path):
+    state_dict = torch.load(model_path, map_location='cpu').state_dict()
+    vit = 'visual.proj' in state_dict
+    if vit:
+        vision_width = state_dict['visual.conv1.weight'].shape[0]
+        vision_layers = len([
+            k for k in state_dict.keys()
+            if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
+        ])
+        vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
+        grid_size = round(
+            (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
+        image_resolution = vision_patch_size * grid_size
+    else:
+        counts: list = [
+            len(
+                set(
+                    k.split('.')[2] for k in state_dict
+                    if k.startswith(f'visual.layer{b}')))
+            for b in [1, 2, 3, 4]
+        ]
+        vision_layers = tuple(counts)
+        vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
+        output_width = round(
+            (state_dict['visual.attnpool.positional_embedding'].shape[0]
+             - 1)**0.5)
+        vision_patch_size = None
+        assert output_width**2 + 1 == state_dict[
+            'visual.attnpool.positional_embedding'].shape[0]
+        image_resolution = output_width * 32
+
+    embed_dim = state_dict['text_projection'].shape[1]
+    context_length = state_dict['positional_embedding'].shape[0]
+    vocab_size = state_dict['token_embedding.weight'].shape[0]
+    transformer_width = state_dict['ln_final.weight'].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(
+        set(
+            k.split('.')[2] for k in state_dict
+            if k.startswith('transformer.resblocks')))
+
+    model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
+                 vision_patch_size, context_length, vocab_size,
+                 transformer_width, transformer_heads, transformer_layers)
+
+    for key in ['input_resolution', 'context_length', 'vocab_size']:
+        if key in state_dict:
+            del state_dict[key]
+
+    model.load_state_dict(state_dict)
+    return model.eval()
+
+
+def _convert_image_to_rgb(image):
+    return image.convert('RGB')
+
+
+def build_clip_transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=BICUBIC),
+        CenterCrop(n_px),
+        _convert_image_to_rgb,
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073),
+                  (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+
 @MODELS.register_module(Tasks.text_to_image_synthesis, module_name=Models.ofa)
 class OfaForTextToImageSynthesis(Model):
 
@@ -65,11 +146,23 @@ class OfaForTextToImageSynthesis(Model):
             vqgan_config,
             ckpt_path=os.path.join(model_dir, 'vqgan_model.ckpt'),
             is_gumbel=True).to(self._device)
+
+        # Initialize OpenAI clip
+        self.clip_tokenizer = ClipTokenizer(model_dir)
+        self.clip_model = build_clip_model(
+            os.path.join(model_dir, 'ViT-B-16.pt'))
+        self.clip_preprocess = build_clip_transform(
+            self.clip_model.visual.input_resolution)
+
+        self.clip_model.to(self._device)
+        self.clip_model.eval()
+
         # Initialize generator
         sampling = Sampling(self.tokenizer, sampling_topp=0.9)
         sg_args = {
             'tokenizer': self.tokenizer,
-            'beam_size': 1,
+            'beam_size': 2,
             'max_len_b': 1024,
             'min_len': 1024,
             'search_strategy': sampling,
@@ -78,13 +171,68 @@ class OfaForTextToImageSynthesis(Model):
         }
         self.generator = sg.SequenceGenerator(**sg_args)
 
+    def clip_tokenize(self, texts, context_length=77, truncate=False):
+
+        if isinstance(texts, str):
+            texts = [texts]
+
+        sot_token = self.clip_tokenizer.encoder['<|startoftext|>']
+        eot_token = self.clip_tokenizer.encoder['<|endoftext|>']
+        all_tokens = [[sot_token] + self.clip_tokenizer.encode(text)
+                      + [eot_token] for text in texts]
+        if packaging.version.parse(
+                torch.__version__) < packaging.version.parse('1.8.0'):
+            result = torch.zeros(
+                len(all_tokens), context_length, dtype=torch.long)
+        else:
+            result = torch.zeros(
+                len(all_tokens), context_length, dtype=torch.int)
+
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > context_length:
+                if truncate:
+                    tokens = tokens[:context_length]
+                    tokens[-1] = eot_token
+                else:
+                    raise RuntimeError(
+                        f'Input {texts[i]} is too long for context length {context_length}'
+                    )
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        return result
+
     def forward(self, input: Dict[str, Any]):
+        text = input['samples'][0]['text']
         input = move_to_device(input, self._device)
+        clip_text_input = self.clip_tokenize([text]).to(self._device)
+
         gen_output = self.generator.generate([self.model], input)
-        gen_tokens = gen_output[0][0]['tokens'][:-1]
-        codes = gen_tokens.view(1, 32, 32) - 50265
+        gen_tokens = torch.stack(
+            [item['tokens'][:-1] for item in gen_output[0]], dim=0)
+        codes = gen_tokens.view(-1, 32, 32) - 50265
+
         quant_b = self.vqgan_model.quantize.get_codebook_entry(
             codes.view(-1),
             list(codes.size()) + [self.vqgan_model.quantize.embedding_dim])
-        dec = self.vqgan_model.decode(quant_b)[0]
-        return custom_to_pil(dec)
+        imgs = self.vqgan_model.decode(quant_b)
+
+        sample_num = imgs.size()[0]
+        pil_imgs = [custom_to_pil(imgs[i]) for i in range(sample_num)]
+
+        clip_image_input = torch.stack(
+            [self.clip_preprocess(img) for img in pil_imgs],
+            dim=0).to(self._device)
+
+        with torch.no_grad():
+            hyp_image_features = self.clip_model.encode_image(clip_image_input)
+            hyp_image_features /= hyp_image_features.norm(dim=-1, keepdim=True)
+            text_features = self.clip_model.encode_text(clip_text_input)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+            ti_similarity = hyp_image_features @ text_features.T
+
+        sorted_score, ti_indices = torch.sort(
+            ti_similarity.view(-1), descending=True)
+
+        pil_imgs_orderby_ti = [pil_imgs[index] for index in ti_indices]
+        return pil_imgs_orderby_ti[0]

modelscope/preprocessors/ofa/text_to_image_synthesis.py (+2, -1)

@@ -19,7 +19,8 @@ class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):
         self.max_src_length = 64
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        source = data['text'].lower().strip().split()[:self.max_src_length]
+        source = ' '.join(
+            data['text'].lower().strip().split()[:self.max_src_length])
         source = 'what is the complete image? caption: {}'.format(source)
         inputs = self.get_inputs(source)
         sample = {
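The preprocessor fix above addresses a type bug: str.split()[:n] returns a list of tokens, so the old code interpolated a Python list repr into the OFA prompt instead of the truncated caption text; the new code joins the tokens back into a string before formatting. A quick illustration (hypothetical caption, truncation length of 4 for brevity):

text = 'A Very Long Caption About A Red Bus'
max_src_length = 4

# Old behaviour: a list object leaks into the prompt string.
old_source = text.lower().strip().split()[:max_src_length]
print('what is the complete image? caption: {}'.format(old_source))
# -> what is the complete image? caption: ['a', 'very', 'long', 'caption']

# Fixed behaviour: tokens are joined back into plain text first.
new_source = ' '.join(text.lower().strip().split()[:max_src_length])
print('what is the complete image? caption: {}'.format(new_source))
# -> what is the complete image? caption: a very long caption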

