diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 680fe2e8..9fad45e2 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -43,6 +43,7 @@ class Pipelines(object):
     person_image_cartoon = 'unet-person-image-cartoon'
     ocr_detection = 'resnet18-ocr-detection'
     action_recognition = 'TAdaConv_action-recognition'
+    animal_recognition = 'resnet101-animal_recog'
 
     # nlp tasks
     sentence_similarity = 'sentence-similarity'
diff --git a/modelscope/models/cv/animal_recognition/__init__.py b/modelscope/models/cv/animal_recognition/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/models/cv/animal_recognition/resnet.py b/modelscope/models/cv/animal_recognition/resnet.py
new file mode 100644
index 00000000..1fd4b93e
--- /dev/null
+++ b/modelscope/models/cv/animal_recognition/resnet.py
@@ -0,0 +1,430 @@
+import math
+
+import torch
+import torch.nn as nn
+
+from .splat import SplAtConv2d
+
+__all__ = ['ResNet', 'Bottleneck']
+
+
+class DropBlock2D(object):
+    """Placeholder: DropBlock is not implemented in this port."""
+
+    def __init__(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class GlobalAvgPool2d(nn.Module):
+
+    def __init__(self):
+        """Global average pooling over the input's spatial dimensions"""
+        super(GlobalAvgPool2d, self).__init__()
+
+    def forward(self, inputs):
+        return nn.functional.adaptive_avg_pool2d(inputs,
+                                                 1).view(inputs.size(0), -1)
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 downsample=None,
+                 radix=1,
+                 cardinality=1,
+                 bottleneck_width=64,
+                 avd=False,
+                 avd_first=False,
+                 dilation=1,
+                 is_first=False,
+                 rectified_conv=False,
+                 rectify_avg=False,
+                 norm_layer=None,
+                 dropblock_prob=0.0,
+                 last_gamma=False):
+        super(Bottleneck, self).__init__()
+        group_width = int(planes * (bottleneck_width / 64.)) * cardinality
+        self.conv1 = nn.Conv2d(
+            inplanes, group_width, kernel_size=1, bias=False)
+        self.bn1 = norm_layer(group_width)
+        self.dropblock_prob = dropblock_prob
+        self.radix = radix
+        self.avd = avd and (stride > 1 or is_first)
+        self.avd_first = avd_first
+
+        if self.avd:
+            self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
+            stride = 1
+
+        if dropblock_prob > 0.0:
+            self.dropblock1 = DropBlock2D(dropblock_prob, 3)
+            if radix == 1:
+                self.dropblock2 = DropBlock2D(dropblock_prob, 3)
+            self.dropblock3 = DropBlock2D(dropblock_prob, 3)
+
+        if radix >= 1:
+            self.conv2 = SplAtConv2d(
+                group_width,
+                group_width,
+                kernel_size=3,
+                stride=stride,
+                padding=dilation,
+                dilation=dilation,
+                groups=cardinality,
+                bias=False,
+                radix=radix,
+                rectify=rectified_conv,
+                rectify_avg=rectify_avg,
+                norm_layer=norm_layer,
+                dropblock_prob=dropblock_prob)
+        elif rectified_conv:
+            from rfconv import RFConv2d
+            self.conv2 = RFConv2d(
+                group_width,
+                group_width,
+                kernel_size=3,
+                stride=stride,
+                padding=dilation,
+                dilation=dilation,
+                groups=cardinality,
+                bias=False,
+                average_mode=rectify_avg)
+            self.bn2 = norm_layer(group_width)
+        else:
+            self.conv2 = nn.Conv2d(
+                group_width,
+                group_width,
+                kernel_size=3,
+                stride=stride,
+                padding=dilation,
+                dilation=dilation,
+                groups=cardinality,
+                bias=False)
+            self.bn2 = norm_layer(group_width)
+
+        self.conv3 = nn.Conv2d(
+            group_width, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = norm_layer(planes * 4)
+
+        if last_gamma:
+            from torch.nn.init import zeros_
+            zeros_(self.bn3.weight)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.dilation = dilation
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        if self.dropblock_prob > 0.0:
+            out = self.dropblock1(out)
+        out = self.relu(out)
+
+        if self.avd and self.avd_first:
+            out = self.avd_layer(out)
+
+        out = self.conv2(out)
+        if self.radix == 0:
+            out = self.bn2(out)
+            if self.dropblock_prob > 0.0:
+                out = self.dropblock2(out)
+            out = self.relu(out)
+
+        if self.avd and not self.avd_first:
+            out = self.avd_layer(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+        if self.dropblock_prob > 0.0:
+            out = self.dropblock3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self,
+                 block,
+                 layers,
+                 radix=1,
+                 groups=1,
+                 bottleneck_width=64,
+                 num_classes=1000,
+                 dilated=False,
+                 dilation=1,
+                 deep_stem=False,
+                 stem_width=64,
+                 avg_down=False,
+                 rectified_conv=False,
+                 rectify_avg=False,
+                 avd=False,
+                 avd_first=False,
+                 final_drop=0.0,
+                 dropblock_prob=0,
+                 last_gamma=False,
+                 norm_layer=nn.BatchNorm2d):
+        self.cardinality = groups
+        self.bottleneck_width = bottleneck_width
+        # ResNet-D params
+        self.inplanes = stem_width * 2 if deep_stem else 64
+        self.avg_down = avg_down
+        self.last_gamma = last_gamma
+        # ResNeSt params
+        self.radix = radix
+        self.avd = avd
+        self.avd_first = avd_first
+
+        super(ResNet, self).__init__()
+        self.rectified_conv = rectified_conv
+        self.rectify_avg = rectify_avg
+        if rectified_conv:
+            from rfconv import RFConv2d
+            conv_layer = RFConv2d
+        else:
+            conv_layer = nn.Conv2d
+        conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
+        if deep_stem:
+            self.conv1 = nn.Sequential(
+                conv_layer(
+                    3,
+                    stem_width,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    bias=False,
+                    **conv_kwargs),
+                norm_layer(stem_width),
+                nn.ReLU(inplace=True),
+                conv_layer(
+                    stem_width,
+                    stem_width,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False,
+                    **conv_kwargs),
+                norm_layer(stem_width),
+                nn.ReLU(inplace=True),
+                conv_layer(
+                    stem_width,
+                    stem_width * 2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False,
+                    **conv_kwargs),
+            )
+        else:
+            self.conv1 = conv_layer(
+                3,
+                64,
+                kernel_size=7,
+                stride=2,
+                padding=3,
+                bias=False,
+                **conv_kwargs)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(
+            block, 64, layers[0], norm_layer=norm_layer, is_first=False)
+        self.layer2 = self._make_layer(
+            block, 128, layers[1], stride=2, norm_layer=norm_layer)
+        if dilated or dilation == 4:
+            self.layer3 = self._make_layer(
+                block,
+                256,
+                layers[2],
+                stride=1,
+                dilation=2,
+                norm_layer=norm_layer,
+                dropblock_prob=dropblock_prob)
+            self.layer4 = self._make_layer(
+                block,
+                512,
+                layers[3],
+                stride=1,
+                dilation=4,
+                norm_layer=norm_layer,
+                dropblock_prob=dropblock_prob)
+        elif dilation == 2:
+            self.layer3 = self._make_layer(
+                block,
+                256,
+                layers[2],
+                stride=2,
+                dilation=1,
+                norm_layer=norm_layer,
+                dropblock_prob=dropblock_prob)
+            self.layer4 = self._make_layer(
+                block,
+                512,
+                layers[3],
+                stride=1,
+                dilation=2,
+                norm_layer=norm_layer,
+                dropblock_prob=dropblock_prob)
+        else:
+            self.layer3 = self._make_layer(
+                block,
+                256,
+                layers[2],
+                stride=2,
+                norm_layer=norm_layer,
+                dropblock_prob=dropblock_prob)
+            self.layer4 = self._make_layer(
+                block,
+                512,
+                layers[3],
+                stride=2,
+                norm_layer=norm_layer,
+                dropblock_prob=dropblock_prob)
+        self.avgpool = GlobalAvgPool2d()
+        self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, norm_layer):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self,
+                    block,
+                    planes,
+                    blocks,
+                    stride=1,
+                    dilation=1,
+                    norm_layer=None,
+                    dropblock_prob=0.0,
+                    is_first=True):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            down_layers = []
+            if self.avg_down:
+                if dilation == 1:
+                    down_layers.append(
+                        nn.AvgPool2d(
+                            kernel_size=stride,
+                            stride=stride,
+                            ceil_mode=True,
+                            count_include_pad=False))
+                else:
+                    down_layers.append(
+                        nn.AvgPool2d(
+                            kernel_size=1,
+                            stride=1,
+                            ceil_mode=True,
+                            count_include_pad=False))
+                down_layers.append(
+                    nn.Conv2d(
+                        self.inplanes,
+                        planes * block.expansion,
+                        kernel_size=1,
+                        stride=1,
+                        bias=False))
+            else:
+                down_layers.append(
+                    nn.Conv2d(
+                        self.inplanes,
+                        planes * block.expansion,
+                        kernel_size=1,
+                        stride=stride,
+                        bias=False))
+            down_layers.append(norm_layer(planes * block.expansion))
+            downsample = nn.Sequential(*down_layers)
+
+        layers = []
+        if dilation == 1 or dilation == 2:
+            layers.append(
+                block(
+                    self.inplanes,
+                    planes,
+                    stride,
+                    downsample=downsample,
+                    radix=self.radix,
+                    cardinality=self.cardinality,
+                    bottleneck_width=self.bottleneck_width,
+                    avd=self.avd,
+                    avd_first=self.avd_first,
+                    dilation=1,
+                    is_first=is_first,
+                    rectified_conv=self.rectified_conv,
+                    rectify_avg=self.rectify_avg,
+                    norm_layer=norm_layer,
+                    dropblock_prob=dropblock_prob,
+                    last_gamma=self.last_gamma))
+        elif dilation == 4:
+            layers.append(
+                block(
+                    self.inplanes,
+                    planes,
+                    stride,
+                    downsample=downsample,
+                    radix=self.radix,
+                    cardinality=self.cardinality,
+                    bottleneck_width=self.bottleneck_width,
+                    avd=self.avd,
+                    avd_first=self.avd_first,
+                    dilation=2,
+                    is_first=is_first,
+                    rectified_conv=self.rectified_conv,
+                    rectify_avg=self.rectify_avg,
+                    norm_layer=norm_layer,
+                    dropblock_prob=dropblock_prob,
+                    last_gamma=self.last_gamma))
+        else:
+            raise RuntimeError('=> unknown dilation size: {}'.format(dilation))
+
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(
+                block(
+                    self.inplanes,
+                    planes,
+                    radix=self.radix,
+                    cardinality=self.cardinality,
+                    bottleneck_width=self.bottleneck_width,
+                    avd=self.avd,
+                    avd_first=self.avd_first,
+                    dilation=dilation,
+                    rectified_conv=self.rectified_conv,
+                    rectify_avg=self.rectify_avg,
+                    norm_layer=norm_layer,
+                    dropblock_prob=dropblock_prob,
+                    last_gamma=self.last_gamma))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        if self.drop:
+            x = self.drop(x)
+        x = self.fc(x)
+
+        return x
diff --git a/modelscope/models/cv/animal_recognition/splat.py b/modelscope/models/cv/animal_recognition/splat.py
new file mode 100644
index 00000000..b12bf154
--- /dev/null
+++ b/modelscope/models/cv/animal_recognition/splat.py
@@ -0,0 +1,125 @@
+"""Split-Attention"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Conv2d, Module, ReLU
+from torch.nn.modules.utils import _pair
+
+__all__ = ['SplAtConv2d']
+
+
+class DropBlock2D(object):
+    """Placeholder mirroring resnet.py; only instantiated when
+    dropblock_prob > 0, which is unsupported in this port."""
+
+    def __init__(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class SplAtConv2d(Module):
+    """Split-Attention Conv2d
+    """
+
+    def __init__(self,
+                 in_channels,
+                 channels,
+                 kernel_size,
+                 stride=(1, 1),
+                 padding=(0, 0),
+                 dilation=(1, 1),
+                 groups=1,
+                 bias=True,
+                 radix=2,
+                 reduction_factor=4,
+                 rectify=False,
+                 rectify_avg=False,
+                 norm_layer=None,
+                 dropblock_prob=0.0,
+                 **kwargs):
+        super(SplAtConv2d, self).__init__()
+        padding = _pair(padding)
+        self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
+        self.rectify_avg = rectify_avg
+        inter_channels = max(in_channels * radix // reduction_factor, 32)
+        self.radix = radix
+        self.cardinality = groups
+        self.channels = channels
+        self.dropblock_prob = dropblock_prob
+        if self.rectify:
+            from rfconv import RFConv2d
+            self.conv = RFConv2d(
+                in_channels,
+                channels * radix,
+                kernel_size,
+                stride,
+                padding,
+                dilation,
+                groups=groups * radix,
+                bias=bias,
+                average_mode=rectify_avg,
+                **kwargs)
+        else:
+            self.conv = Conv2d(
+                in_channels,
+                channels * radix,
+                kernel_size,
+                stride,
+                padding,
+                dilation,
+                groups=groups * radix,
+                bias=bias,
+                **kwargs)
+        self.use_bn = norm_layer is not None
+        if self.use_bn:
+            self.bn0 = norm_layer(channels * radix)
+        self.relu = ReLU(inplace=True)
+        self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
+        if self.use_bn:
+            self.bn1 = norm_layer(inter_channels)
+        self.fc2 = Conv2d(
+            inter_channels, channels * radix, 1, groups=self.cardinality)
+        if dropblock_prob > 0.0:
+            self.dropblock = DropBlock2D(dropblock_prob, 3)
+        self.rsoftmax = rSoftMax(radix, groups)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.use_bn:
+            x = self.bn0(x)
+        if self.dropblock_prob > 0.0:
+            x = self.dropblock(x)
+        x = self.relu(x)
+
+        batch, rchannel = x.shape[:2]
+        if self.radix > 1:
+            splited = torch.split(x, rchannel // self.radix, dim=1)
+            gap = sum(splited)
+        else:
+            gap = x
+        gap = F.adaptive_avg_pool2d(gap, 1)
+        gap = self.fc1(gap)
+
+        if self.use_bn:
+            gap = self.bn1(gap)
+        gap = self.relu(gap)
+
+        atten = self.fc2(gap)
+        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+        if self.radix > 1:
+            attens = torch.split(atten, rchannel // self.radix, dim=1)
+            out = sum([att * split for (att, split) in zip(attens, splited)])
+        else:
+            out = atten * x
+        return out.contiguous()
+
+
+class rSoftMax(nn.Module):
+
+    def __init__(self, radix, cardinality):
+        super().__init__()
+        self.radix = radix
+        self.cardinality = cardinality
+
+    def forward(self, x):
+        batch = x.size(0)
+        if self.radix > 1:
+            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
+            x = F.softmax(x, dim=1)
+            x = x.reshape(batch, -1)
+        else:
+            x = torch.sigmoid(x)
+        return x
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index 68d875ec..b046e076 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -1,4 +1,5 @@
 from .action_recognition_pipeline import ActionRecognitionPipeline
+from .animal_recog_pipeline import AnimalRecogPipeline
 from .image_cartoon_pipeline import ImageCartoonPipeline
 from .image_matting_pipeline import ImageMattingPipeline
 from .ocr_detection_pipeline import OCRDetectionPipeline
diff --git a/modelscope/pipelines/cv/animal_recog_pipeline.py b/modelscope/pipelines/cv/animal_recog_pipeline.py
new file mode 100644
index 00000000..eee9e844
--- /dev/null
+++ b/modelscope/pipelines/cv/animal_recog_pipeline.py
@@ -0,0 +1,127 @@
+import os.path as osp
+from typing import Any, Dict
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.animal_recognition import resnet
+from modelscope.pipelines.base import Input
+from modelscope.preprocessors import load_image
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from ..base import Pipeline
+from ..builder import PIPELINES
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.image_classification, module_name=Pipelines.animal_recognition)
+class AnimalRecogPipeline(Pipeline):
+
+    def __init__(self, model: str):
+        super().__init__(model=model)
+
+        def resnest101(**kwargs):
+            # ResNeSt-101: Bottleneck blocks [3, 4, 23, 3] with radix-2
+            # split attention, deep stem, average-pool downsampling and AVD.
+            model = resnet.ResNet(
+                resnet.Bottleneck, [3, 4, 23, 3],
+                radix=2,
+                groups=1,
+                bottleneck_width=64,
+                deep_stem=True,
+                stem_width=64,
+                avg_down=True,
+                avd=True,
+                avd_first=False,
+                **kwargs)
+            return model
+
+        def filter_param(src_params, own_state):
+            # Copy matching tensors, tolerating a leading 'module.' prefix
+            # left over from DataParallel checkpoints.
+            for name, param in src_params.items():
+                if name.startswith('module.'):
+                    name = name[7:]
+                if '.module.' not in list(own_state.keys())[0]:
+                    name = name.replace('.module.', '.')
+                if (name in own_state) and (own_state[name].shape
+                                            == param.shape):
+                    own_state[name].copy_(param)
+
+        def load_pretrained(model, src_params):
+            if 'state_dict' in src_params:
+                src_params = src_params['state_dict']
+            own_state = model.state_dict()
+            filter_param(src_params, own_state)
+            model.load_state_dict(own_state)
+
+        self.model = resnest101(num_classes=8288)
+        if osp.exists(model):
+            local_model_dir = model
+        else:
+            local_model_dir = snapshot_download(model)
+        self.local_path = local_model_dir
+        src_params = torch.load(
+            osp.join(local_model_dir, 'pytorch_model.pt'), 'cpu')
+        load_pretrained(self.model, src_params)
+        logger.info('load model done')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        if isinstance(input, str):
+            img = load_image(input)
+        elif isinstance(input, Image.Image):
+            img = input.convert('RGB')
+        elif isinstance(input, np.ndarray):
+            if len(input.shape) == 2:
+                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
+            img = input[:, :, ::-1]  # BGR -> RGB
+            img = Image.fromarray(img.astype('uint8')).convert('RGB')
+        else:
+            raise TypeError(f'input should be either str, PIL.Image'
+                            f' or np.ndarray, but got {type(input)}')
+
+        normalize = transforms.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        test_transforms = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(), normalize
+        ])
+        img = test_transforms(img)
+        result = {'img': img}
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        self.model.eval()
+        input_img = torch.unsqueeze(input['img'], 0)
+        with torch.no_grad():
+            outputs = self.model(input_img)
+        return {'outputs': outputs}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        label_mapping_path = osp.join(self.local_path, 'label_mapping.txt')
+        with open(label_mapping_path, 'r') as f:
+            label_mapping = f.readlines()
+        # Report the top-1 logit and its label from the tab-separated
+        # mapping file shipped with the model.
+        score = torch.max(inputs['outputs'])
+        inputs = {
+            'scores': score.item(),
+            'labels': label_mapping[inputs['outputs'].argmax()].split('\t')[1]
+        }
+        return inputs
diff --git a/tests/pipelines/test_animal_recognation.py b/tests/pipelines/test_animal_recognation.py
new file mode 100644
index 00000000..d0f42dc3
--- /dev/null
+++ b/tests/pipelines/test_animal_recognation.py
@@ -0,0 +1,20 @@
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class AnimalRecognitionTest(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run(self):
+        animal_recog = pipeline(
+            Tasks.image_classification,
+            model='damo/cv_resnest101_animal_recognation')
+        result = animal_recog('data/test/images/image1.jpg')
+        print(result)
+
+
+if __name__ == '__main__':
+    unittest.main()
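
Reviewer note (not part of the patch): the backbone can be exercised without the pipeline machinery or a checkpoint. A minimal sketch, assuming only the package layout added by this diff, that builds the same ResNeSt-101 configuration as the pipeline's resnest101() helper and runs a dummy batch:

    import torch
    from modelscope.models.cv.animal_recognition import resnet

    # Same settings as resnest101() in animal_recog_pipeline.py.
    model = resnet.ResNet(
        resnet.Bottleneck, [3, 4, 23, 3],
        radix=2,
        groups=1,
        bottleneck_width=64,
        deep_stem=True,
        stem_width=64,
        avg_down=True,
        avd=True,
        avd_first=False,
        num_classes=8288)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 8288])

With randomly initialized weights the logits are meaningless; the point is only to confirm that the architecture wires up and produces one score per class.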
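
The Split-Attention unit in splat.py can likewise be smoke-tested in isolation. A minimal sketch, assuming BatchNorm2d as the norm layer (any 2d norm module would do):

    import torch
    import torch.nn as nn
    from modelscope.models.cv.animal_recognition.splat import SplAtConv2d

    # With radix=2 the inner conv produces 2 * 64 feature maps; rSoftMax
    # computes per-channel weights across the two radix groups, and the
    # weighted groups are summed back into 64 output channels.
    conv = SplAtConv2d(
        64, 64, kernel_size=3, stride=1, padding=1,
        radix=2, norm_layer=nn.BatchNorm2d)
    out = conv(torch.randn(2, 64, 32, 32))
    print(out.shape)  # torch.Size([2, 64, 32, 32])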