Browse Source

add default config and fix preprocess detokenizer

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10603232
master
yichang.zyc wenmeng.zwm 3 years ago
parent
commit
5f1b9a6218
4 changed files with 32 additions and 23 deletions
  1. +17
    -1
      modelscope/models/multi_modal/ofa_for_all_tasks.py
  2. +1
    -1
      modelscope/preprocessors/multi_modal.py
  3. +3
    -10
      modelscope/preprocessors/ofa/ocr_recognition.py
  4. +11
    -11
      modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py

+ 17
- 1
modelscope/models/multi_modal/ofa_for_all_tasks.py View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import math import math
import os import os
import re
import string import string
from functools import partial from functools import partial
from os import path as osp from os import path as osp
@@ -110,6 +111,8 @@ class OfaForAllTasks(TorchModel):
Tasks.text_classification: inference_d[self.gen_type], Tasks.text_classification: inference_d[self.gen_type],
Tasks.image_classification: inference_d[self.gen_type], Tasks.image_classification: inference_d[self.gen_type],
} }
pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))'
self.pattern = re.compile(pattern_str)


def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
input = move_to_device(input, self.model.device) input = move_to_device(input, self.model.device)
@@ -135,8 +138,18 @@ class OfaForAllTasks(TorchModel):
caption = input[OutputKeys.CAPTION] caption = input[OutputKeys.CAPTION]
result_l = list() result_l = list()
for cap in caption: for cap in caption:
result_l.append(cap.translate(self.transtab).strip())
if self.language == 'en':
result_l.append(cap.translate(self.transtab).strip())
else:
result_l.append(cap)
input[OutputKeys.CAPTION] = result_l input[OutputKeys.CAPTION] = result_l
if self.gen_type == 'generation' and self.language in [
'zh', 'cn'
] and self.cfg.task != Tasks.visual_grounding:
ret_l = list()
for text in input[OFA_TASK_KEY_MAPPING[self.cfg.task]]:
ret_l.append(self.detokenizer(text))
input[OFA_TASK_KEY_MAPPING[self.cfg.task]] = ret_l
return input return input


def _text_gen_inference(self, input): def _text_gen_inference(self, input):
@@ -314,3 +327,6 @@ class OfaForAllTasks(TorchModel):
save_function=partial(save_function, with_meta=False), save_function=partial(save_function, with_meta=False),
config=config, config=config,
**kwargs) **kwargs)

def detokenizer(self, text):
    """Return ``text`` with every match of ``self.pattern`` deleted.

    ``self.pattern`` is the compiled regular expression stored on the
    instance; each non-overlapping match in ``text`` is replaced by the
    empty string. With the pattern shown alongside this method in the
    diff, that removes runs of spaces that directly follow or precede a
    character outside ``[ a-zA-Z0-9.,:!?]``.

    Args:
        text: the string to clean up.

    Returns:
        The cleaned string; ``text`` unchanged if nothing matches.
    """
    # Bind the substitution method first, then apply it with an empty
    # replacement so that every matched span is simply dropped.
    drop_matches = self.pattern.sub
    return drop_matches('', text)

+ 1
- 1
modelscope/preprocessors/multi_modal.py View File

@@ -77,7 +77,7 @@ class OfaPreprocessor(Preprocessor):
data[key] = item data[key] = item
return data return data


def _ofa_input_compatibility_conversion(self, data):
def _ofa_input_compatibility_conversion(self, data): # fake
if 'image' in data and self.cfg.model.get('type', None) == 'ofa': if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
if isinstance(data['image'], str): if isinstance(data['image'], str):
image = load_image(data['image']) image = load_image(data['image'])


+ 3
- 10
modelscope/preprocessors/ofa/ocr_recognition.py View File

@@ -73,21 +73,14 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
""" """
super(OfaOcrRecognitionPreprocessor, super(OfaOcrRecognitionPreprocessor,
self).__init__(cfg, model_dir, mode, *args, **kwargs) self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
if self.cfg.model.imagenet_default_mean_and_std:
mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD
else:
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]


self.patch_resize_transform = transforms.Compose([ self.patch_resize_transform = transforms.Compose([
lambda image: ocr_resize( lambda image: ocr_resize(
image, image,
self.cfg.model.patch_image_size,
is_document=self.cfg.model.is_document),
self.patch_image_size,
is_document=self.cfg.model.get('is_document', False)),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
transforms.Normalize(mean=self.mean, std=self.std),
]) ])


def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:


+ 11
- 11
modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py View File

@@ -103,20 +103,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):


def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
self.sentence_avg = args.sentence_avg
self.eps = args.label_smoothing
self.ignore_prefix_size = args.ignore_prefix_size
self.ignore_eos = args.ignore_eos
self.report_accuracy = args.report_accuracy
self.drop_worst_ratio = args.drop_worst_ratio
self.drop_worst_after = args.drop_worst_after
self.use_rdrop = args.use_rdrop
self.reg_alpha = args.reg_alpha
self.sample_patch_num = args.sample_patch_num
self.sentence_avg = args.get('sentence_avg', False)
self.eps = args.get('label_smoothing', 0.1)
self.ignore_prefix_size = args.get('ignore_prefix_size', 0)
self.ignore_eos = args.get('ignore_eos', False)
self.report_accuracy = args.get('report_accuracy', False)
self.drop_worst_ratio = args.get('drop_worst_ratio', 0.0)
self.drop_worst_after = args.get('drop_worst_after', 0)
self.use_rdrop = args.get('use_rdrop', False)
self.reg_alpha = args.get('reg_alpha', 1.0)
self.sample_patch_num = args.get('sample_patch_num', 196)


self.constraint_start = None self.constraint_start = None
self.constraint_end = None self.constraint_end = None
if args.constraint_range:
if args.get('constraint_range', None):
constraint_start, constraint_end = args.constraint_range.split(',') constraint_start, constraint_end = args.constraint_range.split(',')
self.constraint_start = int(constraint_start) self.constraint_start = int(constraint_start)
self.constraint_end = int(constraint_end) self.constraint_end = int(constraint_end)


Loading…
Cancel
Save