[to #42322933] OFA text-to-image synthesis: add CLIP reranking post-processing & fix a bug in the preprocessor

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9880918
Branch: master
Authors: menrui.mr, yingda.chen · 3 years ago
Commit: 427f0e83ea
2 changed files with 155 additions and 6 deletions:

  1. modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py (+153, -5)
  2. modelscope/preprocessors/ofa/text_to_image_synthesis.py (+2, -1)
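At a high level, the model now draws several image candidates per prompt and keeps the one whose CLIP image embedding best matches the CLIP embedding of the prompt. Below is a minimal sketch of that reranking idea, with `clip_model`, `preprocess` and `tokenize` standing in for the CLIP components the diff wires up; the function and parameter names are illustrative, not taken from the commit.

import torch


def rerank_by_clip(pil_images, text, clip_model, preprocess, tokenize,
                   device='cpu'):
    """Return the candidate images sorted by CLIP text-image similarity."""
    # Batch the candidates and the prompt through CLIP's two encoders.
    image_batch = torch.stack([preprocess(img) for img in pil_images]).to(device)
    text_batch = tokenize([text]).to(device)

    with torch.no_grad():
        img_feat = clip_model.encode_image(image_batch)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
        txt_feat = clip_model.encode_text(text_batch)
        txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

    # Cosine similarity of every candidate against the prompt, best first.
    similarity = (img_feat @ txt_feat.T).view(-1)
    order = torch.argsort(similarity, descending=True).tolist()
    return [pil_images[i] for i in order]

This is essentially what the new `forward` in the first file does; raising `beam_size` from 1 to 2 is what makes the sampler produce more than one candidate to rank.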

modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py (+153, -5)

@@ -6,17 +6,30 @@ import numpy as np
import torch
import torch.cuda
from PIL import Image
from pkg_resources import packaging
from taming.models.vqgan import GumbelVQ, VQModel
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.mmr.models.module_clip import CLIP
from modelscope.models.multi_modal.mmr.models.tokenization_clip import \
    SimpleTokenizer as ClipTokenizer
from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer
from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg
from modelscope.models.multi_modal.ofa.generate.search import Sampling
from modelscope.models.multi_modal.ofa.generate.utils import move_to_device
from modelscope.utils.constant import Tasks

try:
    from torchvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

__all__ = ['OfaForTextToImageSynthesis']


@@ -43,6 +56,74 @@ def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    return model.eval()


def build_clip_model(model_path):
    state_dict = torch.load(model_path, map_location='cpu').state_dict()
    vit = 'visual.proj' in state_dict
    if vit:
        vision_width = state_dict['visual.conv1.weight'].shape[0]
        vision_layers = len([
            k for k in state_dict.keys()
            if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
        ])
        vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
        grid_size = round(
            (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [
            len(
                set(
                    k.split('.')[2] for k in state_dict
                    if k.startswith(f'visual.layer{b}')))
            for b in [1, 2, 3, 4]
        ]
        vision_layers = tuple(counts)
        vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
        output_width = round(
            (state_dict['visual.attnpool.positional_embedding'].shape[0]
             - 1)**0.5)
        vision_patch_size = None
        assert output_width**2 + 1 == state_dict[
            'visual.attnpool.positional_embedding'].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict['text_projection'].shape[1]
    context_length = state_dict['positional_embedding'].shape[0]
    vocab_size = state_dict['token_embedding.weight'].shape[0]
    transformer_width = state_dict['ln_final.weight'].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        set(
            k.split('.')[2] for k in state_dict
            if k.startswith('transformer.resblocks')))

    model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
                 vision_patch_size, context_length, vocab_size,
                 transformer_width, transformer_heads, transformer_layers)

    for key in ['input_resolution', 'context_length', 'vocab_size']:
        if key in state_dict:
            del state_dict[key]

    model.load_state_dict(state_dict)
    return model.eval()


def _convert_image_to_rgb(image):
    return image.convert('RGB')


def build_clip_transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])


@MODELS.register_module(Tasks.text_to_image_synthesis, module_name=Models.ofa)
class OfaForTextToImageSynthesis(Model):

@@ -65,11 +146,23 @@ class OfaForTextToImageSynthesis(Model):
            vqgan_config,
            ckpt_path=os.path.join(model_dir, 'vqgan_model.ckpt'),
            is_gumbel=True).to(self._device)

        # Initialize OpenAI clip
        self.clip_tokenizer = ClipTokenizer(model_dir)
        self.clip_model = build_clip_model(
            os.path.join(model_dir, 'ViT-B-16.pt'))
        self.clip_preprocess = build_clip_transform(
            self.clip_model.visual.input_resolution)

        self.clip_model.to(self._device)
        self.clip_model.eval()

        # Initialize generator
        sampling = Sampling(self.tokenizer, sampling_topp=0.9)
        sg_args = {
            'tokenizer': self.tokenizer,
-           'beam_size': 1,
+           'beam_size': 2,
            'max_len_b': 1024,
            'min_len': 1024,
            'search_strategy': sampling,
@@ -78,13 +171,68 @@ class OfaForTextToImageSynthesis(Model):
        }
        self.generator = sg.SequenceGenerator(**sg_args)

    def clip_tokenize(self, texts, context_length=77, truncate=False):
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.clip_tokenizer.encoder['<|startoftext|>']
        eot_token = self.clip_tokenizer.encoder['<|endoftext|>']
        all_tokens = [[sot_token] + self.clip_tokenizer.encode(text)
                      + [eot_token] for text in texts]
        if packaging.version.parse(
                torch.__version__) < packaging.version.parse('1.8.0'):
            result = torch.zeros(
                len(all_tokens), context_length, dtype=torch.long)
        else:
            result = torch.zeros(
                len(all_tokens), context_length, dtype=torch.int)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate:
                    tokens = tokens[:context_length]
                    tokens[-1] = eot_token
                else:
                    raise RuntimeError(
                        f'Input {texts[i]} is too long for context length {context_length}'
                    )
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result

    def forward(self, input: Dict[str, Any]):
        text = input['samples'][0]['text']
        input = move_to_device(input, self._device)
        clip_text_input = self.clip_tokenize([text]).to(self._device)

        gen_output = self.generator.generate([self.model], input)
-       gen_tokens = gen_output[0][0]['tokens'][:-1]
-       codes = gen_tokens.view(1, 32, 32) - 50265
+       gen_tokens = torch.stack(
+           [item['tokens'][:-1] for item in gen_output[0]], dim=0)
+       codes = gen_tokens.view(-1, 32, 32) - 50265

        quant_b = self.vqgan_model.quantize.get_codebook_entry(
            codes.view(-1),
            list(codes.size()) + [self.vqgan_model.quantize.embedding_dim])
-       dec = self.vqgan_model.decode(quant_b)[0]
-       return custom_to_pil(dec)
+       imgs = self.vqgan_model.decode(quant_b)

        sample_num = imgs.size()[0]
        pil_imgs = [custom_to_pil(imgs[i]) for i in range(sample_num)]

        clip_image_input = torch.stack(
            [self.clip_preprocess(img) for img in pil_imgs],
            dim=0).to(self._device)

        with torch.no_grad():
            hyp_image_features = self.clip_model.encode_image(clip_image_input)
            hyp_image_features /= hyp_image_features.norm(dim=-1, keepdim=True)
            text_features = self.clip_model.encode_text(clip_text_input)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            ti_similarity = hyp_image_features @ text_features.T

        sorted_score, ti_indices = torch.sort(
            ti_similarity.view(-1), descending=True)

        pil_imgs_orderby_ti = [pil_imgs[index] for index in ti_indices]
        return pil_imgs_orderby_ti[0]
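For reference, `build_clip_model` above recovers the CLIP architecture purely from tensor shapes in the checkpoint. A rough sanity check against the public OpenAI ViT-B/16 release that `ViT-B-16.pt` corresponds to (the numbers below come from that release, not from this repository):

# Illustrative shape arithmetic for OpenAI CLIP ViT-B/16; not part of the commit.
# state_dict['visual.conv1.weight'].shape == [768, 3, 16, 16]
vision_width = 768                                  # patch-embedding channels
vision_patch_size = 16                              # conv kernel size == patch size

# state_dict['visual.positional_embedding'].shape == [197, 768]  (1 CLS + 14*14 patches)
grid_size = round((197 - 1) ** 0.5)                 # 14 patches per side
image_resolution = vision_patch_size * grid_size    # 224, i.e. visual.input_resolution

# Text tower: ln_final.weight has width 512, so 512 // 64 = 8 attention heads;
# context_length (77) and vocab_size (49408) come straight from the embedding shapes.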

modelscope/preprocessors/ofa/text_to_image_synthesis.py (+2, -1)

@@ -19,7 +19,8 @@ class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):
        self.max_src_length = 64

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-       source = data['text'].lower().strip().split()[:self.max_src_length]
+       source = ' '.join(
+           data['text'].lower().strip().split()[:self.max_src_length])
        source = 'what is the complete image? caption: {}'.format(source)
        inputs = self.get_inputs(source)
        sample = {
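The one-line preprocessor change fixes a real bug: `str.split()[:max_src_length]` returns a list of tokens, so formatting it straight into the prompt produced a stringified Python list instead of the caption text. A standalone sketch of the before/after behaviour (example text is made up, not from the commit):

text = 'Two dogs playing in the snow'
tokens = text.lower().strip().split()[:64]

# Before the fix: the token list itself was interpolated into the prompt.
buggy = 'what is the complete image? caption: {}'.format(tokens)
# -> "what is the complete image? caption: ['two', 'dogs', 'playing', 'in', 'the', 'snow']"

# After the fix: the tokens are joined back into a plain string first.
fixed = 'what is the complete image? caption: {}'.format(' '.join(tokens))
# -> "what is the complete image? caption: two dogs playing in the snow"

assert 'playing in the snow' in fixed and '[' not in fixed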

