@@ -43,7 +43,7 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
 
     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        target = data[self.column_map['text']]
+        target = sample['label']
         target = target.translate(self.transtab).strip()
         target_token_list = target.strip().split()
         target = ' '.join(target_token_list[:self.max_tgt_length])
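Note: `self.transtab` comes from the base preprocessor and is a translation table that strips punctuation from the caption before the token-length truncation. A rough standalone sketch of that step (the table construction here is an assumption, not copied from the base class):

import string

transtab = str.maketrans({key: None for key in string.punctuation})  # assumed shape of self.transtab
target = 'A man, riding a horse!'
target = target.translate(transtab).strip()     # -> 'A man riding a horse'
target = ' '.join(target.strip().split()[:16])  # keep at most max_tgt_length tokens (16 here)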
@@ -85,11 +85,11 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
 
     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        target = ' {}'.format(data[self.column_map['text']])
-        sample['ref_dict'] = {data[self.column_map['text']]: 1.0}
+        target = ' {}'.format(sample['label'])
+        sample['ref_dict'] = {sample['label']: 1.0}
         sample['target'] = self.tokenize_text(target, add_bos=False)
         sample['prev_output_tokens'] = torch.cat(
-            [self.bos_item, sample['target']])
+            [self.bos_item, sample['target'][:-1]])
         if self.constraint_trie is not None:
             constraint_mask = torch.zeros((len(sample['prev_output_tokens']),
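The `[:-1]` fix above makes `prev_output_tokens` the target shifted right by one, the usual teacher-forcing layout. A toy illustration with made-up token ids:

import torch

bos_item = torch.tensor([0])              # assumed BOS id
target = torch.tensor([11, 12, 13, 2])    # label tokens followed by EOS (made-up ids)
prev_output_tokens = torch.cat([bos_item, target[:-1]])
# prev_output_tokens == tensor([0, 11, 12, 13]); decoder position i predicts target[i]
assert len(prev_output_tokens) == len(target)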
@@ -109,6 +109,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
         }
         if 'text' in self.column_map and self.column_map['text'] in data:
             target = data[self.column_map['text']]
-            target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
-            sample['label'] = target
+            sample['label'] = unicodedata2.normalize(
+                'NFKC', convert(target, 'zh-hans'))
         return sample
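For context, the label line folds two steps into one: convert the text to simplified Chinese, then apply NFKC normalization so full-width forms become half-width. A minimal sketch, assuming `convert` is `zhconv.convert` and that `unicodedata2` behaves like the stdlib module:

import unicodedata           # stand-in for unicodedata2
from zhconv import convert   # assumed source of `convert` in this file

raw = '０１２３ＡＢＣ華語'   # full-width digits/letters plus traditional characters (made-up sample)
label = unicodedata.normalize('NFKC', convert(raw, 'zh-hans'))
# label == '0123ABC华语'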
@@ -1,6 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from typing import Any, Dict
 
+import torch
+
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
@@ -24,9 +26,27 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
               self).__init__(cfg, model_dir, mode, *args, **kwargs)
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target_str = sample['label'].lower()
+        target = super().pre_caption(target_str, max_words=self.max_tgt_length)
+        target = target.replace('[unk]', 'unk').replace('<unk>', 'unk')
+        sample['target'] = self.tokenize_text(target, add_bos=False)
+        noise_target_item = self.add_noise_to_tgt(
+            sample['target'][:-1].clone())
+        sample['prev_output_tokens'] = torch.cat(
+            [self.bos_item, noise_target_item])
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         source = super().pre_caption(
-            data['text'], max_words=self.max_src_length)
-        source = source.strip()[:self.max_src_length]
+            data[self.column_map['text']], max_words=self.max_src_length)
+        # source = source.strip()[:self.max_src_length]
         source = source.replace('[unk]', 'unk').replace('<unk>', 'unk')
         prompt = self.cfg.model.get(
             'prompt', ' " {} " Summarize the article with a title: ')
@@ -42,4 +62,16 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
             'source': inputs,
             'decoder_prompt': decoder_prompt,
         }
+        if 'summary' in self.column_map and self.column_map['summary'] in data:
+            sample['label'] = data[self.column_map['summary']]
         return sample
+
+    def add_noise_to_tgt(self, target):
+        noise_indices = torch.FloatTensor(
+            target.size(0)).uniform_() < self.cfg.model.get(
+                'noise_ratio', 0.0)
+        target[noise_indices] = torch.randint(
+            4,
+            len(self.src_dict) - self.code_dict_size - self.num_bins,
+            size=(noise_indices.sum(), ))
+        return target
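The noise helper above corrupts a fraction of the decoder-input tokens with random vocabulary ids so the model does not over-rely on gold prefixes. A self-contained sketch of the same idea (the vocabulary bounds are placeholders; the real code excludes code tokens and <bin_*> tokens via self.code_dict_size and self.num_bins):

import torch

def add_noise_to_tgt(target: torch.Tensor, noise_ratio: float,
                     vocab_low: int = 4, vocab_high: int = 30000) -> torch.Tensor:
    # Each position is corrupted independently with probability noise_ratio.
    noise_indices = torch.FloatTensor(target.size(0)).uniform_() < noise_ratio
    target[noise_indices] = torch.randint(
        vocab_low, vocab_high, size=(int(noise_indices.sum()), ))
    return target

tgt = torch.tensor([101, 102, 103, 104, 105])
noisy = add_noise_to_tgt(tgt.clone(), noise_ratio=0.2)   # corrupt a copy, keep the clean target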
@@ -38,8 +38,51 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
         ])
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = ' {}'.format(sample['label'])
+        sample['ref_dict'] = {sample['label']: 1.0}
+        tgt_item = self.tokenize_text(target, add_bos=False, add_eos=False)
+        if self.prompt_type == 'none':
+            prev_output_item = torch.cat([self.bos_item, tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        elif self.prompt_type == 'src':
+            prev_output_item = torch.cat([sample['source'], tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        elif self.prompt_type == 'prev_output':
+            prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        else:
+            raise NotImplementedError
+        target_item[:-len(tgt_item) - 1] = self.tgt_dict.pad()
+        sample['target'] = target_item
+        sample['prev_output_tokens'] = prev_output_item
+        if self.constraint_trie is not None:
+            constraint_mask = torch.zeros(
+                (len(target_item), len(self.tgt_dict))).bool()
+            start_idx = len(target_item) - len(tgt_item) - 1
+            for i in range(
+                    len(target_item) - len(tgt_item) - 1, len(target_item)):
+                constraint_prefix_token = [
+                    self.tgt_dict.bos()
+                ] + target_item[start_idx:i].tolist()
+                constraint_nodes = self.constraint_trie.get_next_layer(
+                    constraint_prefix_token)
+                constraint_mask[i][constraint_nodes] = True
+            sample['constraint_mask'] = constraint_mask
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
         if 'text2' not in data:
             hypothesis = self.pre_caption(data['text'], self.max_src_length)
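The constraint-mask loop in _build_train_sample above asks the answer trie which token ids may follow each decoded prefix, so only tokens that can still form a valid label are scored. A toy illustration with a stand-in trie class and made-up ids (it only mimics the get_next_layer call used above, not OFA's full constraint_trie):

import torch

class ToyTrie:
    """Minimal prefix trie over token-id sequences (illustrative stand-in)."""

    def __init__(self, sequences):
        self.root = {}
        for seq in sequences:
            node = self.root
            for tok in seq:
                node = node.setdefault(tok, {})

    def get_next_layer(self, prefix):
        node = self.root
        for tok in prefix:
            node = node.get(tok, {})
        return list(node.keys())

trie = ToyTrie([[7, 2], [8, 2]])                # two allowed answers, e.g. 'yes'->EOS and 'no'->EOS
constraint_mask = torch.zeros((2, 10)).bool()   # (target length, toy vocab size)
prefix = []
for i in range(2):
    allowed = trie.get_next_layer(prefix)       # token ids permitted at step i
    constraint_mask[i][allowed] = True
    prefix.append(allowed[0])                   # follow one branch for the demo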
@@ -68,4 +111,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
             'patch_mask': torch.tensor([True]),
             'decoder_prompt': decoder_prompt,
         }
+        if 'relation' in self.column_map and self.column_map[
+                'relation'] in data:
+            sample['label'] = data[self.column_map['relation']]
         return sample
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from typing import Any, Dict
 
+import numpy as np
 import torch
 from PIL import Image
 from torchvision import transforms
@@ -27,24 +28,95 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
         """
         super(OfaVisualGroundingPreprocessor,
               self).__init__(cfg, model_dir, mode, *args, **kwargs)
-        # Initialize transform
-        self.patch_resize_transform = transforms.Compose([
-            lambda image: image.convert('RGB'),
-            transforms.Resize(
-                (self.patch_image_size, self.patch_image_size),
-                interpolation=transforms.InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=self.mean, std=self.std),
-        ])
+        if self.mode == ModeKeys.TRAIN:
+            # for positioning
+            self.positioning_transform = transforms.Compose([
+                transforms.RandomResize([self.patch_image_size],
+                                        max_size=self.patch_image_size),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=self.mean,
+                    std=self.std,
+                    max_image_size=self.max_image_size)
+            ])
+        else:
+            # Initialize transform
+            self.patch_resize_transform = transforms.Compose([
+                lambda image: image.convert('RGB'),
+                transforms.Resize(
+                    (self.patch_image_size, self.patch_image_size),
+                    interpolation=transforms.InterpolationMode.BICUBIC),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=self.mean, std=self.std),
+            ])
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
+        w, h = image.size
+        b_tgt = {
+            'boxes': [],
+            'labels': [],
+            'area': [],
+            'size': torch.tensor([h, w])
+        }
+        x0, y0, x1, y1 = data[self.column_map['region_coord']].strip().split(
+            ',')
+        region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
+        b_tgt['boxes'] = torch.tensor(
+            [[float(x0), float(y0), float(x1),
+              float(y1)]])
+        b_tgt['labels'] = np.array([0])
+        b_tgt['area'] = [(float(x1) - float(x0)) * (float(y1) - float(y0))]
+        patch_image, patch_boxes = self.positioning_transform(image, b_tgt)
+        resize_h, resize_w = patch_boxes['size'][0], patch_boxes['size'][1]
+        quant_x0 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][0] * (self.num_bins - 1)).round()))
+        quant_y0 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][1] * (self.num_bins - 1)).round()))
+        quant_x1 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][2] * (self.num_bins - 1)).round()))
+        quant_y1 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][3] * (self.num_bins - 1)).round()))
+        region_coord = '{} {} {} {}'.format(quant_x0, quant_y0, quant_x1,
+                                            quant_y1)
+        src_caption = self.pre_caption(data[self.column_map['text']],
+                                       self.max_src_length)
+        prompt = self.cfg.model.get(
+            'prompt', ' which region does the text " {} " describe?')
+        text = prompt.format(src_caption)
+        src_item = self.tokenize_text(text)
+        target_item = self.tokenize_text(
+            region_coord, add_bos=False)  # !!! use_bpe=False
+        prev_output_item = torch.cat([self.bos_item, target_item[:-1]])
+        sample = {
+            'source': src_item,
+            'patch_image': patch_image,
+            'patch_mask': torch.tensor([True]),
+            'target': target_item,
+            'prev_output_tokens': prev_output_item,
+            'w_resize_ratio': resize_w / w,
+            'h_resize_ratio': resize_h / h,
+            'region_coord': region
+        }
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
         w, h = image.size
         patch_image = self.patch_resize_transform(image)
         w_resize_ratio = torch.tensor(self.patch_image_size / w)
         h_resize_ratio = torch.tensor(self.patch_image_size / h)
-        src_caption = self.pre_caption(data['text'], self.max_src_length)
+        src_caption = self.pre_caption(data[self.column_map['text']],
+                                       self.max_src_length)
         prompt = self.cfg.model.get(
             'prompt', ' which region does the text " {} " describe?')
         text = prompt.format(src_caption)
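The <bin_k> tokens built in _build_train_sample turn box corners into discrete position tokens the decoder can emit as text: the positioning transform is assumed to return box coordinates normalized to [0, 1] (which the multiplication by num_bins - 1 implies), and each corner is snapped to one of num_bins grid positions. A small sketch of the round trip with placeholder numbers:

import torch

num_bins = 1000
image_size = 480.0
boxes = torch.tensor([[37.0, 52.0, 410.0, 301.0]]) / image_size   # normalized x0, y0, x1, y1

quant = [
    '<bin_{}>'.format(int((v * (num_bins - 1)).round())) for v in boxes[0]
]
region_coord = ' '.join(quant)   # '<bin_77> <bin_108> <bin_853> <bin_626>'

# Decoding reverses the mapping: bin index -> normalized coordinate -> pixels.
pixels = [int(tok[5:-1]) / (num_bins - 1) * image_size for tok in quant]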
@@ -38,10 +38,70 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
         ])
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        src_item = sample['source']
+        ref = data[self.column_map['ref']]
+        predict_objects = data[self.column_map['predict_objects']]
+        ref_dict = {
+            item.split('|!+')[1]: float(item.split('|!+')[0])
+            for item in ref.split('&&')
+        }
+        answer = max(ref_dict, key=ref_dict.get)
+        sample['conf'] = torch.tensor([ref_dict[answer]])
+        tgt_item = self.tokenize_text(
+            ' {}'.format(answer), add_bos=False, add_eos=False)
+        if self.add_object and predict_objects is not None:
+            predict_object_seq = ' '.join(
+                predict_objects.strip().split('&&')[:self.max_object_length])
+            predict_object_item = self.tokenize_text(
+                ' object: {}'.format(predict_object_seq), add_bos=False)
+            src_item = torch.cat([src_item, predict_object_item[:-1]])
+        if self.prompt_type == 'none':
+            prev_output_item = torch.cat([self.bos_item, tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        elif self.prompt_type == 'src':
+            prev_output_item = torch.cat([src_item, tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        elif self.prompt_type == 'prev_output':
+            prev_output_item = torch.cat([src_item[:-1], tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        else:
+            raise NotImplementedError
+        target_item[:-len(tgt_item) - 1] = self.tgt_dict.pad()
+        sample['prev_output_tokens'] = prev_output_item
+        sample['target'] = target_item
+        sample['ref_dict'] = ref_dict
+        if self.constraint_trie is not None:
+            constraint_mask = torch.zeros(
+                (len(target_item), len(self.tgt_dict))).bool()
+            start_idx = len(target_item) - len(tgt_item) - 1
+            for i in range(
+                    len(target_item) - len(tgt_item) - 1, len(target_item)):
+                constraint_prefix_token = [
+                    self.tgt_dict.bos()
+                ] + target_item[start_idx:i].tolist()
+                constraint_nodes = self.constraint_trie.get_next_layer(
+                    constraint_prefix_token)
+                constraint_mask[i][constraint_nodes] = True
+            sample['constraint_mask'] = constraint_mask
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
-        text = ' {}'.format(data['text'])
+        text = ' {}'.format(data[self.column_map['text']])
         inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item
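The ref annotation parsed in _build_train_sample is expected to pack weighted answers as 'confidence|!+answer' pairs joined by '&&'; the highest-confidence answer becomes the training target and its weight is kept in sample['conf']. A standalone sketch with a made-up annotation string:

import torch

ref = '1.0|!+two&&0.4|!+2&&0.2|!+several'   # hypothetical VQA answer annotation
ref_dict = {
    item.split('|!+')[1]: float(item.split('|!+')[0])
    for item in ref.split('&&')
}
# ref_dict == {'two': 1.0, '2': 0.4, 'several': 0.2}
answer = max(ref_dict, key=ref_dict.get)    # 'two'
conf = torch.tensor([ref_dict[answer]])     # per-sample confidence / loss weight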
@@ -57,4 +117,6 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
             'patch_mask': torch.tensor([True]),
             'decoder_prompt': decoder_prompt,
         }
+        if 'answer' in self.column_map and self.column_map['answer'] in data:
+            sample['label'] = data[self.column_map['answer']]
         return sample