
add default config and fix postprocess detokenizer

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10603232
master
yichang.zyc wenmeng.zwm 3 years ago
parent commit 5f1b9a6218
4 changed files with 32 additions and 23 deletions
  1. modelscope/models/multi_modal/ofa_for_all_tasks.py (+17, -1)
  2. modelscope/preprocessors/multi_modal.py (+1, -1)
  3. modelscope/preprocessors/ofa/ocr_recognition.py (+3, -10)
  4. modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py (+11, -11)

modelscope/models/multi_modal/ofa_for_all_tasks.py (+17, -1)

@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import math
 import os
+import re
 import string
 from functools import partial
 from os import path as osp
@@ -110,6 +111,8 @@ class OfaForAllTasks(TorchModel):
             Tasks.text_classification: inference_d[self.gen_type],
             Tasks.image_classification: inference_d[self.gen_type],
         }
+        pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))'
+        self.pattern = re.compile(pattern_str)

     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
         input = move_to_device(input, self.model.device)
@@ -135,8 +138,18 @@ class OfaForAllTasks(TorchModel):
             caption = input[OutputKeys.CAPTION]
             result_l = list()
             for cap in caption:
-                result_l.append(cap.translate(self.transtab).strip())
+                if self.language == 'en':
+                    result_l.append(cap.translate(self.transtab).strip())
+                else:
+                    result_l.append(cap)
             input[OutputKeys.CAPTION] = result_l
+        if self.gen_type == 'generation' and self.language in [
+                'zh', 'cn'
+        ] and self.cfg.task != Tasks.visual_grounding:
+            ret_l = list()
+            for text in input[OFA_TASK_KEY_MAPPING[self.cfg.task]]:
+                ret_l.append(self.detokenizer(text))
+            input[OFA_TASK_KEY_MAPPING[self.cfg.task]] = ret_l
         return input

     def _text_gen_inference(self, input):
@@ -314,3 +327,6 @@ class OfaForAllTasks(TorchModel):
             save_function=partial(save_function, with_meta=False),
             config=config,
             **kwargs)
+
+    def detokenizer(self, text):
+        return self.pattern.sub('', text)
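
For context, the regex added above removes the spaces that tokenization leaves around characters outside [a-zA-Z0-9.,:!?] (i.e. around Chinese text), while spaces between plain English words and basic punctuation survive. A minimal standalone sketch of that behaviour (illustrative only, not part of the commit):

    import re

    # Same pattern as in OfaForAllTasks above: match runs of spaces that
    # touch a character outside [ a-zA-Z0-9.,:!?] on either side.
    pattern = re.compile('((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))')

    def detokenize(text):
        # Dropping those spaces re-joins Chinese characters but keeps the
        # single spaces between English words untouched.
        return pattern.sub('', text)

    print(detokenize('今天 天气 不错'))     # -> 今天天气不错
    print(detokenize('hello world 你 好'))  # -> hello world你好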

modelscope/preprocessors/multi_modal.py (+1, -1)

@@ -77,7 +77,7 @@ class OfaPreprocessor(Preprocessor):
             data[key] = item
         return data

-    def _ofa_input_compatibility_conversion(self, data):
+    def _ofa_input_compatibility_conversion(self, data): # fake
         if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
             if isinstance(data['image'], str):
                 image = load_image(data['image'])


modelscope/preprocessors/ofa/ocr_recognition.py (+3, -10)

@@ -73,21 +73,14 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
         """
         super(OfaOcrRecognitionPreprocessor,
               self).__init__(cfg, model_dir, mode, *args, **kwargs)
-        # Initialize transform
-        if self.cfg.model.imagenet_default_mean_and_std:
-            mean = IMAGENET_DEFAULT_MEAN
-            std = IMAGENET_DEFAULT_STD
-        else:
-            mean = [0.5, 0.5, 0.5]
-            std = [0.5, 0.5, 0.5]

         self.patch_resize_transform = transforms.Compose([
             lambda image: ocr_resize(
                 image,
-                self.cfg.model.patch_image_size,
-                is_document=self.cfg.model.is_document),
+                self.patch_image_size,
+                is_document=self.cfg.model.get('is_document', False)),
             transforms.ToTensor(),
-            transforms.Normalize(mean=mean, std=std),
+            transforms.Normalize(mean=self.mean, std=self.std),
         ])

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
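
The point of the preprocessor change above is that the OCR config no longer has to define every field: is_document falls back to False via cfg.model.get('is_document', False), and patch_image_size / mean / std are read from attributes the base preprocessor is expected to set. A toy sketch of the dict-style defaulting (ModelCfg below is a stand-in for the real config object, not modelscope code):

    # Hypothetical config stand-in: attribute access fails on missing keys,
    # while dict-style .get() falls back to a default instead.
    class ModelCfg(dict):
        def __getattr__(self, name):
            return self[name]

    cfg_old = ModelCfg(patch_image_size=480, is_document=True)
    cfg_new = ModelCfg(patch_image_size=480)  # 'is_document' not configured

    print(cfg_old.get('is_document', False))  # True, taken from the config
    print(cfg_new.get('is_document', False))  # False, the default kicks in
    # cfg_new.is_document would raise KeyError in this toy class.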


modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py (+11, -11)

@@ -103,20 +103,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):

     def __init__(self, args):
         super().__init__()
-        self.sentence_avg = args.sentence_avg
-        self.eps = args.label_smoothing
-        self.ignore_prefix_size = args.ignore_prefix_size
-        self.ignore_eos = args.ignore_eos
-        self.report_accuracy = args.report_accuracy
-        self.drop_worst_ratio = args.drop_worst_ratio
-        self.drop_worst_after = args.drop_worst_after
-        self.use_rdrop = args.use_rdrop
-        self.reg_alpha = args.reg_alpha
-        self.sample_patch_num = args.sample_patch_num
+        self.sentence_avg = args.get('sentence_avg', False)
+        self.eps = args.get('label_smoothing', 0.1)
+        self.ignore_prefix_size = args.get('ignore_prefix_size', 0)
+        self.ignore_eos = args.get('ignore_eos', False)
+        self.report_accuracy = args.get('report_accuracy', False)
+        self.drop_worst_ratio = args.get('drop_worst_ratio', 0.0)
+        self.drop_worst_after = args.get('drop_worst_after', 0)
+        self.use_rdrop = args.get('use_rdrop', False)
+        self.reg_alpha = args.get('reg_alpha', 1.0)
+        self.sample_patch_num = args.get('sample_patch_num', 196)

         self.constraint_start = None
         self.constraint_end = None
-        if args.constraint_range:
+        if args.get('constraint_range', None):
             constraint_start, constraint_end = args.constraint_range.split(',')
             self.constraint_start = int(constraint_start)
             self.constraint_end = int(constraint_end)
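
With the .get() defaults above, a trainer config only needs to list the criterion options it actually overrides. A trimmed-down, hypothetical sketch of the same option handling (not the modelscope class) showing how the defaults and the constraint_range parsing interact:

    # Hypothetical helper mirroring the defaulting pattern in the diff above.
    def read_criterion_args(args):
        opts = {
            'label_smoothing': args.get('label_smoothing', 0.1),
            'drop_worst_ratio': args.get('drop_worst_ratio', 0.0),
            'sample_patch_num': args.get('sample_patch_num', 196),
        }
        constraint_start = constraint_end = None
        if args.get('constraint_range', None):
            # e.g. '4,21128' restricts output to that token-id range
            start, end = args['constraint_range'].split(',')
            constraint_start, constraint_end = int(start), int(end)
        return opts, constraint_start, constraint_end

    # Only label_smoothing is overridden; everything else takes its default.
    print(read_criterion_args({'label_smoothing': 0.2}))
    # ({'label_smoothing': 0.2, 'drop_worst_ratio': 0.0, 'sample_patch_num': 196}, None, None)
    print(read_criterion_args({'constraint_range': '4,21128'}))
    # ({'label_smoothing': 0.1, 'drop_worst_ratio': 0.0, 'sample_patch_num': 196}, 4, 21128)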

