Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10603232
master
| @@ -1,6 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import math | import math | ||||
| import os | import os | ||||
| import re | |||||
| import string | import string | ||||
| from functools import partial | from functools import partial | ||||
| from os import path as osp | from os import path as osp | ||||
| @@ -110,6 +111,8 @@ class OfaForAllTasks(TorchModel): | |||||
| Tasks.text_classification: inference_d[self.gen_type], | Tasks.text_classification: inference_d[self.gen_type], | ||||
| Tasks.image_classification: inference_d[self.gen_type], | Tasks.image_classification: inference_d[self.gen_type], | ||||
| } | } | ||||
| pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))' | |||||
| self.pattern = re.compile(pattern_str) | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | ||||
| input = move_to_device(input, self.model.device) | input = move_to_device(input, self.model.device) | ||||
| @@ -135,8 +138,18 @@ class OfaForAllTasks(TorchModel): | |||||
| caption = input[OutputKeys.CAPTION] | caption = input[OutputKeys.CAPTION] | ||||
| result_l = list() | result_l = list() | ||||
| for cap in caption: | for cap in caption: | ||||
| result_l.append(cap.translate(self.transtab).strip()) | |||||
| if self.language == 'en': | |||||
| result_l.append(cap.translate(self.transtab).strip()) | |||||
| else: | |||||
| result_l.append(cap) | |||||
| input[OutputKeys.CAPTION] = result_l | input[OutputKeys.CAPTION] = result_l | ||||
| if self.gen_type == 'generation' and self.language in [ | |||||
| 'zh', 'cn' | |||||
| ] and self.cfg.task != Tasks.visual_grounding: | |||||
| ret_l = list() | |||||
| for text in input[OFA_TASK_KEY_MAPPING[self.cfg.task]]: | |||||
| ret_l.append(self.detokenizer(text)) | |||||
| input[OFA_TASK_KEY_MAPPING[self.cfg.task]] = ret_l | |||||
| return input | return input | ||||
| def _text_gen_inference(self, input): | def _text_gen_inference(self, input): | ||||
| @@ -314,3 +327,6 @@ class OfaForAllTasks(TorchModel): | |||||
| save_function=partial(save_function, with_meta=False), | save_function=partial(save_function, with_meta=False), | ||||
| config=config, | config=config, | ||||
| **kwargs) | **kwargs) | ||||
| def detokenizer(self, text): | |||||
| return self.pattern.sub('', text) | |||||
| @@ -77,7 +77,7 @@ class OfaPreprocessor(Preprocessor): | |||||
| data[key] = item | data[key] = item | ||||
| return data | return data | ||||
| def _ofa_input_compatibility_conversion(self, data): | |||||
| def _ofa_input_compatibility_conversion(self, data): # fake | |||||
| if 'image' in data and self.cfg.model.get('type', None) == 'ofa': | if 'image' in data and self.cfg.model.get('type', None) == 'ofa': | ||||
| if isinstance(data['image'], str): | if isinstance(data['image'], str): | ||||
| image = load_image(data['image']) | image = load_image(data['image']) | ||||
| @@ -73,21 +73,14 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): | |||||
| """ | """ | ||||
| super(OfaOcrRecognitionPreprocessor, | super(OfaOcrRecognitionPreprocessor, | ||||
| self).__init__(cfg, model_dir, mode, *args, **kwargs) | self).__init__(cfg, model_dir, mode, *args, **kwargs) | ||||
| # Initialize transform | |||||
| if self.cfg.model.imagenet_default_mean_and_std: | |||||
| mean = IMAGENET_DEFAULT_MEAN | |||||
| std = IMAGENET_DEFAULT_STD | |||||
| else: | |||||
| mean = [0.5, 0.5, 0.5] | |||||
| std = [0.5, 0.5, 0.5] | |||||
| self.patch_resize_transform = transforms.Compose([ | self.patch_resize_transform = transforms.Compose([ | ||||
| lambda image: ocr_resize( | lambda image: ocr_resize( | ||||
| image, | image, | ||||
| self.cfg.model.patch_image_size, | |||||
| is_document=self.cfg.model.is_document), | |||||
| self.patch_image_size, | |||||
| is_document=self.cfg.model.get('is_document', False)), | |||||
| transforms.ToTensor(), | transforms.ToTensor(), | ||||
| transforms.Normalize(mean=mean, std=std), | |||||
| transforms.Normalize(mean=self.mean, std=self.std), | |||||
| ]) | ]) | ||||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||||
| @@ -103,20 +103,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||||
| def __init__(self, args): | def __init__(self, args): | ||||
| super().__init__() | super().__init__() | ||||
| self.sentence_avg = args.sentence_avg | |||||
| self.eps = args.label_smoothing | |||||
| self.ignore_prefix_size = args.ignore_prefix_size | |||||
| self.ignore_eos = args.ignore_eos | |||||
| self.report_accuracy = args.report_accuracy | |||||
| self.drop_worst_ratio = args.drop_worst_ratio | |||||
| self.drop_worst_after = args.drop_worst_after | |||||
| self.use_rdrop = args.use_rdrop | |||||
| self.reg_alpha = args.reg_alpha | |||||
| self.sample_patch_num = args.sample_patch_num | |||||
| self.sentence_avg = args.get('sentence_avg', False) | |||||
| self.eps = args.get('label_smoothing', 0.1) | |||||
| self.ignore_prefix_size = args.get('ignore_prefix_size', 0) | |||||
| self.ignore_eos = args.get('ignore_eos', False) | |||||
| self.report_accuracy = args.get('report_accuracy', False) | |||||
| self.drop_worst_ratio = args.get('drop_worst_ratio', 0.0) | |||||
| self.drop_worst_after = args.get('drop_worst_after', 0) | |||||
| self.use_rdrop = args.get('use_rdrop', False) | |||||
| self.reg_alpha = args.get('reg_alpha', 1.0) | |||||
| self.sample_patch_num = args.get('sample_patch_num', 196) | |||||
| self.constraint_start = None | self.constraint_start = None | ||||
| self.constraint_end = None | self.constraint_end = None | ||||
| if args.constraint_range: | |||||
| if args.get('constraint_range', None): | |||||
| constraint_start, constraint_end = args.constraint_range.split(',') | constraint_start, constraint_end = args.constraint_range.split(',') | ||||
| self.constraint_start = int(constraint_start) | self.constraint_start = int(constraint_start) | ||||
| self.constraint_end = int(constraint_end) | self.constraint_end = int(constraint_end) | ||||