|
|
|
@@ -8,6 +8,7 @@ from torchvision.transforms import InterpolationMode |
|
|
|
from torchvision.transforms import functional as F |
|
|
|
|
|
|
|
from modelscope.preprocessors.image import load_image |
|
|
|
from modelscope.utils.constant import ModeKeys |
|
|
|
from .base import OfaBasePreprocessor |
|
|
|
|
|
|
|
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) |
|
|
|
@@ -57,14 +58,21 @@ def ocr_resize(img, patch_image_size, is_document=False): |
|
|
|
|
|
|
|
class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): |
|
|
|
|
|
|
|
def __init__(self, cfg, model_dir): |
|
|
|
def __init__(self, |
|
|
|
cfg, |
|
|
|
model_dir, |
|
|
|
mode=ModeKeys.INFERENCE, |
|
|
|
*args, |
|
|
|
**kwargs): |
|
|
|
"""preprocess the data |
|
|
|
|
|
|
|
Args: |
|
|
|
cfg(modelscope.utils.config.ConfigDict) : model config |
|
|
|
model_dir (str): model path |
|
|
|
model_dir (str): model path, |
|
|
|
mode: preprocessor mode (model mode) |
|
|
|
""" |
|
|
|
super(OfaOcrRecognitionPreprocessor, self).__init__(cfg, model_dir) |
|
|
|
super(OfaOcrRecognitionPreprocessor, |
|
|
|
self).__init__(cfg, model_dir, mode, *args, **kwargs) |
|
|
|
# Initialize transform |
|
|
|
if self.cfg.model.imagenet_default_mean_and_std: |
|
|
|
mean = IMAGENET_DEFAULT_MEAN |
|
|
|
@@ -87,7 +95,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): |
|
|
|
data['image'], Image.Image) else load_image(data['image']) |
|
|
|
patch_image = self.patch_resize_transform(image) |
|
|
|
prompt = self.cfg.model.get('prompt', '图片上的文字是什么?') |
|
|
|
inputs = self.get_inputs(prompt) |
|
|
|
inputs = self.tokenize_text(prompt) |
|
|
|
|
|
|
|
sample = { |
|
|
|
'source': inputs, |
|
|
|
|