Browse Source

add classification preprocess

master
翎航 3 years ago
parent
commit
fd3679b547
2 changed files with 80 additions and 22 deletions
  1. +79
    -11
      modelscope/preprocessors/ofa/image_classification.py
  2. +1
    -11
      modelscope/preprocessors/ofa/ocr_recognition.py

+ 79
- 11
modelscope/preprocessors/ofa/image_classification.py View File

@@ -1,13 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import functools
from typing import Any, Dict

import torch
from PIL import Image
from PIL import Image, ImageFile
from timm.data import create_transform
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor
from .utils.vision_helper import RandomAugment

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None


class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
@@ -28,18 +35,77 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
super(OfaImageClassificationPreprocessor,
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize(
(self.patch_image_size, self.patch_image_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])
if self.mode != ModeKeys.TRAIN:
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize(
(self.patch_image_size, self.patch_image_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])
else:
self.patch_resize_transform = create_transform(
input_size=self.patch_image_size,
is_training=True,
color_jitter=0.4,
auto_augment='rand-m9-mstd0.5-inc1',
interpolation='bicubic',
re_prob=0.25,
re_mode='pixel',
re_count=1,
mean=self.mean,
std=self.std)
self.patch_resize_transform = transforms.Compose(
functools.reduce(lambda x, y: x + y, [
[
lambda image: image.convert('RGB'),
],
self.patch_resize_transform.transforms[:2],
[self.patch_resize_transform.transforms[2]],
[
RandomAugment(
2,
7,
isPIL=True,
augs=[
'Identity', 'AutoContrast', 'Equalize',
'Brightness', 'Sharpness', 'ShearX', 'ShearY',
'TranslateX', 'TranslateY', 'Rotate'
]),
],
self.patch_resize_transform.transforms[3:],
]))

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch preprocessing to the train or inference sample builder.

    Args:
        data: Raw input dict; expected to carry at least the image column
            referenced by ``self.column_map`` (the builders fetch it).

    Returns:
        The sample dict produced by ``_build_train_sample`` when
        ``self.mode`` is ``ModeKeys.TRAIN``, otherwise the one from
        ``_build_infer_sample``.
    """
    # NOTE: the previous version loaded `data['image']` here into a local
    # that was never used — both builders load the image themselves via
    # `get_img_pil`, so that eager load was dead, duplicated work and has
    # been removed.
    if self.mode == ModeKeys.TRAIN:
        return self._build_train_sample(data)
    return self._build_infer_sample(data)

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Extend the inference sample with training-only fields.

    Adds the tokenized target label, the decoder input sequence
    (BOS followed by the target), and — when a constraint trie is
    configured — a per-decoding-step vocabulary mask restricting the
    decoder to valid label continuations.

    Args:
        data: Raw input dict; the label text is read from the column
            mapped by ``self.column_map['text']``.

    Returns:
        The sample dict from ``_build_infer_sample`` augmented with
        ``ref_dict``, ``target``, ``prev_output_tokens`` and, optionally,
        ``constraint_mask``.
    """
    sample = self._build_infer_sample(data)
    label = data[self.column_map['text']]
    sample['ref_dict'] = {label: 1.0}
    # Leading space follows the tokenizer's word-boundary convention.
    target_ids = self.tokenize_text(' {}'.format(label), add_bos=False)
    sample['target'] = target_ids
    prev_tokens = torch.cat([self.bos_item, target_ids])
    sample['prev_output_tokens'] = prev_tokens

    if self.constraint_trie is not None:
        num_steps = len(prev_tokens)
        mask = torch.zeros(num_steps, len(self.tgt_dict), dtype=torch.bool)
        for step in range(num_steps):
            prefix_tokens = prev_tokens[:step + 1].tolist()
            allowed = self.constraint_trie.get_next_layer(prefix_tokens)
            mask[step][allowed] = True
        sample['constraint_mask'] = mask

    return sample

def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
patch_image = self.patch_resize_transform(image)
prompt = self.cfg.model.get('prompt', ' what does the image describe?')
inputs = self.tokenize_text(prompt)
@@ -48,4 +114,6 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
'patch_image': patch_image,
'patch_mask': torch.tensor([True])
}
if 'text' in self.column_map and self.column_map['text'] in data:
sample['label'] = data[self.column_map['text']]
return sample

+ 1
- 11
modelscope/preprocessors/ofa/ocr_recognition.py View File

@@ -11,9 +11,6 @@ from zhconv import convert
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def ocr_resize(img, patch_image_size, is_document=False):
img = img.convert('RGB')
@@ -73,13 +70,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
"""
super(OfaOcrRecognitionPreprocessor,
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
if self.cfg.model.imagenet_default_mean_and_std:
mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD
else:
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

self.patch_resize_transform = transforms.Compose([
lambda image: ocr_resize(
@@ -87,7 +77,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
self.cfg.model.patch_image_size,
is_document=self.cfg.model.is_document),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
transforms.Normalize(mean=self.mean, std=self.std),
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:


Loading…
Cancel
Save