diff --git a/data/test/images/ocr_recognition.jpg b/data/test/images/ocr_recognition.jpg new file mode 100644 index 00000000..069ac03d --- /dev/null +++ b/data/test/images/ocr_recognition.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d68cfcaa7cc7b8276877c2dfa022deebe82076bc178ece1bfe7fd5423cd5b99 +size 60009 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 0baa7444..cbab0e0b 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -101,6 +101,7 @@ class Pipelines(object): image2image_translation = 'image-to-image-translation' live_category = 'live-category' video_category = 'video-category' + ocr_recognition = 'convnextTiny-ocr-recognition' image_portrait_enhancement = 'gpen-image-portrait-enhancement' image_to_image_generation = 'image-to-image-generation' skin_retouching = 'unet-skin-retouching' diff --git a/modelscope/outputs.py b/modelscope/outputs.py index e95640dc..3dc3cc44 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -47,6 +47,12 @@ TASK_OUTPUTS = { # } Tasks.ocr_detection: [OutputKeys.POLYGONS], + # ocr recognition result for single sample + # { + # "text": "电子元器件提供BOM配单" + # } + Tasks.ocr_recognition: [OutputKeys.TEXT], + # face detection result for single sample # { # "scores": [0.9, 0.1, 0.05, 0.05] diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 5f662b5f..5c87bac5 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -119,6 +119,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.image_classification: (Pipelines.daily_image_classification, 'damo/cv_vit-base_image-classification_Dailylife-labels'), + Tasks.ocr_recognition: (Pipelines.ocr_recognition, + 'damo/cv_convnextTiny_ocr-recognition_damo'), Tasks.skin_retouching: (Pipelines.skin_retouching, 'damo/cv_unet_skin-retouching'), } diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 76d0d575..c424818b 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -29,6 +29,7 @@ if TYPE_CHECKING: from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline from .live_category_pipeline import LiveCategoryPipeline from .ocr_detection_pipeline import OCRDetectionPipeline + from .ocr_recognition_pipeline import OCRRecognitionPipeline from .skin_retouching_pipeline import SkinRetouchingPipeline from .tinynas_classification_pipeline import TinynasClassificationPipeline from .video_category_pipeline import VideoCategoryPipeline @@ -65,6 +66,7 @@ else: 'image_to_image_generation_pipeline': ['Image2ImageGenerationPipeline'], 'ocr_detection_pipeline': ['OCRDetectionPipeline'], + 'ocr_recognition_pipeline': ['OCRRecognitionPipeline'], 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], 'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], 'video_category_pipeline': ['VideoCategoryPipeline'], diff --git a/modelscope/pipelines/cv/ocr_recognition_pipeline.py b/modelscope/pipelines/cv/ocr_recognition_pipeline.py new file mode 100644 index 00000000..4b095042 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_recognition_pipeline.py @@ -0,0 +1,131 @@ +import math +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from 
modelscope.pipelines.cv.ocr_utils.model_convnext_transformer import \
+    OCRRecModel
+from modelscope.preprocessors import load_image
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+# constants
+NUM_CLASSES = 7644
+IMG_HEIGHT = 32
+IMG_WIDTH = 300
+PRED_LENTH = 75
+PRED_PAD = 6
+
+
+@PIPELINES.register_module(
+    Tasks.ocr_recognition, module_name=Pipelines.ocr_recognition)
+class OCRRecognitionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
+        label_path = osp.join(self.model, 'label_dict.txt')
+        logger.info(f'loading model from {model_path}')
+
+        self.device = torch.device(
+            'cuda' if torch.cuda.is_available() else 'cpu')
+        self.infer_model = OCRRecModel(NUM_CLASSES).to(self.device)
+        self.infer_model.eval()
+        self.infer_model.load_state_dict(
+            torch.load(model_path, map_location=self.device))
+        self.labelMapping = dict()
+        with open(label_path, 'r') as f:
+            lines = f.readlines()
+            cnt = 2
+            for line in lines:
+                line = line.strip('\n')
+                self.labelMapping[cnt] = line
+                cnt += 1
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        if isinstance(input, str):
+            img = np.array(load_image(input).convert('L'))
+        elif isinstance(input, PIL.Image.Image):
+            img = np.array(input.convert('L'))
+        elif isinstance(input, np.ndarray):
+            if len(input.shape) == 3:
+                img = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY)
+            else:
+                # input is already a single-channel (grayscale) image
+                img = input
+        else:
+            raise TypeError(f'input should be either str, PIL.Image'
+                            f' or np.ndarray, but got {type(input)}')
+        data = []
+        img_h, img_w = img.shape
+        wh_ratio = img_w / img_h
+        true_w = int(IMG_HEIGHT * wh_ratio)
+        split_batch_cnt = 1
+        if true_w < IMG_WIDTH * 1.2:
+            img = cv2.resize(img, (min(true_w, IMG_WIDTH), IMG_HEIGHT))
+        else:
+            split_batch_cnt = math.ceil((true_w - 48) * 1.0 / 252)
+            img = cv2.resize(img, (true_w, IMG_HEIGHT))
+
+        if split_batch_cnt == 1:
+            mask = np.zeros((IMG_HEIGHT, IMG_WIDTH))
+            mask[:, :img.shape[1]] = img
+            data.append(mask)
+        else:
+            for idx in range(split_batch_cnt):
+                mask = np.zeros((IMG_HEIGHT, IMG_WIDTH))
+                left = (PRED_LENTH * 4 - PRED_PAD * 4) * idx
+                trunk_img = img[:, left:min(left + PRED_LENTH * 4, true_w)]
+                mask[:, :trunk_img.shape[1]] = trunk_img
+                data.append(mask)
+
+        data = torch.FloatTensor(data).view(
+            len(data), 1, IMG_HEIGHT, IMG_WIDTH).to(self.device) / 255.
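+        # At this point `data` holds a batch of shape
+        # (num_crops, 1, IMG_HEIGHT, IMG_WIDTH) scaled to [0, 1]; a wide text
+        # line is split into overlapping crops whose predictions are merged
+        # back together in postprocess().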
+ + result = {'img': data} + + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + pred = self.infer_model(input['img']) + return {'results': pred} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + preds = inputs['results'] + batchSize, length = preds.shape + pred_idx = [] + if batchSize == 1: + pred_idx = preds[0].cpu().data.tolist() + else: + for idx in range(batchSize): + if idx == 0: + pred_idx.extend(preds[idx].cpu().data[:PRED_LENTH + - PRED_PAD].tolist()) + elif idx == batchSize - 1: + pred_idx.extend(preds[idx].cpu().data[PRED_PAD:].tolist()) + else: + pred_idx.extend(preds[idx].cpu().data[PRED_PAD:PRED_LENTH + - PRED_PAD].tolist()) + + # ctc decoder + last_p = 0 + str_pred = [] + for p in pred_idx: + if p != last_p and p != 0: + str_pred.append(self.labelMapping[p]) + last_p = p + + final_str = ''.join(str_pred) + result = {OutputKeys.TEXT: final_str} + return result diff --git a/modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py b/modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py new file mode 100644 index 00000000..cf5e2fe1 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py @@ -0,0 +1,23 @@ +import torch +import torch.nn as nn + +from .ocr_modules.convnext import convnext_tiny +from .ocr_modules.vitstr import vitstr_tiny + + +class OCRRecModel(nn.Module): + + def __init__(self, num_classes): + super(OCRRecModel, self).__init__() + self.cnn_model = convnext_tiny() + self.num_classes = num_classes + self.vitstr = vitstr_tiny(num_tokens=num_classes) + + def forward(self, input): + """ Transformation stage """ + features = self.cnn_model(input) + prediction = self.vitstr(features) + prediction = torch.nn.functional.softmax(prediction, dim=-1) + + output = torch.argmax(prediction, -1) + return output diff --git a/modelscope/pipelines/cv/ocr_utils/ocr_modules/__init__.py b/modelscope/pipelines/cv/ocr_utils/ocr_modules/__init__.py new file mode 100644 index 00000000..7799c34f --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/ocr_modules/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .convnext import convnext_tiny + from .vitstr import vitstr_tiny +else: + _import_structure = { + 'convnext': ['convnext_tiny'], + 'vitstr': ['vitstr_tiny'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py b/modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py new file mode 100644 index 00000000..c2059107 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py @@ -0,0 +1,169 @@ +""" Contains various versions of ConvNext Networks. +ConvNext Networks (ConvNext) were proposed in: + Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell and Saining Xie + A ConvNet for the 2020s. CVPR 2022. +Compared to https://github.com/facebookresearch/ConvNeXt, +we obtain different ConvNext variants by changing the network depth, width, +feature number, and downsample ratio. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .timm_tinyc import DropPath + + +class Block(nn.Module): + r""" ConvNeXt Block. 
There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv2d( + dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, + 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ConvNeXt(nn.Module): + r""" ConvNeXt + A PyTorch impl of : `A ConvNet for the 2020s` - + https://arxiv.org/pdf/2201.03545.pdf + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
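+
+    Note:
+        Compared to the reference ConvNeXt, the three downsampling layers
+        between stages use kernel_size=(2, 1) and stride=(2, 1), so after the
+        stem only the height axis is further downsampled and the horizontal
+        (text reading) resolution is preserved for recognition.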
+ """ + + def __init__( + self, + in_chans=1, + num_classes=1000, + depths=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + drop_path_rate=0., + layer_scale_init_value=1e-6, + head_init_scale=1., + ): + super().__init__() + + self.downsample_layers = nn.ModuleList( + ) # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format='channels_first')) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format='channels_first'), + nn.Conv2d( + dims[i], dims[i + 1], kernel_size=(2, 1), stride=(2, 1)), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList( + ) # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] + cur = 0 + for i in range(4): + stage = nn.Sequential(*[ + Block( + dim=dims[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value) + for j in range(depths[i]) + ]) + self.stages.append(stage) + cur += depths[i] + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + for i in range(4): + x = self.downsample_layers[i](x.contiguous()) + x = self.stages[i](x.contiguous()) + return x # global average pooling, (N, C, H, W) -> (N, C) + + def forward(self, x): + x = self.forward_features(x.contiguous()) + + return x.contiguous() + + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, + normalized_shape, + eps=1e-6, + data_format='channels_last'): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ['channels_last', 'channels_first']: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == 'channels_last': + return F.layer_norm(x, self.normalized_shape, self.weight, + self.bias, self.eps) + elif self.data_format == 'channels_first': + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def convnext_tiny(): + model = ConvNeXt(depths=[3, 3, 8, 3], dims=[96, 192, 256, 512]) + return model diff --git a/modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py b/modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py new file mode 100644 index 00000000..f54c0e78 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py @@ -0,0 +1,334 @@ +'''Referenced from rwightman's pytorch-image-models(timm). +Github: https://github.com/rwightman/pytorch-image-models +We use some modules and modify the parameters according to our network. 
+''' +import collections.abc +import logging +import math +from collections import OrderedDict +from copy import deepcopy +from functools import partial +from itertools import repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True): + super().__init__() + img_size = (1, 75) + to_2tuple = _ntuple(2) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + x = x.permute(0, 1, 3, 2) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0], ) + (1, ) * ( + x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0.1, + proj_drop=0.1): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + representation_size=None, + distilled=False, + drop_rate=0.1, + attn_drop_rate=0.1, + drop_path_rate=0., + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + weight_init=''): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + distilled (bool): model includes a distillation token and head as in DeiT models + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.patch_embed = embed_layer( + img_size=img_size, + 
patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros( + 1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer) for i in range(depth) + ]) + self.norm = norm_layer(embed_dim) + + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential( + OrderedDict([('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh())])) + else: + self.pre_logits = nn.Identity() + + # Classifier head(s) + self.head = nn.Linear( + self.num_features, + num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear( + self.embed_dim, + self.num_classes) if num_classes > 0 else nn.Identity() + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear( + self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.num_tokens == 2: + self.head_dist = nn.Linear( + self.embed_dim, + self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + cls_token = self.cls_token.expand( + x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat( + (cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), + dim=1) + x = self.pos_drop(x + self.pos_embed) + x = self.blocks(x) + x = self.norm(x) + if self.dist_token is None: + return self.pre_logits(x[:, 0]) + else: + return x[:, 0], x[:, 1] + + def forward(self, x): + x = self.forward_features(x) + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist( + x[1]) # x must be a tuple + if self.training and not torch.jit.is_scripting(): + # during inference, return the average of both classifier predictions + return x, x_dist + else: + return (x + x_dist) / 2 + else: + x = self.head(x) + return x diff --git a/modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py b/modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py new file mode 100644 index 00000000..e7d96574 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py @@ -0,0 +1,63 @@ +""" Contains various versions of ViTSTR. +ViTSTR were proposed in: + Rowel Atienza + Vision transformer for fast and efficient scene text recognition. ICDAR 2021. +Compared to https://github.com/roatienza/deep-text-recognition-benchmark, +we obtain different ViTSTR variants by changing the network patch_size and in_chans. +""" +from __future__ import absolute_import, division, print_function +import logging +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +from .timm_tinyc import VisionTransformer + + +class ViTSTR(VisionTransformer): + ''' + ViTSTR is basically a ViT that uses DeiT weights. 
+ Modified head to support a sequence of characters prediction for STR. + ''' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def reset_classifier(self, num_classes): + self.num_classes = num_classes + self.head = nn.Linear( + self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + + x = x + self.pos_embed + x = self.pos_drop(x) + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + b, s, e = x.size() + x = x.reshape(b * s, e) + x = self.head(x).view(b, s, self.num_classes) + return x + + +def vitstr_tiny(num_tokens): + vitstr = ViTSTR( + patch_size=1, + in_chans=512, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4, + qkv_bias=True) + vitstr.reset_classifier(num_classes=num_tokens) + return vitstr diff --git a/tests/pipelines/test_ocr_recognition.py b/tests/pipelines/test_ocr_recognition.py new file mode 100644 index 00000000..d86c2266 --- /dev/null +++ b/tests/pipelines/test_ocr_recognition.py @@ -0,0 +1,47 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import shutil +import sys +import tempfile +import unittest +from typing import Any, Dict, List, Tuple, Union + +import cv2 +import numpy as np +import PIL + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class OCRRecognitionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_convnextTiny_ocr-recognition_damo' + self.test_image = 'data/test/images/ocr_recognition.jpg' + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + print('ocr recognition results: ', result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + ocr_recognition = pipeline(Tasks.ocr_recognition, model=self.model_id) + self.pipeline_inference(ocr_recognition, self.test_image) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub_PILinput(self): + ocr_recognition = pipeline(Tasks.ocr_recognition, model=self.model_id) + imagePIL = PIL.Image.open(self.test_image) + self.pipeline_inference(ocr_recognition, imagePIL) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + ocr_recognition = pipeline(Tasks.ocr_recognition) + self.pipeline_inference(ocr_recognition, self.test_image) + + +if __name__ == '__main__': + unittest.main()
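
Usage sketch for the new pipeline (the model id, test image path and output format are taken from the tests and from modelscope/outputs.py above; this assumes the model files are published on the ModelScope hub under that id):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    ocr_recognition = pipeline(
        Tasks.ocr_recognition,
        model='damo/cv_convnextTiny_ocr-recognition_damo')
    result = ocr_recognition('data/test/images/ocr_recognition.jpg')
    print(result)  # e.g. {'text': '电子元器件提供BOM配单'}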