Browse Source

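fix a bug: OfaOcrRecognitionPreprocessor now accepts `mode` (plus *args/**kwargs) and forwards them to the base preprocessor, and tokenizes the prompt with `tokenize_text` instead of `get_inputs`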

master
行嗔 3 years ago
parent
commit
46c3bdcfe8
2 changed files with 16 additions and 8 deletions
  1. +12
    -4
      modelscope/preprocessors/ofa/ocr_recognition.py
  2. +4
    -4
      tests/trainers/test_ofa_trainer.py

+ 12
- 4
modelscope/preprocessors/ofa/ocr_recognition.py View File

@@ -8,6 +8,7 @@ from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
@@ -57,14 +58,21 @@ def ocr_resize(img, patch_image_size, is_document=False):

class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir):
def __init__(self,
cfg,
model_dir,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path
model_dir (str): model path
mode: preprocessor mode (model mode), defaults to ModeKeys.INFERENCE
"""
super(OfaOcrRecognitionPreprocessor, self).__init__(cfg, model_dir)
super(OfaOcrRecognitionPreprocessor,
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
if self.cfg.model.imagenet_default_mean_and_std:
mean = IMAGENET_DEFAULT_MEAN
@@ -87,7 +95,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
data['image'], Image.Image) else load_image(data['image'])
patch_image = self.patch_resize_transform(image)
prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
inputs = self.get_inputs(prompt)
inputs = self.tokenize_text(prompt)

sample = {
'source': inputs,


+ 4
- 4
tests/trainers/test_ofa_trainer.py View File

@@ -36,10 +36,10 @@ class TestOfaTrainer(unittest.TestCase):
# 'launcher': 'pytorch',
'max_epochs': 1,
'use_fp16': True,
'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0},
'lr_scheduler': {'name': 'polynomial_decay',
'warmup_proportion': 0.01,
'lr_end': 1e-07},
'lr_endo': 1e-07},
'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01},
'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
@@ -86,11 +86,11 @@ class TestOfaTrainer(unittest.TestCase):
train_dataset=MsDataset.load(
'coco_2014_caption',
namespace='modelscope',
split='train[:100]'),
split='train[:20]'),
eval_dataset=MsDataset.load(
'coco_2014_caption',
namespace='modelscope',
split='validation[:20]'),
split='validation[:10]'),
metrics=[Metrics.BLEU],
cfg_file=config_file)
trainer = build_trainer(name=Trainers.ofa_tasks, default_args=args)


Loading…
Cancel
Save