From 5f1b9a621871f310ee44138c62b588bbc7d83c73 Mon Sep 17 00:00:00 2001 From: "yichang.zyc" Date: Wed, 2 Nov 2022 14:23:26 +0800 Subject: [PATCH] add default config and fix preprocess detokenizer Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10603232 --- .../models/multi_modal/ofa_for_all_tasks.py | 18 ++++++++++++++- modelscope/preprocessors/multi_modal.py | 2 +- .../preprocessors/ofa/ocr_recognition.py | 13 +++-------- .../multi_modal/ofa/ofa_trainer_utils.py | 22 +++++++++---------- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py index 2c6034e8..fc578b25 100644 --- a/modelscope/models/multi_modal/ofa_for_all_tasks.py +++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import math import os +import re import string from functools import partial from os import path as osp @@ -110,6 +111,8 @@ class OfaForAllTasks(TorchModel): Tasks.text_classification: inference_d[self.gen_type], Tasks.image_classification: inference_d[self.gen_type], } + pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))' + self.pattern = re.compile(pattern_str) def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: input = move_to_device(input, self.model.device) @@ -135,8 +138,18 @@ class OfaForAllTasks(TorchModel): caption = input[OutputKeys.CAPTION] result_l = list() for cap in caption: - result_l.append(cap.translate(self.transtab).strip()) + if self.language == 'en': + result_l.append(cap.translate(self.transtab).strip()) + else: + result_l.append(cap) input[OutputKeys.CAPTION] = result_l + if self.gen_type == 'generation' and self.language in [ + 'zh', 'cn' + ] and self.cfg.task != Tasks.visual_grounding: + ret_l = list() + for text in input[OFA_TASK_KEY_MAPPING[self.cfg.task]]: + ret_l.append(self.detokenizer(text)) + input[OFA_TASK_KEY_MAPPING[self.cfg.task]] = 
ret_l return input def _text_gen_inference(self, input): @@ -314,3 +327,6 @@ class OfaForAllTasks(TorchModel): save_function=partial(save_function, with_meta=False), config=config, **kwargs) + + def detokenizer(self, text): + return self.pattern.sub('', text) diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 13876058..3a3ae820 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -77,7 +77,7 @@ class OfaPreprocessor(Preprocessor): data[key] = item return data - def _ofa_input_compatibility_conversion(self, data): + def _ofa_input_compatibility_conversion(self, data): # fake if 'image' in data and self.cfg.model.get('type', None) == 'ofa': if isinstance(data['image'], str): image = load_image(data['image']) diff --git a/modelscope/preprocessors/ofa/ocr_recognition.py b/modelscope/preprocessors/ofa/ocr_recognition.py index a0342c14..58e3ea6e 100644 --- a/modelscope/preprocessors/ofa/ocr_recognition.py +++ b/modelscope/preprocessors/ofa/ocr_recognition.py @@ -73,21 +73,14 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): """ super(OfaOcrRecognitionPreprocessor, self).__init__(cfg, model_dir, mode, *args, **kwargs) - # Initialize transform - if self.cfg.model.imagenet_default_mean_and_std: - mean = IMAGENET_DEFAULT_MEAN - std = IMAGENET_DEFAULT_STD - else: - mean = [0.5, 0.5, 0.5] - std = [0.5, 0.5, 0.5] self.patch_resize_transform = transforms.Compose([ lambda image: ocr_resize( image, - self.cfg.model.patch_image_size, - is_document=self.cfg.model.is_document), + self.patch_image_size, + is_document=self.cfg.model.get('is_document', False)), transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), + transforms.Normalize(mean=self.mean, std=self.std), ]) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py index 
3c38884c..3930febb 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py @@ -103,20 +103,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): def __init__(self, args): super().__init__() - self.sentence_avg = args.sentence_avg - self.eps = args.label_smoothing - self.ignore_prefix_size = args.ignore_prefix_size - self.ignore_eos = args.ignore_eos - self.report_accuracy = args.report_accuracy - self.drop_worst_ratio = args.drop_worst_ratio - self.drop_worst_after = args.drop_worst_after - self.use_rdrop = args.use_rdrop - self.reg_alpha = args.reg_alpha - self.sample_patch_num = args.sample_patch_num + self.sentence_avg = args.get('sentence_avg', False) + self.eps = args.get('label_smoothing', 0.1) + self.ignore_prefix_size = args.get('ignore_prefix_size', 0) + self.ignore_eos = args.get('ignore_eos', False) + self.report_accuracy = args.get('report_accuracy', False) + self.drop_worst_ratio = args.get('drop_worst_ratio', 0.0) + self.drop_worst_after = args.get('drop_worst_after', 0) + self.use_rdrop = args.get('use_rdrop', False) + self.reg_alpha = args.get('reg_alpha', 1.0) + self.sample_patch_num = args.get('sample_patch_num', 196) self.constraint_start = None self.constraint_end = None - if args.constraint_range: + if args.get('constraint_range', None): constraint_start, constraint_end = args.constraint_range.split(',') self.constraint_start = int(constraint_start) self.constraint_end = int(constraint_end)