Title: [to #42322933] add finetune & merge master: add finetune support for the other OFA tasks. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10656541master
@@ -402,6 +402,7 @@ class Metrics(object):
     # accuracy
     accuracy = 'accuracy'
+    multi_average_precision = 'mAP'
     audio_noise_metric = 'audio-noise-metric'

     # text gen
@@ -24,6 +24,7 @@ class MetricKeys(object):
     ROUGE_1 = 'rouge-1'
     ROUGE_L = 'rouge-l'
     NED = 'ned'  # ocr metric
+    mAP = 'mAP'
     BatchAcc = 'inbatch_t2i_recall_at_1'
@@ -0,0 +1,67 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Dict
+
+import numpy as np
+
+from modelscope.metainfo import Metrics
+from modelscope.outputs import OutputKeys
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(
+    group_key=default_group, module_name=Metrics.multi_average_precision)
+class AveragePrecisionMetric(Metric):
| """The metric computation class for multi avarage precision classes. | |||||
| This metric class calculates multi avarage precision for the whole input batches. | |||||
| """ | |||||
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.preds = []
+        self.labels = []
+        self.thresh = kwargs.get('threshold', 0.5)
+
+    def add(self, outputs: Dict, inputs: Dict):
+        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
+        ground_truths = inputs[label_name]
+        eval_results = outputs[label_name]
+        for key in [
+                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
+                OutputKeys.LABELS, OutputKeys.SCORES
+        ]:
+            if key in outputs and outputs[key] is not None:
+                eval_results = outputs[key]
+                break
+        assert type(ground_truths) == type(eval_results)
+        for truth in ground_truths:
+            self.labels.append(truth)
+        for result in eval_results:
+            if isinstance(truth, str):
+                self.preds.append(result.strip().replace(' ', ''))
+            else:
+                self.preds.append(result)
+
+    def evaluate(self):
+        assert len(self.preds) == len(self.labels)
+        scores = self._calculate_ap_score(self.preds, self.labels, self.thresh)
+        return {MetricKeys.mAP: scores.mean().item()}
+
+    def _calculate_ap_score(self, preds, labels, thresh=0.5):
+        hyps = np.array(preds)
+        refs = np.array(labels)
+        a = np.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2])
+        b = np.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])
+        interacts = np.concatenate([a, b], axis=1)
+        area_predictions = (hyps[:, 2] - hyps[:, 0]) * (
+            hyps[:, 3] - hyps[:, 1])
+        area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
+        interacts_w = interacts[:, 2] - interacts[:, 0]
+        interacts_h = interacts[:, 3] - interacts[:, 1]
+        area_interacts = interacts_w * interacts_h
+        ious = area_interacts / (
+            area_predictions + area_targets - area_interacts + 1e-6)
+        return (ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)
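Note on the new metric: despite the mAP key, AveragePrecisionMetric reduces to a per-sample thresholded IoU check. A prediction counts as a hit when its box overlaps the reference box with IoU >= threshold (default 0.5), and evaluate() reports the hit rate. A minimal standalone sketch of the same computation with hand-made boxes (numpy only; the coordinates are illustrative):

import numpy as np

def iou_hits(preds, refs, thresh=0.5):
    # boxes are [x0, y0, x1, y1]; the intersection is bounded by the max of the
    # top-left corners and the min of the bottom-right corners
    hyps, refs = np.asarray(preds, dtype=float), np.asarray(refs, dtype=float)
    tl = np.maximum(hyps[:, :2], refs[:, :2])
    br = np.minimum(hyps[:, 2:], refs[:, 2:])
    wh = br - tl
    inter = wh[:, 0] * wh[:, 1]
    area_h = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1])
    area_r = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
    ious = inter / (area_h + area_r - inter + 1e-6)
    return (ious >= thresh) & (wh[:, 0] > 0) & (wh[:, 1] > 0)

# exact overlap is a hit, one-third IoU is a miss at the default threshold
hits = iou_hits([[0, 0, 10, 10], [0, 0, 10, 10]],
                [[0, 0, 10, 10], [5, 0, 15, 10]])
print(hits.mean())  # 0.5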
@@ -43,7 +43,7 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):

     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        target = data[self.column_map['text']]
+        target = sample['label']
         target = target.translate(self.transtab).strip()
         target_token_list = target.strip().split()
         target = ' '.join(target_token_list[:self.max_tgt_length])
@@ -1,13 +1,20 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import functools
 from typing import Any, Dict

 import torch
-from PIL import Image
+from PIL import Image, ImageFile
+from timm.data import create_transform
 from torchvision import transforms

 from modelscope.preprocessors.image import load_image
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
+from .utils.vision_helper import RandomAugment
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+ImageFile.MAX_IMAGE_PIXELS = None
+Image.MAX_IMAGE_PIXELS = None


 class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
@@ -28,18 +35,77 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
         super(OfaImageClassificationPreprocessor,
               self).__init__(cfg, model_dir, mode, *args, **kwargs)
         # Initialize transform
-        self.patch_resize_transform = transforms.Compose([
-            lambda image: image.convert('RGB'),
-            transforms.Resize(
-                (self.patch_image_size, self.patch_image_size),
-                interpolation=transforms.InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=self.mean, std=self.std),
-        ])
+        if self.mode != ModeKeys.TRAIN:
+            self.patch_resize_transform = transforms.Compose([
+                lambda image: image.convert('RGB'),
+                transforms.Resize(
+                    (self.patch_image_size, self.patch_image_size),
+                    interpolation=transforms.InterpolationMode.BICUBIC),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=self.mean, std=self.std),
+            ])
+        else:
+            self.patch_resize_transform = create_transform(
+                input_size=self.patch_image_size,
+                is_training=True,
+                color_jitter=0.4,
+                auto_augment='rand-m9-mstd0.5-inc1',
+                interpolation='bicubic',
+                re_prob=0.25,
+                re_mode='pixel',
+                re_count=1,
+                mean=self.mean,
+                std=self.std)
+            self.patch_resize_transform = transforms.Compose(
+                functools.reduce(lambda x, y: x + y, [
+                    [
+                        lambda image: image.convert('RGB'),
+                    ],
+                    self.patch_resize_transform.transforms[:2],
+                    [self.patch_resize_transform.transforms[2]],
+                    [
+                        RandomAugment(
+                            2,
+                            7,
+                            isPIL=True,
+                            augs=[
+                                'Identity', 'AutoContrast', 'Equalize',
+                                'Brightness', 'Sharpness', 'ShearX', 'ShearY',
+                                'TranslateX', 'TranslateY', 'Rotate'
+                            ]),
+                    ],
+                    self.patch_resize_transform.transforms[3:],
+                ]))
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = ' {}'.format(sample['label'])
+        sample['ref_dict'] = {sample['label']: 1.0}
+        sample['target'] = self.tokenize_text(target, add_bos=False)
+        sample['prev_output_tokens'] = torch.cat(
+            [self.bos_item, sample['target'][:-1]])
+        if self.constraint_trie is not None:
+            constraint_mask = torch.zeros((len(sample['prev_output_tokens']),
+                                           len(self.tgt_dict))).bool()
+            for i in range(len(sample['prev_output_tokens'])):
+                constraint_prefix_token = sample[
+                    'prev_output_tokens'][:i + 1].tolist()
+                constraint_nodes = self.constraint_trie.get_next_layer(
+                    constraint_prefix_token)
+                constraint_mask[i][constraint_nodes] = True
+            sample['constraint_mask'] = constraint_mask
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
         prompt = self.cfg.model.get('prompt', ' what does the image describe?')
         inputs = self.tokenize_text(prompt)
@@ -48,4 +114,6 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
             'patch_image': patch_image,
             'patch_mask': torch.tensor([True])
         }
+        if 'text' in self.column_map and self.column_map['text'] in data:
+            sample['label'] = data[self.column_map['text']]
         return sample
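Note on constraint_mask: when an answer trie is available, _build_train_sample builds a per-step boolean mask so that only tokens which extend a valid closed-set answer are considered at each decoding position. The real trie is loaded by the model and only its get_next_layer(prefix) method is used here; the ToyTrie below is a hypothetical stand-in, just to show the masking loop:

import torch

class ToyTrie:
    # stand-in exposing the same get_next_layer interface as OFA's answer trie
    def __init__(self, sequences):
        self.sequences = [tuple(s) for s in sequences]

    def get_next_layer(self, prefix):
        prefix = tuple(prefix)
        return sorted({s[len(prefix)] for s in self.sequences
                       if len(s) > len(prefix) and s[:len(prefix)] == prefix})

# token ids are made up: 0 = bos, and '7 8' / '7 9' are the only valid answers
trie = ToyTrie([(0, 7, 8), (0, 7, 9)])
prev_output_tokens = torch.tensor([0, 7, 8])  # bos + target[:-1]
constraint_mask = torch.zeros(len(prev_output_tokens), 12).bool()
for i in range(len(prev_output_tokens)):
    allowed = trie.get_next_layer(prev_output_tokens[:i + 1].tolist())
    constraint_mask[i][allowed] = True
print(constraint_mask.nonzero().tolist())  # [[0, 7], [1, 8], [1, 9]]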
@@ -11,9 +11,6 @@ from zhconv import convert
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor

-IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
-IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
-

 def ocr_resize(img, patch_image_size, is_document=False):
     img = img.convert('RGB')
@@ -112,6 +109,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
         }
         if 'text' in self.column_map and self.column_map['text'] in data:
             target = data[self.column_map['text']]
-            target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
-            sample['label'] = target
+            sample['label'] = unicodedata2.normalize(
+                'NFKC', convert(target, 'zh-hans'))
         return sample
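Note on label normalization: OCR targets are normalized before tokenization: zhconv.convert(target, 'zh-hans') maps traditional characters to simplified Chinese, and NFKC folds full-width and compatibility forms. The diff uses the unicodedata2 backport, but the standard-library unicodedata exposes the same normalize call. A small standalone check (the sample string is illustrative and assumes zhconv is installed):

import unicodedata
from zhconv import convert

raw = '編號：ＡＢＣ１２３'                   # traditional chars plus full-width ASCII
simplified = convert(raw, 'zh-hans')        # '编号：ＡＢＣ１２３'
label = unicodedata.normalize('NFKC', simplified)
print(label)                                # '编号:ABC123'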
@@ -1,6 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from typing import Any, Dict

+import torch
+
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
@@ -24,9 +26,26 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
               self).__init__(cfg, model_dir, mode, *args, **kwargs)

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target_str = sample['label'].lower()
+        target = super().pre_caption(target_str, max_words=self.max_tgt_length)
+        target = target.replace('[unk]', 'unk').replace('<unk>', 'unk')
+        sample['target'] = self.tokenize_text(target, add_bos=False)
+        noise_target_item = self.add_noise_to_tgt(
+            sample['target'][:-1].clone())
+        sample['prev_output_tokens'] = torch.cat(
+            [self.bos_item, noise_target_item])
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         source = super().pre_caption(
-            data['text'], max_words=self.max_src_length)
-        source = source.strip()[:self.max_src_length]
+            data[self.column_map['text']], max_words=self.max_src_length)
         source = source.replace('[unk]', 'unk').replace('<unk>', 'unk')
         prompt = self.cfg.model.get(
             'prompt', ' " {} " Summarize the article with a title: ')
@@ -42,4 +61,17 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
             'source': inputs,
             'decoder_prompt': decoder_prompt,
         }
+        if 'summary' in self.column_map and self.column_map['summary'] in data:
+            sample['label'] = data[self.column_map['summary']]
         return sample
+
+    def add_noise_to_tgt(self, target):
+        noise_indices = torch.FloatTensor(
+            target.size(0)).uniform_() < self.cfg.model.get(
+                'noise_ratio', 0.0)
+        target[noise_indices] = torch.randint(
+            4,
+            len(self.src_dict) - self.cfg.model.get('num_codes', 8192)
+            - self.cfg.model.get('num_bins', 1000),
+            size=(noise_indices.sum(), ))
+        return target
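Note on add_noise_to_tgt: during summarization finetuning the decoder input is a noised copy of the target, where each token is replaced with probability noise_ratio by a random id from the plain-text vocabulary range (ids 4 and above, excluding the image-code and location-bin tokens appended at the end of the dictionary). A standalone sketch of the same replacement, with a hypothetical vocabulary size:

import torch

def add_noise_to_tgt(target, noise_ratio, vocab_size, num_codes=8192, num_bins=1000):
    # choose positions to corrupt with probability noise_ratio
    noise_indices = torch.FloatTensor(target.size(0)).uniform_() < noise_ratio
    # replacements are drawn from [4, vocab_size - num_codes - num_bins),
    # skipping special tokens at the front and code/bin tokens at the back
    target[noise_indices] = torch.randint(
        4, vocab_size - num_codes - num_bins, size=(int(noise_indices.sum()), ))
    return target

tgt = torch.arange(100, 110)
print(add_noise_to_tgt(tgt.clone(), noise_ratio=0.3, vocab_size=59457))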
@@ -38,18 +38,64 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
         ])

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = ' {}'.format(sample['label'])
+        sample['ref_dict'] = {sample['label']: 1.0}
+        tgt_item = self.tokenize_text(target, add_bos=False, add_eos=False)
+        if self.prompt_type == 'none':
+            prev_output_item = torch.cat([self.bos_item, tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        elif self.prompt_type == 'src':
+            prev_output_item = torch.cat([sample['source'], tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        elif self.prompt_type == 'prev_output':
+            prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        else:
+            raise NotImplementedError
+        target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
+        sample['target'] = target_item
+        sample['prev_output_tokens'] = prev_output_item
+        if self.constraint_trie is not None:
+            constraint_mask = torch.zeros(
+                (len(target_item), len(self.tgt_dict))).bool()
+            start_idx = len(target_item) - len(tgt_item) - 1
+            for i in range(
+                    len(target_item) - len(tgt_item) - 1, len(target_item)):
+                constraint_prefix_token = [
+                    self.tgt_dict.bos()
+                ] + target_item[start_idx:i].tolist()
+                constraint_nodes = self.constraint_trie.get_next_layer(
+                    constraint_prefix_token)
+                constraint_mask[i][constraint_nodes] = True
+            sample['constraint_mask'] = constraint_mask
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
         if 'text2' not in data:
-            hypothesis = self.pre_caption(data['text'], self.max_src_length)
+            hypothesis = self.pre_caption(data[self.column_map['text']],
+                                          self.max_src_length)
             prompt = self.cfg.model.get('prompt',
                                         ' does the image describe " {} "?')
             text = prompt.format(hypothesis)
         else:
             assert 'text' in data, f'text must be in the input {data.keys()}'
-            caption = self.pre_caption(data['text2'], self.max_src_length)
-            hypothesis = self.pre_caption(data['text'], self.max_src_length)
+            caption = self.pre_caption(data[self.column_map['text2']],
+                                       self.max_src_length)
+            hypothesis = self.pre_caption(data[self.column_map['text']],
+                                          self.max_src_length)
             prompt = self.cfg.model.get(
                 'prompt', ' can image and text1 " {} " imply text2 " {} "?')
             text = prompt.format(caption, hypothesis)
@@ -68,4 +114,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
             'patch_mask': torch.tensor([True]),
             'decoder_prompt': decoder_prompt,
         }
+        if 'relation' in self.column_map and self.column_map[
+                'relation'] in data:
+            sample['label'] = data[self.column_map['relation']]
         return sample
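Note on prompt_type: the three modes decide what the decoder sees before the answer: only the bos token ('none'), the full encoded prompt ('src'), or the prompt without its trailing eos ('prev_output'). In every case the loss target is the shifted sequence with the prompt positions set to the pad id, so only the answer tokens (and the final eos) are scored. A tiny sketch with made-up token ids:

import torch

bos, pad, eos = 0, 1, 2
source = torch.tensor([0, 11, 12, 13, 2])  # encoded prompt: bos ... eos
tgt_item = torch.tensor([21, 22])          # answer tokens, no bos/eos

prompt_type = 'prev_output'
if prompt_type == 'none':
    prev_output_item = torch.cat([torch.tensor([bos]), tgt_item])
elif prompt_type == 'src':
    prev_output_item = torch.cat([source, tgt_item])
elif prompt_type == 'prev_output':
    prev_output_item = torch.cat([source[:-1], tgt_item])

target_item = torch.cat([prev_output_item[1:], torch.tensor([eos])])
target_item[:-len(tgt_item) - 1] = pad     # mask the prompt part out of the loss
print(prev_output_item.tolist())  # [0, 11, 12, 13, 21, 22]
print(target_item.tolist())       # [1, 1, 1, 21, 22, 2]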
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from typing import Any, Dict

+import numpy as np
 import torch
 from PIL import Image
 from torchvision import transforms
@@ -8,6 +9,7 @@ from torchvision import transforms
 from modelscope.preprocessors.image import load_image
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
+from .utils import transforms as T


 class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
@@ -27,24 +29,98 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
         """
         super(OfaVisualGroundingPreprocessor,
               self).__init__(cfg, model_dir, mode, *args, **kwargs)
-        # Initialize transform
-        self.patch_resize_transform = transforms.Compose([
-            lambda image: image.convert('RGB'),
-            transforms.Resize(
-                (self.patch_image_size, self.patch_image_size),
-                interpolation=transforms.InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=self.mean, std=self.std),
-        ])
+        self.num_bins = self.cfg.model.get('num_bins', 1000)
+        if self.mode == ModeKeys.TRAIN:
+            # for positioning
+            self.positioning_transform = T.Compose([
+                T.RandomResize([self.patch_image_size],
+                               max_size=self.patch_image_size),
+                T.ToTensor(),
+                T.Normalize(
+                    mean=self.mean,
+                    std=self.std,
+                    max_image_size=self.max_image_size)
+            ])
+        else:
+            # Initialize transform
+            self.patch_resize_transform = transforms.Compose([
+                lambda image: image.convert('RGB'),
+                transforms.Resize(
+                    (self.patch_image_size, self.patch_image_size),
+                    interpolation=transforms.InterpolationMode.BICUBIC),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=self.mean, std=self.std),
+            ])

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
+        w, h = image.size
+        boxes_target = {
+            'boxes': [],
+            'labels': [],
+            'area': [],
+            'size': torch.tensor([h, w])
+        }
+        x0, y0, x1, y1 = data[self.column_map['region_coord']].strip().split(
+            ',')
+        region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
+        boxes_target['boxes'] = torch.tensor(
+            [[float(x0), float(y0), float(x1),
+              float(y1)]])
+        boxes_target['labels'] = np.array([0])
+        area = [(float(x1) - float(x0)) * (float(y1) - float(y0))]
+        boxes_target['area'] = torch.tensor(area)
+        patch_image, patch_boxes = self.positioning_transform(
+            image, boxes_target)
+        resize_h, resize_w = patch_boxes['size'][0], patch_boxes['size'][1]
+        quant_x0 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][0] * (self.num_bins - 1)).round()))
+        quant_y0 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][1] * (self.num_bins - 1)).round()))
+        quant_x1 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][2] * (self.num_bins - 1)).round()))
+        quant_y1 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][3] * (self.num_bins - 1)).round()))
+        region_coord = '{} {} {} {}'.format(quant_x0, quant_y0, quant_x1,
+                                            quant_y1)
+        src_caption = self.pre_caption(data[self.column_map['text']],
+                                       self.max_src_length)
+        prompt = self.cfg.model.get(
+            'prompt', ' which region does the text " {} " describe?')
+        text = prompt.format(src_caption)
+        src_item = self.tokenize_text(text)
+        target_item = self.tokenize_text(
+            region_coord, add_bos=False)  # !!! use_bpe=False
+        prev_output_item = torch.cat([self.bos_item, target_item[:-1]])
+        sample = {
+            'source': src_item,
+            'patch_image': patch_image,
+            'patch_mask': torch.tensor([True]),
+            'target': target_item,
+            'prev_output_tokens': prev_output_item,
+            'w_resize_ratio': resize_w / w,
+            'h_resize_ratio': resize_h / h,
+            'region_coord': region
+        }
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
         w, h = image.size
         patch_image = self.patch_resize_transform(image)
         w_resize_ratio = torch.tensor(self.patch_image_size / w)
         h_resize_ratio = torch.tensor(self.patch_image_size / h)
-        src_caption = self.pre_caption(data['text'], self.max_src_length)
+        src_caption = self.pre_caption(data[self.column_map['text']],
+                                       self.max_src_length)
         prompt = self.cfg.model.get(
             'prompt', ' which region does the text " {} " describe?')
         text = prompt.format(src_caption)
@@ -56,4 +132,10 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
             'w_resize_ratio': w_resize_ratio,
             'h_resize_ratio': h_resize_ratio,
         }
+        if 'region_coord' in self.column_map and self.column_map[
+                'region_coord'] in data:
+            x0, y0, x1, y1 = data[
+                self.column_map['region_coord']].strip().split(',')
+            sample['label'] = [float(x0), float(y0), float(x1), float(y1)]
         return sample
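Note on region quantization: the grounding target is the box itself, written as four location tokens. After the positioning transform scales the box to [0, 1], each coordinate is mapped to one of num_bins tokens <bin_0> ... <bin_{num_bins - 1}>, and those four tokens become the decoder target. A standalone sketch of the encode/decode round trip with the default num_bins = 1000 (the real code tokenizes the string with the model dictionary):

def box_to_bins(box_norm, num_bins=1000):
    # box_norm is [x0, y0, x1, y1], already scaled to [0, 1]
    return ' '.join('<bin_{}>'.format(round(v * (num_bins - 1))) for v in box_norm)

def bins_to_box(tokens, num_bins=1000):
    return [int(t[len('<bin_'):-1]) / (num_bins - 1) for t in tokens.split()]

coords = [0.1, 0.25, 0.8, 0.9]
print(box_to_bins(coords))               # <bin_100> <bin_250> <bin_799> <bin_899>
print(bins_to_box(box_to_bins(coords)))  # each value within ~1/999 of the input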
@@ -38,10 +38,52 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
         ])

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        tgt_item = self.tokenize_text(
+            ' {}'.format(sample['label']), add_bos=False, add_eos=False)
+        if self.prompt_type == 'none':
+            prev_output_item = torch.cat([self.bos_item, tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        elif self.prompt_type == 'src':
+            prev_output_item = torch.cat([sample['source'], tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        elif self.prompt_type == 'prev_output':
+            prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        else:
+            raise NotImplementedError
+        target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
+        sample['prev_output_tokens'] = prev_output_item
+        sample['target'] = target_item
+        if self.constraint_trie is not None:
+            constraint_mask = torch.zeros(
+                (len(target_item), len(self.tgt_dict))).bool()
+            start_idx = len(target_item) - len(tgt_item) - 1
+            for i in range(
+                    len(target_item) - len(tgt_item) - 1, len(target_item)):
+                constraint_prefix_token = [
+                    self.tgt_dict.bos()
+                ] + target_item[start_idx:i].tolist()
+                constraint_nodes = self.constraint_trie.get_next_layer(
+                    constraint_prefix_token)
+                constraint_mask[i][constraint_nodes] = True
+            sample['constraint_mask'] = constraint_mask
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
-        text = ' {}'.format(data['text'])
+        text = ' {}'.format(data[self.column_map['text']])
         inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item
@@ -57,4 +99,6 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
             'patch_mask': torch.tensor([True]),
             'decoder_prompt': decoder_prompt,
         }
+        if 'answer' in self.column_map and self.column_map['answer'] in data:
+            sample['label'] = data[self.column_map['answer']]
         return sample
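Note on column_map: the repeated data[self.column_map['text']] lookups are what let one preprocessor serve datasets whose columns are not literally named image/text/answer; the mapping is read from the task configuration and falls back to the identity. A minimal illustration with hypothetical column names and values:

column_map = {'image': 'image_path', 'text': 'question', 'answer': 'gt_answer'}

data = {
    'image_path': 'path/to/image.jpg',
    'question': 'what color is the car?',
    'gt_answer': 'red',
}

text = ' {}'.format(data[column_map['text']])       # ' what color is the car?'
if 'answer' in column_map and column_map['answer'] in data:
    label = data[column_map['answer']]              # 'red'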
@@ -34,6 +34,7 @@ class OFATrainer(EpochBasedTrainer):
             self,
             model: Optional[Union[TorchModel, nn.Module, str]] = None,
             cfg_file: Optional[str] = None,
+            cfg_modify_fn: Optional[Callable] = None,
             arg_parse_fn: Optional[Callable] = None,
             data_collator: Optional[Union[Callable, Dict[str,
                                                          Callable]]] = None,
@@ -49,7 +50,8 @@ class OFATrainer(EpochBasedTrainer):
             **kwargs):
         model = Model.from_pretrained(model, revision=model_revision)
         model_dir = model.model_dir
-        cfg = Config.from_file(cfg_file)
+        self.cfg_modify_fn = cfg_modify_fn
+        cfg = self.rebuild_config(Config.from_file(cfg_file))
         if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0:
             work_dir = cfg.train.work_dir
         else:
@@ -57,10 +59,12 @@ class OFATrainer(EpochBasedTrainer):
         tokenizer_files = {
             'zh': [
                 'tokenizer.json', 'tokenizer_config.json', 'vocab.txt',
-                'config.json'
+                'config.json', 'ans2label.json'
+            ],
+            'en': [
+                'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json',
+                'ans2label.json'
             ],
-            'en':
-            ['tokenizer.json', 'vocab.json', 'merges.txt', 'config.json'],
         }
         for filename in tokenizer_files[cfg.model.get('language', 'en')]:
             finetune_file = os.path.join(work_dir, filename)
@@ -127,6 +131,11 @@ class OFATrainer(EpochBasedTrainer):
             **kwargs,
         )

+    def rebuild_config(self, cfg: Config):
+        if self.cfg_modify_fn is not None:
+            cfg = self.cfg_modify_fn(cfg)
+        return cfg
+
     def train_step(self, model, inputs):
         model.train()
         loss, sample_size, logging_output = self.criterion(model, inputs)
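Note on cfg_modify_fn: exposing it on OFATrainer lets callers patch the configuration shipped with the model (epochs, learning rate, dataset column_map, work_dir and so on) in code instead of editing configuration.json on disk; rebuild_config applies the function right after the file is loaded. A hedged usage sketch based on this PR's test setup (the trainer key in Trainers and the exact config fields are assumptions and may differ per task):

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

def cfg_modify_fn(cfg):
    # cfg is the parsed Config; adjust whatever the task needs
    cfg.train.max_epochs = 1
    cfg.train.work_dir = './ofa_finetune_work_dir'
    return cfg

args = dict(
    model='damo/ofa_ocr-recognition_scene_base_zh',  # model id used in the test below
    cfg_modify_fn=cfg_modify_fn,
    work_dir='./ofa_finetune_work_dir',
    # train_dataset / eval_dataset (e.g. built with MsDataset.load) omitted here
)
trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train()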
@@ -9,6 +9,7 @@ from modelscope.metainfo import Trainers
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import DownloadMode, ModelFile
+from modelscope.utils.hub import read_config
 from modelscope.utils.test_utils import test_level
@@ -78,6 +79,7 @@ class TestOfaTrainer(unittest.TestCase):
             json.dump(self.finetune_cfg, writer)

         pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
         args = dict(
             model=pretrained_model,
             work_dir=WORKSPACE,